更新
本文进入热榜收到了不少关注,所以将本文的代码放在了GitHub上,jupyter的,有需要的自取。
同时也欢迎查看后续更新:
pytorch DataLoader(2): Dataset,DataLoader自定义训练数据_opencv,skimage,PIL接口
前置知识
在使用pytorch进行dataload,transform之前,需要了解一些数据的知识,许多人使用不同的接口因为不熟悉犯了一些错误。在这里对一些常用的OpenCV,PIL,skimage进行了一些总结,以及pytorchvision.transorforms的一些简单使用。
import cv2
from PIL import Image
from skimage import io, transform, color
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
img_path = 'data/1803151818-00000065.jpg'
alpha_path = 'data/1803151818-00000065.png'
常用接口
1.1 OpenCV
# 默认彩图
img_cv2 = cv2.imread(img_path)
# 灰度图
img_cv2_gray = cv2.imread(alpha_path,0)
print(img_cv2.shape)
# (250, 250, 3) (H,W,C)
type(img_cv2)
# numpy.ndarray
1.2 PIL.Image
# 默认彩图
img_pil = Image.open(img_path)
# 灰度图
img_pil_gray = Image.open(alpha_path).convert('L') # 打开图片并转成灰度图
print(img_pil.size)
# (250, 250)
print(np.array(img_pil).shape) # PIL没有shape属性,需要转成 numpy.ndarray
#(250, 250, 3)
type(img_pil)
# PIL.JpegImagePlugin.JpegImageFile HWC
1.3 skimage1
# 默认彩图
img_skimage = io.imread(img_path)
# 灰度图
img_skimage_gray = io.imread(alpha_path,-1)
print(img_skimage.shape)
# (250, 250, 3)
type(img_skimage)
# numpy.ndarray
# imageio.core.util.Array
(800, 600, 3)
numpy.ndarray
1.4 小结
- OpenCV读进来的是numpy数组,是uint8类型,0-255范围,图像形状是(H,W,C),读入的顺序是BGR,这点需要注意
- PIL是有自己的数据结构的,类型是<class ‘PIL.Image.Image’>;但是可以转换成numpy数组,转换后的数组为unit8,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
- skimage读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
- matplotlib读取进来的图片是numpy数组,是unit8类型,0-255范围,图像形状是(H,W,C),读入的顺序是RGB
名称 | type | 数据类型 | 读入图像格式 | 数据形状 | 能否通过transforms转换 |
---|---|---|---|---|---|
opencv | numpy.ndarray | uint8类型,0-255范围 | BGR | H×W×C | 否 |
PIL | PIL.Image.Image | RGB | H×W×C | 是 | |
skimage | numpy.ndarray | uint8类型,0-255范围 | RGB | H×W×C | 否 |
#cv2
# cv2 BGR-->RGB 两种方法
#img_cv2 = img_cv2[:,:,::-1]
img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
plt.subplot(1,4,1)
plt.title('cv2')
plt.imshow(img_cv2)
#PIL
plt.subplot(1,4,2)
plt.title('PIL')
plt.imshow(img_pil)
#PIL
plt.subplot(1,4,3)
plt.title('skimage')
plt.imshow(img_skimage)
#plt
img = plt.imread(img_path)
plt.subplot(1,4,4)
plt.title('plt')
plt.imshow(img_pil)
#show
plt.show()
2. 相互转换
2.1 opencv <—> pil
img_cv = cv2.imread(img_path)
img_pil = Image.open(img_path)
img_skimage = io.imread(img_path)
# opencv -> pil
img_pil = Image.fromarray(cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB))
# pil -> opencv
img_cv = cv2.cvtColor(np.asarray(img_pil),cv2.COLOR_RGB2BGR)
2.2 skimage <—> pil
# skimage -> pil
img_pil = Image.fromarray(img_skimage)
# pil -> skimage
img_pil = np.array(img_skimage)
2.3 skimage <—> opencv
# opencv -> skimage
img_skimage = cv2.cvtColor(img_cv,cv2.COLOR_BGR2RGB)
# skimage -> opencv
from skimage import img_as_ubyte
cv_image = img_as_ubyte(img_skimage)
3. transforms, tensor转换
为了方便进行图像数据的操作,pytorch团队提供了一个torchvision.transforms包,我们可以用transforms进行以下操作:
- PIL.Image / numpy.ndarray与Tensor的相互转化;
- 归一化;
- 对PIL.Image进行裁剪、缩放等操作。
注意1: transforms.ToTensor()
可以将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到[0, 1.0],但是transforms的其他操作只能对PIL读入的数据操作,所以使用transforms.Compose()
将这些操作组合到一起的如果有其他操作则只能输入PIL数据。
transforms包含多种图像操作的函数,可以单独使用,也可以通过transforms.Compose([function1, function2,……functionN])操作。
注意2:Tensor的形状是[C,H,W],而cv2,plt,PIL,skimage形状都是[H,W,C]
3.1 H×W×C ——> C×H×W
img_cv2.transpose(2,0,1).shape
# (3,250, 250)
img_skimage.transpose(2,0,1).shape
# (3,250, 250)
(3, 800, 600)
3.2 toTensor
- PIL.Image / numpy.ndarray --> Tensor: train 数据读取
- Tensor --> PIL.Image / numpy.ndarray: inference 数据输出。
我们可以使用 transforms.ToTensor() 将 PIL.Image/numpy.ndarray 数据进转化为torch.FloatTensor,并归一化到[0, 1.0]:
- 取值范围为[0, 255]的PIL.Image,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloatTensor;
- 形状为[H, W, C]的numpy.ndarray,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloatTensor;
- 而
transforms.ToPILImage
则是将Tensor或numpy.ndarray转化为PIL.Image。如果,我们要将Tensor转化为numpy,只需要使用 .numpy() 即可。
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
img_path = 'data/1803151818-00000065.jpg'
# transforms.ToTensor()
transform1 = transforms.Compose([
transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] and convert [H,W,C] to [C,H,W]
])
img = plt.imread(img_path)
print('plt',img.shape) #(H,W,C)
img = transform1(img)
print(img.shape) #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255 #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8') #convert Float to Int
print(img_arr.shape) #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0)) #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)
img = cv2.imread(img_path)
#img = img[:,:,::-1] ### ValueError???
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
print('plt',img.shape) #(H,W,C)
img = transform1(img)
print(img.shape) #torch.Size([C,H,W])
# 转化为numpy.ndarray并显示
img_arr = img.numpy() * 255 #use np.numpy(): convert Tensor to numpy
img_arr = img_arr.astype('uint8') #convert Float to Int
print(img_arr.shape) #[C,H,W]
img_new = np.transpose(img_arr, (1, 2, 0)) #use np.transpose() convert [C,H,W] to [H,W,C]
plt.imshow(img_new)
plt.show()
plt (800, 600, 3)
torch.Size([3, 800, 600])
(3, 800, 600)
3.3 Normalize
c h a n n e l = c h a n n e l − m e a n s t d channel = \frac{channel - mean}{std} channel=stdchannel−mean进行规范化。(是对tensor进行归一化,所以需要放在transforms.ToTensor()之后)
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
# 这两组值是 ImageNet数据集大样本统计得出的
#归一化
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
]
)
3.4 compose
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# 先转PIL 再进入Compose 进行数据增强
all_transforms = transforms.Compose([
transforms.Resize(256),
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(), # 对PIL.Image图片进行操作
transforms.ToTensor(),
normalize])
# 或者ToTensor之后 再转PIL
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop((300,300)),
])
img = Image.open(img_path).convert('RGB')
img2 = transform2(img)
img2.show()
Reference:
数据来源:爱分割 github
https://blog.csdn.net/tsq292978891/article/details/78767326
转载:https://blog.csdn.net/Mao_Jonah/article/details/117228926