飞道的博客

一种通用的载入本地数据集的方法

612人阅读  评论(0)

1.说明

1.1 数据集放置格式说明

数据集文件夹下的不同类别图片需要先进行整理,放在不同的子文件夹,放置格式如图所示:

这里只有2类,当然多个分类也行,这个对分类类别的数量没有要求。

1.2 函数引用说明

在其他程序中引用这里的函数,引用方法如下:

import sys      #绝对路径引用,不然引用load_data会报错
#load_data所在程序路径
sys.path.append(r'E:\Pycharm\project\yeah&ok\load_data')	
from load_data import load_data_func,test_image,augment

一般只需要引用load_data_func和test_image即可。

1.3 加载数据集程序中函数的使用方法说明

经过载入数据集程序的处理后,加载数据集就很简单了,加载方法如下:

ata_dir = 'E:\Pycharm\project\yeah&ok\dataset'
Batch_size = 32     #批处理尺寸
train_dataset,test_dataset = load_data_func(data_dir,batchsize=Batch_size)
test_image(train_dataset)	#显示9张图像

然后就能继续进行网络结构的搭建,进行训练等步骤了。

2.配置库文件(开始)

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import random
import tensorflow_datasets as tfds

3.主函数

主要函数:功能是输入数据集路径与批处理大小,返回训练集与测试集。

def load_data_func(data_dir,batch_size):
    data_root = pathlib.Path(data_dir)  #读取路径,创建path对象
    print(data_dir)
    print(data_root)
    all_image_path = list(data_root.glob('*/*'))    #*/*是获取文件夹下的所有文件及其子文件
    print(all_image_path)
    all_image_path = [str(path) for path in all_image_path] #获取所有图片的完整路径
    print(all_image_path)
    random.shuffle(all_image_path)  #打乱

    label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())    #获取图像文件夹名字
    label_to_index = dict((name, index) for index, name in enumerate(label_names))  #创建字典对象,设置图像名称的映射为整数
    print(label_to_index)   #OK:0,Yeah:1
	# 获取所有图像对应的标签
    all_image_label = [label_to_index[pathlib.Path(p).parent.name]for p in all_image_path]  #获取每个图象的父类名称,并变成数值,0101...
    print(len(all_image_label))	#显示获取的数据量
    index_to_label = dict((v,k) for k,v in label_to_index.items())  #获取数值对应的标签名字,以备后用

    image_patn = all_image_path[5]
    image_show = (1 + load_preprocess_image(image_patn)) / 2.  # 要变成image/255.才能正常显示
    plt.imshow(image_show)  # 这里是测试图片能不能正常显示
    plt.show()

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
    image_dataset = path_ds.map(load_preprocess_image)  # 这里才是把所有图片提取出来,前面的都是路径

    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))  # 做成数据集,zip将label和image对应起来
    image_count = len(all_image_path)  # 数据集的数量
    test_count = int(image_count * 0.2)
    train_count = image_count - test_count
    print(test_count, train_count)

    train_dataset = dataset.skip(test_count)  # 跳过test_count构成数据集
    test_dataset = dataset.take(test_count)  # 取test_count构成数据集
    BATCH_SIZE = batch_size # buffer_size = train_count
    train_dataset = train_dataset.shuffle(buffer_size=150).repeat(3).batch(BATCH_SIZE) 
    # 数据集数量不够则加个.repeat()
    test_dataset = test_dataset.batch(BATCH_SIZE)
    # 数据增强,OK,之前打乱过了,只需要对训练集数据增强
    train_dataset = train_dataset.map(augment)
    return train_dataset,test_dataset

4.从路径提取图片,并进行归一化处理

def load_preprocess_image(img_path):
    img_raw = tf.io.read_file(img_path)           #读取路径
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)   #解码图片 decode_image通用,但不会返回shape,改成对应的格式
    img_tensor = tf.image.resize(img_tensor,[160,160])      #改变图片大小
    img_tensor = tf.cast(img_tensor, tf.float32)  #转换数据类型
    img = img_tensor/127.5-1                   #标准化,归一化
    return img

5.对图片进行数据增强的函数

根据需要选择。

def augment(image,label):
    #随机进行水平翻转
    image = tf.image.random_flip_left_right(image)
    #随机设置对比度
    image = tf.image.random_contrast(image,lower=0.0,upper=1.0)
    #垂直翻转
    image = tf.image.random_flip_up_down(image)
    #设置亮度
    image = tf.image.random_brightness(image,max_delta=0.5)
    #设置色度
    image = tf.image.random_hue(image,max_delta=0.3)
    #设置饱和度
    image = tf.image.random_saturation(image,lower=0.3,upper=0.5)
    return image,label

6.显示9张图片,可以用来看数据增强后图片效果

这个函数会比较耗费时间,不需要每次都调用它。

def test_image(train_dataset):
    #用一次就行了
    plt.figure(figsize=(12,12))
    for batch in tfds.as_numpy(train_dataset):  #这里耗时间很久。。尽量不用
        for i in range(9):
            image, label = (1+batch[0][i])/2., batch[1][i]   #image前面进行了归一化,因此这里要先恢复过来,才能正常显示图像
            plt.subplot(3,3,i+1)
            plt.imshow(image)
            plt.grid(False)
        break
    plt.show()

转载:https://blog.csdn.net/weixin_45371989/article/details/106319320
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场