1.说明
1.1 数据集放置格式说明
数据集文件夹下的不同类别图片需要先进行整理,放在不同的子文件夹,放置格式如图所示:
这里只有2类,当然多个分类也行,这个对分类类别的数量没有要求。
1.2 函数引用说明
在其他程序中引用这里的函数,引用方法如下:
import sys #绝对路径引用,不然引用load_data会报错
#load_data所在程序路径
sys.path.append(r'E:\Pycharm\project\yeah&ok\load_data')
from load_data import load_data_func,test_image,augment
一般只需要引用load_data_func和test_image即可。
1.3 加载数据集程序中函数的使用方法说明
经过载入数据集程序的处理后,加载数据集就很简单了,加载方法如下:
ata_dir = 'E:\Pycharm\project\yeah&ok\dataset'
Batch_size = 32 #批处理尺寸
train_dataset,test_dataset = load_data_func(data_dir,batchsize=Batch_size)
test_image(train_dataset) #显示9张图像
然后就能继续进行网络结构的搭建,进行训练等步骤了。
2.配置库文件(开始)
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import random
import tensorflow_datasets as tfds
3.主函数
主要函数:功能是输入数据集路径与批处理大小,返回训练集与测试集。
def load_data_func(data_dir,batch_size):
data_root = pathlib.Path(data_dir) #读取路径,创建path对象
print(data_dir)
print(data_root)
all_image_path = list(data_root.glob('*/*')) #*/*是获取文件夹下的所有文件及其子文件
print(all_image_path)
all_image_path = [str(path) for path in all_image_path] #获取所有图片的完整路径
print(all_image_path)
random.shuffle(all_image_path) #打乱
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir()) #获取图像文件夹名字
label_to_index = dict((name, index) for index, name in enumerate(label_names)) #创建字典对象,设置图像名称的映射为整数
print(label_to_index) #OK:0,Yeah:1
# 获取所有图像对应的标签
all_image_label = [label_to_index[pathlib.Path(p).parent.name]for p in all_image_path] #获取每个图象的父类名称,并变成数值,0101...
print(len(all_image_label)) #显示获取的数据量
index_to_label = dict((v,k) for k,v in label_to_index.items()) #获取数值对应的标签名字,以备后用
image_patn = all_image_path[5]
image_show = (1 + load_preprocess_image(image_patn)) / 2. # 要变成image/255.才能正常显示
plt.imshow(image_show) # 这里是测试图片能不能正常显示
plt.show()
path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
image_dataset = path_ds.map(load_preprocess_image) # 这里才是把所有图片提取出来,前面的都是路径
label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
dataset = tf.data.Dataset.zip((image_dataset, label_dataset)) # 做成数据集,zip将label和image对应起来
image_count = len(all_image_path) # 数据集的数量
test_count = int(image_count * 0.2)
train_count = image_count - test_count
print(test_count, train_count)
train_dataset = dataset.skip(test_count) # 跳过test_count构成数据集
test_dataset = dataset.take(test_count) # 取test_count构成数据集
BATCH_SIZE = batch_size # buffer_size = train_count
train_dataset = train_dataset.shuffle(buffer_size=150).repeat(3).batch(BATCH_SIZE)
# 数据集数量不够则加个.repeat()
test_dataset = test_dataset.batch(BATCH_SIZE)
# 数据增强,OK,之前打乱过了,只需要对训练集数据增强
train_dataset = train_dataset.map(augment)
return train_dataset,test_dataset
4.从路径提取图片,并进行归一化处理
def load_preprocess_image(img_path):
img_raw = tf.io.read_file(img_path) #读取路径
img_tensor = tf.image.decode_jpeg(img_raw,channels=3) #解码图片 decode_image通用,但不会返回shape,改成对应的格式
img_tensor = tf.image.resize(img_tensor,[160,160]) #改变图片大小
img_tensor = tf.cast(img_tensor, tf.float32) #转换数据类型
img = img_tensor/127.5-1 #标准化,归一化
return img
5.对图片进行数据增强的函数
根据需要选择。
def augment(image,label):
#随机进行水平翻转
image = tf.image.random_flip_left_right(image)
#随机设置对比度
image = tf.image.random_contrast(image,lower=0.0,upper=1.0)
#垂直翻转
image = tf.image.random_flip_up_down(image)
#设置亮度
image = tf.image.random_brightness(image,max_delta=0.5)
#设置色度
image = tf.image.random_hue(image,max_delta=0.3)
#设置饱和度
image = tf.image.random_saturation(image,lower=0.3,upper=0.5)
return image,label
6.显示9张图片,可以用来看数据增强后图片效果
这个函数会比较耗费时间,不需要每次都调用它。
def test_image(train_dataset):
#用一次就行了
plt.figure(figsize=(12,12))
for batch in tfds.as_numpy(train_dataset): #这里耗时间很久。。尽量不用
for i in range(9):
image, label = (1+batch[0][i])/2., batch[1][i] #image前面进行了归一化,因此这里要先恢复过来,才能正常显示图像
plt.subplot(3,3,i+1)
plt.imshow(image)
plt.grid(False)
break
plt.show()
转载:https://blog.csdn.net/weixin_45371989/article/details/106319320
查看评论