失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset

PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset

时间:2019-04-03 04:13:49

相关推荐

PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset

文章目录

一、PyTorch数据读取机制Dataloader

一、PyTorch数据读取机制Dataloader

PyTorch数据读取在Dataloader模块下,Dataloader又可以分为DataSet与Sampler。Sampler模块的功能是生成索引(样本序号);DataSet是依据索引读取Img、Lable。我们主要学习Dataloader与Dataset。

torch.utils.data.DataLoader()

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)

功能:构建可迭代的数据装载器

dataset: Dataset类,决定数据从哪读取及如何读取batch_size :批大小num_works:是否多进程读取数据shuffle:每个epoch是否乱序drop_last :当样本数不能被batchsize整除时,是否舍弃最后一批数据

Epoch:所有训练样本都已输入到模型中,称为一个Epoch

Iteration:一批样本输入到模型中,称之为一个lteration

Batchsize:批大小,决定一个Epoch有多少个lteration

样本总数:80,Batchsize : 8

1 Epoch = 10 lteration

样本总数:87, Batchsize: 8

1 Epoch = 10 lteration ? drop_last = True

1 Epoch = 11 lteration drop_last = False

torch.utils.data.Dataset()

class Dataset(object):def __getitem__(self,index):raise NotImplementedErrordef __add__(self, other) :return ConcatDataset([self, other])

功能: Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()

getitem:接收一个索引,返回一个样本

数据读取流程如下:

for i, data in enumerate(train_loader):==># 判断是单进程还是多进程def __iter__(self):# 单进程if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)# 多进程else:return _MultiProcessingDataLoaderIter(self)==># 以单进程为例class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):def __init__(self, loader):super(_SingleProcessDataLoaderIter, self).__init__(loader)assert self.timeout == 0assert self.num_workers == 0self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)# 这个函数告诉我们每个iteration中读哪些数据def __next__(self):# index = self._next_index() # may raise StopIterationdata = self.dataset_fetcher.fetch(index) # may raise StopIterationif self.pin_memory:data = _utils.pin_memory.pin_memory(data)return datanext = __next__ # Python 2 compatibility==>def _next_index(self):return next(self.sampler_iter) # may raise StopIteration==># 利用sampler输出的index来进行采样def __iter__(self):batch = []# for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if len(batch) > 0 and not self.drop_last:yield batch==>class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:# 这一步实现了正式的数据读取data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]return self.collate_fn(data)==>class RMBDataset(Dataset):def __init__(self, data_dir, transform=None):"""rmb面额分类任务的Dataset:param data_dir: str, 数据集所在路径:param transform: torch.transform,数据预处理"""self.label_name = {"1": 0, "100": 1}self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本self.transform = transformdef __getitem__(self, index):# 根据索引index获得数据与标签path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')# 0~255if self.transform is not None:img = self.transform(img) # 在这里做transform,转为tensor等等return img, labeldef __len__(self):return len(self.data_info)@staticmethoddef get_img_info(data_dir):data_info = list()# 遍历一个目录内,各个子目录与子文件for root, dirs, _ in os.walk(data_dir):# 遍历类别for sub_dir in dirs:img_names = os.listdir(os.path.join(root, sub_dir))img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))# 遍历图片for i in range(len(img_names)):img_name = img_names[i]path_img = os.path.join(root, sub_dir, img_name)label = rmb_label[sub_dir]data_info.append((path_img, int(label)))return data_info==>class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]# 数据的整理器,将读取到的数据整理成batch的形式return self.collate_fn(data)==>for i, data in enumerate(train_loader):# forward# data由两个Tensor组成inputs, labels = data

数据整理器将数据由下面的形式:

转化为batch形式:

如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!

如果觉得《PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。