失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > PyTorch框架学习七——自定义transforms方法

PyTorch框架学习七——自定义transforms方法

时间:2020-07-04 00:45:08

相关推荐

PyTorch框架学习七——自定义transforms方法

PyTorch框架学习七——自定义transforms方法

一、自定义transforms注意要素二、自定义transforms步骤三、自定义transforms实例:椒盐噪声

虽然前面的笔记介绍了很多PyTorch给出的transforms方法,也非常有用,但是也有可能在具体的问题中需要开发者自定义transforms方法,这次笔记就介绍如何自定义transforms方法。

ps:本次笔记中使用的原始图像出自上次笔记:/qq_40467656/article/details/107958492

一、自定义transforms注意要素

从数据读取机制DataLoader我们知道了transforms的内部工作原理,是在Compose类的__call__函数定义并实现的:

class Compose(object):"""Composes several transforms together.Args:transforms (list of ``Transform`` objects): list of transforms to compose.Example:>>> pose([>>>transforms.CenterCrop(10),>>>transforms.ToTensor(),>>> ])"""def __init__(self, transforms):self.transforms = transformsdef __call__(self, img):for t in self.transforms:img = t(img)return img

由此出发,可以看出自定义transforms需要注意两个要素:

仅接收一个参数img,并返回一个参数img;transforms之间要注意上下游的输入与输出的格式匹配。

二、自定义transforms步骤

首先,自定义的transforms的输入参数可能不只img一个,如概率p等等,但是原来的代码只允许接收一个参数返回一个参数,所以可以在原来的基础上改进:

class YourTransforms(object):def __init__(self, ...): # ...是要传入的多个参数# 对多参数进行传入# 如 self.p = p 传入概率# ...def __call__(self, img): # __call__函数还是只有一个参数传入# 该自定义transforms方法的具体实现过程# ...return img

步骤如下:

自定义一个类YourTransforms,结构类似Compose类__init__函数作为多参数传入的地方__call__函数具体实现自定义的transforms方法

三、自定义transforms实例:椒盐噪声

椒盐噪声:又称为脉冲噪声,是一种随机出现的白点或黑点,白点被称为盐噪声,黑点被称为椒噪声,其与信噪比(SNR)息息相关。

此外,我们还想加入概率p这个参数,实现随机添加椒盐噪声。

仿照实现步骤,先写出其实现的大致框架:

class AddPepperNoise(object):def __init__(self, snr, p): # snr, p 是要传入的多个参数self.snr = snrself.p = p def __call__(self, img): # __call__函数还是只有一个参数传入'''添加椒盐噪声的具体实现过程'''return img

完整实现代码:

class AddPepperNoise(object):"""增加椒盐噪声Args:snr (float): Signal Noise Ratep (float): 概率值,依概率执行该操作"""def __init__(self, snr, p=0.9):assert isinstance(snr, float) and (isinstance(p, float))self.snr = snrself.p = pdef __call__(self, img):"""Args:img (PIL Image): PIL ImageReturns:PIL Image: PIL image."""if random.uniform(0, 1) < self.p:img_ = np.array(img).copy()h, w, c = img_.shapesignal_pct = self.snrnoise_pct = (1 - self.snr)mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])mask = np.repeat(mask, c, axis=2)img_[mask == 1] = 255 # 盐噪声img_[mask == 2] = 0# 椒噪声return Image.fromarray(img_.astype('uint8')).convert('RGB')else:return img

添加椒盐噪声之后:

如果觉得《PyTorch框架学习七——自定义transforms方法》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。