# dataset 源代码 (source code of the dataset module)

# -*- coding: utf-8 -*-
# @Time    : 2018/8/23 22:18
# @Author  : zhoujun
import pathlib

import cv2
import numpy as np
from mxnet import image, nd, recordio
from mxnet.gluon.data import Dataset, RecordFileDataset

class ImageDataset(Dataset):
    """Text-line image dataset driven by a plain-text annotation file.

    Each annotation line holds an image path and its label separated by a tab
    (``path<TAB>label``); for ``.jpg`` paths a single space after the extension
    is also accepted as the separator.
    """

    # Extra pixels added to the target (h, w) before the training-time random
    # crop back to (img_h, img_w). With the default data_shape (32, 320) this
    # reproduces the 40x340 canvas the original implementation hard-coded.
    aug_pad_h = 8
    aug_pad_w = 20

    def __init__(self, data_txt: str, data_shape: tuple, img_channel: int, num_label: int, alphabet: str,
                 phase: str = 'train'):
        """
        Initialise the dataset.

        :param data_txt: annotation file; each line stores an image path and its label
        :param data_shape: target image size as (h, w)
        :param img_channel: number of image channels (3 decodes colour, anything else grayscale)
        :param num_label: maximum number of characters per label; should equal the
            sequence length produced by the network
        :param alphabet: string of all recognisable characters; a character's position
            in this string is its class index
        :param phase: 'train' enables random augmentation, 'test' disables it
        """
        super().__init__()
        assert phase in ['train', 'test']
        self.data_list = []
        with open(data_txt, 'r', encoding='utf-8') as f:
            for line in f:
                # rstrip('\r\n') also handles CRLF files; stripping only '\n' left a
                # trailing '\r' in the label that then failed the alphabet lookup.
                parts = line.rstrip('\r\n').replace('.jpg ', '.jpg\t').split('\t')
                # Skip blank or malformed lines instead of raising IndexError.
                if len(parts) < 2 or not parts[1]:
                    continue
                img_path = pathlib.Path(parts[0])
                # Only keep samples whose image file exists and is non-empty.
                if img_path.exists() and img_path.stat().st_size > 0:
                    self.data_list.append((parts[0], parts[1]))
        self.img_h = data_shape[0]
        self.img_w = data_shape[1]
        self.img_channel = img_channel
        self.num_label = num_label
        self.alphabet = alphabet
        self.phase = phase
        # Map each character to its index in the alphabet (later duplicates win).
        self.label_dict = {char: i for i, char in enumerate(self.alphabet)}

    def __getitem__(self, idx):
        """
        Return one (image, encoded_label) sample.

        :raises ValueError: if the label contains a character outside the alphabet
            or is longer than num_label (the message names the offending image)
        """
        img_path, label = self.data_list[idx]
        label = label.replace(' ', '')
        try:
            label = self.label_enocder(label)
        except (KeyError, ValueError) as e:
            # Previously the error was printed and the raw string label returned,
            # which only failed later, confusingly, inside batching.
            raise ValueError('failed to encode label {!r} of image {}'.format(label, img_path)) from e
        img = self.pre_processing(img_path)
        return img, label

    def __len__(self):
        return len(self.data_list)

    def label_enocder(self, label):
        """
        Convert a label string into a fixed-length vector of alphabet indices.

        :param label: label string
        :return: float32 NDArray of shape (num_label,); positions past the end of
            the label are filled with -1
        :raises KeyError: if a character is not in the alphabet
        :raises ValueError: if the label is longer than num_label
        """
        if len(label) > self.num_label:
            raise ValueError('label length {} exceeds num_label {}'.format(len(label), self.num_label))
        # Fill a NumPy buffer and convert once; assigning into an NDArray element
        # by element issues one MXNet operation per character.
        tmp_label = np.full(self.num_label, -1, dtype=np.float32)
        for i, ch in enumerate(label):
            tmp_label[i] = self.label_dict[ch]
        return nd.array(tmp_label, dtype=np.float32)

    def pre_processing(self, img_path):
        """
        Load an image and bring it to the network input size.

        The image is resized to the target height keeping its aspect ratio. If the
        resulting width is below the target width it is right-padded with black
        (zero) pixels; otherwise it is squeezed to the target width. In the 'train'
        phase, roughly half of the samples are first fitted to a slightly larger
        canvas (see aug_pad_h / aug_pad_w) and then randomly cropped back to
        (img_h, img_w).

        :param img_path: path of the image file
        :return: image NDArray as produced by mxnet.image (height-width-channel layout)
        """
        data_augment = self.phase == 'train' and np.random.rand() > 0.5
        if data_augment:
            img_h = self.img_h + self.aug_pad_h
            img_w = self.img_w + self.aug_pad_w
        else:
            img_h = self.img_h
            img_w = self.img_w
        # Read through a context manager so the file handle is closed promptly.
        with open(img_path, 'rb') as f:
            buf = f.read()
        img = image.imdecode(buf, 1 if self.img_channel == 3 else 0)
        img = self._resize_and_pad(img, img_h, img_w)
        if data_augment:
            img, _ = image.random_crop(img, (self.img_w, self.img_h))
        return img

    def _resize_and_pad(self, img, img_h, img_w):
        """Resize img to height img_h, then right-pad with zeros or squeeze to width img_w."""
        h, w = img.shape[:2]
        ratio_h = float(img_h) / h
        # max(1, ...) keeps imresize from receiving a zero width for very thin images.
        new_w = max(1, int(w * ratio_h))
        if new_w < img_w:
            img = image.imresize(img, w=new_w, h=img_h)
            step = nd.zeros((img_h, img_w - new_w, self.img_channel), dtype=img.dtype)
            img = nd.concat(img, step, dim=1)
        else:
            img = image.imresize(img, w=img_w, h=img_h)
        return img
class RecordDataset(RecordFileDataset):
    """
    A dataset wrapping over a RecordIO file containing images.
    Each sample is an image and its corresponding label.
    """

    def __init__(self, filename, data_shape: tuple, img_channel: int, num_label: int):
        """
        :param filename: path of the RecordIO file, forwarded to RecordFileDataset
        :param data_shape: target image size as (h, w)
        :param img_channel: number of image channels (3 decodes colour, anything else grayscale)
        :param num_label: fixed label length; shorter labels are padded with -1
        """
        super().__init__(filename)
        self.data_shape = data_shape
        self.img_channel = img_channel
        self.num_label = num_label

    def __getitem__(self, idx):
        """Return one (image, padded_label) sample decoded from the record file."""
        record = super().__getitem__(idx)
        header, img = recordio.unpack(record)
        img = self.pre_processing(img)
        label = self.label_enocder(header.label)
        return img, label

    def label_enocder(self, label):
        """
        Pad the numeric label stored in the record header to num_label entries.

        :param label: label value(s) from the record header, already numeric
            class indices (a scalar or a 1-D array)
        :return: float32 NDArray of shape (num_label,); unused tail positions are -1
        :raises ValueError: if the label has more than num_label entries
        """
        # np.asarray(...).ravel() also accepts a scalar header label. Padding in
        # NumPy avoids building a zero-size NDArray when the label exactly fills
        # num_label, and a negative size when it is too long.
        values = np.asarray(label, dtype=np.float32).ravel()
        if values.size > self.num_label:
            raise ValueError('label length {} exceeds num_label {}'.format(values.size, self.num_label))
        tmp_label = np.full(self.num_label, -1, dtype=np.float32)
        tmp_label[:values.size] = values
        return nd.array(tmp_label, dtype=np.float32)

    def pre_processing(self, img):
        """
        Decode an encoded image and bring it to data_shape.

        The image is resized to the target height keeping its aspect ratio, then
        right-padded with black (zero) pixels or squeezed to the target width.

        :param img: encoded image bytes taken from the record
        :return: image NDArray as produced by mxnet.image (height-width-channel layout)
        """
        img = image.imdecode(img, 1 if self.img_channel == 3 else 0)
        return self._resize_and_pad(img, self.data_shape[0], self.data_shape[1])

    def _resize_and_pad(self, img, img_h, img_w):
        """Resize img to height img_h, then right-pad with zeros or squeeze to width img_w."""
        h, w = img.shape[:2]
        ratio_h = float(img_h) / h
        # max(1, ...) keeps imresize from receiving a zero width for very thin images.
        new_w = max(1, int(w * ratio_h))
        if new_w < img_w:
            img = image.imresize(img, w=new_w, h=img_h)
            step = nd.zeros((img_h, img_w - new_w, self.img_channel), dtype=img.dtype)
            img = nd.concat(img, step, dim=1)
        else:
            img = image.imresize(img, w=img_w, h=img_h)
        return img
if __name__ == '__main__':
    # Quick throughput check of the data pipeline: iterate the loader once and
    # report the wall time of every 100 batches.
    import time

    import keys
    from mxnet.gluon.data import DataLoader
    from mxnet.gluon.data.vision.transforms import ToTensor

    batch_size = 128
    alphabet = keys.txt_alphabet
    dataset = ImageDataset('/data/zhy/crnn/Chinese_character/train2.txt', (32, 320), 3, 81, alphabet)
    # Alternative source: dataset = RecordDataset('/data1/zj/data/crnn/txt/val.rec', (32, 320), 3, 81)
    data_loader = DataLoader(dataset.transform_first(ToTensor()), batch_size, shuffle=True, num_workers=12)
    all_step = len(dataset) // batch_size
    start = time.time()
    for i, (img, label) in enumerate(data_loader):
        if (i + 1) % 100 == 0:
            print('step {}/{}: {:.3f}s for the last 100 batches'.format(i + 1, all_step, time.time() - start))
            start = time.time()