CASIA手写体数据集HWDB gnt和dgrl格式解析

引言

  • 最近在做手写识别项目，打算使用CASIA手写体数据集（HWDB）来做模型测试，数据集网址见文末参考资料。
  • 需要先对数据集中的gnt和dgrl格式文件进行解析，网上也找了很多现成的代码拿来用。
  • 最终找到了《CASIA手写体数据集HWDB1.0 gnt和dgrl格式解析》这篇文章（见文末参考资料）。
  • 在其处理gnt格式的代码基础上做了修改，增加了多进程处理功能，可以提高批量处理图片的效率。

Gnt格式解析

import struct
from pathlib import Path
from PIL import Image
from multiprocessing import Pool



def process_gnt_file(gnt_paths, output_dir=None):
    """Convert CASIA-HWDB .gnt files into one PNG image per character sample.

    Each record in a .gnt file is read as (little-endian):
    ``length:uint32 | tag_code:2 bytes (GBK) | width:uint16 | height:uint16``
    followed by ``width * height`` uint8 grayscale pixels stored row by row.

    For the n-th record (1-based) of ``<stem>.gnt`` an RGB image is saved to
    ``<output_dir>/zf_images_train/<stem>/<n>.png``.

    Args:
        gnt_paths: iterable of paths (str or pathlib.Path) to .gnt files.
        output_dir: output root directory. When None, falls back to the
            module-level ``save_dir`` global that the original script sets.

    Returns:
        list[str]: one label line per saved image, formatted as
        ``'<stem>/<n>.png\\t<character>'`` (e.g. ``'1290-c/563.png\\t兼'``).

    Raises:
        RuntimeError: if a record must be saved but no output root is known.
    """
    header = struct.Struct('<I2sHH')  # record length, GBK tag code, width, height
    root = output_dir if output_dir is not None else globals().get('save_dir')

    label_list = []
    for gnt_path in gnt_paths:
        gnt_path = Path(gnt_path)
        print(f'gnt路径--->{gnt_path}')
        out_dir = None  # created lazily, only if the file has at least one record
        count = 0

        with open(gnt_path, 'rb') as f:
            while True:
                head = f.read(header.size)
                if not head:
                    break  # clean end of file
                if len(head) < header.size:
                    print(f'warning: truncated record header in {gnt_path}, stopping')
                    break
                _length, tag_code, width, height = header.unpack(head)

                bitmap = f.read(width * height)
                if len(bitmap) < width * height:
                    print(f'warning: truncated bitmap in {gnt_path}, stopping')
                    break
                count += 1

                try:
                    # The tag code is a GBK-encoded character; strip NUL padding.
                    char = tag_code.decode('gbk').strip('\x00')
                except UnicodeDecodeError:
                    # Skip just this sample instead of abandoning the rest of the file.
                    print(f'warning: undecodable tag code {tag_code!r} in record '
                          f'{count} of {gnt_path}, skipped')
                    continue

                if out_dir is None:
                    if root is None:
                        raise RuntimeError(
                            'no output directory: pass output_dir or define save_dir')
                    out_dir = Path(f'{root}/zf_images_train/{gnt_path.stem}')
                    out_dir.mkdir(parents=True, exist_ok=True)

                filename = f'{count}.png'
                # Decode the whole row-major grayscale bitmap at once and
                # replicate it into three channels, as the per-pixel loop did.
                image = Image.frombytes('L', (width, height), bitmap).convert('RGB')
                image.save(out_dir / filename)

                label_list.append(f'{gnt_path.stem}/{filename}\t{char}')

    return label_list


def write_txt(save_path: str, content: list, mode='w'):
    """
    Write every item of a list to a UTF-8 text file, one item per line.

    Each item is also echoed to stdout before it is written.
    @param
    save_path: path of the output text file
    content: list of strings; a newline is appended to each
    mode: open() mode, 'w' to overwrite (default) or 'a' to append
    @return: None
    """
    with open(save_path, mode, encoding='utf-8') as out_file:
        for line in content:
            print(f'value--->{line}')
            out_file.write(line + '\n')



# Directory containing the CASIA-HWDB .gnt files to convert.
path = '/local/nas/from-jfs/data/zf_gnt_train'

# Output root for the images and the label file. Kept at module level
# (outside the __main__ guard) so that worker processes created with the
# "spawn" start method (default on Windows and macOS), which re-run this
# module's top level instead of forking, still see it as a global.
save_dir = '/local/nas/from-jfs/data/HWDB1'


if __name__ == '__main__':
    # Only regular *.gnt files; sorted so the label file order is reproducible.
    gnt_paths = sorted(
        p for p in Path(path).iterdir()
        if p.is_file() and p.suffix.lower() == '.gnt'
    )
    print(f'gnt_paths--->{len(gnt_paths)} files')

    # One task per file keeps all workers busy regardless of the file count
    # (grouping 10 files per task left workers idle for < 100 files).
    with Pool(10) as pool:
        per_file_labels = pool.map(process_gnt_file, [[p] for p in gnt_paths])

    # Flatten into a single list; map() preserves the input order.
    label_list = [item for labels in per_file_labels for item in labels]
    print(f'label_list--->{len(label_list)} samples')

    write_txt(f'{save_dir}/zf_gnt_train.txt', label_list)


Dgrl格式解析

import os
import struct
from pathlib import Path

import cv2 as cv
import numpy as np
from tqdm import tqdm


def _read_exact(f, size, what):
    """Read exactly ``size`` bytes from binary file ``f`` or raise ValueError."""
    data = f.read(size)
    if len(data) != size:
        raise ValueError(
            f'truncated DGRL file: expected {size} bytes for {what}, got {len(data)}')
    return data


def _read_u32s(f, count, what):
    """Read ``count`` little-endian unsigned 32-bit integers from ``f``."""
    return struct.unpack(f'<{count}I', _read_exact(f, 4 * count, what))


def read_from_dgrl(dgrl):
    """Split one CASIA-HWDB .dgrl page file into per-line label and image files.

    For ``<dir>/<name>.dgrl`` and every text line k this writes
    ``<dir>_label/<name>_<k>.txt`` (UTF-8 transcription of the line) and
    ``<dir>_images/<name>_<k>.png`` (the line's 8-bit grayscale bitmap).

    All integers are parsed with ``struct`` on raw bytes. Accumulating
    ``np.uint8`` scalars with ``<<``/``sum`` wraps at 256 under NumPy >= 2.

    Args:
        dgrl: path (str or os.PathLike) of the .dgrl file.

    Returns:
        None. If the file does not exist a message is printed and nothing
        is written.

    Raises:
        ValueError: if the header size is invalid or the file ends before a
            declared field or bitmap.
    """
    if not os.path.exists(dgrl):
        print(f'DGRL file does not exist: {dgrl}')
        return

    dir_name, base_name = os.path.split(dgrl)
    stem = os.path.splitext(base_name)[0]
    label_dir = dir_name + '_label'
    image_dir = dir_name + '_images'
    os.makedirs(label_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)

    with open(dgrl, 'rb') as f:
        # File header: total header size (uint32, counting these 4 bytes),
        # then header_size - 4 bytes whose second-to-last 2-byte field is
        # the byte length of one character code.
        (header_size,) = _read_u32s(f, 1, 'header size')
        if header_size < 8:
            raise ValueError(f'invalid DGRL header size {header_size} in {dgrl}')
        header = _read_exact(f, header_size - 4, 'header')
        code_length = int.from_bytes(header[-4:-2], 'little')

        # Image record: page height, page width, number of text lines.
        _height, _width, line_num = _read_u32s(f, 3, 'image record')

        for k in range(line_num):
            (char_num,) = _read_u32s(f, 1, f'character count of line {k}')
            label_bytes = _read_exact(
                f, code_length * char_num, f'labels of line {k}')

            # Each character is a code_length-byte GBK code. Keep the first
            # decoded character of each code; undecodable bytes are ignored
            # and NUL padding is removed so it cannot leak into the label.
            if code_length:
                chars = (
                    label_bytes[i:i + code_length].decode('gbk', 'ignore')[:1]
                    for i in range(0, len(label_bytes), code_length)
                )
                label = ''.join(chars).replace('\x00', '')
            else:
                label = ''

            # Line position/size (y, x, h, w) followed by h*w bitmap bytes.
            _y, _x, h, w = _read_u32s(f, 4, f'position of line {k}')
            bitmap_bytes = _read_exact(f, h * w, f'bitmap of line {k}')

            name = f'{stem}_{k}'
            if h == 0 or w == 0:
                # cv.imwrite cannot encode an empty image; skip the pair.
                print(f'warning: empty bitmap for {name}, skipped')
                continue
            # copy() gives a writable, contiguous array (frombuffer is read-only).
            bitmap = np.frombuffer(bitmap_bytes, dtype=np.uint8).reshape(h, w).copy()

            with open(os.path.join(label_dir, name + '.txt'), 'w',
                      encoding='utf-8') as f1:
                f1.write(label)
            if not cv.imwrite(os.path.join(image_dir, name + '.png'), bitmap):
                print(f'warning: failed to write image {name}.png')


if __name__ == '__main__':
    dgrl_dir = Path('/local/nas/from-jfs/data/wbx_dgrl_train')
    # Only regular *.dgrl files, in a stable order; stray files would otherwise
    # be parsed as garbage.
    dgrl_paths = sorted(
        p for p in dgrl_dir.iterdir()
        if p.is_file() and p.suffix.lower() == '.dgrl'
    )

    failed = []
    for dgrl_path in tqdm(dgrl_paths):
        try:
            read_from_dgrl(dgrl_path)
        except Exception as err:  # batch boundary: report, keep converting the rest
            tqdm.write(f'failed to parse {dgrl_path}: {err!r}')
            failed.append(dgrl_path)

    if failed:
        print(f'{len(failed)} of {len(dgrl_paths)} DGRL files failed')


参考资料


使用 python 获取 CASIA 脱机和在线手写汉字库

python 获取 CASIA 脱机和在线手写汉字库 (三)
CASIA数据集网址
CASIA手写体数据集HWDB1.0 gnt和dgrl格式解析

猜你喜欢

转载自blog.csdn.net/modi88/article/details/130213007