1 引言
元学习是近年来兴起的一种深度学习任务,它的主要目标是训练出具有强学习能力的神经网络。元学习一开始是一个小众的领域,此前很多年都没有取得很好的进展,直到 Finn, C. 在就读博士期间发表了一篇元学习的论文,也就是大名鼎鼎的 MAML,它在回归、分类、强化学习三类任务上都达到了当时最好的性能。
2 数据集
Omniglot 一般会被戏称为 MNIST 的转置,大家可以想想为什么?Omniglot 数据集包含来自 50 个不同字母表的 1623 个不同手写字符。每一个字符都是由 20 个不同的人通过亚马逊的 Mechanical Turk 在线绘制的。
Omniglot 数据集总共包含 50 个字母表。我们通常将这些分成一组包含 30 个字母表的背景(background)集和一组包含 20 个字母表的评估(evaluation)集。
更具挑战性的表示学习任务是使用较小的背景集 “background small 1” 和 “background small 2”。每一个都只包含 5 个字母表,更类似于一个成年人在学习一般的字符时可能遇到的经验。
3 代码分段详解
3.1 数据预处理
import torch
import numpy as np
import os
import zipfile
root_path = './../datasets'
processed_folder = os.path.join(root_path)
zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot_standard.zip'), 'r')

# 数据预处理
root_dir = './../datasets/omniglot/python'
import torchvision.transforms as transforms
from PIL import Image
# an example of img_items (one tuple per image):
# ('0709_17.png', '<alphabet>/<character>', '<directory containing the file>')
def find_classes(root_dir):
    """Collect every PNG image found below ``root_dir``.

    Each image is described by a tuple ``(file_name, label, directory)`` where
    ``label`` is ``"<parent dir>/<dir>"`` built from the last two components of
    the directory that holds the file, and ``directory`` is the path reported
    by ``os.walk``.

    :param root_dir: directory to scan recursively.
    :return: list of ``(file_name, label, directory)`` tuples.
    """
    img_items = []
    for root, dirs, files in os.walk(root_dir):
        # Sort in place so the traversal (and therefore the class indices
        # derived from it later) does not depend on filesystem order.
        dirs.sort()
        # Use os.path instead of splitting on '/' so the label is also
        # correct where the path separator is not '/', and so a shallow
        # path yields an empty parent rather than an IndexError.
        parent, character = os.path.split(os.path.normpath(root))
        label = os.path.basename(parent) + "/" + character
        for file in sorted(files):
            if file.endswith("png"):
                img_items.append((file, label, root))
    print("== Found %d items " % len(img_items))
    return img_items
## Build a dictionary {class: idx}
def index_classes(items):
    """Assign a consecutive integer index to every distinct class label.

    :param items: iterable of tuples whose element ``[1]`` is the class label
        (the tuples produced by ``find_classes``).
    :return: dict mapping each label to its index; indices start at 0 and
        follow the order in which labels first appear in ``items``.
    """
    class_idx = {}
    for item in items:
        # len(class_idx) is evaluated before insertion, so a new label gets
        # the next free index and a known label keeps its existing one.
        class_idx.setdefault(item[1], len(class_idx))
    print('== Found {} classes'.format(len(class_idx)))
    return class_idx
img_items = find_classes(root_dir)
class_idx = index_classes(img_items)

# Preprocessing pipeline, built once instead of once per image:
# open as 8-bit grayscale -> resize to 28x28 -> add a channel axis ->
# move the channel axis first (1, 28, 28) -> scale pixel values by 1/255.
transform = transforms.Compose([
    lambda path: Image.open(path).convert('L'),
    lambda img: img.resize((28, 28)),
    lambda img: np.reshape(img, (28, 28, 1)),
    lambda img: np.transpose(img, [2, 0, 1]),
    lambda img: img / 255.,
])

# Group the preprocessed images by class index: {label: [img, img, ...]}.
temp = dict()
for imgname, classes, dirs in img_items:
    img = '{}/{}'.format(dirs, imgname)
    label = class_idx[classes]
    img = transform(img)
    # Append to the class's list, creating it on first sight; plain
    # assignment here would keep only one image per class.
    temp.setdefault(label, []).append(img)
print('begin to generate omniglot.npy')

## Drop the label keys and stack the per-class image lists
## (the original note states that every label holds 20 samples).
img_list = []
for label, imgs in temp.items():
    img_list.append(np.array(imgs))
# np.float was an alias of the builtin float and no longer exists in recent
# NumPy releases; np.float64 is the same dtype.
img_list = np.array(img_list).astype(np.float64)  # [[20 imgs],..., 1623 classes in total]
print('data shape:{}'.format(img_list.shape))  # (1623, 20, 1, 28, 28)
temp = []  # release the per-class lists before saving
np.save(os.path.join(root_dir, 'omniglot.npy'), img_list)
3.3 构造训练集和测试集
# Load the array written by the preprocessing step; per the original note its
# shape is (1623, 20, 1, 28, 28), i.e. classes along the first axis.
img_list = np.load(os.path.join(root_dir, 'omniglot.npy'))
num_classes = img_list.shape[0]

# The first 1200 classes are used for meta-training, the rest for meta-testing.
x_train, x_test = img_list[:1200], img_list[1200:]
datasets = dict(train=x_train, test=x_test)
def next(mode='train'):
    """Get the next batch from the cached batches of the given split.

    Reads and updates the module-level ``indexes`` and ``datasets_cache``
    mappings. When every cached batch of ``mode`` has been handed out, the
    cache is rebuilt with ``load_data_cache(datasets[mode])`` and reading
    restarts from position 0.

    NOTE: this function shadows the builtin ``next`` inside this module.

    :param mode: name of the split; must be a key of ``datasets`` (only
        ``"train"`` and ``"test"`` are defined there).
    :return: the batch stored at the current position of the split's cache.
    """
    # update cache if indexes is larger than len(data_cache)
    if indexes[mode] >= len(datasets_cache[mode]):
        indexes[mode] = 0
        datasets_cache[mode] = load_data_cache(datasets[mode])
    next_batch = datasets_cache[mode][indexes[mode]]
    indexes[mode] += 1
    return next_batch
3.2 构造Base-Learner
# Excerpt of the base learner's forward pass (first block only).
# Fall back to the network's own parameters when no explicit weights are given.
if params is None:
    params = self.vars
weight, bias = params[0], params[1]  # 1st CONV layer
x = F.conv2d(x, weight, bias, stride=2, padding=2)
weight, bias = params[2], params[3]  # 1st BN layer (affine weight / bias)
running_mean, running_var = self.vars_bn[0], self.vars_bn[1]
x = F.batch_norm(x, running_mean, running_var, weight=weight, bias=bias, training=bn_training)
x = F.max_pool2d(x, kernel_size=2)  # 1st MAX_POOL layer
# `inplace` is a boolean flag; the previous value `[True]` was a list that
# only worked because a non-empty list is truthy.
x = F.relu(x, inplace=True)  # 1st ReLU
CONV 层-> BN 层 -> POOL层 -> ReLU层,以上四个层组成一个块,然后将四个类似的块堆叠起来,结尾接一个Flatten层和一个Linear层。
class MetaLearner(nn.Module):
    """MAML meta-learner wrapped around a ``BaseNet``.

    ``forward`` adapts the base network to each task with a few gradient
    steps on the support set, evaluates the adapted weights on the query set,
    and updates the base network from the query loss of the last step.
    ``finetunning`` runs the same adaptation on a copy of the network and
    only reports accuracies.
    """

    def __init__(self):
        super(MetaLearner, self).__init__()
        self.update_step = 5  # task-level inner update steps used by forward()
        self.update_step_test = 5  # inner update steps used by finetunning()
        self.net = BaseNet()
        self.meta_lr = 2e-4  # learning rate of the outer (meta) optimizer
        self.base_lr = 4 * 1e-2  # step size of the inner gradient updates
        self.inner_lr = 0.4  # not read anywhere in this class
        self.outer_lr = 1e-2  # not read anywhere in this class
        self.meta_optim = torch.optim.Adam(self.net.parameters(), lr=self.meta_lr)

    @staticmethod
    def _count_correct(logits, targets):
        """Return how many argmax predictions of ``logits`` equal ``targets``."""
        # softmax is monotonic per row, so taking argmax of the raw logits
        # selects the same class as argmax of the softmax output.
        return torch.eq(logits.argmax(dim=1), targets).sum().item()

    def _inner_update(self, loss, weights):
        """Return ``w - base_lr * dloss/dw`` for every ``w`` in ``weights``."""
        # autograd.grad is called without create_graph, so the gradient values
        # themselves carry no graph into the meta-gradient.
        grads = torch.autograd.grad(loss, weights)
        return [w - self.base_lr * g for w, g in zip(weights, grads)]

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        """Run one meta-training step over a batch of tasks.

        The first axis of every argument indexes the task. For each task the
        support pair ``(x_spt[i], y_spt[i])`` drives the inner updates and the
        query pair ``(x_qry[i], y_qry[i])`` is used for evaluation.

        :return: ``(accs, loss)``, two NumPy arrays of length
            ``update_step + 1``. Entry 0 is measured before any update and
            entry ``k`` after ``k`` inner updates. ``accs`` is divided by
            ``x_qry.size(1) * task_num`` and ``loss`` by ``task_num``.

        NOTE: the meta-update back-propagates the last entry of the summed
        query loss, which carries gradients only when ``update_step >= 2``.
        """
        task_num = x_spt.size(0)
        query_size = x_qry.size(1)  # 75 = 15 * 5 in the original note
        loss_list_qry = [0 for _ in range(self.update_step + 1)]
        correct_list = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            # Update step 0: one gradient step from the current parameters.
            y_hat = self.net(x_spt[i], params=None, bn_training=True)  # (ways * shots, ways)
            loss = F.cross_entropy(y_hat, y_spt[i])
            # fast_weights = theta - alpha * grad(L)
            fast_weights = self._inner_update(loss, list(self.net.parameters()))

            # Evaluate on the query set with the parameters before the update.
            with torch.no_grad():
                y_hat = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_list_qry[0] += F.cross_entropy(y_hat, y_qry[i])
                correct_list[0] += self._count_correct(y_hat, y_qry[i])

            # Evaluate on the query set with the parameters after the update.
            with torch.no_grad():
                y_hat = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_list_qry[1] += F.cross_entropy(y_hat, y_qry[i])
                correct_list[1] += self._count_correct(y_hat, y_qry[i])

            for k in range(1, self.update_step):
                y_hat = self.net(x_spt[i], params=fast_weights, bn_training=True)
                loss = F.cross_entropy(y_hat, y_spt[i])
                fast_weights = self._inner_update(loss, fast_weights)

                # This query loss keeps its graph: the last one is the
                # meta-objective that is back-propagated below.
                y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
                loss_list_qry[k + 1] += F.cross_entropy(y_hat, y_qry[i])
                with torch.no_grad():
                    correct_list[k + 1] += self._count_correct(y_hat, y_qry[i])

        # Meta-update: average the final-step query loss over the tasks and
        # take one optimizer step on the base network's parameters.
        loss_qry = loss_list_qry[-1] / task_num
        self.meta_optim.zero_grad()  # clear old gradients
        loss_qry.backward()
        self.meta_optim.step()

        accs = np.array(correct_list) / (query_size * task_num)
        # Convert to plain floats first: NumPy cannot wrap tensors that still
        # require grad.
        loss = np.array([float(total) / task_num for total in loss_list_qry])
        return accs, loss

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """Adapt a copy of the network to a single task and report accuracy.

        ``self.net`` is deep-copied first, so the meta-learned parameters are
        left untouched.

        :param x_spt: support inputs of one task; must be 4-dimensional.
        :return: NumPy array of length ``update_step_test + 1`` holding the
            query accuracy before any update (entry 0) and after each inner
            update, each divided by ``x_qry.size(0)``.
        """
        assert len(x_spt.shape) == 4

        query_size = x_qry.size(0)
        correct_list = [0 for _ in range(self.update_step_test + 1)]

        new_net = deepcopy(self.net)
        y_hat = new_net(x_spt)
        loss = F.cross_entropy(y_hat, y_spt)
        fast_weights = self._inner_update(loss, list(new_net.parameters()))

        # Evaluate on the query set with the parameters before the update.
        with torch.no_grad():
            y_hat = new_net(x_qry, params=new_net.parameters(), bn_training=True)
            correct_list[0] += self._count_correct(y_hat, y_qry)

        # Evaluate on the query set with the parameters after the update.
        with torch.no_grad():
            y_hat = new_net(x_qry, params=fast_weights, bn_training=True)
            correct_list[1] += self._count_correct(y_hat, y_qry)

        for k in range(1, self.update_step_test):
            y_hat = new_net(x_spt, params=fast_weights, bn_training=True)
            loss = F.cross_entropy(y_hat, y_spt)
            fast_weights = self._inner_update(loss, fast_weights)

            y_hat = new_net(x_qry, fast_weights, bn_training=True)
            with torch.no_grad():
                correct_list[k + 1] += self._count_correct(y_hat, y_qry)

        del new_net  # drop the temporary copy
        accs = np.array(correct_list) / query_size
        return accs
4 全部源码
5 实验结果