失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > 2.PyTorch的Dataset和DataLoader

2.PyTorch的Dataset和DataLoader

时间:2023-01-25 08:56:14

相关推荐

2.PyTorch的Dataset和DataLoader

#! /p/543481537

2.PyTorch的Dataset和DataLoader Datasets & DataLoader 构建自己的数据集使用DataLoaders为训练准备数据DataLoader深入剖析

2.PyTorch的Dataset和DataLoader

本次笔记记录如何构建数据集,如何搭建minibatch,以及DataLoader源码剖析.

Datasets & DataLoader

torch.utils.data.Dataset:

处理单个训练样本,从磁盘读取训练数据(特征、标签),预处理变成{x,y}训练对。torch.utils.data.DataLoader:

得到单个训练样本后,变成minibatch的形式(将数据打乱,保存在gpu中等)

#使用Torchvision导入内置数据集import torchfrom torch.utils.data import Datasetfrom torchvision import datasetsfrom torchvision.transforms import ToTensorimport matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())

构建自己的数据集

继承抽象类Dataset,并自己实现以下三个函数:

__init__

初始化操作(文件路径、transform······)__len__

在DataLoader中需要用到Dataset的长度,在Init中读取数据,直接用len()返回即可__getitem__

对Dataset索引返回单个训练样本,返回x,y训练对(有标签时)

import osimport pandas as pdfrom torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])#每张图片路径image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)#预处理if self.target_transform:label = self.target_transform(label)#预处理labelreturn image, label

这种dataset叫做map-style datasets

读流式数据可以用iterable-style dataset

使用DataLoaders为训练准备数据

使用Minibatch的形式训练数据,在每个epoch都reshuffle数据从而降低模型过拟合的可能性,还通过python的multiprocessing多进程加载数据,使得读数据的过程不影响gpu训练,实现低延迟。(有的数据集读取的过程贼慢,比方说TextVQA这种,很影响GPU训练的效率)

DataLoader返回的是一个列表。

from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)#test不需要梯度下降,只要前向传递,不需要shuffletest_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

展示一张图片与标签,

下面第一行代码的next,iter方法后面解释

# Display image and label.train_features, train_labels = next(iter(train_dataloader))print(f"Feature batch shape: {train_features.size()}")print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()#灰度图片要变成三通道label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"Label: {label}")

DataLoader深入剖析

官网的教程过于简单,下面深入剖析DataLoader,

包括sampler,collate_fn处理等,还有一些排序方法(比如让一个minibatch中样本长度差不多,这样在进行padding等操作的时候就会简单很多)。

主要解析三部分代码:

torch/utils/data/dataloader.pytorch/utils/data/sampler.pytorch/utils/data/_utils/collate.py

首先看dataloader

'''dataset: Dataset[T_co]:传入一个实例化的dataset对象shuffle:每个epoch后打乱数据,一般在训练集上用sampler/batch_sampler:自定义采样方式,显然与shuffle冲突num_workers:进程数量,多进程读数据,0代表只用一个主进程pin_memory:把tensor保存在gpu上,不用重复保存,对于效率影响有待考究drop_last:样本数量不是batch_size整数倍时,把最后一部分数据丢掉collate_fn:自定义聚集函数,对小批次数据进行再次处理,输入一个Batch,输出一个batchtimeout:是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。'''def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,shuffle: bool = False, sampler: Optional[Sampler[int]] = None,batch_sampler: Optional[Sampler[Sequence[int]]] = None,num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,pin_memory: bool = False, drop_last: bool = False,timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,multiprocessing_context=None, generator=None,*, prefetch_factor: int = 2,persistent_workers: bool = False):torch._C._log_api_usage_once("python.data_loader")if num_workers < 0:raise ValueError('num_workers option should be non-negative; ''use num_workers=0 to disable multiprocessing.')if timeout < 0:raise ValueError('timeout option should be non-negative')if num_workers == 0 and prefetch_factor != 2:raise ValueError('prefetch_factor option could only be specified in multiprocessing.''let num_workers > 0 to enable multiprocessing.')assert prefetch_factor > 0if persistent_workers and num_workers == 0:raise ValueError('persistent_workers option needs num_workers > 0')#设置成员变量self.dataset = datasetself.num_workers = num_workersself.prefetch_factor = prefetch_factorself.pin_memory = pin_memoryself.timeout = timeoutself.worker_init_fn = worker_init_fnself.multiprocessing_context = multiprocessing_context#后续代码太长,这里不一一展示,后面不算很难

后续代码太长,这里不一一展示,init函数主要做三件事:

构建sampler构建batchsampler构建collate_fn

这里说几点重要的,方便放矢地看源码:

1.构建sampler

1.1.shuffle内部通过RandomSampler的具体实现,重点看__iter__.

def __iter__(self) -> Iterator[int]:n = len(self.data_source)if self.generator is None:generator = torch.Generator()generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))else:generator = self.generatorif self.replacement:for _ in range(self.num_samples // 32):yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()else:yield from torch.randperm(n, generator=generator).tolist()

1.2.shuffle = None通过SequentialSampler实现.

class SequentialSampler(Sampler[int]):r"""Samples elements sequentially, always in the same order.Args:data_source (Dataset): dataset to sample from"""data_source: Sizeddef __init__(self, data_source: Sized) -> None:self.data_source = data_sourcedef __iter__(self) -> Iterator[int]:return iter(range(len(self.data_source)))def __len__(self) -> int:return len(self.data_source)

2.我们不使用batchsampler而仅仅是设定batch_size,内部其实还是通过BatchSampler实现的.

if batch_size is not None and batch_sampler is None:# auto_collation without custom batch_samplerbatch_sampler = BatchSampler(sampler, batch_size, drop_last)

这里插入说一下python的迭代器和生成器,方便下面的理解

'''迭代是Python最强大的功能之一,是访问集合元素的一种方式。迭代器是一个可以记住遍历的位置的对象。迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。迭代器有两个基本的方法:iter() 和 next()。字符串,列表或元组对象都可用于创建迭代器:'''>>> list=[1,2,3,4]>>> it = iter(list) # 创建迭代器对象>>> print (next(it)) # 输出迭代器的下一个元素1>>> print (next(it))2#迭代器对象可以使用常规for语句进行遍历:list=[1,2,3,4]it = iter(list) # 创建迭代器对象for x in it:print (x, end=" ")'''在 Python 中,使用了 yield 的函数被称为生成器(generator)。跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。调用一个生成器函数,返回的是一个迭代器对象。以下实例使用 yield 实现斐波那契数列:'''import sysdef fibonacci(n): # 生成器函数 - 斐波那契a, b, counter = 0, 1, 0while True:if (counter > n): returnyield aa, b = b, a + bcounter += 1f = fibonacci(10) # f 是一个迭代器,由生成器返回生成while True:try:print (next(f), end=" ")except StopIteration:sys.exit()

回归BatchSampler实现:

def __iter__(self) -> Iterator[List[int]]:batch = []for idx in self.sampler:batch.append(idx)f len(batch) == self.batch_size:yield batch #生成器,每次next返回一个batchbatch = []if len(batch) > 0 and not self.drop_last:#drop_last的原理yield batch

3.collate_fn自定义聚集函数

if collate_fn is None:if self._auto_collation: #batch_sample is not None ->truecollate_fn = _utils.collate.default_collate #输入输出都是batch,做了一些数据类型的转换else:collate_fn = _utils.collate.default_convertself.collate_fn = collate_fn

4.迭代器iter()可以把dataloader变成一个迭代器

iter->_get_iterator()

_get_iterator()->_SingleProcessDataLoaderIter()or_MultiProcessingDataLoaderIter()

_SingleProcessDataLoaderIter()继承自BaseDataLoaderIter.

我们发现_SingleProcessDataLoaderIter()_next_data函数没有在子类中被调用,猜测可以在父类BaseDataLoaderIter找到调用.

经查看,父类__next__中确实调用了_next_data,这就可以解释了代码的内部原理。

如果觉得《2.PyTorch的Dataset和DataLoader》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。