小言_互联网的博客

PyTorch~自定义数据读取

355人阅读  评论(0)

这次是PyTorch的自定义数据读取pipeline模板和相关trciks以及如何优化数据读取的pipeline等。

因为有torch也放人工智能模块了~

从PyTorch的数据对象类Dataset开始。Dataset在PyTorch中的模块位于utils.data下。

from torch.utils.data import Dataset

围绕Dataset对象分别从原始模板、torchvision的transforms模块、使用pandas来辅助读取、torch内置数据划分功能和DataLoader来展开阐述。

Dataset原始模板

PyTorch官方为我们提供了自定义数据读取的标准化代码代码模块,作为一个读取框架,我们这里称之为原始模板。其代码结构如下:


  
  1. from torch.utils.data import Dataset
  2. class  CustomDataset(Dataset):
  3.      def  __init__( self, ...):
  4.          # stuff
  5.         
  6.      def  __getitem__( self, index):
  7.          # stuff
  8.          return (img, label)
  9.         
  10.      def  __len__( self):
  11.          # return examples size
  12.          return count

根据这个标准化的代码模板,我们只需要根据自己的数据读取任务,分别往__init__()、__getitem__()和__len__()三个方法里添加读取逻辑即可。作为PyTorch范式下的数据读取以及为了后续的data loader,三个方法缺一不可。其中:

  • __init__()函数用于初始化数据读取逻辑,比如读取包含标签和图片地址的csv文件、定义transform组合等。

  • __getitem__()函数用来返回数据和标签。目的上是为了能够被后续的dataloader所调用。

  • __len__()函数则用于返回样本数量。

现在我们往这个框架里填几行代码来形成一个简单的数字案例。创建一个从1到100的数字例子:


  
  1. from torch.utils.data  import Dataset
  2. class  CustomDataset( Dataset):
  3.      def  __init__( self):
  4.         self.samples =  list( range( 1101))
  5.      def  __len__( self):
  6.          return  len(self.samples)
  7.      def  __getitem__( self, idx):
  8.          return self.samples[idx]
  9.         
  10. if __name__ ==  '__main__':
  11.     dataset = CustomDataset()
  12.      print( len(dataset))
  13.      print(dataset[ 50])
  14.      print(dataset[ 1: 100])

 添加torchvision.transforms

然后我们来看如何从内存中读取数据以及如何在读取过程中嵌入torchvision中的transforms功能。torchvision是一个独立于torch的关于数据、模型和一些图像增强操作的辅助库。主要包括datasets默认数据集模块、models经典模型模块、transforms图像增强模块以及utils模块等。在使用torch读取数据的时候,一般会搭配上transforms模块对数据进行一些处理和增强工作。

添加了tranforms之后的读取模块可以改写为:


  
  1. from torch.utils.data  import Dataset
  2. from torchvision  import transforms  as T
  3. class  CustomDataset( Dataset):
  4.      def  __init__( self, ...):
  5.          # stuff
  6.         ...
  7.          # compose the transforms methods
  8.         self.transform = T.Compose([T.CenterCrop( 100),
  9.                                 T.ToTensor()])
  10.         
  11.      def  __getitem__( self, index):
  12.          # stuff
  13.         ...
  14.         data =  # Some data read from a file or image
  15.          # execute the transform
  16.         data = self.transform(data)  
  17.          return (img, label)
  18.         
  19.      def  __len__( self):
  20.          # return examples size
  21.          return count
  22.         
  23. if __name__ ==  '__main__':
  24.      # Call the dataset
  25.     custom_dataset = CustomDataset(...)

可以看到,我们使用了Compose方法来把各种数据处理方法聚合到一起进行定义数据转换方法。通常作为初始化方法放在__init__()函数下。我们以猫狗图像数据为例进行说明。

定义数据读取方法如下:


  
  1. class  DogCat( Dataset):    
  2.      def  __init__( self, root, transforms=None, train=True, val=False):
  3.          """
  4.         get images and execute transforms.
  5.         """
  6.         self.val = val
  7.         imgs = [os.path.join(root, img)  for img  in os.listdir(root)]
  8.          # train: Cats_Dogs/trainset/cat.1.jpg
  9.          # val: Cats_Dogs/valset/cat.10004.jpg
  10.         imgs =  sorted(imgs, key= lambda x: x.split( '.')[- 2])
  11.         self.imgs = imgs         
  12.          if transforms  is  None:
  13.              # normalize      
  14.             normalize = T.Normalize(mean = [ 0.4850.4560.406],
  15.                                      std = [ 0.2290.2240.225])
  16.              # trainset and valset have different data transform 
  17.              # trainset need data augmentation but valset don't.
  18.              # valset
  19.              if self.val:
  20.                 self.transforms = T.Compose([
  21.                     T.Resize( 224),
  22.                     T.CenterCrop( 224),
  23.                     T.ToTensor(),
  24.                     normalize
  25.                 ])
  26.              # trainset
  27.              else:
  28.                 self.transforms = T.Compose([
  29.                     T.Resize( 256),
  30.                     T.RandomResizedCrop( 224),
  31.                     T.RandomHorizontalFlip(),
  32.                     T.ToTensor(),
  33.                     normalize
  34.                 ])
  35.                        
  36.      def  __getitem__( self, index):
  37.          """
  38.         return data and label
  39.         """
  40.         img_path = self.imgs[index]
  41.         label =  1  if  'dog'  in img_path.split( '/')[- 1else  0
  42.         data = Image. open(img_path)
  43.         data = self.transforms(data)
  44.          return data, label
  45.   
  46.      def  __len__( self):
  47.          """
  48.         return images size.
  49.         """
  50.          return  len(self.imgs)
  51. if __name__ ==  "__main__":
  52.     train_dataset = DogCat( './Cats_Dogs/trainset/', train= True)
  53.      print( len(train_dataset))
  54.      print(train_dataset[ 0])

因为这个数据集已经分好了训练集和验证集,所以在读取和transforms的时候需要进行区分。运行示例如下:

与pandas一起使用

很多时候数据的目录地址和标签都是通过csv文件给出的。如下所示:

此时在数据读取的pipeline中我们需要在__init__()方法中利用pandas把csv文件中包含的图片地址和标签融合进去。相应的数据读取pipeline模板可以改写为:


  
  1. class  CustomDatasetFromCSV( Dataset):
  2.      def  __init__( self, csv_path):
  3.          """
  4.         Args:
  5.             csv_path (string): path to csv file
  6.             transform: pytorch transforms for transforms and tensor conversion
  7.         """
  8.          # Transforms
  9.         self.to_tensor = transforms.ToTensor()
  10.          # Read the csv file
  11.         self.data_info = pd.read_csv(csv_path, header= None)
  12.          # First column contains the image paths
  13.         self.image_arr = np.asarray(self.data_info.iloc[:,  0])
  14.          # Second column is the labels
  15.         self.label_arr = np.asarray(self.data_info.iloc[:,  1])
  16.          # Calculate len
  17.         self.data_len =  len(self.data_info.index)
  18.      def  __getitem__( self, index):
  19.          # Get image name from the pandas df
  20.         single_image_name = self.image_arr[index]
  21.          # Open image
  22.         img_as_img = Image. open(single_image_name)
  23.          # Transform image to tensor
  24.         img_as_tensor = self.to_tensor(img_as_img)
  25.          # Get label of the image based on the cropped pandas column
  26.         single_image_label = self.label_arr[index]
  27.          return (img_as_tensor, single_image_label)
  28.      def  __len__( self):
  29.          return self.data_len
  30. if __name__ ==  "__main__":
  31.      # Call dataset
  32.     dataset =  CustomDatasetFromCSV( './labels.csv')

以mnist_label.csv文件为示例:


  
  1. from torch.utils.data  import Dataset
  2. from torch.utils.data  import DataLoader
  3. from torchvision  import transforms  as T
  4. from PIL  import Image
  5. import os
  6. import numpy  as np
  7. import pandas  as pd
  8. class  CustomDatasetFromCSV( Dataset):
  9.      def  __init__( self, csv_path):
  10.          """
  11.         Args:
  12.             csv_path (string): path to csv file            
  13.             transform: pytorch transforms for transforms and tensor conversion
  14.         """
  15.          # Transforms
  16.         self.to_tensor = T.ToTensor()
  17.          # Read the csv file
  18.         self.data_info = pd.read_csv(csv_path, header= None)
  19.          # First column contains the image paths
  20.         self.image_arr = np.asarray(self.data_info.iloc[:,  0])
  21.          # Second column is the labels
  22.         self.label_arr = np.asarray(self.data_info.iloc[:,  1])
  23.          # Third column is for an operation indicator
  24.         self.operation_arr = np.asarray(self.data_info.iloc[:,  2])
  25.          # Calculate len
  26.         self.data_len =  len(self.data_info.index)
  27.      def  __getitem__( self, index):
  28.          # Get image name from the pandas df
  29.         single_image_name = self.image_arr[index]
  30.          # Open image
  31.         img_as_img = Image. open(single_image_name)
  32.          # Check if there is an operation
  33.         some_operation = self.operation_arr[index]
  34.          # If there is an operation
  35.          if some_operation:
  36.              # Do some operation on image
  37.              # ...
  38.              # ...
  39.              pass
  40.          # Transform image to tensor
  41.         img_as_tensor = self.to_tensor(img_as_img)
  42.          # Get label of the image based on the cropped pandas column
  43.         single_image_label = self.label_arr[index]
  44.          return (img_as_tensor, single_image_label)
  45.      def  __len__( self):
  46.          return self.data_len
  47. if __name__ ==  "__main__":
  48.     transform = T.Compose([T.ToTensor()])
  49.     dataset = CustomDatasetFromCSV( './mnist_labels.csv')
  50.      print( len(dataset))
  51.      print(dataset[ 5])

运行示例如下:

训练集验证集划分

一般来说,为了模型训练的稳定,我们需要对数据划分训练集和验证集。torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。

以kaggle的花朵数据为例:  whaosoft aiot http://143ai.com


  
  1. from torch.utils.data  import DataLoader
  2. from torchvision.datasets  import ImageFolder
  3. from torchvision  import transforms  as T
  4. from torch.utils.data  import random_split
  5. transform = T.Compose([
  6.     T.Resize(( 224224)),
  7.     T.RandomHorizontalFlip(),
  8.     T.ToTensor()
  9.  ])
  10. dataset = ImageFolder( './flowers_photos', transform=transform)
  11. print(dataset.class_to_idx)
  12. trainset, valset = random_split(dataset, 
  13.                 [ int( len(dataset)* 0.7),  len(dataset)- int( len(dataset)* 0.7)])
  14. trainloader = DataLoader(dataset=trainset, batch_size= 32, shuffle= True, num_workers= 1)
  15. for i, (img, label)  in  enumerate(trainloader):
  16.     img, label = img.numpy(), label.numpy()
  17.      print(img, label)
  18. valloader = DataLoader(dataset=valset, batch_size= 32, shuffle= True, num_workers= 1)
  19. for i, (img, label)  in  enumerate(trainloader):
  20.     img, label = img.numpy(), label.numpy()
  21.      print(img.shape, label)

这里使用了ImageFolder模块,可以直接读取各标签对应的文件夹,部分运行示例如下: 

使用DataLoader

dataset方法写好之后,我们还需要使用DataLoader将其逐个喂给模型。上一节的数据划分我们已经用到了DataLoader函数。从本质上来讲,DataLoader只是调用了__getitem__()方法并按批次返回数据和标签。使用方法如下:


  
  1. from torch.utils.data  import DataLoader
  2. from torchvision  import transforms  as T
  3. if __name__ ==  "__main__":
  4.      # Define transforms
  5.     transformations = T.Compose([T.ToTensor()])
  6.      # Define custom dataset
  7.     dataset = CustomDatasetFromCSV( './labels.csv')
  8.      # Define data loader
  9.     data_loader = DataLoader(dataset=dataset, batch_size= 10, shuffle= True)
  10.      for images, labels  in data_loader:
  11.          # Feed the data to the model

以上就是PyTorch读取数据的Pipeline主要方法和流程。基于Dataset对象的基本框架不变,具体细节可自定义化调整。


转载:https://blog.csdn.net/qq_29788741/article/details/128244886
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场