失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > 【pytorch】【Dataset/DataLoader】制作数据集(一)

【pytorch】【Dataset/DataLoader】制作数据集(一)

时间:2023-12-18 21:31:03

相关推荐

【pytorch】【Dataset/DataLoader】制作数据集(一)

导入自定义数据

来源官方教程

数据下载链接:/tutorial/faces.zip

1、导入库

import matplotlib.pyplot as pltimport osimport numpy as npimport pandas as pdimport warningsimport torchsnooperfrom torch.utils.data import Dataset, DataLoaderimport pandas as pdimport torchfrom skimage import io, transformfrom torchvision import transforms, utilswarnings.filterwarnings("ignore")

2、查看单个样本

landmark_frame = pd.read_csv("faces\\face_landmarks.csv")n = 64image_name = landmark_frame.iloc[n, 0]landmark = landmark_frame.iloc[n, 1:]landmark = np.asarray(landmark)landmark = landmark.astype(float).reshape(-1, 2)def show_image(image_file, landmarks):image = io.imread(image_file)plt.imshow(image)plt.scatter(landmark[:, 0], landmark[:, 1], s=10, c='r', marker='.')plt.pause(0.001)plt.ion()image_file = os.path.join('faces', image_name)image = io.imread(image_file)plt.figure()show_image(image_file, landmark)plt.show()

知识点:

pandas:DataFrame.iloc[]#像数组一样访问元素

3、制作数据集

总结:

继承类:torch.utils.data.Dataset 初始化:init() 实现接口:len(self)、getitem(self, idx)

返回: 返回的类型类似于字典列表,可以通过方括号[]进行索引获得每条数据。 类似于data = [dict1,dict2,dict3],data[0]

此处是继承一个类,并且要实现其接口,接口必须要实现,通过接口使得这个类更加具有灵活 性,想返回什么样类型的,只要将其包装成字典就可以。

class FaceLandMarksDataset(Dataset):def __init__(self, csv_file, root_dir, transform=None):super(FaceLandMarksDataset, self).__init__()self.img_file = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.img_file)def __getitem__(self, item):if torch.is_tensor(item):#防止索引是tensoritem = item.tolist()img_name = self.img_file.iloc[item, 0]img_file = os.path.join(self.root_dir, img_name)land_marks = self.img_file.iloc[item, 1:]land_marks = np.asarray(land_marks).astype(float).reshape(-1, 2)img = io.imread(img_file)sample = {'image': img, 'landmarks': land_marks}if self.transform:sample = self.transform(sample)return sample

展示:

with torchsnooper.snoop():face_dataset = FaceLandMarksDataset("faces\\face_landmarks.csv", "faces\\")print(type(face_dataset[0]))fig = plt.figure()img, landma = face_dataset[0]['image'], face_dataset[0]['landmarks']show_image(img, landma)

知识点:

1)创建的Dataset类可以通过方括号索引直接获得。得到的是一个字典。

2)torchsnooper是调试助手,可通过pip安装,其作用是可以将其上下文中的变量形状,类型以及运行设备显示出来。

设置数据的增强转换:

1)图片缩放

class Rescale(object):def __init__(self, output_size):#输入两种格式,要么给出最短边的长度,等比例缩放,要么给出长宽assert isinstance(output_size, (int, tuple)) #不满足直接抛出异常self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]if isinstance(self.output_size, int):if h > w:new_h, new_w = self.output_size*h/w, self.output_sizeelse:new_h, new_w = self.output_size, self.output_size*w/helse:new_h, new_w = self.output_sizeimg = transform.resize(image, (new_h, new_w))landmarks = landmarks * [new_w/w, new_h/h] #点乘,利用broadcastreturn {"image":img, "landmarks":landmarks}

知识点:

1)assert 断言,必须满足,否则直接抛出异常。

2)_call_():在类中实现此函数,相当于重载(),使得可以直接通过Class()来调用此函数。对于可调用对象,实际上“名称()”可以理解为是“名称._call_()”的简写。链接:/view/2380.html

3)判断是否属于某类:isinstance()

4)*用于点乘,landmark是(31,2)

5)输入单个样本,输出认为单个样本

2)随机裁剪

class RandomCrop(object):def __init__(self, output_size):assert isinstance(output_size, (int, tuple))if isinstance(output_size, int):self.output_size = output_size, output_sizeelse:assert len(output_size) == 2self.output_size = output_sizedef __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']h, w = image.shape[:2]new_h, new_w = self.output_sizetop = np.random.randint(0, h-new_h)left = np.random.randint(0, w-new_w)image = image[top:top+new_w, left:left+new_w]landmarks = landmarks - [left, top]return {"image": image, "landmarks": landmarks}

3)转化为tensor

class ToTensor(object):def __call__(self, sample):image, landmarks = sample['image'], sample['landmarks']image = image.transpose((2, 0, 1))image = torch.from_numpy(image)landmarks = torch.from_numpy(landmarks)return {"image": image, "landmarks": landmarks}

知识点:

字典的传递时,保证接收的函数参数与字典的键一致,并且字典传递使用双星**;

注意skimage.io.imread()读入的数据是H×W×C,但是torch需要的是C×H×W,因此才有transpose那一句。

创建数据集实例:

数据图片增强,通过裁剪,缩放,旋转等操作可进行,而通过创建可调用的类,可以直接通过“名称()”形式进行调用,而pose([])可以直接将类进行综合,传入数据集中,对每个返回的数据进行处理之后吐出来。

transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',root_dir='faces/',transform=pose([Rescale(256),RandomCrop(224),ToTensor()]))

4、创建DataLoader

制作数据集后,为了有助于进行训练,如设置批大小,是否打乱等,使用这样一个装饰器,将数据集实例传入,返回一个含有批量数据的迭代器。

dataloader = DataLoader(face_dataset, batch_size=2, shuffle=True, num_workers=0)

批量显示图片:

def show_batch(sample_batched):image_batch, landmark_batch = sample_batched['image'], sample_batched['landmarks']batch_size = len(image_batch)im_size = image_batch.size(2)grid = utils.make_grid(image_batch)plt.imshow(grid.numpy().transpose((1, 2, 0)))for i in range(batch_size):plt.scatter(landmark_batch[i, :, 0].numpy()+i*im_size, landmark_batch[i, :, 1].numpy(),s=10, marker='.', c='r')plt.title('Batch from dataloader')

知识点:

1)torchvision.utils.make_gride()将一批次的图片合成一个大的图,高度为图片的高度,但是长度是四个图片长度相加。因此在画landmarks时才需要在i*im_size。

2)grid后是tensor类型,但是show只能展示array的数据,因此将其转化,并且调整轴的维度。torch中是C×H×W,而画图时是H×W×C。

with torchsnooper.snoop():for i_batch, sample_batched in enumerate(dataloader):print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())if i_batch == 0:plt.figure()show_batch(sample_batched)plt.axis('off')plt.show()break

用到的库总结:

from torch.utils.data import Dataset, DataLoaderfrom skimage import io, transformfrom torchvision import transforms, utils

导入数据总结:

1)创建数据集类,继承自torch.utils.data.Dataset,并实现__init__()、__len__()、__getitem__()__init__()接收样本地址、标签地址、以及转化。__len__()返回样本个数__getitem__()返回字典,包含样本和标签2)自定义转换缩放、切割、转化为tensor继承自 object;实现函数__init__()、__call__(),call的 作用见前面对于实现的类,使用pose([])组合起来,传入数据集3)创建实例4)创建DataLoader,设置是否打乱、批量,线程

展示批量图片:

torchvision.utils.make_grid(),传入实例化的DataLoader(迭代器),生成的是tensor,一张大图片,整个图的高维图片的高,宽为批次图片宽度相加。在show前,进行两个转化,一个是转为ndarray,一个是通道转变。

实现细节注意:

1)图片通道:读取和show通道在后,而torch中通道在前。2)数据类型:标签读入时可能为str,要进行转化,ndarray和torch的转化。3)图像增强类:继承object,实现call,使用pose([])聚合。4)自定义Dataset类的实例,相当于样本标签字典组合的列表,可以通过方括号进行索引。5)DataLoader类的实例是迭代器,可通过for迭代访问,常用enumerate。6)作图时使用make_grid要注意,通道(channel)、类型(ndarray)。

如果觉得《【pytorch】【Dataset/DataLoader】制作数据集(一)》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。