
PyG (PyTorch Geometric) Heterogeneous Graph Neural Networks (HGNN)


诸神缄默不语 - personal CSDN blog post index
PyTorch Geometric (PyG) package documentation and official code examples: study notes (continuously updated…)

This post shows how to implement heterogeneous graph neural networks (HGNN) with PyG.

This post mainly follows the heterogeneous graph section of the PyG documentation: Heterogeneous Graph Learning (pytorch_geometric documentation)
Related official code examples: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero

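Notes:
① Many of these operations can also be done directly, without PyG's HeteroData object.
② Some datasets cannot be downloaded directly from mainland China; for workarounds see my earlier post: Workarounds for Dropbox-hosted PyG (PyTorch Geometric) graph datasets that cannot be downloaded (AMiner, DBLP, IMDB, LastFM) (continuously updated…)
③ I use torch-geometric 2.2.0 installed with pip; some earlier versions may not support the drop_orig_edge_types argument of T.AddMetaPaths.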

1. Example dataset

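Schema of the ogbn-mag heterogeneous graph (the original figure is not preserved): four node types (paper, author, institution, field_of_study) connected by the edge types cites (paper to paper), writes (author to paper), affiliated_with (author to institution) and has_topic (paper to field_of_study).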

It has 1,939,743 nodes and 21,111,007 edges.
The dataset's original task is node classification: predicting the venue (conference or journal) of each paper.

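Loading it in PyG (the raw data only provides features for paper nodes; here the preprocess argument adds features for the other node types, learned from the graph structure):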

from torch_geometric.datasets import OGB_MAG

dataset = OGB_MAG(root='./data', preprocess='metapath2vec')
# preprocess can also be 'TransE' (or None to keep only the paper features)
data = dataset[0]

2. The HeteroData object

from torch_geometric.data import HeteroData

data = HeteroData()

data['paper'].x = ... # [num_papers, num_features_paper]
data['author'].x = ... # [num_authors, num_features_author]
data['institution'].x = ... # [num_institutions, num_features_institution]
data['field_of_study'].x = ... # [num_field, num_features_field]

data['paper', 'cites', 'paper'].edge_index = ... # [2, num_edges_cites]
data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges_writes]
data['author', 'affiliated_with', 'institution'].edge_index = ... # [2, num_edges_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_index = ... # [2, num_edges_topic]

data['paper', 'cites', 'paper'].edge_attr = ... # [num_edges_cites, num_features_cites]
data['author', 'writes', 'paper'].edge_attr = ... # [num_edges_writes, num_features_writes]
data['author', 'affiliated_with', 'institution'].edge_attr = ... # [num_edges_affiliated, num_features_affiliated]
data['paper', 'has_topic', 'field_of_study'].edge_attr = ... # [num_edges_topic, num_features_topic]

 

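Node types are indexed with a string; edge types are indexed with a string triple (source node type, relation, destination node type).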

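data.{attribute_name}_dict collects that attribute for every node/edge type into a dict keyed by type. These dicts can be passed directly to a GNN model: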

model = HeteroGNN(...)

output = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)

Taking the ogbn-mag data from Section 1 as an example, printing the data object gives:

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 5416271] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] }
)
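As a quick sanity check, the shapes printed above add up to the totals quoted in Section 1 (the numbers are copied from the printout):

```python
# Node and edge counts copied from the printed ogbn-mag HeteroData above.
num_nodes = {
    'paper': 736389,
    'author': 1134649,
    'institution': 8740,
    'field_of_study': 59965,
}
num_edges = {
    ('author', 'affiliated_with', 'institution'): 1043998,
    ('author', 'writes', 'paper'): 7145660,
    ('paper', 'cites', 'paper'): 5416271,
    ('paper', 'has_topic', 'field_of_study'): 7505078,
}

total_nodes = sum(num_nodes.values())
total_edges = sum(num_edges.values())
print(f'{total_nodes:,} nodes, {total_edges:,} edges')
# 1,939,743 nodes, 21,111,007 edges
```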

 

3. Available methods

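  1. Indexing returns the storage of a node/edge type, which behaves like a dict: keys are attribute names (e.g. x), values are the attribute values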
paper_node_data = data['paper']
cites_edge_data = data['paper', 'cites', 'paper']
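  2. If the source/destination node-type pair or the relation name alone uniquely identifies an edge type, these shortcuts also work: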
cites_edge_data = data['paper', 'paper']
cites_edge_data = data['cites']
  3. Adding and deleting attributes, node types and edge types:
data['paper'].year = ...    # Setting a new paper attribute
del data['field_of_study']  # Deleting 'field_of_study' node type
del data['has_topic']       # Deleting 'has_topic' edge type
  4. metadata(): returns a tuple of two lists; the first holds the node types, the second the edge types (example output for a graph with paper, author and institution nodes):
node_types, edge_types = data.metadata()
print(node_types)
# ['paper', 'author', 'institution']
print(edge_types)
# [('paper', 'cites', 'paper'),
#  ('author', 'writes', 'paper'),
#  ('author', 'affiliated_with', 'institution')]
  5. Moving between devices:
data = data.to('cuda:0')
data = data.cpu()
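  6. Checking for isolated nodes and self-loops, checking whether the graph is undirected, and converting to a homogeneous graph. Notes from my tests: (1) If some node types have no features (e.g. ogbn-mag built without the preprocess argument), every node in the resulting homogeneous graph still has an x entry (even though only paper nodes had features in the heterogeneous graph), but for nodes that had no labels in the heterogeneous graph the corresponding values are NaN or -1 (I could not work out how to tell which one you get; both can happen). (2) The original node/edge types become integer IDs stored in node_type and edge_type.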
data.has_isolated_nodes()
data.has_self_loops()
data.is_undirected()
homogeneous_data = data.to_homogeneous()
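  7. Transforming heterogeneous graph objects with torch_geometric.transforms (many of the same operations as for homogeneous graphs):
import torch_geometric.transforms as T

data = T.ToUndirected()(data)
data = T.AddSelfLoops()(data)
data = T.NormalizeFeatures()(data)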

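Converting the heterogeneous graph to an undirected graph: reverse edges are added so that messages can be passed in both directions along every edge; where needed, reverse edge types (named rev_<relation>, as in the output below) are added as well.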

Example:

import torch_geometric.transforms as T

T.ToUndirected()(data)

Output:

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 10792672] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)
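What T.ToUndirected() does per edge type can be sketched in plain PyTorch. This is a simplified illustration under stated assumptions (the helper name to_undirected_sketch is hypothetical, and it ignores edge features), not PyG's implementation: an edge type whose source and destination node types differ gets a new rev_ edge type with the two rows swapped, while for an edge type such as ('paper', 'cites', 'paper') the reversed edges are merged into the same type and duplicates are dropped. Duplicate removal would also explain why the cites count above grows from 5,416,271 to 10,792,672 rather than to exactly twice as many.

```python
import torch

def to_undirected_sketch(edge_index_dict):
    """Hypothetical helper mimicking the idea of T.ToUndirected() (no edge features)."""
    out = {}
    for (src, rel, dst), edge_index in edge_index_dict.items():
        if src == dst:
            # Same node type on both ends: add reversed edges, then drop duplicate columns.
            both = torch.cat([edge_index, edge_index.flip(0)], dim=1)
            out[(src, rel, dst)] = torch.unique(both, dim=1)
        else:
            # Different node types: keep the edges and add a reverse edge type.
            out[(src, rel, dst)] = edge_index
            out[(dst, f'rev_{rel}', src)] = edge_index.flip(0)
    return out

edge_index_dict = {
    ('paper', 'cites', 'paper'): torch.tensor([[0, 1, 1], [1, 0, 2]]),
    ('author', 'writes', 'paper'): torch.tensor([[0, 1], [2, 2]]),
}
result = to_undirected_sketch(edge_index_dict)
for edge_type, edge_index in result.items():
    print(edge_type, edge_index.tolist())
```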

 

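Adding self-loops: T.AddSelfLoops() works per edge type, but only edge types whose source and destination node types are the same (e.g. ('paper', 'cites', 'paper')) get self-loops; bipartite edge types such as ('author', 'writes', 'paper') are skipped, since a self-loop is not defined there.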
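A plain-PyTorch sketch of adding self-loops per edge type, assuming that only edge types with identical source and destination node types receive loops (an edge (i, i) is meaningless when i indexes two different node types). The helper add_self_loops_sketch is hypothetical and does not handle edge features:

```python
import torch

def add_self_loops_sketch(edge_index_dict, num_nodes_dict):
    """Hypothetical helper: append (i, i) edges only where src and dst node types match."""
    out = {}
    for (src, rel, dst), edge_index in edge_index_dict.items():
        if src == dst:
            loops = torch.arange(num_nodes_dict[src]).repeat(2, 1)  # shape [2, num_nodes]
            edge_index = torch.cat([edge_index, loops], dim=1)
        out[(src, rel, dst)] = edge_index
    return out

edge_index_dict = {
    ('paper', 'cites', 'paper'): torch.tensor([[0], [1]]),
    ('author', 'writes', 'paper'): torch.tensor([[0, 1], [0, 2]]),
}
looped = add_self_loops_sketch(edge_index_dict, {'paper': 3, 'author': 2})
print(looped[('paper', 'cites', 'paper')].tolist())  # [[0, 0, 1, 2], [1, 0, 1, 2]]
```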

4. Node representation learning

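Lazy initialization: layers created with in_channels=-1 only infer their input sizes on the first forward pass, so run one forward pass (e.g. under torch.no_grad()) before creating the optimizer: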

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)
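The same mechanism exists in plain PyTorch as torch.nn.LazyLinear, which makes for a small self-contained illustration. PyG's Linear(-1, ...) is PyG's own lazy layer, so this is only an analogy:

```python
import torch

# in_features is not given: it is inferred from the first input.
model = torch.nn.Sequential(torch.nn.LazyLinear(4), torch.nn.ReLU())

x = torch.randn(5, 7)
with torch.no_grad():  # dummy forward pass materializes the weights
    model(x)

print(model[0].weight.shape)  # torch.Size([4, 7])

# Only now do the parameters have concrete shapes, so build the optimizer afterwards.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
```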

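Points to watch: first, check tensor dtypes (edge_index must be torch.long, and all x tensors must share one dtype, usually torch.float); second, check that edge_index stays within bounds (my most common mistake is an index equal to the number of nodes).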

For the bug that problem one produces, see: AssertionError when implementing heterogenous GNN · Discussion #5175 · pyg-team/pytorch_geometric

To self-check for problem two:

for edge_type in data.edge_types:
    src, _, dst = edge_type
    assert data[edge_type].edge_index[0].max() < data[src].num_nodes
    assert data[edge_type].edge_index[1].max() < data[dst].num_nodes
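Both checks can be folded into one helper that runs on plain dicts such as data.x_dict and data.edge_index_dict. A sketch (the function name check_hetero_inputs is hypothetical; it assumes every node type has a feature matrix in x_dict, whose first dimension is taken as that type's node count):

```python
import torch

def check_hetero_inputs(x_dict, edge_index_dict):
    """Raise ValueError on mixed feature dtypes, non-long edge_index, or out-of-range indices."""
    dtypes = {x.dtype for x in x_dict.values()}
    if len(dtypes) > 1:
        raise ValueError(f'node features use mixed dtypes: {dtypes}')
    for edge_type, edge_index in edge_index_dict.items():
        src, _, dst = edge_type
        if edge_index.dtype != torch.long:
            raise ValueError(f'{edge_type}: edge_index must be torch.long, got {edge_index.dtype}')
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            raise ValueError(f'{edge_type}: edge_index must have shape [2, num_edges]')
        if edge_index.numel() == 0:
            continue
        if int(edge_index.min()) < 0:
            raise ValueError(f'{edge_type}: negative node index')
        if int(edge_index[0].max()) >= x_dict[src].size(0):
            raise ValueError(f'{edge_type}: source index out of range for {src!r}')
        if int(edge_index[1].max()) >= x_dict[dst].size(0):
            raise ValueError(f'{edge_type}: destination index out of range for {dst!r}')

x_dict = {'paper': torch.randn(3, 4), 'author': torch.randn(2, 4)}
good = {('author', 'writes', 'paper'): torch.tensor([[0, 1], [0, 2]])}
bad = {('author', 'writes', 'paper'): torch.tensor([[0, 1], [0, 3]])}  # 3 == number of papers

check_hetero_inputs(x_dict, good)  # passes silently
try:
    check_hetero_inputs(x_dict, bad)
except ValueError as err:
    print(err)
```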

For how to fix it, see my earlier post: RuntimeError: CUDA error: device-side assert triggered

4.1 Converting a homogeneous GNN directly into a heterogeneous GNN

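That is, define the GNN model as usual (some homogeneous GNN layers cannot be applied to heterogeneous graphs). The conversion runs a separate instance of the homogeneous GNN on each edge type and combines the results for each destination node type according to the aggr argument.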

torch_geometric.nn.to_hetero()
torch_geometric.nn.to_hetero_with_bases() (same idea, but the per-edge-type weights are shared through a basis decomposition, as in R-GCN)
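To make the idea concrete, here is a deliberately simplified plain-PyTorch sketch. The class PerEdgeTypeSum is hypothetical and is not how to_hetero works internally (to_hetero rewrites the model via torch.fx); it only shows the principle: each edge type gets its own weights, messages are summed into destination nodes, and outputs for the same destination type are summed across edge types, similar in spirit to aggr='sum'.

```python
import torch
from torch import nn

class PerEdgeTypeSum(nn.Module):
    """Hypothetical toy layer: one Linear per edge type, sum aggregation."""

    def __init__(self, edge_types, in_dim, out_dim):
        super().__init__()
        self.edge_types = list(edge_types)
        self.out_dim = out_dim
        # ModuleDict keys may not contain '.', so join the triple with '__'.
        self.lins = nn.ModuleDict({'__'.join(et): nn.Linear(in_dim, out_dim)
                                   for et in self.edge_types})

    def forward(self, x_dict, edge_index_dict):
        out = {}
        for et in self.edge_types:
            src, _, dst = et
            edge_index = edge_index_dict[et]
            # Transform source features with this edge type's own weights.
            msg = self.lins['__'.join(et)](x_dict[src])[edge_index[0]]
            # Sum messages into their destination nodes.
            agg = torch.zeros(x_dict[dst].size(0), self.out_dim)
            agg.index_add_(0, edge_index[1], msg)
            # Combine edge types that share a destination node type.
            out[dst] = out[dst] + agg if dst in out else agg
        return out

edge_types = [('author', 'writes', 'paper'), ('paper', 'cites', 'paper')]
layer = PerEdgeTypeSum(edge_types, in_dim=4, out_dim=5)
x_dict = {'author': torch.randn(3, 4), 'paper': torch.randn(2, 4)}
edge_index_dict = {
    ('author', 'writes', 'paper'): torch.tensor([[0, 1, 2], [0, 0, 1]]),
    ('paper', 'cites', 'paper'): torch.tensor([[0], [1]]),
}
out = layer(x_dict, edge_index_dict)
print({k: tuple(v.shape) for k, v in out.items()})  # {'paper': (2, 5)}
```

Note that 'author' is missing from the output: it never appears as a destination, so nothing updates it.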

Example code:

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='/data/pyg_data',preprocess='metapath2vec',transform=T.ToUndirected())
data = dataset[0]

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

 

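in_channels is given as a tuple (source size, destination size) so that source and destination nodes can have different input sizes, which is what message passing on bipartite graphs (different source and destination node types) requires; -1 means the size is inferred lazily. SAGEConv expands an int in_channels into such a pair internally, so in this example an int works just as well: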

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='/data/pyg_data',preprocess='metapath2vec',transform=T.ToUndirected())
data = dataset[0]

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(-1, hidden_channels)
        self.conv2 = SAGEConv(-1, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
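To see what the (source, destination) pair of input sizes is for, here is a simplified bipartite SAGE-style layer in plain PyTorch. The class BipartiteMeanSAGE and the toy sizes are hypothetical; it only mirrors the general structure of SAGEConv with mean aggregation: one linear map for the aggregated neighbor features and one for the destination node's own features. Since source and destination nodes may be of different types, the two maps need different input sizes:

```python
import torch
from torch import nn

class BipartiteMeanSAGE(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        in_src, in_dst = in_channels  # separate input sizes for source and destination nodes
        self.lin_neigh = nn.Linear(in_src, out_channels)             # aggregated source features
        self.lin_root = nn.Linear(in_dst, out_channels, bias=False)  # destination's own features

    def forward(self, x_src, x_dst, edge_index):
        src, dst = edge_index
        # Mean of incoming source features for every destination node.
        summed = torch.zeros(x_dst.size(0), x_src.size(1)).index_add_(0, dst, x_src[src])
        deg = torch.zeros(x_dst.size(0)).index_add_(0, dst, torch.ones(dst.numel()))
        mean = summed / deg.clamp(min=1).unsqueeze(-1)
        return self.lin_neigh(mean) + self.lin_root(x_dst)

x_author = torch.randn(3, 8)   # source type: authors with 8 features
x_paper = torch.randn(4, 16)   # destination type: papers with 16 features
writes = torch.tensor([[0, 0, 1, 2], [0, 1, 2, 3]])
layer = BipartiteMeanSAGE((8, 16), 32)
out = layer(x_author, x_paper, writes)
print(out.shape)  # torch.Size([4, 32])
```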

 

If the graph is not converted to an undirected graph, author nodes have no incoming edges, which leads to a NotImplementedError. The warning emitted is:
env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py:145: UserWarning: There exist node types ({'author'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behaviour. warnings.warn(
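This situation can be spotted before calling to_hetero by checking which node types never occur as a destination in data.metadata(). A small plain-Python helper (the name non_destination_node_types is hypothetical), applied to the ogbn-mag metadata as printed earlier, before T.ToUndirected():

```python
def non_destination_node_types(metadata):
    """Return node types that are never the destination of any edge type."""
    node_types, edge_types = metadata
    dst_types = {dst for _, _, dst in edge_types}
    return [node_type for node_type in node_types if node_type not in dst_types]

metadata = (
    ['paper', 'author', 'institution', 'field_of_study'],
    [('author', 'affiliated_with', 'institution'),
     ('author', 'writes', 'paper'),
     ('paper', 'cites', 'paper'),
     ('paper', 'has_topic', 'field_of_study')],
)
print(non_destination_node_types(metadata))  # ['author']
```

After T.ToUndirected() adds ('paper', 'rev_writes', 'author'), the same check returns an empty list.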

Version with learnable skip connections (each layer's convolution output has a linear transformation of the layer's input added to it):

import torch
from torch_geometric.nn import GATConv, Linear, to_hetero

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index) + self.lin1(x)
        x = x.relu()
        x = self.conv2(x, edge_index) + self.lin2(x)
        return x

model = GAT(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
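Written out, the forward pass of this GAT model computes the following (W_1, b_1, W_2, b_2 denote the weights and biases of lin1 and lin2; this notation is introduced here for illustration). After to_hetero the computation becomes, roughly, one GATConv per edge type r, summed over all edge types whose destination is node type τ (aggr='sum'), plus one copy of each Linear per node type:

```latex
\mathbf{h} = \mathrm{ReLU}\left(\mathrm{GAT}_1(\mathbf{x}) + \mathbf{W}_1\mathbf{x} + \mathbf{b}_1\right),
\qquad
\mathbf{o} = \mathrm{GAT}_2(\mathbf{h}) + \mathbf{W}_2\mathbf{h} + \mathbf{b}_2

\mathbf{h}_\tau = \mathrm{ReLU}\Big(\sum_{r:\,\mathrm{dst}(r)=\tau} \mathrm{GAT}_1^{(r)}\big(\mathbf{x}_{\mathrm{src}(r)}, \mathbf{x}_\tau\big)
  + \mathbf{W}_1^{(\tau)}\mathbf{x}_\tau + \mathbf{b}_1^{(\tau)}\Big)
```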

 

Training code for reference:

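import torch.nn.functional as F

with torch.no_grad():  # Initialize lazy modules before building the optimizer.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)  # lr chosen for illustration

def train():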
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['paper'].train_mask
    loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)
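The training step usually goes together with an evaluation step. The accuracy computation used by test() in the later examples boils down to the following self-contained sketch on made-up logits (the helper name masked_accuracy is hypothetical); for ogbn-mag you would pass, for instance, out['paper'], data['paper'].y and data['paper'].val_mask:

```python
import torch

def masked_accuracy(logits, y, mask):
    """Accuracy of argmax predictions on the nodes selected by a boolean mask."""
    pred = logits.argmax(dim=-1)
    correct = int((pred[mask] == y[mask]).sum())
    total = int(mask.sum())
    return correct / total if total > 0 else 0.0

logits = torch.tensor([[2.0, 1.0], [0.0, 3.0], [5.0, 0.0]])
y = torch.tensor([0, 1, 1])
mask = torch.tensor([True, False, True])
print(masked_accuracy(logits, y, mask))  # 0.5
```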

4.2 Defining a GNN with HeteroConv

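HeteroConv lets you assign a different GNN operator to each edge type; results for the same destination node type are combined with aggr.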

torch_geometric.nn.conv.HeteroConv
文档:https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.HeteroConv

Example code:

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear

dataset = OGB_MAG(root='/data/pyg_data',preprocess='metapath2vec',transform=T.ToUndirected())
data = dataset[0]

class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
                ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict['author'])

model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes,
                  num_layers=2)
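HeteroConv combines the outputs of all edge types pointing to the same destination node type with aggr. The plain-PyTorch sketch below (hypothetical helper group_by_destination, following a stack-then-reduce idea) shows that step in isolation. One consequence for the example above: the output dict only contains destination node types, here 'paper' and 'author', which is why the final layer reads x_dict['author'] and why institution and field_of_study features are not used by this model.

```python
import torch

def group_by_destination(out_per_edge_type, aggr='sum'):
    """Combine per-edge-type outputs that share a destination node type."""
    grouped = {}
    for (_, _, dst), out in out_per_edge_type.items():
        grouped.setdefault(dst, []).append(out)

    result = {}
    for dst, outs in grouped.items():
        stacked = torch.stack(outs, dim=0)  # [num_edge_types_into_dst, num_dst_nodes, channels]
        if aggr == 'sum':
            result[dst] = stacked.sum(dim=0)
        elif aggr == 'mean':
            result[dst] = stacked.mean(dim=0)
        elif aggr == 'max':
            result[dst] = stacked.max(dim=0).values
        else:
            raise ValueError(f'unsupported aggr: {aggr}')
    return result

outs = {
    ('paper', 'cites', 'paper'): torch.ones(2, 3),
    ('author', 'writes', 'paper'): torch.full((2, 3), 2.0),
    ('paper', 'rev_writes', 'author'): torch.zeros(4, 3),
}
combined = group_by_destination(outs, aggr='sum')
print(sorted(combined))      # ['author', 'paper']
print(combined['paper'][0])  # tensor([3., 3., 3.])
```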

 

4.3 Using existing or hand-written heterogeneous graph operators

Taking the HGT model as an example:

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import HGTConv, Linear

dataset = OGB_MAG(root='/data/pyg_data',preprocess='metapath2vec',transform=T.ToUndirected())
data = dataset[0]

class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for node_type, x in x_dict.items():
            x_dict[node_type] = self.lin_dict[node_type](x).relu_()

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return self.lin(x_dict['author'])

model = HGT(hidden_channels=64, out_channels=dataset.num_classes,
            num_heads=2, num_layers=2)
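The first loop in forward projects every node type into the same hidden_channels space, because the HGTConv layers here are built with a single hidden_channels as their input size. In isolation, with made-up feature sizes and plain torch.nn.Linear instead of PyG's lazy Linear(-1, ...), that step looks like this:

```python
import torch
from torch import nn

feature_dims = {'paper': 16, 'author': 8, 'institution': 4}  # made-up input sizes
hidden_channels = 32

lin_dict = nn.ModuleDict({
    node_type: nn.Linear(dim, hidden_channels) for node_type, dim in feature_dims.items()
})

x_dict = {node_type: torch.randn(5, dim) for node_type, dim in feature_dims.items()}

# Build a new dict instead of overwriting the input dict in place.
h_dict = {node_type: lin_dict[node_type](x).relu() for node_type, x in x_dict.items()}
print({node_type: tuple(h.shape) for node_type, h in h_dict.items()})
```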

 

5. Node classification

5.1 Whole-batch (full-batch) training

5.1.1 Using existing heterogeneous graph operators

5.1.1.1 HAN

Reference: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/han_imdb.py
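The script uses T.AddMetaPaths to turn the movie-actor-movie and movie-director-movie paths into direct movie-movie edge types (metapath_0 and metapath_1 in the printed output). Conceptually, a two-hop metapath corresponds to multiplying the adjacency matrices along the path. A dense toy sketch with made-up edges; it does not reproduce AddMetaPaths options such as drop_orig_edge_types, and keeping the diagonal (a movie connected to itself) is a choice of this sketch:

```python
import torch

num_movies, num_actors = 3, 2
# (movie, actor) edges: movies 0 and 1 share actor 0; movie 2 has actor 1.
movie_actor = torch.tensor([[0, 1, 2],
                            [0, 0, 1]])

A = torch.zeros(num_movies, num_actors)
A[movie_actor[0], movie_actor[1]] = 1.0

# (A @ A^T)[i, j] = number of actors shared by movie i and movie j.
M = A @ A.t()
metapath_edge_index = M.nonzero().t()  # [2, num_metapath_edges]
print(metapath_edge_index.tolist())    # [[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]]
```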

from typing import Dict, List, Union

import torch
import torch.nn.functional as F
from torch import nn

import torch_geometric.transforms as T
from torch_geometric.datasets import IMDB
from torch_geometric.nn import HANConv

metapaths = [[('movie', 'actor'), ('actor', 'movie')],
             [('movie', 'director'), ('director', 'movie')]]
transform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True,
                           drop_unconnected_node_types=True)
dataset = IMDB('/data/pyg_data/IMDB', transform=transform)
data = dataset[0]
print(data)

class HAN(nn.Module):
    def __init__(self, in_channels: Union[int, Dict[str, int]],
                 out_channels: int, hidden_channels=128, heads=8):
        super().__init__()
        self.han_conv = HANConv(in_channels, hidden_channels, heads=heads,
                                dropout=0.6, metadata=data.metadata())
        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        out = self.han_conv(x_dict, edge_index_dict)
        out = self.lin(out['movie'])
        return out

model = HAN(in_channels=-1, out_channels=3)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

def train() -> float:
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['movie'].train_mask
    loss = F.cross_entropy(out[mask], data['movie'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test() -> List[float]:
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['movie'][split]
        acc = (pred[mask] == data['movie'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

best_val_acc = 0
start_patience = patience = 100
for epoch in range(1, 200):

    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

    if best_val_acc <= val_acc:
        patience = start_patience
        best_val_acc = val_acc
    else:
        patience -= 1

    if patience <= 0:
        print('Stopping training as validation accuracy did not improve '
              f'for {start_patience} epochs')
        break

 

Output:

HeteroData(
  metapath_dict={
    (movie, metapath_0, movie)=[2],
    (movie, metapath_1, movie)=[2]
  },
  movie={
    x=[4278, 3066],
    y=[4278],
    train_mask=[4278],
    val_mask=[4278],
    test_mask=[4278]
  },
  (movie, metapath_0, movie)={ edge_index=[2, 85358] },
  (movie, metapath_1, movie)={ edge_index=[2, 17446] }
)
Epoch: 001, Loss: 1.1020, Train: 0.5125, Val: 0.4100, Test: 0.3890
Epoch: 002, Loss: 1.0783, Train: 0.5575, Val: 0.4075, Test: 0.3813
Epoch: 003, Loss: 1.0498, Train: 0.6350, Val: 0.4325, Test: 0.4112
Epoch: 004, Loss: 1.0205, Train: 0.7075, Val: 0.4850, Test: 0.4448
Epoch: 005, Loss: 0.9788, Train: 0.7375, Val: 0.5050, Test: 0.4669
Epoch: 006, Loss: 0.9410, Train: 0.7600, Val: 0.5225, Test: 0.4796
Epoch: 007, Loss: 0.8921, Train: 0.7750, Val: 0.5375, Test: 0.4937
Epoch: 008, Loss: 0.8517, Train: 0.8000, Val: 0.5475, Test: 0.5003
Epoch: 009, Loss: 0.7975, Train: 0.8175, Val: 0.5475, Test: 0.5135
Epoch: 010, Loss: 0.7488, Train: 0.8475, Val: 0.5525, Test: 0.5216
Epoch: 011, Loss: 0.7133, Train: 0.8625, Val: 0.5575, Test: 0.5308
Epoch: 012, Loss: 0.6626, Train: 0.8875, Val: 0.5700, Test: 0.5443
Epoch: 013, Loss: 0.6171, Train: 0.9050, Val: 0.5900, Test: 0.5552
Epoch: 014, Loss: 0.5769, Train: 0.9225, Val: 0.5925, Test: 0.5710
Epoch: 015, Loss: 0.5236, Train: 0.9375, Val: 0.5900, Test: 0.5785
Epoch: 016, Loss: 0.4929, Train: 0.9425, Val: 0.5925, Test: 0.5851
Epoch: 017, Loss: 0.4456, Train: 0.9375, Val: 0.5925, Test: 0.5868
Epoch: 018, Loss: 0.4266, Train: 0.9375, Val: 0.5825, Test: 0.5909
Epoch: 019, Loss: 0.3856, Train: 0.9425, Val: 0.5900, Test: 0.5926
Epoch: 020, Loss: 0.3525, Train: 0.9425, Val: 0.5900, Test: 0.5909
Epoch: 021, Loss: 0.3250, Train: 0.9450, Val: 0.5975, Test: 0.5897
Epoch: 022, Loss: 0.2900, Train: 0.9500, Val: 0.6050, Test: 0.5831
Epoch: 023, Loss: 0.2754, Train: 0.9525, Val: 0.6075, Test: 0.5825
Epoch: 024, Loss: 0.2603, Train: 0.9500, Val: 0.6075, Test: 0.5802
Epoch: 025, Loss: 0.2436, Train: 0.9500, Val: 0.6050, Test: 0.5739
Epoch: 026, Loss: 0.2251, Train: 0.9525, Val: 0.6000, Test: 0.5722
Epoch: 027, Loss: 0.2156, Train: 0.9500, Val: 0.6000, Test: 0.5733
Epoch: 028, Loss: 0.2077, Train: 0.9525, Val: 0.5950, Test: 0.5702
Epoch: 029, Loss: 0.1806, Train: 0.9550, Val: 0.5900, Test: 0.5699
Epoch: 030, Loss: 0.1942, Train: 0.9675, Val: 0.5975, Test: 0.5707
Epoch: 031, Loss: 0.1899, Train: 0.9750, Val: 0.6050, Test: 0.5693
Epoch: 032, Loss: 0.1879, Train: 0.9800, Val: 0.6050, Test: 0.5687
Epoch: 033, Loss: 0.1759, Train: 0.9825, Val: 0.6000, Test: 0.5684
Epoch: 034, Loss: 0.1706, Train: 0.9825, Val: 0.5950, Test: 0.5670
Epoch: 035, Loss: 0.1678, Train: 0.9800, Val: 0.5925, Test: 0.5656
Epoch: 036, Loss: 0.1655, Train: 0.9750, Val: 0.5950, Test: 0.5647
Epoch: 037, Loss: 0.1561, Train: 0.9750, Val: 0.6025, Test: 0.5656
Epoch: 038, Loss: 0.1588, Train: 0.9775, Val: 0.6025, Test: 0.5644
Epoch: 039, Loss: 0.1502, Train: 0.9750, Val: 0.6025, Test: 0.5644
Epoch: 040, Loss: 0.1535, Train: 0.9775, Val: 0.6000, Test: 0.5638
Epoch: 041, Loss: 0.1502, Train: 0.9800, Val: 0.6000, Test: 0.5633
Epoch: 042, Loss: 0.1638, Train: 0.9800, Val: 0.6000, Test: 0.5621
Epoch: 043, Loss: 0.1530, Train: 0.9800, Val: 0.6000, Test: 0.5624
Epoch: 044, Loss: 0.1566, Train: 0.9800, Val: 0.5975, Test: 0.5624
Epoch: 045, Loss: 0.1578, Train: 0.9800, Val: 0.6150, Test: 0.5610
Epoch: 046, Loss: 0.1441, Train: 0.9800, Val: 0.6150, Test: 0.5615
Epoch: 047, Loss: 0.1430, Train: 0.9825, Val: 0.6175, Test: 0.5604
Epoch: 048, Loss: 0.1389, Train: 0.9875, Val: 0.6150, Test: 0.5578
Epoch: 049, Loss: 0.1396, Train: 0.9875, Val: 0.6200, Test: 0.5566
Epoch: 050, Loss: 0.1547, Train: 0.9875, Val: 0.6150, Test: 0.5610
Epoch: 051, Loss: 0.1471, Train: 0.9875, Val: 0.6125, Test: 0.5644
Epoch: 052, Loss: 0.1398, Train: 0.9900, Val: 0.6150, Test: 0.5647
Epoch: 053, Loss: 0.1393, Train: 0.9875, Val: 0.6125, Test: 0.5644
Epoch: 054, Loss: 0.1542, Train: 0.9850, Val: 0.6075, Test: 0.5638
Epoch: 055, Loss: 0.1435, Train: 0.9875, Val: 0.6150, Test: 0.5627
Epoch: 056, Loss: 0.1338, Train: 0.9850, Val: 0.6225, Test: 0.5633
Epoch: 057, Loss: 0.1311, Train: 0.9875, Val: 0.6125, Test: 0.5618
Epoch: 058, Loss: 0.1353, Train: 0.9900, Val: 0.6150, Test: 0.5592
Epoch: 059, Loss: 0.1308, Train: 0.9900, Val: 0.6050, Test: 0.5581
Epoch: 060, Loss: 0.1369, Train: 0.9900, Val: 0.6100, Test: 0.5584
Epoch: 061, Loss: 0.1303, Train: 0.9900, Val: 0.6075, Test: 0.5581
Epoch: 062, Loss: 0.1279, Train: 0.9900, Val: 0.6025, Test: 0.5604
Epoch: 063, Loss: 0.1355, Train: 0.9875, Val: 0.6025, Test: 0.5621
Epoch: 064, Loss: 0.1184, Train: 0.9925, Val: 0.6075, Test: 0.5664
Epoch: 065, Loss: 0.1291, Train: 0.9925, Val: 0.6025, Test: 0.5690
Epoch: 066, Loss: 0.1242, Train: 0.9900, Val: 0.6000, Test: 0.5676
Epoch: 067, Loss: 0.1238, Train: 0.9900, Val: 0.6025, Test: 0.5670
Epoch: 068, Loss: 0.1121, Train: 0.9900, Val: 0.6025, Test: 0.5656
Epoch: 069, Loss: 0.1126, Train: 0.9900, Val: 0.6050, Test: 0.5635
Epoch: 070, Loss: 0.1208, Train: 0.9900, Val: 0.6050, Test: 0.5612
Epoch: 071, Loss: 0.1059, Train: 0.9900, Val: 0.6075, Test: 0.5589
Epoch: 072, Loss: 0.1098, Train: 0.9900, Val: 0.6025, Test: 0.5581
Epoch: 073, Loss: 0.1198, Train: 0.9950, Val: 0.5950, Test: 0.5598
Epoch: 074, Loss: 0.1214, Train: 0.9925, Val: 0.5925, Test: 0.5621
Epoch: 075, Loss: 0.1016, Train: 0.9925, Val: 0.5950, Test: 0.5601
Epoch: 076, Loss: 0.1145, Train: 0.9950, Val: 0.6000, Test: 0.5621
Epoch: 077, Loss: 0.1148, Train: 0.9950, Val: 0.6000, Test: 0.5615
Epoch: 078, Loss: 0.1135, Train: 0.9925, Val: 0.5975, Test: 0.5612
Epoch: 079, Loss: 0.1104, Train: 0.9925, Val: 0.6000, Test: 0.5624
Epoch: 080, Loss: 0.1108, Train: 0.9900, Val: 0.6050, Test: 0.5572
Epoch: 081, Loss: 0.0916, Train: 0.9900, Val: 0.6050, Test: 0.5561
Epoch: 082, Loss: 0.1275, Train: 0.9900, Val: 0.6025, Test: 0.5581
Epoch: 083, Loss: 0.0970, Train: 1.0000, Val: 0.6025, Test: 0.5607
Epoch: 084, Loss: 0.0923, Train: 1.0000, Val: 0.6025, Test: 0.5592
Epoch: 085, Loss: 0.1089, Train: 1.0000, Val: 0.6025, Test: 0.5598
Epoch: 086, Loss: 0.1032, Train: 1.0000, Val: 0.6025, Test: 0.5598
Epoch: 087, Loss: 0.0983, Train: 1.0000, Val: 0.6000, Test: 0.5615
Epoch: 088, Loss: 0.0982, Train: 1.0000, Val: 0.5950, Test: 0.5615
Epoch: 089, Loss: 0.0849, Train: 1.0000, Val: 0.5925, Test: 0.5607
Epoch: 090, Loss: 0.0982, Train: 0.9975, Val: 0.5900, Test: 0.5610
Epoch: 091, Loss: 0.1133, Train: 1.0000, Val: 0.5950, Test: 0.5650
Epoch: 092, Loss: 0.0890, Train: 1.0000, Val: 0.5950, Test: 0.5664
Epoch: 093, Loss: 0.0935, Train: 1.0000, Val: 0.6000, Test: 0.5658
Epoch: 094, Loss: 0.0935, Train: 1.0000, Val: 0.6050, Test: 0.5673
Epoch: 095, Loss: 0.1027, Train: 1.0000, Val: 0.6075, Test: 0.5681
Epoch: 096, Loss: 0.0914, Train: 0.9975, Val: 0.6000, Test: 0.5679
Epoch: 097, Loss: 0.0908, Train: 0.9975, Val: 0.5900, Test: 0.5690
Epoch: 098, Loss: 0.1003, Train: 1.0000, Val: 0.5900, Test: 0.5667
Epoch: 099, Loss: 0.0835, Train: 1.0000, Val: 0.5875, Test: 0.5670
Epoch: 100, Loss: 0.0968, Train: 1.0000, Val: 0.5900, Test: 0.5670
Epoch: 101, Loss: 0.0868, Train: 1.0000, Val: 0.5900, Test: 0.5679
Epoch: 102, Loss: 0.0906, Train: 1.0000, Val: 0.6000, Test: 0.5681
Epoch: 103, Loss: 0.0967, Train: 1.0000, Val: 0.5975, Test: 0.5681
Epoch: 104, Loss: 0.0983, Train: 1.0000, Val: 0.5925, Test: 0.5699
Epoch: 105, Loss: 0.0775, Train: 1.0000, Val: 0.5975, Test: 0.5681
Epoch: 106, Loss: 0.0840, Train: 1.0000, Val: 0.5950, Test: 0.5664
Epoch: 107, Loss: 0.0962, Train: 1.0000, Val: 0.5950, Test: 0.5633
Epoch: 108, Loss: 0.0900, Train: 1.0000, Val: 0.5950, Test: 0.5621
Epoch: 109, Loss: 0.0831, Train: 1.0000, Val: 0.5975, Test: 0.5644
Epoch: 110, Loss: 0.0844, Train: 1.0000, Val: 0.5950, Test: 0.5653
Epoch: 111, Loss: 0.1017, Train: 0.9975, Val: 0.5925, Test: 0.5667
Epoch: 112, Loss: 0.0833, Train: 0.9975, Val: 0.5950, Test: 0.5661
Epoch: 113, Loss: 0.0840, Train: 0.9975, Val: 0.5875, Test: 0.5670
Epoch: 114, Loss: 0.0809, Train: 0.9975, Val: 0.5900, Test: 0.5664
Epoch: 115, Loss: 0.0854, Train: 0.9975, Val: 0.5950, Test: 0.5673
Epoch: 116, Loss: 0.0896, Train: 0.9975, Val: 0.5975, Test: 0.5687
Epoch: 117, Loss: 0.0999, Train: 1.0000, Val: 0.5975, Test: 0.5664
Epoch: 118, Loss: 0.0890, Train: 1.0000, Val: 0.5950, Test: 0.5667
Epoch: 119, Loss: 0.0780, Train: 1.0000, Val: 0.5900, Test: 0.5658
Epoch: 120, Loss: 0.0751, Train: 1.0000, Val: 0.5875, Test: 0.5670
Epoch: 121, Loss: 0.0693, Train: 1.0000, Val: 0.5950, Test: 0.5661
Epoch: 122, Loss: 0.0822, Train: 1.0000, Val: 0.5975, Test: 0.5664
Epoch: 123, Loss: 0.0782, Train: 1.0000, Val: 0.5925, Test: 0.5635
Epoch: 124, Loss: 0.0791, Train: 1.0000, Val: 0.5950, Test: 0.5627
Epoch: 125, Loss: 0.0958, Train: 1.0000, Val: 0.6000, Test: 0.5644
Epoch: 126, Loss: 0.0764, Train: 1.0000, Val: 0.5950, Test: 0.5650
Epoch: 127, Loss: 0.0878, Train: 1.0000, Val: 0.5900, Test: 0.5650
Epoch: 128, Loss: 0.0679, Train: 1.0000, Val: 0.5900, Test: 0.5641
Epoch: 129, Loss: 0.0791, Train: 1.0000, Val: 0.5900, Test: 0.5647
Epoch: 130, Loss: 0.0809, Train: 1.0000, Val: 0.5900, Test: 0.5647
Epoch: 131, Loss: 0.0740, Train: 1.0000, Val: 0.5850, Test: 0.5661
Epoch: 132, Loss: 0.0694, Train: 1.0000, Val: 0.5825, Test: 0.5647
Epoch: 133, Loss: 0.0859, Train: 1.0000, Val: 0.5875, Test: 0.5633
Epoch: 134, Loss: 0.0833, Train: 0.9975, Val: 0.5875, Test: 0.5638
Epoch: 135, Loss: 0.0797, Train: 1.0000, Val: 0.5900, Test: 0.5656
Epoch: 136, Loss: 0.0867, Train: 1.0000, Val: 0.5950, Test: 0.5696
Epoch: 137, Loss: 0.0811, Train: 1.0000, Val: 0.5975, Test: 0.5696
Epoch: 138, Loss: 0.0710, Train: 1.0000, Val: 0.5925, Test: 0.5713
Epoch: 139, Loss: 0.0603, Train: 1.0000, Val: 0.5950, Test: 0.5722
Epoch: 140, Loss: 0.0776, Train: 1.0000, Val: 0.5925, Test: 0.5719
Epoch: 141, Loss: 0.0705, Train: 1.0000, Val: 0.5975, Test: 0.5679
Epoch: 142, Loss: 0.0775, Train: 1.0000, Val: 0.5950, Test: 0.5679
Epoch: 143, Loss: 0.0700, Train: 1.0000, Val: 0.5975, Test: 0.5696
Epoch: 144, Loss: 0.0829, Train: 1.0000, Val: 0.5975, Test: 0.5727
Epoch: 145, Loss: 0.0697, Train: 1.0000, Val: 0.6000, Test: 0.5727
Epoch: 146, Loss: 0.0697, Train: 1.0000, Val: 0.6025, Test: 0.5750
Epoch: 147, Loss: 0.0706, Train: 1.0000, Val: 0.6075, Test: 0.5727
Epoch: 148, Loss: 0.0723, Train: 1.0000, Val: 0.5975, Test: 0.5690
Epoch: 149, Loss: 0.0771, Train: 1.0000, Val: 0.5950, Test: 0.5696
Epoch: 150, Loss: 0.0650, Train: 1.0000, Val: 0.6025, Test: 0.5699
Epoch: 151, Loss: 0.0802, Train: 1.0000, Val: 0.5950, Test: 0.5676
Epoch: 152, Loss: 0.0687, Train: 1.0000, Val: 0.5925, Test: 0.5710
Epoch: 153, Loss: 0.0705, Train: 1.0000, Val: 0.5925, Test: 0.5704
Epoch: 154, Loss: 0.0831, Train: 1.0000, Val: 0.5925, Test: 0.5696
Epoch: 155, Loss: 0.0714, Train: 1.0000, Val: 0.5900, Test: 0.5690
Epoch: 156, Loss: 0.0662, Train: 1.0000, Val: 0.5850, Test: 0.5635
Stopping training as validation accuracy did not improve for 100 epochs
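The patience logic in the loop above can be factored into a small reusable helper. A plain-Python sketch (the class name EarlyStopping is hypothetical); like the original loop, it treats an equal validation accuracy as an improvement. It also records the test accuracy at the best validation epoch, since the loop above only prints per-epoch numbers and the last printed test accuracy does not necessarily belong to the best model:

```python
class EarlyStopping:
    """Stop after `patience` consecutive epochs without validation improvement."""

    def __init__(self, patience=100):
        self.patience = patience
        self.best_val = float('-inf')
        self.best_test = None
        self.bad_epochs = 0

    def step(self, val_acc, test_acc=None):
        """Record one epoch; return True when training should stop."""
        if val_acc >= self.best_val:  # same as `best_val_acc <= val_acc` in the loop above
            self.best_val = val_acc
            self.best_test = test_acc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=2)
history = [(0.50, 0.48), (0.60, 0.55), (0.55, 0.57), (0.58, 0.59)]
stopped_at = None
for epoch, (val_acc, test_acc) in enumerate(history, start=1):
    if stopper.step(val_acc, test_acc):
        stopped_at = epoch
        print(f'stop at epoch {epoch}; best val {stopper.best_val:.2f}, '
              f'test at best val {stopper.best_test:.2f}')
        break
```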

 

5.1.1.2 HGT

Reference: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hgt_dblp.py
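In DBLP the conference nodes come without input features (the printed output shows conference={num_nodes=20, x=[20, 1]}, where x is added by the transform), so the script applies T.Constant(node_types='conference') to give every conference node a constant feature. A plain-PyTorch sketch of the same idea for any node type that lacks x (the helper fill_missing_features is hypothetical; the fill value 1.0 mirrors the transform's default):

```python
import torch

def fill_missing_features(x_dict, num_nodes_dict, value=1.0):
    """Give every node type without features a [num_nodes, 1] constant feature."""
    out = dict(x_dict)
    for node_type, num_nodes in num_nodes_dict.items():
        if node_type not in out:
            out[node_type] = torch.full((num_nodes, 1), value)
    return out

x_dict = {'author': torch.randn(4, 334)}
num_nodes_dict = {'author': 4, 'conference': 20}
x_dict = fill_missing_features(x_dict, num_nodes_dict)
print(x_dict['conference'].shape)  # torch.Size([20, 1])
```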

Example code:

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HGTConv, Linear

dataset = DBLP('/data/pyg_data/DBLP', transform=T.Constant(node_types='conference'))
data = dataset[0]
print(data)

class HGT(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):
        super().__init__()

        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                           num_heads, group='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }

        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)

        return self.lin(x_dict['author'])

model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['author'].train_mask
    loss = F.cross_entropy(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['author'][split]
        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

 

Output:

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1]
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)
Epoch: 001, Loss: 1.3967, Train: 0.2550, Val: 0.2700, Test: 0.2539
Epoch: 002, Loss: 1.3708, Train: 0.6750, Val: 0.4825, Test: 0.5272
Epoch: 003, Loss: 1.3459, Train: 0.6200, Val: 0.4525, Test: 0.5302
Epoch: 004, Loss: 1.3173, Train: 0.5675, Val: 0.4300, Test: 0.4992
Epoch: 005, Loss: 1.2809, Train: 0.5550, Val: 0.4150, Test: 0.4836
Epoch: 006, Loss: 1.2323, Train: 0.5825, Val: 0.4175, Test: 0.4814
Epoch: 007, Loss: 1.1662, Train: 0.6575, Val: 0.4325, Test: 0.5026
Epoch: 008, Loss: 1.0761, Train: 0.7475, Val: 0.4775, Test: 0.5511
Epoch: 009, Loss: 0.9564, Train: 0.8425, Val: 0.5575, Test: 0.6098
Epoch: 010, Loss: 0.8064, Train: 0.9400, Val: 0.6275, Test: 0.6831
Epoch: 011, Loss: 0.6348, Train: 0.9750, Val: 0.7125, Test: 0.7446
Epoch: 012, Loss: 0.4627, Train: 0.9950, Val: 0.7225, Test: 0.7768
Epoch: 013, Loss: 0.3151, Train: 0.9975, Val: 0.7275, Test: 0.7842
Epoch: 014, Loss: 0.2002, Train: 0.9975, Val: 0.7200, Test: 0.7835
Epoch: 015, Loss: 0.1142, Train: 0.9950, Val: 0.7200, Test: 0.7706
Epoch: 016, Loss: 0.0614, Train: 0.9950, Val: 0.7250, Test: 0.7596
Epoch: 017, Loss: 0.0336, Train: 1.0000, Val: 0.7125, Test: 0.7633
Epoch: 018, Loss: 0.0163, Train: 1.0000, Val: 0.6950, Test: 0.7676
Epoch: 019, Loss: 0.0068, Train: 1.0000, Val: 0.7150, Test: 0.7688
Epoch: 020, Loss: 0.0033, Train: 1.0000, Val: 0.7175, Test: 0.7630
Epoch: 021, Loss: 0.0022, Train: 1.0000, Val: 0.7200, Test: 0.7630
Epoch: 022, Loss: 0.0016, Train: 1.0000, Val: 0.7175, Test: 0.7587
Epoch: 023, Loss: 0.0010, Train: 1.0000, Val: 0.7175, Test: 0.7624
Epoch: 024, Loss: 0.0006, Train: 1.0000, Val: 0.7300, Test: 0.7621
Epoch: 025, Loss: 0.0003, Train: 1.0000, Val: 0.7300, Test: 0.7599
Epoch: 026, Loss: 0.0002, Train: 1.0000, Val: 0.7325, Test: 0.7571
Epoch: 027, Loss: 0.0002, Train: 1.0000, Val: 0.7425, Test: 0.7608
Epoch: 028, Loss: 0.0002, Train: 1.0000, Val: 0.7425, Test: 0.7633
Epoch: 029, Loss: 0.0002, Train: 1.0000, Val: 0.7475, Test: 0.7624
Epoch: 030, Loss: 0.0003, Train: 1.0000, Val: 0.7525, Test: 0.7636
Epoch: 031, Loss: 0.0005, Train: 1.0000, Val: 0.7475, Test: 0.7667
Epoch: 032, Loss: 0.0006, Train: 1.0000, Val: 0.7475, Test: 0.7663
Epoch: 033, Loss: 0.0007, Train: 1.0000, Val: 0.7550, Test: 0.7673
Epoch: 034, Loss: 0.0008, Train: 1.0000, Val: 0.7600, Test: 0.7682
Epoch: 035, Loss: 0.0008, Train: 1.0000, Val: 0.7625, Test: 0.7703
Epoch: 036, Loss: 0.0008, Train: 1.0000, Val: 0.7650, Test: 0.7722
Epoch: 037, Loss: 0.0008, Train: 1.0000, Val: 0.7650, Test: 0.7759
Epoch: 038, Loss: 0.0008, Train: 1.0000, Val: 0.7625, Test: 0.7722
Epoch: 039, Loss: 0.0009, Train: 1.0000, Val: 0.7550, Test: 0.7756
Epoch: 040, Loss: 0.0010, Train: 1.0000, Val: 0.7550, Test: 0.7734
Epoch: 041, Loss: 0.0010, Train: 1.0000, Val: 0.7525, Test: 0.7749
Epoch: 042, Loss: 0.0011, Train: 1.0000, Val: 0.7475, Test: 0.7743
Epoch: 043, Loss: 0.0011, Train: 1.0000, Val: 0.7500, Test: 0.7753
Epoch: 044, Loss: 0.0011, Train: 1.0000, Val: 0.7525, Test: 0.7746
Epoch: 045, Loss: 0.0012, Train: 1.0000, Val: 0.7500, Test: 0.7749
Epoch: 046, Loss: 0.0012, Train: 1.0000, Val: 0.7550, Test: 0.7762
Epoch: 047, Loss: 0.0013, Train: 1.0000, Val: 0.7575, Test: 0.7792
Epoch: 048, Loss: 0.0015, Train: 1.0000, Val: 0.7550, Test: 0.7808
Epoch: 049, Loss: 0.0016, Train: 1.0000, Val: 0.7525, Test: 0.7783
Epoch: 050, Loss: 0.0016, Train: 1.0000, Val: 0.7575, Test: 0.7808
Epoch: 051, Loss: 0.0016, Train: 1.0000, Val: 0.7600, Test: 0.7811
Epoch: 052, Loss: 0.0016, Train: 1.0000, Val: 0.7625, Test: 0.7842
Epoch: 053, Loss: 0.0017, Train: 1.0000, Val: 0.7600, Test: 0.7823
Epoch: 054, Loss: 0.0018, Train: 1.0000, Val: 0.7600, Test: 0.7835
Epoch: 055, Loss: 0.0019, Train: 1.0000, Val: 0.7600, Test: 0.7808
Epoch: 056, Loss: 0.0019, Train: 1.0000, Val: 0.7575, Test: 0.7820
Epoch: 057, Loss: 0.0019, Train: 1.0000, Val: 0.7600, Test: 0.7832
Epoch: 058, Loss: 0.0020, Train: 1.0000, Val: 0.7625, Test: 0.7848
Epoch: 059, Loss: 0.0021, Train: 1.0000, Val: 0.7625, Test: 0.7845
Epoch: 060, Loss: 0.0021, Train: 1.0000, Val: 0.7625, Test: 0.7839
Epoch: 061, Loss: 0.0022, Train: 1.0000, Val: 0.7650, Test: 0.7826
Epoch: 062, Loss: 0.0023, Train: 1.0000, Val: 0.7700, Test: 0.7826
Epoch: 063, Loss: 0.0023, Train: 1.0000, Val: 0.7700, Test: 0.7848
Epoch: 064, Loss: 0.0024, Train: 1.0000, Val: 0.7700, Test: 0.7820
Epoch: 065, Loss: 0.0025, Train: 1.0000, Val: 0.7700, Test: 0.7839
Epoch: 066, Loss: 0.0025, Train: 1.0000, Val: 0.7675, Test: 0.7826
Epoch: 067, Loss: 0.0026, Train: 1.0000, Val: 0.7675, Test: 0.7832
Epoch: 068, Loss: 0.0026, Train: 1.0000, Val: 0.7650, Test: 0.7854
Epoch: 069, Loss: 0.0027, Train: 1.0000, Val: 0.7650, Test: 0.7863
Epoch: 070, Loss: 0.0027, Train: 1.0000, Val: 0.7625, Test: 0.7866
Epoch: 071, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7860
Epoch: 072, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 073, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 074, Loss: 0.0028, Train: 1.0000, Val: 0.7625, Test: 0.7860
Epoch: 075, Loss: 0.0029, Train: 1.0000, Val: 0.7625, Test: 0.7854
Epoch: 076, Loss: 0.0029, Train: 1.0000, Val: 0.7650, Test: 0.7863
Epoch: 077, Loss: 0.0029, Train: 1.0000, Val: 0.7600, Test: 0.7866
Epoch: 078, Loss: 0.0029, Train: 1.0000, Val: 0.7625, Test: 0.7875
Epoch: 079, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7872
Epoch: 080, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7885
Epoch: 081, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7897
Epoch: 082, Loss: 0.0030, Train: 1.0000, Val: 0.7600, Test: 0.7894
Epoch: 083, Loss: 0.0030, Train: 1.0000, Val: 0.7600, Test: 0.7897
Epoch: 084, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7900
Epoch: 085, Loss: 0.0030, Train: 1.0000, Val: 0.7625, Test: 0.7897
Epoch: 086, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7903
Epoch: 087, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7906
Epoch: 088, Loss: 0.0031, Train: 1.0000, Val: 0.7650, Test: 0.7909
Epoch: 089, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 090, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 091, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915
Epoch: 092, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 093, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 094, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 095, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 096, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7909
Epoch: 097, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7906
Epoch: 098, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 099, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7912
Epoch: 100, Loss: 0.0031, Train: 1.0000, Val: 0.7675, Test: 0.7915

 

5.1.2 Using HeteroConv

5.1.2.1 GraphSAGE

Reference: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hetero_conv_dblp.py

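(Node types that have no features in the original dataset, here the conference nodes, are given the constant [1.] as their initial feature.)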
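Here is a minimal plain-PyTorch sketch of what T.Constant(node_types='conference') does for a node type that has no features. Every conference node receives the one-element feature [1.], which matches the x=[20, 1] in the printout below. The helper add_constant_features is hypothetical and is not a PyG function. The real transform, by default, also appends the constant to existing features instead of replacing them.

```python
import torch


def add_constant_features(x_dict, num_nodes_dict, node_types, value=1.0):
    """Hypothetical helper: give each listed node type a [num_nodes, 1] feature
    matrix filled with `value`, mirroring T.Constant for featureless node types."""
    out = dict(x_dict)  # leave the caller's dict untouched
    for node_type in node_types:
        out[node_type] = torch.full((num_nodes_dict[node_type], 1), value)
    return out


# Toy sizes taken from the DBLP printout: 20 conference nodes without features.
x_dict = {'author': torch.randn(4057, 334)}
num_nodes_dict = {'author': 4057, 'conference': 20}
x_dict = add_constant_features(x_dict, num_nodes_dict, ['conference'])
print(x_dict['conference'].shape)  # torch.Size([20, 1])
```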

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HeteroConv, Linear, SAGEConv

# We initialize conference node features with a single one-vector as feature:
dataset = DBLP('/data/pyg_data/DBLP', transform=T.Constant(node_types='conference'))
data = dataset[0]
print(data)

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
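            # One SAGEConv per edge type; HeteroConv sums the outputs that
            # share a destination node type (its default aggr is 'sum').
            conv = HeteroConv({
                edge_type: SAGEConv((-1, -1), hidden_channels)
                for edge_type in metadata[1]
            })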
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
        return self.lin(x_dict['author'])

model = HeteroGNN(data.metadata(), hidden_channels=64, out_channels=4,
                  num_layers=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data, model = data.to(device), model.to(device)

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['author'].train_mask
    loss = F.cross_entropy(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)

    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['author'][split]
        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

 

Output:

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1]
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)
Epoch: 001, Loss: 1.3721, Train: 0.4550, Val: 0.3450, Test: 0.3819
Epoch: 002, Loss: 1.2867, Train: 0.6050, Val: 0.4800, Test: 0.5333
Epoch: 003, Loss: 1.1778, Train: 0.7175, Val: 0.5325, Test: 0.5941
Epoch: 004, Loss: 1.0368, Train: 0.8350, Val: 0.6050, Test: 0.6788
Epoch: 005, Loss: 0.8729, Train: 0.8950, Val: 0.6725, Test: 0.7228
Epoch: 006, Loss: 0.6991, Train: 0.9350, Val: 0.7025, Test: 0.7479
Epoch: 007, Loss: 0.5314, Train: 0.9600, Val: 0.7350, Test: 0.7765
Epoch: 008, Loss: 0.3831, Train: 0.9825, Val: 0.7525, Test: 0.8010
Epoch: 009, Loss: 0.2585, Train: 0.9900, Val: 0.7800, Test: 0.8189
Epoch: 010, Loss: 0.1632, Train: 0.9975, Val: 0.8025, Test: 0.8284
Epoch: 011, Loss: 0.0988, Train: 0.9975, Val: 0.8100, Test: 0.8293
Epoch: 012, Loss: 0.0578, Train: 1.0000, Val: 0.8025, Test: 0.8290
Epoch: 013, Loss: 0.0324, Train: 1.0000, Val: 0.8150, Test: 0.8296
Epoch: 014, Loss: 0.0178, Train: 1.0000, Val: 0.8075, Test: 0.8281
Epoch: 015, Loss: 0.0100, Train: 1.0000, Val: 0.8050, Test: 0.8281
Epoch: 016, Loss: 0.0060, Train: 1.0000, Val: 0.8050, Test: 0.8262
Epoch: 017, Loss: 0.0039, Train: 1.0000, Val: 0.8025, Test: 0.8235
Epoch: 018, Loss: 0.0027, Train: 1.0000, Val: 0.8100, Test: 0.8232
Epoch: 019, Loss: 0.0020, Train: 1.0000, Val: 0.8125, Test: 0.8235
Epoch: 020, Loss: 0.0017, Train: 1.0000, Val: 0.8125, Test: 0.8238
Epoch: 021, Loss: 0.0015, Train: 1.0000, Val: 0.8175, Test: 0.8268
Epoch: 022, Loss: 0.0014, Train: 1.0000, Val: 0.8175, Test: 0.8271
Epoch: 023, Loss: 0.0013, Train: 1.0000, Val: 0.8125, Test: 0.8274
Epoch: 024, Loss: 0.0014, Train: 1.0000, Val: 0.8150, Test: 0.8284
Epoch: 025, Loss: 0.0014, Train: 1.0000, Val: 0.8175, Test: 0.8281
Epoch: 026, Loss: 0.0015, Train: 1.0000, Val: 0.8175, Test: 0.8268
Epoch: 027, Loss: 0.0017, Train: 1.0000, Val: 0.8225, Test: 0.8225
Epoch: 028, Loss: 0.0019, Train: 1.0000, Val: 0.8250, Test: 0.8216
Epoch: 029, Loss: 0.0021, Train: 1.0000, Val: 0.8250, Test: 0.8198
Epoch: 030, Loss: 0.0024, Train: 1.0000, Val: 0.8225, Test: 0.8195
Epoch: 031, Loss: 0.0026, Train: 1.0000, Val: 0.8200, Test: 0.8192
Epoch: 032, Loss: 0.0029, Train: 1.0000, Val: 0.8225, Test: 0.8189
Epoch: 033, Loss: 0.0032, Train: 1.0000, Val: 0.8175, Test: 0.8185
Epoch: 034, Loss: 0.0035, Train: 1.0000, Val: 0.8200, Test: 0.8185
Epoch: 035, Loss: 0.0037, Train: 1.0000, Val: 0.8200, Test: 0.8176
Epoch: 036, Loss: 0.0038, Train: 1.0000, Val: 0.8225, Test: 0.8185
Epoch: 037, Loss: 0.0039, Train: 1.0000, Val: 0.8200, Test: 0.8176
Epoch: 038, Loss: 0.0041, Train: 1.0000, Val: 0.8175, Test: 0.8192
Epoch: 039, Loss: 0.0043, Train: 1.0000, Val: 0.8175, Test: 0.8204
Epoch: 040, Loss: 0.0044, Train: 1.0000, Val: 0.8150, Test: 0.8189
Epoch: 041, Loss: 0.0045, Train: 1.0000, Val: 0.8150, Test: 0.8173
Epoch: 042, Loss: 0.0046, Train: 1.0000, Val: 0.8175, Test: 0.8179
Epoch: 043, Loss: 0.0047, Train: 1.0000, Val: 0.8150, Test: 0.8170
Epoch: 044, Loss: 0.0047, Train: 1.0000, Val: 0.8175, Test: 0.8185
Epoch: 045, Loss: 0.0047, Train: 1.0000, Val: 0.8125, Test: 0.8195
Epoch: 046, Loss: 0.0047, Train: 1.0000, Val: 0.8150, Test: 0.8192
Epoch: 047, Loss: 0.0047, Train: 1.0000, Val: 0.8125, Test: 0.8182
Epoch: 048, Loss: 0.0047, Train: 1.0000, Val: 0.8075, Test: 0.8167
Epoch: 049, Loss: 0.0047, Train: 1.0000, Val: 0.8050, Test: 0.8158
Epoch: 050, Loss: 0.0047, Train: 1.0000, Val: 0.8000, Test: 0.8167
Epoch: 051, Loss: 0.0047, Train: 1.0000, Val: 0.8050, Test: 0.8170
Epoch: 052, Loss: 0.0047, Train: 1.0000, Val: 0.8100, Test: 0.8152
Epoch: 053, Loss: 0.0046, Train: 1.0000, Val: 0.8075, Test: 0.8149
Epoch: 054, Loss: 0.0046, Train: 1.0000, Val: 0.8075, Test: 0.8133
Epoch: 055, Loss: 0.0046, Train: 1.0000, Val: 0.8100, Test: 0.8139
Epoch: 056, Loss: 0.0046, Train: 1.0000, Val: 0.8100, Test: 0.8152
Epoch: 057, Loss: 0.0045, Train: 1.0000, Val: 0.8050, Test: 0.8149
Epoch: 058, Loss: 0.0045, Train: 1.0000, Val: 0.8025, Test: 0.8146
Epoch: 059, Loss: 0.0045, Train: 1.0000, Val: 0.8100, Test: 0.8139
Epoch: 060, Loss: 0.0044, Train: 1.0000, Val: 0.8125, Test: 0.8149
Epoch: 061, Loss: 0.0044, Train: 1.0000, Val: 0.8100, Test: 0.8149
Epoch: 062, Loss: 0.0043, Train: 1.0000, Val: 0.8025, Test: 0.8130
Epoch: 063, Loss: 0.0043, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 064, Loss: 0.0042, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 065, Loss: 0.0042, Train: 1.0000, Val: 0.8100, Test: 0.8127
Epoch: 066, Loss: 0.0041, Train: 1.0000, Val: 0.8100, Test: 0.8130
Epoch: 067, Loss: 0.0041, Train: 1.0000, Val: 0.8050, Test: 0.8121
Epoch: 068, Loss: 0.0040, Train: 1.0000, Val: 0.8050, Test: 0.8124
Epoch: 069, Loss: 0.0040, Train: 1.0000, Val: 0.8100, Test: 0.8118
Epoch: 070, Loss: 0.0039, Train: 1.0000, Val: 0.8075, Test: 0.8115
Epoch: 071, Loss: 0.0038, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 072, Loss: 0.0038, Train: 1.0000, Val: 0.8075, Test: 0.8109
Epoch: 073, Loss: 0.0037, Train: 1.0000, Val: 0.8100, Test: 0.8099
Epoch: 074, Loss: 0.0037, Train: 1.0000, Val: 0.8100, Test: 0.8109
Epoch: 075, Loss: 0.0036, Train: 1.0000, Val: 0.8050, Test: 0.8106
Epoch: 076, Loss: 0.0036, Train: 1.0000, Val: 0.8075, Test: 0.8115
Epoch: 077, Loss: 0.0036, Train: 1.0000, Val: 0.8100, Test: 0.8112
Epoch: 078, Loss: 0.0035, Train: 1.0000, Val: 0.8075, Test: 0.8112
Epoch: 079, Loss: 0.0035, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 080, Loss: 0.0035, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 081, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8115
Epoch: 082, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 083, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8109
Epoch: 084, Loss: 0.0034, Train: 1.0000, Val: 0.8050, Test: 0.8112
Epoch: 085, Loss: 0.0033, Train: 1.0000, Val: 0.8050, Test: 0.8106
Epoch: 086, Loss: 0.0033, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 087, Loss: 0.0033, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 088, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8099
Epoch: 089, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 090, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8103
Epoch: 091, Loss: 0.0032, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 092, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8106
Epoch: 093, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8109
Epoch: 094, Loss: 0.0031, Train: 1.0000, Val: 0.8025, Test: 0.8115
Epoch: 095, Loss: 0.0031, Train: 1.0000, Val: 0.8000, Test: 0.8121
Epoch: 096, Loss: 0.0030, Train: 1.0000, Val: 0.8025, Test: 0.8124
Epoch: 097, Loss: 0.0030, Train: 1.0000, Val: 0.8000, Test: 0.8130
Epoch: 098, Loss: 0.0030, Train: 1.0000, Val: 0.8025, Test: 0.8127
Epoch: 099, Loss: 0.0030, Train: 1.0000, Val: 0.7975, Test: 0.8124
Epoch: 100, Loss: 0.0030, Train: 1.0000, Val: 0.8000, Test: 0.8127
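In the HeteroConv example above, each edge type has its own SAGEConv. When several edge types point at the same destination node type, their outputs are combined, and HeteroConv's aggr defaults to 'sum'. Below is a minimal plain-PyTorch sketch of that per-relation-then-sum pattern. A weight-free mean aggregation stands in for SAGEConv, so there are no learnable layers and no root term. The function names mean_aggregate and hetero_layer are illustrative and not part of PyG.

```python
import torch


def mean_aggregate(x_src, edge_index, num_dst):
    """Average the source features over each destination node's incoming edges.

    Stand-in for one relation-specific convolution: no weights, no root term.
    Destination nodes without incoming edges of this type get zeros.
    """
    src, dst = edge_index
    out = torch.zeros(num_dst, x_src.size(1))
    out.index_add_(0, dst, x_src[src])
    deg = torch.zeros(num_dst).index_add_(0, dst, torch.ones(dst.numel()))
    return out / deg.clamp(min=1).unsqueeze(-1)


def hetero_layer(x_dict, edge_index_dict):
    """One 'conv' per edge type, then sum the results per destination type."""
    out_dict = {}
    for (src_type, _, dst_type), edge_index in edge_index_dict.items():
        h = mean_aggregate(x_dict[src_type], edge_index, x_dict[dst_type].size(0))
        out_dict[dst_type] = out_dict.get(dst_type, 0) + h
    return out_dict


# Toy graph: two authors write paper 0, one term belongs to paper 1.
x_dict = {
    'author': torch.tensor([[1.0], [3.0]]),
    'term': torch.tensor([[10.0]]),
    'paper': torch.zeros(2, 1),
}
edge_index_dict = {
    ('author', 'to', 'paper'): torch.tensor([[0, 1], [0, 0]]),
    ('term', 'to', 'paper'): torch.tensor([[0], [1]]),
}
out = hetero_layer(x_dict, edge_index_dict)
# paper 0: mean of authors (1 + 3) / 2 = 2; paper 1: the term's feature 10.
```

Node types that never appear as a destination (author and term in this toy graph) are missing from the output dict. In the DBLP example, every node type receives edges, so all four types survive each layer. The aggr='sum' passed to to_hetero later in this post combines relations in the same summing way.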

 

5.2 mini-batch

Available DataLoaders:
https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.NeighborLoader
https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.HGTLoader

As with homogeneous graphs, this is quite convenient: each mini-batch is returned directly as a HeteroData object.

Template for building a DataLoader:

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader

transform = T.ToUndirected()  # Add reverse edge types.
data = OGB_MAG(root='./data', preprocess='metapath2vec', transform=transform)[0]

train_loader = NeighborLoader(
    data,
    # Sample 15 neighbors for each node and each edge type for 2 iterations:
    num_neighbors=[15] * 2,
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=128,
    input_nodes=('paper', data['paper'].train_mask),
)

batch = next(iter(train_loader))

 

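The number of sampled neighbors can also be set at a finer granularity, per edge type: num_neighbors = {key: [15] * 2 for key in data.edge_types}
Note that batch_size is the number of seed nodes whose embeddings we actually want. Computing them requires the whole sampled batch. The first batch_size nodes of the seed type in the batch are exactly those seed nodes, so only their outputs are used.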
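This toy example shows why only the first batch_size rows are used. It is a pure-Python version of 2-hop neighbor sampling on a single edge type. The helper sample_neighbors is a hypothetical simplification: NeighborLoader does this per edge type in compiled code and also relabels node ids. The key property is the same. Seed nodes sit at the front of the sampled node list, so after message passing the first batch_size outputs belong to them, and the other nodes only act as neighbors.

```python
import random


def sample_neighbors(adj, seeds, num_neighbors, seed=0):
    """Toy multi-hop neighbor sampling on one edge type.

    adj: dict mapping a destination node to the list of its source neighbors.
    num_neighbors: neighbors to sample per node for each hop, e.g. [15] * 2.
    Returns (nodes, edges): nodes lists the seed nodes first, then newly reached
    nodes in discovery order; edges are (src, dst) pairs in global ids.
    """
    rng = random.Random(seed)
    nodes, seen = list(seeds), set(seeds)
    frontier, edges = list(seeds), []
    for k in num_neighbors:  # one iteration per hop
        next_frontier = []
        for dst in frontier:
            nbrs = adj.get(dst, [])
            for src in rng.sample(nbrs, min(k, len(nbrs))):
                edges.append((src, dst))
                if src not in seen:
                    seen.add(src)
                    nodes.append(src)
                    next_frontier.append(src)
        frontier = next_frontier
    return nodes, edges


adj = {0: [2, 3, 4], 1: [3, 5], 2: [6], 3: [0, 7], 4: [], 5: [1]}
nodes, edges = sample_neighbors(adj, seeds=[0, 1], num_neighbors=[2, 2])
batch_size = 2
print(nodes[:batch_size])  # [0, 1]: the only rows that enter the loss
```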

Training template:

def train():
    model.train()

    total_examples = total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to('cuda:0')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)
        loss = F.cross_entropy(out['paper'][:batch_size],
                               batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples
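The loop multiplies each batch's mean loss by its batch_size before summing, because the last batch is usually smaller. The pure-Python check below uses toy numbers, not values from the run. It shows that this bookkeeping reproduces the mean loss over all seed nodes, while naively averaging the per-batch means does not.

```python
per_node_losses = [0.2, 0.4, 0.6, 0.8, 1.0]  # toy per-seed-node losses
batch_size = 2
batches = [per_node_losses[i:i + batch_size]
           for i in range(0, len(per_node_losses), batch_size)]  # last batch has 1 node

total_examples = total_loss = 0.0
for b in batches:
    loss = sum(b) / len(b)      # mean-reduced loss, like F.cross_entropy's default
    total_examples += len(b)    # plays the role of batch['paper'].batch_size
    total_loss += loss * len(b)

epoch_loss = total_loss / total_examples                     # weighted: equals the overall mean
naive = sum(sum(b) / len(b) for b in batches) / len(batches)  # over-weights the small last batch
```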

 

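Using NeighborLoader with multiple worker processes on a heterogeneous graph can get stuck after many epochs, so I recommend a single process: Heterogenous graph, use NeighborLoader with num_workers>0, and stucks after many epochs · Issue #5348 · pyg-team/pytorch_geometric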
To get the indices in the original graph that correspond to the mini-batch nodes, see this discussion: I wonder how to use NeighborLoader correctly? · Discussion #3409 · pyg-team/pytorch_geometric
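One common way to recover original node ids is given here as an assumption, not as the exact method from that discussion. Before building the loader, store an arange of node ids as an extra node-level attribute, for example data['paper'].orig_id = torch.arange(data['paper'].num_nodes), where orig_id is a hypothetical attribute name. The loader indexes node-level tensors by the sampled nodes, so each batch then carries the global ids of its nodes. Recent PyG versions may also attach these ids directly as an n_id attribute; check the NeighborLoader docs for your installed version. The plain-PyTorch sketch below imitates that slicing with a hand-picked sample.

```python
import torch

num_papers = 8
x = torch.randn(num_papers, 4)        # node features of the full graph
orig_id = torch.arange(num_papers)    # hypothetical extra node-level attribute

sampled = torch.tensor([5, 2, 7])     # toy node ids picked by a sampler, seeds first
batch_x = x[sampled]                  # how node-level tensors are sliced into a batch
batch_orig_id = orig_id[sampled]      # ... including the stored ids

# Row i of the mini-batch corresponds to node batch_orig_id[i] of the full graph.
print(batch_orig_id.tolist())  # [5, 2, 7]
```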

Example code (based on https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py):

import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from torch.nn import ReLU
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import HGTLoader, NeighborLoader
from torch_geometric.nn import Linear, SAGEConv, Sequential, to_hetero

parser = argparse.ArgumentParser()
parser.add_argument('--use_hgt_loader', action='store_true')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.ToUndirected(merge=True)
dataset = OGB_MAG('/data/pyg_data', preprocess='metapath2vec', transform=transform)

# Already send node features/labels to GPU for faster access during sampling:
data = dataset[0].to(device, 'x', 'y')

train_input_nodes = ('paper', data['paper'].train_mask)
val_input_nodes = ('paper', data['paper'].val_mask)
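# Note (see the issue above): num_workers > 0 may hang on heterogeneous graphs.
# For a single process, set 'num_workers': 0 and drop 'persistent_workers',
# which requires num_workers > 0.
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}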

if not args.use_hgt_loader:
    train_loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True,
                                  input_nodes=train_input_nodes, **kwargs)
    val_loader = NeighborLoader(data, num_neighbors=[10] * 2,
                                input_nodes=val_input_nodes, **kwargs)
else:
    train_loader = HGTLoader(data, num_samples=[1024] * 4, shuffle=True,
                             input_nodes=train_input_nodes, **kwargs)
    val_loader = HGTLoader(data, num_samples=[1024] * 4,
                           input_nodes=val_input_nodes, **kwargs)

model = Sequential('x, edge_index', [
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (Linear(-1, dataset.num_classes), 'x -> x'),
])
model = to_hetero(model, data.metadata(), aggr='sum').to(device)


@torch.no_grad()
def init_params():
    # Initialize lazy parameters via forwarding a single batch to the model:
    batch = next(iter(train_loader))
    batch = batch.to(device, 'edge_index')
    model(batch.x_dict, batch.edge_index_dict)


def train():
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        loss = F.cross_entropy(out, batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples


@torch.no_grad()
def test(loader):
    model.eval()

    total_examples = total_correct = 0
    for batch in tqdm(loader):
        batch = batch.to(device, 'edge_index')
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
        pred = out.argmax(dim=-1)

        total_examples += batch_size
        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())

    return total_correct / total_examples


init_params()  # Initialize parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 21):
    loss = train()
    val_acc = test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')

 

Output of the last epoch with NeighborLoader: Epoch: 20, Loss: 1.9040, Val: 0.4445
Output of the last epoch with HGTLoader: Epoch: 20, Loss: 2.0077, Val: 0.4271
(The tqdm progress bars make the full output very long, so it is not included here.)

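Note on the example code: the to() method of PyG data objects can move only selected attributes to a device. For example, dataset[0].to(device, 'x', 'y') moves only the node features and labels, and batch.to(device, 'edge_index') moves only the edge indices.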

6. Link prediction

6.1 transductive

6.1.1 GraphSAGE encoder + MLP decoder, predicting user ratings

Reference: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hetero_link_pred.py

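  1. A GraphSAGE model is converted into a heterogeneous GNN with to_hetero and used to encode the nodes.
  2. Decoding node-pair representations (computing a link score for each node pair): the two node embeddings are concatenated and passed through a 2-layer MLP.
  3. This is not a standard link prediction task. It predicts the rating on known ['user', 'rates', 'movie'] node pairs, where ratings are discrete integers from 0 to 5, so it is treated as regression. With --use_weighted_loss, the loss is a weighted MSE, because the 6 rating values are imbalanced.
  4. At test time, predictions are clamped to [0, 5] before computing the RMSE, which is reported as the metric.
  5. Uses the MovieLens dataset. The original dataset has two node types:
    1. Movie nodes: their features come from text. The code embeds the titles with a SentenceTransformer model, and these embeddings are combined with genre indicators into the 404-dimensional movie features shown in the output below. Note that passing a local model directory as model_name causes problems. The crude workaround is to run once with the local path, then change the stored processed object so that it refers to the model name, and from then on call the dataset with the model name directly. This is hard to explain briefly, so see this issue (I submitted a PR, but for some reason it has not been merged, so the fix still has to be applied by hand): Unable to process movie_lens dataset with local directory transformers model · Issue #5500 · pyg-team/pytorch_geometric
    2. User nodes: no features. The code uses one-hot encodings (an identity matrix) as the initial node features.
    3. The graph is converted to an undirected graph by adding reverse edges.
  6. Data split: edges are randomly split 8:1:1 and the node set is unchanged. The training and validation graphs use the same message-passing edges. The test graph uses the training edges plus the edges the validation set scores on. Because this is regression on known node pairs, no negative edges are needed (neg_sampling_ratio=0.0).
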
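Here is a pure-Python sketch of the split described in item 6. It assumes the rating edges are shuffled once and split 8:1:1, and it ignores reverse edges (RandomLinkSplit also keeps the rev_rates edges consistent through rev_edge_types). The helper split_edges is hypothetical. It returns, for each split, the message-passing edges and the supervision edges that are scored. Supervision edges are disjoint across splits, while the message-passing edges grow from train to test.

```python
import random


def split_edges(edges, num_val=0.1, num_test=0.1, seed=0):
    """Hypothetical helper imitating the transductive split used in this example."""
    edges = list(edges)
    random.Random(seed).shuffle(edges)
    n_val = int(num_val * len(edges))
    n_test = int(num_test * len(edges))
    n_train = len(edges) - n_val - n_test
    train = edges[:n_train]
    val = edges[n_train:n_train + n_val]
    test = edges[n_train + n_val:]
    return {
        # split name: (message-passing edges, supervision edges)
        'train': (train, train),
        'val': (train, val),
        'test': (train + val, test),
    }


ratings = [(u, m) for u in range(10) for m in range(10)]  # toy user -> movie edges
splits = split_edges(ratings)
for name, (mp, sup) in splits.items():
    print(name, len(mp), len(sup))
# train 80 80
# val 80 10
# test 90 10
```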
import argparse

import torch
import torch.nn.functional as F
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import MovieLens
from torch_geometric.nn import SAGEConv, to_hetero

parser = argparse.ArgumentParser()
parser.add_argument('--use_weighted_loss', action='store_true',
                    help='Whether to use weighted MSE loss.')
args = parser.parse_args()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

dataset = MovieLens('/data/pyg_data/MovieLens', model_name='all-MiniLM-L6-v2')
data = dataset[0].to(device)

# Add user node features for message passing:
data['user'].x = torch.eye(data['user'].num_nodes, device=device)
del data['user'].num_nodes

# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:
data = T.ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.

# Perform a link-level split into training, validation, and test edges:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)

# We have an unbalanced dataset with many labels for rating 3 and 4, and very
# few for 0 and 1. Therefore we use a weighted MSE loss.
if args.use_weighted_loss:
    weight = torch.bincount(train_data['user', 'movie'].edge_label)
    weight = weight.max() / weight
else:
    weight = None


def weighted_mse_loss(pred, target, weight=None):
    weight = 1. if weight is None else weight[target].to(pred.dtype)
    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()


class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)


model = Model(hidden_channels=32).to(device)

# Due to lazy initialization, we need to run one model step so the number
# of parameters can be inferred:
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def train():
    model.train()
    optimizer.zero_grad()
    pred = model(train_data.x_dict, train_data.edge_index_dict,
                 train_data['user', 'movie'].edge_label_index)
    target = train_data['user', 'movie'].edge_label
    loss = weighted_mse_loss(pred, target, weight)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(data):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict,
                 data['user', 'movie'].edge_label_index)
    pred = pred.clamp(min=0, max=5)
    target = data['user', 'movie'].edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)


for epoch in range(1, 301):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    test_rmse = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')

 

Output:

HeteroData(
  movie={ x=[9742, 404] },
  user={ x=[610, 610] },
  (user, rates, movie)={
    edge_index=[2, 100836],
    edge_label=[100836]
  },
  (movie, rev_rates, user)={ edge_index=[2, 100836] }
)
Epoch: 001, Loss: 11.1455, Train: 3.0880, Val: 3.0996, Test: 3.0917
Epoch: 002, Loss: 9.5358, Train: 2.6658, Val: 2.6792, Test: 2.6712
Epoch: 003, Loss: 7.1066, Train: 1.8713, Val: 1.8877, Test: 1.8804
Epoch: 004, Loss: 3.5019, Train: 1.1067, Val: 1.0977, Test: 1.1063
Epoch: 005, Loss: 1.2249, Train: 1.9740, Val: 1.9311, Test: 1.9472
Epoch: 006, Loss: 5.5210, Train: 1.6758, Val: 1.6408, Test: 1.6595
Epoch: 007, Loss: 2.8131, Train: 1.0975, Val: 1.0894, Test: 1.0977
Epoch: 008, Loss: 1.2045, Train: 1.2613, Val: 1.2734, Test: 1.2708
Epoch: 009, Loss: 1.5908, Train: 1.5404, Val: 1.5562, Test: 1.5505
Epoch: 010, Loss: 2.3730, Train: 1.6708, Val: 1.6869, Test: 1.6805
Epoch: 011, Loss: 2.7915, Train: 1.6564, Val: 1.6724, Test: 1.6661
Epoch: 012, Loss: 2.7436, Train: 1.5262, Val: 1.5418, Test: 1.5362
Epoch: 013, Loss: 2.3292, Train: 1.3166, Val: 1.3301, Test: 1.3266
Epoch: 014, Loss: 1.7333, Train: 1.1141, Val: 1.1205, Test: 1.1219
Epoch: 015, Loss: 1.2412, Train: 1.0907, Val: 1.0819, Test: 1.0912
Epoch: 016, Loss: 1.1896, Train: 1.2743, Val: 1.2524, Test: 1.2672
Epoch: 017, Loss: 1.6238, Train: 1.3905, Val: 1.3643, Test: 1.3808
Epoch: 018, Loss: 1.9334, Train: 1.3026, Val: 1.2797, Test: 1.2951
Epoch: 019, Loss: 1.6968, Train: 1.1325, Val: 1.1190, Test: 1.1307
Epoch: 020, Loss: 1.2826, Train: 1.0549, Val: 1.0534, Test: 1.0595
Epoch: 021, Loss: 1.1128, Train: 1.1005, Val: 1.1076, Test: 1.1089
Epoch: 022, Loss: 1.2111, Train: 1.1792, Val: 1.1902, Test: 1.1889
Epoch: 023, Loss: 1.3905, Train: 1.2222, Val: 1.2345, Test: 1.2323
Epoch: 024, Loss: 1.4938, Train: 1.2101, Val: 1.2222, Test: 1.2202
Epoch: 025, Loss: 1.4643, Train: 1.1525, Val: 1.1631, Test: 1.1623
Epoch: 026, Loss: 1.3283, Train: 1.0802, Val: 1.0871, Test: 1.0889
Epoch: 027, Loss: 1.1668, Train: 1.0385, Val: 1.0392, Test: 1.0446
Epoch: 028, Loss: 1.0784, Train: 1.0565, Val: 1.0501, Test: 1.0592
Epoch: 029, Loss: 1.1163, Train: 1.1062, Val: 1.0946, Test: 1.1060
Epoch: 030, Loss: 1.2236, Train: 1.1266, Val: 1.1135, Test: 1.1257
Epoch: 031, Loss: 1.2693, Train: 1.0953, Val: 1.0846, Test: 1.0959
Epoch: 032, Loss: 1.1998, Train: 1.0460, Val: 1.0404, Test: 1.0495
Epoch: 033, Loss: 1.0942, Train: 1.0229, Val: 1.0234, Test: 1.0295
Epoch: 034, Loss: 1.0464, Train: 1.0346, Val: 1.0399, Test: 1.0434
Epoch: 035, Loss: 1.0705, Train: 1.0578, Val: 1.0659, Test: 1.0677
Epoch: 036, Loss: 1.1189, Train: 1.0684, Val: 1.0775, Test: 1.0787
Epoch: 037, Loss: 1.1414, Train: 1.0578, Val: 1.0665, Test: 1.0681
Epoch: 038, Loss: 1.1190, Train: 1.0331, Val: 1.0400, Test: 1.0428
Epoch: 039, Loss: 1.0672, Train: 1.0112, Val: 1.0150, Test: 1.0196
Epoch: 040, Loss: 1.0224, Train: 1.0071, Val: 1.0070, Test: 1.0139
Epoch: 041, Loss: 1.0143, Train: 1.0195, Val: 1.0159, Test: 1.0246
Epoch: 042, Loss: 1.0395, Train: 1.0297, Val: 1.0244, Test: 1.0339
Epoch: 043, Loss: 1.0602, Train: 1.0227, Val: 1.0182, Test: 1.0274
Epoch: 044, Loss: 1.0460, Train: 1.0051, Val: 1.0033, Test: 1.0113
Epoch: 045, Loss: 1.0103, Train: 0.9932, Val: 0.9948, Test: 1.0011
Epoch: 046, Loss: 0.9864, Train: 0.9937, Val: 0.9984, Test: 1.0032
Epoch: 047, Loss: 0.9875, Train: 1.0002, Val: 1.0069, Test: 1.0105
Epoch: 048, Loss: 1.0004, Train: 1.0024, Val: 1.0099, Test: 1.0131
Epoch: 049, Loss: 1.0048, Train: 0.9962, Val: 1.0033, Test: 1.0069
Epoch: 050, Loss: 0.9925, Train: 0.9855, Val: 0.9912, Test: 0.9956
Epoch: 051, Loss: 0.9712, Train: 0.9780, Val: 0.9815, Test: 0.9871
Epoch: 052, Loss: 0.9565, Train: 0.9782, Val: 0.9792, Test: 0.9862
Epoch: 053, Loss: 0.9568, Train: 0.9818, Val: 0.9813, Test: 0.9891
Epoch: 054, Loss: 0.9640, Train: 0.9807, Val: 0.9801, Test: 0.9880
Epoch: 055, Loss: 0.9619, Train: 0.9737, Val: 0.9744, Test: 0.9818
Epoch: 056, Loss: 0.9481, Train: 0.9670, Val: 0.9697, Test: 0.9762
Epoch: 057, Loss: 0.9351, Train: 0.9652, Val: 0.9699, Test: 0.9754
Epoch: 058, Loss: 0.9315, Train: 0.9664, Val: 0.9725, Test: 0.9773
Epoch: 059, Loss: 0.9338, Train: 0.9662, Val: 0.9729, Test: 0.9775
Epoch: 060, Loss: 0.9335, Train: 0.9626, Val: 0.9690, Test: 0.9739
Epoch: 061, Loss: 0.9265, Train: 0.9575, Val: 0.9629, Test: 0.9684
Epoch: 062, Loss: 0.9167, Train: 0.9542, Val: 0.9584, Test: 0.9646
Epoch: 063, Loss: 0.9106, Train: 0.9539, Val: 0.9568, Test: 0.9637
Epoch: 064, Loss: 0.9099, Train: 0.9539, Val: 0.9561, Test: 0.9634
Epoch: 065, Loss: 0.9099, Train: 0.9516, Val: 0.9541, Test: 0.9614
Epoch: 066, Loss: 0.9056, Train: 0.9479, Val: 0.9514, Test: 0.9582
Epoch: 067, Loss: 0.8985, Train: 0.9453, Val: 0.9500, Test: 0.9563
Epoch: 068, Loss: 0.8935, Train: 0.9446, Val: 0.9504, Test: 0.9563
Epoch: 069, Loss: 0.8922, Train: 0.9442, Val: 0.9507, Test: 0.9563
Epoch: 070, Loss: 0.8914, Train: 0.9424, Val: 0.9490, Test: 0.9546
Epoch: 071, Loss: 0.8882, Train: 0.9397, Val: 0.9459, Test: 0.9517
Epoch: 072, Loss: 0.8831, Train: 0.9376, Val: 0.9432, Test: 0.9491
Epoch: 073, Loss: 0.8791, Train: 0.9367, Val: 0.9417, Test: 0.9480
Epoch: 074, Loss: 0.8775, Train: 0.9359, Val: 0.9408, Test: 0.9471
Epoch: 075, Loss: 0.8760, Train: 0.9341, Val: 0.9395, Test: 0.9456
Epoch: 076, Loss: 0.8726, Train: 0.9319, Val: 0.9383, Test: 0.9441
Epoch: 077, Loss: 0.8685, Train: 0.9307, Val: 0.9382, Test: 0.9435
Epoch: 078, Loss: 0.8662, Train: 0.9302, Val: 0.9386, Test: 0.9434
Epoch: 079, Loss: 0.8653, Train: 0.9289, Val: 0.9376, Test: 0.9422
Epoch: 080, Loss: 0.8630, Train: 0.9270, Val: 0.9353, Test: 0.9399
Epoch: 081, Loss: 0.8593, Train: 0.9257, Val: 0.9335, Test: 0.9381
Epoch: 082, Loss: 0.8570, Train: 0.9251, Val: 0.9326, Test: 0.9372
Epoch: 083, Loss: 0.8560, Train: 0.9239, Val: 0.9317, Test: 0.9361
Epoch: 084, Loss: 0.8538, Train: 0.9223, Val: 0.9308, Test: 0.9349
Epoch: 085, Loss: 0.8507, Train: 0.9211, Val: 0.9307, Test: 0.9344
Epoch: 086, Loss: 0.8486, Train: 0.9205, Val: 0.9307, Test: 0.9341
Epoch: 087, Loss: 0.8473, Train: 0.9192, Val: 0.9296, Test: 0.9329
Epoch: 088, Loss: 0.8450, Train: 0.9178, Val: 0.9280, Test: 0.9312
Epoch: 089, Loss: 0.8425, Train: 0.9169, Val: 0.9269, Test: 0.9300
Epoch: 090, Loss: 0.8408, Train: 0.9160, Val: 0.9261, Test: 0.9291
Epoch: 091, Loss: 0.8393, Train: 0.9148, Val: 0.9255, Test: 0.9282
Epoch: 092, Loss: 0.8370, Train: 0.9136, Val: 0.9253, Test: 0.9275
Epoch: 093, Loss: 0.8348, Train: 0.9127, Val: 0.9253, Test: 0.9272
Epoch: 094, Loss: 0.8333, Train: 0.9118, Val: 0.9249, Test: 0.9266
Epoch: 095, Loss: 0.8315, Train: 0.9106, Val: 0.9240, Test: 0.9254
Epoch: 096, Loss: 0.8294, Train: 0.9097, Val: 0.9231, Test: 0.9243
Epoch: 097, Loss: 0.8277, Train: 0.9089, Val: 0.9226, Test: 0.9237
Epoch: 098, Loss: 0.8263, Train: 0.9079, Val: 0.9224, Test: 0.9232
Epoch: 099, Loss: 0.8246, Train: 0.9070, Val: 0.9225, Test: 0.9230
Epoch: 100, Loss: 0.8230, Train: 0.9063, Val: 0.9227, Test: 0.9230
Epoch: 101, Loss: 0.8217, Train: 0.9056, Val: 0.9227, Test: 0.9227
Epoch: 102, Loss: 0.8204, Train: 0.9048, Val: 0.9223, Test: 0.9222
Epoch: 103, Loss: 0.8190, Train: 0.9042, Val: 0.9219, Test: 0.9216
Epoch: 104, Loss: 0.8178, Train: 0.9036, Val: 0.9217, Test: 0.9213
Epoch: 105, Loss: 0.8168, Train: 0.9030, Val: 0.9218, Test: 0.9212
Epoch: 106, Loss: 0.8157, Train: 0.9024, Val: 0.9220, Test: 0.9212
Epoch: 107, Loss: 0.8146, Train: 0.9019, Val: 0.9223, Test: 0.9212
Epoch: 108, Loss: 0.8137, Train: 0.9013, Val: 0.9224, Test: 0.9210
Epoch: 109, Loss: 0.8127, Train: 0.9008, Val: 0.9222, Test: 0.9205
Epoch: 110, Loss: 0.8117, Train: 0.9003, Val: 0.9220, Test: 0.9201
Epoch: 111, Loss: 0.8108, Train: 0.8998, Val: 0.9220, Test: 0.9198
Epoch: 112, Loss: 0.8100, Train: 0.8993, Val: 0.9221, Test: 0.9196
Epoch: 113, Loss: 0.8091, Train: 0.8989, Val: 0.9223, Test: 0.9196
Epoch: 114, Loss: 0.8083, Train: 0.8985, Val: 0.9225, Test: 0.9195
Epoch: 115, Loss: 0.8076, Train: 0.8981, Val: 0.9224, Test: 0.9193
Epoch: 116, Loss: 0.8069, Train: 0.8977, Val: 0.9222, Test: 0.9189
Epoch: 117, Loss: 0.8063, Train: 0.8974, Val: 0.9220, Test: 0.9186
Epoch: 118, Loss: 0.8057, Train: 0.8971, Val: 0.9219, Test: 0.9184
Epoch: 119, Loss: 0.8051, Train: 0.8967, Val: 0.9219, Test: 0.9184
Epoch: 120, Loss: 0.8045, Train: 0.8965, Val: 0.9220, Test: 0.9184
Epoch: 121, Loss: 0.8040, Train: 0.8962, Val: 0.9220, Test: 0.9183
Epoch: 122, Loss: 0.8035, Train: 0.8959, Val: 0.9218, Test: 0.9181
Epoch: 123, Loss: 0.8030, Train: 0.8956, Val: 0.9216, Test: 0.9178
Epoch: 124, Loss: 0.8026, Train: 0.8954, Val: 0.9214, Test: 0.9176
Epoch: 125, Loss: 0.8022, Train: 0.8952, Val: 0.9214, Test: 0.9176
Epoch: 126, Loss: 0.8017, Train: 0.8950, Val: 0.9214, Test: 0.9176
Epoch: 127, Loss: 0.8013, Train: 0.8947, Val: 0.9214, Test: 0.9175
Epoch: 128, Loss: 0.8009, Train: 0.8945, Val: 0.9213, Test: 0.9174
Epoch: 129, Loss: 0.8006, Train: 0.8943, Val: 0.9211, Test: 0.9172
Epoch: 130, Loss: 0.8002, Train: 0.8941, Val: 0.9208, Test: 0.9171
Epoch: 131, Loss: 0.7999, Train: 0.8939, Val: 0.9207, Test: 0.9169
Epoch: 132, Loss: 0.7995, Train: 0.8938, Val: 0.9207, Test: 0.9169
Epoch: 133, Loss: 0.7992, Train: 0.8936, Val: 0.9207, Test: 0.9169
Epoch: 134, Loss: 0.7989, Train: 0.8934, Val: 0.9206, Test: 0.9168
Epoch: 135, Loss: 0.7986, Train: 0.8932, Val: 0.9204, Test: 0.9166
Epoch: 136, Loss: 0.7983, Train: 0.8931, Val: 0.9202, Test: 0.9165
Epoch: 137, Loss: 0.7980, Train: 0.8929, Val: 0.9201, Test: 0.9164
Epoch: 138, Loss: 0.7977, Train: 0.8928, Val: 0.9201, Test: 0.9164
Epoch: 139, Loss: 0.7974, Train: 0.8926, Val: 0.9201, Test: 0.9164
Epoch: 140, Loss: 0.7972, Train: 0.8925, Val: 0.9200, Test: 0.9163
Epoch: 141, Loss: 0.7969, Train: 0.8923, Val: 0.9199, Test: 0.9162
Epoch: 142, Loss: 0.7967, Train: 0.8922, Val: 0.9197, Test: 0.9161
Epoch: 143, Loss: 0.7965, Train: 0.8921, Val: 0.9196, Test: 0.9159
Epoch: 144, Loss: 0.7962, Train: 0.8919, Val: 0.9198, Test: 0.9161
Epoch: 145, Loss: 0.7960, Train: 0.8918, Val: 0.9194, Test: 0.9158
Epoch: 146, Loss: 0.7958, Train: 0.8917, Val: 0.9193, Test: 0.9157
Epoch: 147, Loss: 0.7956, Train: 0.8916, Val: 0.9193, Test: 0.9157
Epoch: 148, Loss: 0.7954, Train: 0.8915, Val: 0.9192, Test: 0.9156
Epoch: 149, Loss: 0.7952, Train: 0.8914, Val: 0.9191, Test: 0.9156
Epoch: 150, Loss: 0.7950, Train: 0.8913, Val: 0.9190, Test: 0.9155
Epoch: 151, Loss: 0.7949, Train: 0.8912, Val: 0.9190, Test: 0.9155
Epoch: 152, Loss: 0.7947, Train: 0.8911, Val: 0.9189, Test: 0.9155
Epoch: 153, Loss: 0.7945, Train: 0.8910, Val: 0.9189, Test: 0.9154
Epoch: 154, Loss: 0.7944, Train: 0.8909, Val: 0.9189, Test: 0.9154
Epoch: 155, Loss: 0.7942, Train: 0.8909, Val: 0.9188, Test: 0.9153
Epoch: 156, Loss: 0.7941, Train: 0.8908, Val: 0.9188, Test: 0.9153
Epoch: 157, Loss: 0.7939, Train: 0.8907, Val: 0.9188, Test: 0.9152
Epoch: 158, Loss: 0.7938, Train: 0.8906, Val: 0.9188, Test: 0.9152
Epoch: 159, Loss: 0.7936, Train: 0.8905, Val: 0.9188, Test: 0.9151
Epoch: 160, Loss: 0.7935, Train: 0.8905, Val: 0.9188, Test: 0.9151
Epoch: 161, Loss: 0.7934, Train: 0.8904, Val: 0.9188, Test: 0.9151
Epoch: 162, Loss: 0.7932, Train: 0.8903, Val: 0.9188, Test: 0.9151
Epoch: 163, Loss: 0.7931, Train: 0.8903, Val: 0.9188, Test: 0.9151
Epoch: 164, Loss: 0.7930, Train: 0.8902, Val: 0.9188, Test: 0.9150
Epoch: 165, Loss: 0.7929, Train: 0.8901, Val: 0.9189, Test: 0.9150
Epoch: 166, Loss: 0.7928, Train: 0.8901, Val: 0.9189, Test: 0.9150
Epoch: 167, Loss: 0.7926, Train: 0.8900, Val: 0.9189, Test: 0.9150
Epoch: 168, Loss: 0.7925, Train: 0.8899, Val: 0.9189, Test: 0.9150
Epoch: 169, Loss: 0.7924, Train: 0.8899, Val: 0.9189, Test: 0.9150
Epoch: 170, Loss: 0.7923, Train: 0.8898, Val: 0.9190, Test: 0.9150
Epoch: 171, Loss: 0.7922, Train: 0.8898, Val: 0.9190, Test: 0.9149
Epoch: 172, Loss: 0.7921, Train: 0.8897, Val: 0.9190, Test: 0.9149
Epoch: 173, Loss: 0.7920, Train: 0.8896, Val: 0.9190, Test: 0.9148
Epoch: 174, Loss: 0.7919, Train: 0.8896, Val: 0.9190, Test: 0.9148
Epoch: 175, Loss: 0.7918, Train: 0.8895, Val: 0.9190, Test: 0.9148
Epoch: 176, Loss: 0.7917, Train: 0.8895, Val: 0.9190, Test: 0.9147
Epoch: 177, Loss: 0.7916, Train: 0.8894, Val: 0.9190, Test: 0.9147
Epoch: 178, Loss: 0.7916, Train: 0.8894, Val: 0.9190, Test: 0.9147
Epoch: 179, Loss: 0.7915, Train: 0.8893, Val: 0.9191, Test: 0.9147
Epoch: 180, Loss: 0.7914, Train: 0.8893, Val: 0.9191, Test: 0.9147
Epoch: 181, Loss: 0.7913, Train: 0.8893, Val: 0.9191, Test: 0.9147
Epoch: 182, Loss: 0.7912, Train: 0.8892, Val: 0.9191, Test: 0.9146
Epoch: 183, Loss: 0.7911, Train: 0.8892, Val: 0.9191, Test: 0.9146
Epoch: 184, Loss: 0.7910, Train: 0.8891, Val: 0.9191, Test: 0.9146
Epoch: 185, Loss: 0.7910, Train: 0.8891, Val: 0.9192, Test: 0.9146
Epoch: 186, Loss: 0.7909, Train: 0.8890, Val: 0.9192, Test: 0.9146
Epoch: 187, Loss: 0.7908, Train: 0.8890, Val: 0.9192, Test: 0.9146
Epoch: 188, Loss: 0.7907, Train: 0.8889, Val: 0.9192, Test: 0.9145
Epoch: 189, Loss: 0.7907, Train: 0.8889, Val: 0.9192, Test: 0.9145
Epoch: 190, Loss: 0.7906, Train: 0.8889, Val: 0.9192, Test: 0.9145
Epoch: 191, Loss: 0.7905, Train: 0.8888, Val: 0.9192, Test: 0.9145
Epoch: 192, Loss: 0.7905, Train: 0.8888, Val: 0.9192, Test: 0.9144
Epoch: 193, Loss: 0.7904, Train: 0.8888, Val: 0.9192, Test: 0.9144
Epoch: 194, Loss: 0.7903, Train: 0.8887, Val: 0.9192, Test: 0.9144
Epoch: 195, Loss: 0.7903, Train: 0.8887, Val: 0.9192, Test: 0.9144
Epoch: 196, Loss: 0.7902, Train: 0.8886, Val: 0.9192, Test: 0.9144
Epoch: 197, Loss: 0.7901, Train: 0.8886, Val: 0.9193, Test: 0.9143
Epoch: 198, Loss: 0.7901, Train: 0.8886, Val: 0.9193, Test: 0.9143
Epoch: 199, Loss: 0.7900, Train: 0.8885, Val: 0.9193, Test: 0.9143
Epoch: 200, Loss: 0.7899, Train: 0.8885, Val: 0.9193, Test: 0.9143
Epoch: 201, Loss: 0.7899, Train: 0.8885, Val: 0.9192, Test: 0.9143
Epoch: 202, Loss: 0.7898, Train: 0.8884, Val: 0.9193, Test: 0.9143
Epoch: 203, Loss: 0.7898, Train: 0.8884, Val: 0.9193, Test: 0.9143
Epoch: 204, Loss: 0.7897, Train: 0.8884, Val: 0.9193, Test: 0.9143
Epoch: 205, Loss: 0.7896, Train: 0.8883, Val: 0.9193, Test: 0.9143
Epoch: 206, Loss: 0.7896, Train: 0.8883, Val: 0.9193, Test: 0.9142
Epoch: 207, Loss: 0.7895, Train: 0.8883, Val: 0.9193, Test: 0.9143
Epoch: 208, Loss: 0.7895, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 209, Loss: 0.7894, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 210, Loss: 0.7894, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 211, Loss: 0.7893, Train: 0.8882, Val: 0.9193, Test: 0.9142
Epoch: 212, Loss: 0.7893, Train: 0.8881, Val: 0.9193, Test: 0.9142
Epoch: 213, Loss: 0.7892, Train: 0.8881, Val: 0.9193, Test: 0.9142
Epoch: 214, Loss: 0.7892, Train: 0.8881, Val: 0.9193, Test: 0.9142
Epoch: 215, Loss: 0.7891, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 216, Loss: 0.7891, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 217, Loss: 0.7890, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 218, Loss: 0.7890, Train: 0.8880, Val: 0.9194, Test: 0.9141
Epoch: 219, Loss: 0.7889, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 220, Loss: 0.7889, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 221, Loss: 0.7888, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 222, Loss: 0.7888, Train: 0.8879, Val: 0.9194, Test: 0.9141
Epoch: 223, Loss: 0.7887, Train: 0.8878, Val: 0.9194, Test: 0.9141
Epoch: 224, Loss: 0.7887, Train: 0.8878, Val: 0.9195, Test: 0.9141
Epoch: 225, Loss: 0.7887, Train: 0.8878, Val: 0.9194, Test: 0.9141
Epoch: 226, Loss: 0.7886, Train: 0.8878, Val: 0.9194, Test: 0.9141
Epoch: 227, Loss: 0.7886, Train: 0.8877, Val: 0.9195, Test: 0.9141
Epoch: 228, Loss: 0.7885, Train: 0.8877, Val: 0.9195, Test: 0.9141
Epoch: 229, Loss: 0.7885, Train: 0.8877, Val: 0.9195, Test: 0.9140
Epoch: 230, Loss: 0.7884, Train: 0.8877, Val: 0.9195, Test: 0.9140
Epoch: 231, Loss: 0.7884, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 232, Loss: 0.7884, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 233, Loss: 0.7883, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 234, Loss: 0.7883, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 235, Loss: 0.7882, Train: 0.8876, Val: 0.9195, Test: 0.9140
Epoch: 236, Loss: 0.7882, Train: 0.8875, Val: 0.9196, Test: 0.9140
Epoch: 237, Loss: 0.7882, Train: 0.8875, Val: 0.9196, Test: 0.9140
Epoch: 238, Loss: 0.7881, Train: 0.8875, Val: 0.9195, Test: 0.9140
Epoch: 239, Loss: 0.7881, Train: 0.8875, Val: 0.9196, Test: 0.9140
Epoch: 240, Loss: 0.7880, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 241, Loss: 0.7880, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 242, Loss: 0.7880, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 243, Loss: 0.7879, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 244, Loss: 0.7879, Train: 0.8874, Val: 0.9196, Test: 0.9140
Epoch: 245, Loss: 0.7879, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 246, Loss: 0.7878, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 247, Loss: 0.7878, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 248, Loss: 0.7877, Train: 0.8873, Val: 0.9196, Test: 0.9140
Epoch: 249, Loss: 0.7877, Train: 0.8873, Val: 0.9196, Test: 0.9139
Epoch: 250, Loss: 0.7877, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 251, Loss: 0.7876, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 252, Loss: 0.7876, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 253, Loss: 0.7876, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 254, Loss: 0.7875, Train: 0.8872, Val: 0.9196, Test: 0.9140
Epoch: 255, Loss: 0.7875, Train: 0.8871, Val: 0.9196, Test: 0.9139
Epoch: 256, Loss: 0.7875, Train: 0.8871, Val: 0.9196, Test: 0.9140
Epoch: 257, Loss: 0.7874, Train: 0.8871, Val: 0.9196, Test: 0.9140
Epoch: 258, Loss: 0.7874, Train: 0.8871, Val: 0.9196, Test: 0.9139
Epoch: 259, Loss: 0.7874, Train: 0.8871, Val: 0.9196, Test: 0.9139
Epoch: 260, Loss: 0.7873, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 261, Loss: 0.7873, Train: 0.8870, Val: 0.9196, Test: 0.9140
Epoch: 262, Loss: 0.7873, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 263, Loss: 0.7872, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 264, Loss: 0.7872, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 265, Loss: 0.7872, Train: 0.8870, Val: 0.9196, Test: 0.9139
Epoch: 266, Loss: 0.7871, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 267, Loss: 0.7871, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 268, Loss: 0.7871, Train: 0.8869, Val: 0.9196, Test: 0.9139
Epoch: 269, Loss: 0.7871, Train: 0.8869, Val: 0.9196, Test: 0.9139
Epoch: 270, Loss: 0.7870, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 271, Loss: 0.7870, Train: 0.8869, Val: 0.9197, Test: 0.9139
Epoch: 272, Loss: 0.7870, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 273, Loss: 0.7869, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 274, Loss: 0.7869, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 275, Loss: 0.7869, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 276, Loss: 0.7868, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 277, Loss: 0.7868, Train: 0.8868, Val: 0.9197, Test: 0.9139
Epoch: 278, Loss: 0.7868, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 279, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 280, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 281, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 282, Loss: 0.7867, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 283, Loss: 0.7866, Train: 0.8867, Val: 0.9197, Test: 0.9139
Epoch: 284, Loss: 0.7866, Train: 0.8866, Val: 0.9196, Test: 0.9139
Epoch: 285, Loss: 0.7866, Train: 0.8866, Val: 0.9196, Test: 0.9139
Epoch: 286, Loss: 0.7865, Train: 0.8866, Val: 0.9197, Test: 0.9139
Epoch: 287, Loss: 0.7865, Train: 0.8866, Val: 0.9196, Test: 0.9138
Epoch: 288, Loss: 0.7865, Train: 0.8866, Val: 0.9197, Test: 0.9139
Epoch: 289, Loss: 0.7865, Train: 0.8866, Val: 0.9196, Test: 0.9139
Epoch: 290, Loss: 0.7864, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 291, Loss: 0.7864, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 292, Loss: 0.7864, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 293, Loss: 0.7863, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 294, Loss: 0.7863, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 295, Loss: 0.7863, Train: 0.8865, Val: 0.9196, Test: 0.9138
Epoch: 296, Loss: 0.7863, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 297, Loss: 0.7862, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 298, Loss: 0.7862, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 299, Loss: 0.7862, Train: 0.8864, Val: 0.9196, Test: 0.9138
Epoch: 300, Loss: 0.7861, Train: 0.8864, Val: 0.9195, Test: 0.9138

 

6.1.2 Predicting user ratings on a bipartite graph

Reference: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/bipartite_sage.py

The main difference from the previous section is that two separate models encode the two node types.

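  1. Dataset: MovieLens (see the previous section for details).
    1. Movie nodes: processed the same way as in the previous section.
    2. User nodes: have no features. The code gives each user its index and learns an initial feature vector through a randomly initialized torch.nn.Embedding.
    3. The graph is converted to an undirected graph, which adds reverse edges.
    4. Movie-movie edges are generated from the metapath-based neighborhood movie → user → movie, so two movies are linked if some user rated both. I did not fully understand why gcn_norm is then used to compute edge weights that filter these edges. With no input edge weights and add_self_loops=False, gcn_norm gives each edge the weight 1/sqrt(deg(i)·deg(j)), so keeping edges with weight > 0.002 keeps only pairs whose degree product is below 250,000. In other words, it drops edges between pairs of very high-degree (heavily co-rated) movies, presumably to sparsify the dense co-occurrence graph.
  2. Data split: same as the previous section.
  3. Encoders: movie nodes are encoded with 2 GraphSAGE layers followed by a linear layer over the movie-movie edges. User nodes are encoded with 3 GraphSAGE layers followed by a linear layer: one layer over the movie-movie edges to refine movie features, then two layers over the ('movie', 'rev_rates', 'user') edges to pass movie information to users.
  4. Link-prediction decoder: same as the previous section.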
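To make item 4 concrete, here is a minimal pure-Python sketch on hypothetical toy ratings (no PyG required). It assumes the movie → user → movie metapath links two movies whenever some user rated both (like the sparse product of the two adjacency matrices), and that gcn_norm, called without edge weights and with add_self_loops=False, assigns each edge the weight 1/sqrt(deg(i)·deg(j)) with degrees counted on target nodes. The helper names and the toy threshold of 0.3 are illustrative assumptions, not PyG API.

```python
import math
from collections import defaultdict

# Hypothetical toy ratings as (user, movie) pairs, standing in for
# the ('user', 'rates', 'movie') edges.
ratings = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 2), (2, 3)]


def metapath_edges(ratings):
    """movie -> user -> movie: link two movies if at least one user rated both.

    Duplicate paths collapse into a single edge; self-loops (i, i) appear
    because every rated movie reaches itself through its own raters.
    """
    movies_by_user = defaultdict(set)
    for user, movie in ratings:
        movies_by_user[user].add(movie)
    edges = set()
    for movies in movies_by_user.values():
        for src in movies:
            for dst in movies:
                edges.add((src, dst))
    return sorted(edges)


def gcn_norm_weights(edges):
    """Symmetric normalization with unit edge weights and no added self-loops:
    w_ij = 1 / sqrt(deg(i) * deg(j)), degree counted on target nodes."""
    deg = defaultdict(int)
    for _, dst in edges:
        deg[dst] += 1
    return {(src, dst): 1.0 / math.sqrt(deg[src] * deg[dst]) for src, dst in edges}


def keep_edges(weights, threshold):
    """Keep edges whose normalized weight exceeds the threshold,
    i.e. whose degree product is below 1 / threshold**2."""
    return sorted(edge for edge, w in weights.items() if w > threshold)


edges = metapath_edges(ratings)
weights = gcn_norm_weights(edges)
print(keep_edges(weights, threshold=0.3))
# prints [(0, 0), (0, 1), (1, 0), (1, 1), (2, 3), (3, 2), (3, 3)]
```

With the example's real threshold of 0.002, the cut-off on the degree product is 1/0.002² ≈ 250,000. In the toy graph, the dropped edges pair movie 2 (the highest degree) with another well-connected movie, while (2, 3) survives because movie 3 has a low degree.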
import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import MovieLens
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm

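dataset = MovieLens('/data/pyg_data/MovieLens', model_name='all-MiniLM-L6-v2')
data = dataset[0]
# Users have no features: store node indices, later mapped to learnable embeddings.
data['user'].x = torch.arange(data['user'].num_nodes)
# Ratings are regression targets for the MSE loss, so cast them to float.
data['user', 'movie'].edge_label = data['user', 'movie'].edge_label.float()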

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:
data = T.ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.

# Perform a link-level split into training, validation, and test edges:
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)

# Generate the movie<>movie co-occurrence graph (movie -> user -> movie):
metapath = [('movie', 'rev_rates', 'user'), ('user', 'rates', 'movie')]
train_data = T.AddMetaPaths(metapaths=[metapath])(train_data)

# Filter the metapath edges by their GCN-normalized weight 1 / sqrt(deg_i * deg_j),
# i.e. drop edges between pairs of very high-degree movies:
_, edge_weight = gcn_norm(
    train_data['movie', 'movie'].edge_index,
    num_nodes=train_data['movie'].num_nodes,
    add_self_loops=False,
)
edge_index = train_data['movie', 'movie'].edge_index[:, edge_weight > 0.002]

train_data['movie', 'metapath_0', 'movie'].edge_index = edge_index
val_data['movie', 'metapath_0', 'movie'].edge_index = edge_index
test_data['movie', 'metapath_0', 'movie'].edge_index = edge_index


class MovieGNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = SAGEConv(-1, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return self.lin(x)


class UserGNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.conv3 = SAGEConv((-1, -1), hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

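    def forward(self, x_dict, edge_index_dict):
        # Refine movie features over the filtered movie-movie metapath edges.
        movie_x = self.conv1(
            x_dict['movie'],
            edge_index_dict[('movie', 'metapath_0', 'movie')],
        ).relu()

        # Bipartite message passing movie -> user on the input features.
        user_x = self.conv2(
            (x_dict['movie'], x_dict['user']),
            edge_index_dict[('movie', 'rev_rates', 'user')],
        ).relu()

        # Second movie -> user pass using the refined movie and user features.
        user_x = self.conv3(
            (movie_x, user_x),
            edge_index_dict[('movie', 'rev_rates', 'user')],
        ).relu()

        return self.lin(user_x)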


class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)

    def forward(self, z_src, z_dst, edge_label_index):
        row, col = edge_label_index
        z = torch.cat([z_src[row], z_dst[col]], dim=-1)

        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)


class Model(torch.nn.Module):
    def __init__(self, num_users, hidden_channels, out_channels):
        super().__init__()
        self.user_emb = Embedding(num_users, hidden_channels)
        self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)
        self.movie_encoder = MovieGNNEncoder(hidden_channels, out_channels)
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        z_dict = {}
        x_dict['user'] = self.user_emb(x_dict['user'])
        z_dict['user'] = self.user_encoder(x_dict, edge_index_dict)
        z_dict['movie'] = self.movie_encoder(
            x_dict['movie'],
            edge_index_dict[('movie', 'metapath_0', 'movie')],
        )
        return self.decoder(z_dict['user'], z_dict['movie'], edge_label_index)


model = Model(data['user'].num_nodes, hidden_channels=64, out_channels=64)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(
        train_data.x_dict,
        train_data.edge_index_dict,
        train_data['user', 'movie'].edge_label_index,
    )
    loss = F.mse_loss(out, train_data['user', 'movie'].edge_label)
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(data):
    model.eval()
    out = model(
        data.x_dict,
        data.edge_index_dict,
        data['user', 'movie'].edge_label_index,
    ).clamp(min=0, max=5)
    rmse = F.mse_loss(out, data['user', 'movie'].edge_label).sqrt()
    return float(rmse)


for epoch in range(1, 701):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    test_rmse = test(test_data)
    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')
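The Loss column and the RMSE columns in the log are related: train() reports the MSE of the raw predictions on the training edges, while test() clamps predictions to the rating range [0, 5] and reports the square root of the MSE. A minimal pure-Python sketch of that evaluation step (clamped_rmse is a hypothetical helper, not part of the PyG example):

```python
import math


def clamped_rmse(preds, targets, low=0.0, high=5.0):
    """RMSE after clamping predictions into the rating range, mirroring
    out.clamp(min=0, max=5) followed by F.mse_loss(...).sqrt() in test()."""
    clamped = [min(max(p, low), high) for p in preds]
    mse = sum((p - t) ** 2 for p, t in zip(clamped, targets)) / len(targets)
    return math.sqrt(mse)


# 5.8 -> 5.0 and -0.4 -> 0.0 before the errors are squared.
print(clamped_rmse([5.8, -0.4, 3.0], [5.0, 0.5, 4.0]))
```

As a sanity check on the first log line, sqrt(11.8958) ≈ 3.449, close to the reported Train RMSE of 3.4372; they differ slightly because the RMSE is computed after the optimizer step and on clamped predictions.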

 

Output:

Epoch: 0001, Loss: 11.8958, Train: 3.4372, Val: 3.4460, Test: 3.4240
Epoch: 0002, Loss: 11.8144, Train: 3.4252, Val: 3.4339, Test: 3.4120
Epoch: 0003, Loss: 11.7317, Train: 3.4129, Val: 3.4216, Test: 3.3997
Epoch: 0004, Loss: 11.6476, Train: 3.4003, Val: 3.4091, Test: 3.3872
Epoch: 0005, Loss: 11.5624, Train: 3.3876, Val: 3.3963, Test: 3.3745
Epoch: 0006, Loss: 11.4759, Train: 3.3747, Val: 3.3834, Test: 3.3616
Epoch: 0007, Loss: 11.3884, Train: 3.3615, Val: 3.3702, Test: 3.3485
Epoch: 0008, Loss: 11.2998, Train: 3.3481, Val: 3.3568, Test: 3.3351
Epoch: 0009, Loss: 11.2096, Train: 3.3342, Val: 3.3429, Test: 3.3213
Epoch: 0010, Loss: 11.1170, Train: 3.3198, Val: 3.3285, Test: 3.3069
Epoch: 0011, Loss: 11.0213, Train: 3.3048, Val: 3.3135, Test: 3.2919
Epoch: 0012, Loss: 10.9215, Train: 3.2888, Val: 3.2976, Test: 3.2760
Epoch: 0013, Loss: 10.8164, Train: 3.2719, Val: 3.2806, Test: 3.2591
Epoch: 0014, Loss: 10.7052, Train: 3.2537, Val: 3.2625, Test: 3.2410
Epoch: 0015, Loss: 10.5868, Train: 3.2342, Val: 3.2429, Test: 3.2216
Epoch: 0016, Loss: 10.4600, Train: 3.2131, Val: 3.2219, Test: 3.2005
Epoch: 0017, Loss: 10.3240, Train: 3.1903, Val: 3.1991, Test: 3.1778
Epoch: 0018, Loss: 10.1780, Train: 3.1656, Val: 3.1745, Test: 3.1533
Epoch: 0019, Loss: 10.0213, Train: 3.1389, Val: 3.1478, Test: 3.1267
Epoch: 0020, Loss: 9.8530, Train: 3.1100, Val: 3.1189, Test: 3.0979
Epoch: 0021, Loss: 9.6721, Train: 3.0785, Val: 3.0875, Test: 3.0666
Epoch: 0022, Loss: 9.4775, Train: 3.0444, Val: 3.0533, Test: 3.0326
Epoch: 0023, Loss: 9.2682, Train: 3.0072, Val: 3.0162, Test: 2.9956
Epoch: 0024, Loss: 9.0433, Train: 2.9668, Val: 2.9759, Test: 2.9554
Epoch: 0025, Loss: 8.8020, Train: 2.9230, Val: 2.9322, Test: 2.9118
Epoch: 0026, Loss: 8.5438, Train: 2.8755, Val: 2.8848, Test: 2.8646
Epoch: 0027, Loss: 8.2685, Train: 2.8241, Val: 2.8334, Test: 2.8134
Epoch: 0028, Loss: 7.9754, Train: 2.7684, Val: 2.7779, Test: 2.7580
Epoch: 0029, Loss: 7.6640, Train: 2.7082, Val: 2.7178, Test: 2.6981
Epoch: 0030, Loss: 7.3342, Train: 2.6431, Val: 2.6528, Test: 2.6334
Epoch: 0031, Loss: 6.9859, Train: 2.5728, Val: 2.5827, Test: 2.5635
Epoch: 0032, Loss: 6.6194, Train: 2.4970, Val: 2.5071, Test: 2.4882
Epoch: 0033, Loss: 6.2352, Train: 2.4155, Val: 2.4257, Test: 2.4071
Epoch: 0034, Loss: 5.8348, Train: 2.3281, Val: 2.3384, Test: 2.3201
Epoch: 0035, Loss: 5.4198, Train: 2.2345, Val: 2.2449, Test: 2.2270
Epoch: 0036, Loss: 4.9928, Train: 2.1347, Val: 2.1453, Test: 2.1278
Epoch: 0037, Loss: 4.5570, Train: 2.0289, Val: 2.0396, Test: 2.0225
Epoch: 0038, Loss: 4.1166, Train: 1.9177, Val: 1.9284, Test: 1.9118
Epoch: 0039, Loss: 3.6775, Train: 1.8019, Val: 1.8125, Test: 1.7965
Epoch: 0040, Loss: 3.2467, Train: 1.6832, Val: 1.6937, Test: 1.6783
Epoch: 0041, Loss: 2.8331, Train: 1.5643, Val: 1.5744, Test: 1.5598
Epoch: 0042, Loss: 2.4470, Train: 1.4493, Val: 1.4586, Test: 1.4449
Epoch: 0043, Loss: 2.1004, Train: 1.3440, Val: 1.3521, Test: 1.3395
Epoch: 0044, Loss: 1.8063, Train: 1.2560, Val: 1.2624, Test: 1.2511
Epoch: 0045, Loss: 1.5776, Train: 1.1939, Val: 1.1978, Test: 1.1881
Epoch: 0046, Loss: 1.4255, Train: 1.1646, Val: 1.1654, Test: 1.1574
Epoch: 0047, Loss: 1.3563, Train: 1.1696, Val: 1.1671, Test: 1.1609
Epoch: 0048, Loss: 1.3679, Train: 1.2025, Val: 1.1970, Test: 1.1923
Epoch: 0049, Loss: 1.4462, Train: 1.2497, Val: 1.2420, Test: 1.2385
Epoch: 0050, Loss: 1.5637, Train: 1.2956, Val: 1.2867, Test: 1.2837
Epoch: 0051, Loss: 1.6858, Train: 1.3288, Val: 1.3193, Test: 1.3166
Epoch: 0052, Loss: 1.7801, Train: 1.3444, Val: 1.3345, Test: 1.3320
Epoch: 0053, Loss: 1.8263, Train: 1.3421, Val: 1.3321, Test: 1.3296
Epoch: 0054, Loss: 1.8189, Train: 1.3240, Val: 1.3141, Test: 1.3115
Epoch: 0055, Loss: 1.7654, Train: 1.2938, Val: 1.2843, Test: 1.2814
Epoch: 0056, Loss: 1.6807, Train: 1.2566, Val: 1.2477, Test: 1.2443
Epoch: 0057, Loss: 1.5817, Train: 1.2177, Val: 1.2097, Test: 1.2056
Epoch: 0058, Loss: 1.4835, Train: 1.1819, Val: 1.1752, Test: 1.1703
Epoch: 0059, Loss: 1.3973, Train: 1.1530, Val: 1.1477, Test: 1.1418
Epoch: 0060, Loss: 1.3294, Train: 1.1323, Val: 1.1285, Test: 1.1216
Epoch: 0061, Loss: 1.2822, Train: 1.1200, Val: 1.1176, Test: 1.1098
Epoch: 0062, Loss: 1.2544, Train: 1.1149, Val: 1.1136, Test: 1.1051
Epoch: 0063, Loss: 1.2429, Train: 1.1151, Val: 1.1149, Test: 1.1056
Epoch: 0064, Loss: 1.2434, Train: 1.1186, Val: 1.1192, Test: 1.1093
Epoch: 0065, Loss: 1.2512, Train: 1.1235, Val: 1.1248, Test: 1.1144
Epoch: 0066, Loss: 1.2622, Train: 1.1283, Val: 1.1301, Test: 1.1193
Epoch: 0067, Loss: 1.2732, Train: 1.1320, Val: 1.1341, Test: 1.1230
Epoch: 0068, Loss: 1.2814, Train: 1.1338, Val: 1.1361, Test: 1.1247
Epoch: 0069, Loss: 1.2855, Train: 1.1334, Val: 1.1357, Test: 1.1242
Epoch: 0070, Loss: 1.2845, Train: 1.1308, Val: 1.1330, Test: 1.1215
Epoch: 0071, Loss: 1.2786, Train: 1.1262, Val: 1.1282, Test: 1.1167
Epoch: 0072, Loss: 1.2683, Train: 1.1200, Val: 1.1218, Test: 1.1104
Epoch: 0073, Loss: 1.2545, Train: 1.1129, Val: 1.1143, Test: 1.1030
Epoch: 0074, Loss: 1.2385, Train: 1.1052, Val: 1.1062, Test: 1.0951
Epoch: 0075, Loss: 1.2215, Train: 1.0977, Val: 1.0982, Test: 1.0873
Epoch: 0076, Loss: 1.2049, Train: 1.0908, Val: 1.0907, Test: 1.0801
Epoch: 0077, Loss: 1.1898, Train: 1.0849, Val: 1.0843, Test: 1.0740
Epoch: 0078, Loss: 1.1771, Train: 1.0804, Val: 1.0792, Test: 1.0692
Epoch: 0079, Loss: 1.1672, Train: 1.0772, Val: 1.0755, Test: 1.0658
Epoch: 0080, Loss: 1.1604, Train: 1.0753, Val: 1.0731, Test: 1.0637
Epoch: 0081, Loss: 1.1564, Train: 1.0745, Val: 1.0718, Test: 1.0627
Epoch: 0082, Loss: 1.1546, Train: 1.0744, Val: 1.0713, Test: 1.0623
Epoch: 0083, Loss: 1.1543, Train: 1.0746, Val: 1.0712, Test: 1.0624
Epoch: 0084, Loss: 1.1547, Train: 1.0747, Val: 1.0711, Test: 1.0624
Epoch: 0085, Loss: 1.1550, Train: 1.0745, Val: 1.0708, Test: 1.0622
Epoch: 0086, Loss: 1.1546, Train: 1.0738, Val: 1.0700, Test: 1.0614
Epoch: 0087, Loss: 1.1531, Train: 1.0725, Val: 1.0687, Test: 1.0601
Epoch: 0088, Loss: 1.1503, Train: 1.0707, Val: 1.0670, Test: 1.0583
Epoch: 0089, Loss: 1.1464, Train: 1.0685, Val: 1.0649, Test: 1.0561
Epoch: 0090, Loss: 1.1417, Train: 1.0661, Val: 1.0627, Test: 1.0536
Epoch: 0091, Loss: 1.1365, Train: 1.0636, Val: 1.0604, Test: 1.0512
Epoch: 0092, Loss: 1.1312, Train: 1.0613, Val: 1.0584, Test: 1.0489
Epoch: 0093, Loss: 1.1263, Train: 1.0592, Val: 1.0565, Test: 1.0469
Epoch: 0094, Loss: 1.1219, Train: 1.0574, Val: 1.0551, Test: 1.0452
Epoch: 0095, Loss: 1.1182, Train: 1.0560, Val: 1.0539, Test: 1.0439
Epoch: 0096, Loss: 1.1152, Train: 1.0549, Val: 1.0530, Test: 1.0428
Epoch: 0097, Loss: 1.1128, Train: 1.0540, Val: 1.0523, Test: 1.0420
Epoch: 0098, Loss: 1.1110, Train: 1.0533, Val: 1.0517, Test: 1.0412
Epoch: 0099, Loss: 1.1094, Train: 1.0526, Val: 1.0512, Test: 1.0406
Epoch: 0100, Loss: 1.1079, Train: 1.0518, Val: 1.0505, Test: 1.0399
Epoch: 0101, Loss: 1.1064, Train: 1.0511, Val: 1.0498, Test: 1.0391
Epoch: 0102, Loss: 1.1047, Train: 1.0502, Val: 1.0490, Test: 1.0382
Epoch: 0103, Loss: 1.1029, Train: 1.0492, Val: 1.0480, Test: 1.0372
Epoch: 0104, Loss: 1.1008, Train: 1.0481, Val: 1.0469, Test: 1.0362
Epoch: 0105, Loss: 1.0985, Train: 1.0470, Val: 1.0458, Test: 1.0350
Epoch: 0106, Loss: 1.0962, Train: 1.0459, Val: 1.0446, Test: 1.0339
Epoch: 0107, Loss: 1.0938, Train: 1.0447, Val: 1.0434, Test: 1.0327
Epoch: 0108, Loss: 1.0915, Train: 1.0437, Val: 1.0422, Test: 1.0317
Epoch: 0109, Loss: 1.0892, Train: 1.0427, Val: 1.0412, Test: 1.0306
Epoch: 0110, Loss: 1.0872, Train: 1.0418, Val: 1.0402, Test: 1.0297
Epoch: 0111, Loss: 1.0853, Train: 1.0409, Val: 1.0393, Test: 1.0289
Epoch: 0112, Loss: 1.0835, Train: 1.0402, Val: 1.0384, Test: 1.0281
Epoch: 0113, Loss: 1.0820, Train: 1.0395, Val: 1.0377, Test: 1.0274
Epoch: 0114, Loss: 1.0805, Train: 1.0388, Val: 1.0369, Test: 1.0267
Epoch: 0115, Loss: 1.0790, Train: 1.0381, Val: 1.0362, Test: 1.0260
Epoch: 0116, Loss: 1.0776, Train: 1.0374, Val: 1.0355, Test: 1.0253
Epoch: 0117, Loss: 1.0761, Train: 1.0367, Val: 1.0348, Test: 1.0246
Epoch: 0118, Loss: 1.0747, Train: 1.0359, Val: 1.0341, Test: 1.0239
Epoch: 0119, Loss: 1.0731, Train: 1.0352, Val: 1.0334, Test: 1.0232
Epoch: 0120, Loss: 1.0716, Train: 1.0344, Val: 1.0327, Test: 1.0224
Epoch: 0121, Loss: 1.0700, Train: 1.0336, Val: 1.0320, Test: 1.0217
Epoch: 0122, Loss: 1.0684, Train: 1.0329, Val: 1.0314, Test: 1.0210
Epoch: 0123, Loss: 1.0669, Train: 1.0322, Val: 1.0307, Test: 1.0203
Epoch: 0124, Loss: 1.0654, Train: 1.0315, Val: 1.0301, Test: 1.0196
Epoch: 0125, Loss: 1.0639, Train: 1.0308, Val: 1.0295, Test: 1.0189
Epoch: 0126, Loss: 1.0625, Train: 1.0301, Val: 1.0289, Test: 1.0183
Epoch: 0127, Loss: 1.0612, Train: 1.0295, Val: 1.0283, Test: 1.0177
Epoch: 0128, Loss: 1.0598, Train: 1.0289, Val: 1.0277, Test: 1.0171
Epoch: 0129, Loss: 1.0586, Train: 1.0282, Val: 1.0272, Test: 1.0165
Epoch: 0130, Loss: 1.0573, Train: 1.0276, Val: 1.0266, Test: 1.0159
Epoch: 0131, Loss: 1.0560, Train: 1.0270, Val: 1.0260, Test: 1.0153
Epoch: 0132, Loss: 1.0548, Train: 1.0264, Val: 1.0254, Test: 1.0147
Epoch: 0133, Loss: 1.0535, Train: 1.0258, Val: 1.0248, Test: 1.0141
Epoch: 0134, Loss: 1.0522, Train: 1.0252, Val: 1.0242, Test: 1.0135
Epoch: 0135, Loss: 1.0510, Train: 1.0246, Val: 1.0236, Test: 1.0129
Epoch: 0136, Loss: 1.0497, Train: 1.0240, Val: 1.0230, Test: 1.0123
Epoch: 0137, Loss: 1.0485, Train: 1.0234, Val: 1.0224, Test: 1.0117
Epoch: 0138, Loss: 1.0473, Train: 1.0228, Val: 1.0218, Test: 1.0112
Epoch: 0139, Loss: 1.0461, Train: 1.0222, Val: 1.0213, Test: 1.0106
Epoch: 0140, Loss: 1.0449, Train: 1.0216, Val: 1.0207, Test: 1.0100
Epoch: 0141, Loss: 1.0437, Train: 1.0211, Val: 1.0201, Test: 1.0095
Epoch: 0142, Loss: 1.0426, Train: 1.0205, Val: 1.0196, Test: 1.0090
Epoch: 0143, Loss: 1.0414, Train: 1.0200, Val: 1.0191, Test: 1.0084
Epoch: 0144, Loss: 1.0403, Train: 1.0194, Val: 1.0185, Test: 1.0079
Epoch: 0145, Loss: 1.0392, Train: 1.0189, Val: 1.0180, Test: 1.0074
Epoch: 0146, Loss: 1.0381, Train: 1.0183, Val: 1.0175, Test: 1.0068
Epoch: 0147, Loss: 1.0370, Train: 1.0178, Val: 1.0170, Test: 1.0063
Epoch: 0148, Loss: 1.0359, Train: 1.0172, Val: 1.0165, Test: 1.0058
Epoch: 0149, Loss: 1.0348, Train: 1.0167, Val: 1.0159, Test: 1.0053
Epoch: 0150, Loss: 1.0337, Train: 1.0162, Val: 1.0154, Test: 1.0047
Epoch: 0151, Loss: 1.0326, Train: 1.0156, Val: 1.0149, Test: 1.0042
Epoch: 0152, Loss: 1.0315, Train: 1.0151, Val: 1.0144, Test: 1.0037
Epoch: 0153, Loss: 1.0304, Train: 1.0146, Val: 1.0140, Test: 1.0032
Epoch: 0154, Loss: 1.0294, Train: 1.0141, Val: 1.0135, Test: 1.0027
Epoch: 0155, Loss: 1.0283, Train: 1.0135, Val: 1.0130, Test: 1.0022
Epoch: 0156, Loss: 1.0273, Train: 1.0130, Val: 1.0125, Test: 1.0017
Epoch: 0157, Loss: 1.0262, Train: 1.0125, Val: 1.0120, Test: 1.0012
Epoch: 0158, Loss: 1.0252, Train: 1.0120, Val: 1.0115, Test: 1.0008
Epoch: 0159, Loss: 1.0242, Train: 1.0115, Val: 1.0110, Test: 1.0003
Epoch: 0160, Loss: 1.0231, Train: 1.0110, Val: 1.0105, Test: 0.9998
Epoch: 0161, Loss: 1.0221, Train: 1.0105, Val: 1.0101, Test: 0.9993
Epoch: 0162, Loss: 1.0211, Train: 1.0100, Val: 1.0096, Test: 0.9988
Epoch: 0163, Loss: 1.0201, Train: 1.0095, Val: 1.0091, Test: 0.9984
Epoch: 0164, Loss: 1.0191, Train: 1.0090, Val: 1.0086, Test: 0.9979
Epoch: 0165, Loss: 1.0181, Train: 1.0085, Val: 1.0082, Test: 0.9974
Epoch: 0166, Loss: 1.0171, Train: 1.0080, Val: 1.0077, Test: 0.9970
Epoch: 0167, Loss: 1.0162, Train: 1.0076, Val: 1.0072, Test: 0.9965
Epoch: 0168, Loss: 1.0152, Train: 1.0071, Val: 1.0068, Test: 0.9960
Epoch: 0169, Loss: 1.0142, Train: 1.0066, Val: 1.0063, Test: 0.9956
Epoch: 0170, Loss: 1.0132, Train: 1.0061, Val: 1.0058, Test: 0.9951
Epoch: 0171, Loss: 1.0123, Train: 1.0056, Val: 1.0054, Test: 0.9947
Epoch: 0172, Loss: 1.0113, Train: 1.0052, Val: 1.0049, Test: 0.9942
Epoch: 0173, Loss: 1.0104, Train: 1.0047, Val: 1.0045, Test: 0.9938
Epoch: 0174, Loss: 1.0094, Train: 1.0042, Val: 1.0040, Test: 0.9933
Epoch: 0175, Loss: 1.0084, Train: 1.0037, Val: 1.0036, Test: 0.9929
Epoch: 0176, Loss: 1.0075, Train: 1.0033, Val: 1.0031, Test: 0.9924
Epoch: 0177, Loss: 1.0066, Train: 1.0028, Val: 1.0027, Test: 0.9920
Epoch: 0178, Loss: 1.0056, Train: 1.0023, Val: 1.0022, Test: 0.9916
Epoch: 0179, Loss: 1.0047, Train: 1.0019, Val: 1.0018, Test: 0.9911
Epoch: 0180, Loss: 1.0038, Train: 1.0014, Val: 1.0014, Test: 0.9907
Epoch: 0181, Loss: 1.0028, Train: 1.0010, Val: 1.0009, Test: 0.9903
Epoch: 0182, Loss: 1.0019, Train: 1.0005, Val: 1.0005, Test: 0.9898
Epoch: 0183, Loss: 1.0010, Train: 1.0001, Val: 1.0001, Test: 0.9894
Epoch: 0184, Loss: 1.0001, Train: 0.9996, Val: 0.9996, Test: 0.9890
Epoch: 0185, Loss: 0.9992, Train: 0.9991, Val: 0.9992, Test: 0.9885
Epoch: 0186, Loss: 0.9983, Train: 0.9987, Val: 0.9988, Test: 0.9881
Epoch: 0187, Loss: 0.9974, Train: 0.9982, Val: 0.9983, Test: 0.9877
Epoch: 0188, Loss: 0.9965, Train: 0.9978, Val: 0.9979, Test: 0.9873
Epoch: 0189, Loss: 0.9956, Train: 0.9974, Val: 0.9975, Test: 0.9868
Epoch: 0190, Loss: 0.9947, Train: 0.9969, Val: 0.9971, Test: 0.9864
Epoch: 0191, Loss: 0.9938, Train: 0.9965, Val: 0.9966, Test: 0.9860
Epoch: 0192, Loss: 0.9930, Train: 0.9960, Val: 0.9962, Test: 0.9856
Epoch: 0193, Loss: 0.9921, Train: 0.9956, Val: 0.9958, Test: 0.9852
Epoch: 0194, Loss: 0.9912, Train: 0.9952, Val: 0.9954, Test: 0.9848
Epoch: 0195, Loss: 0.9903, Train: 0.9947, Val: 0.9950, Test: 0.9844
Epoch: 0196, Loss: 0.9895, Train: 0.9943, Val: 0.9946, Test: 0.9840
Epoch: 0197, Loss: 0.9886, Train: 0.9939, Val: 0.9942, Test: 0.9836
Epoch: 0198, Loss: 0.9878, Train: 0.9934, Val: 0.9937, Test: 0.9832
Epoch: 0199, Loss: 0.9869, Train: 0.9930, Val: 0.9933, Test: 0.9828
Epoch: 0200, Loss: 0.9861, Train: 0.9926, Val: 0.9929, Test: 0.9824
Epoch: 0201, Loss: 0.9852, Train: 0.9922, Val: 0.9925, Test: 0.9820
Epoch: 0202, Loss: 0.9844, Train: 0.9917, Val: 0.9921, Test: 0.9816
Epoch: 0203, Loss: 0.9836, Train: 0.9913, Val: 0.9917, Test: 0.9812
Epoch: 0204, Loss: 0.9827, Train: 0.9909, Val: 0.9913, Test: 0.9808
Epoch: 0205, Loss: 0.9819, Train: 0.9905, Val: 0.9909, Test: 0.9804
Epoch: 0206, Loss: 0.9811, Train: 0.9901, Val: 0.9905, Test: 0.9801
Epoch: 0207, Loss: 0.9802, Train: 0.9897, Val: 0.9902, Test: 0.9797
Epoch: 0208, Loss: 0.9794, Train: 0.9892, Val: 0.9898, Test: 0.9793
Epoch: 0209, Loss: 0.9786, Train: 0.9888, Val: 0.9894, Test: 0.9789
Epoch: 0210, Loss: 0.9778, Train: 0.9884, Val: 0.9890, Test: 0.9786
Epoch: 0211, Loss: 0.9770, Train: 0.9880, Val: 0.9886, Test: 0.9782
Epoch: 0212, Loss: 0.9762, Train: 0.9876, Val: 0.9882, Test: 0.9778
Epoch: 0213, Loss: 0.9754, Train: 0.9872, Val: 0.9878, Test: 0.9775
Epoch: 0214, Loss: 0.9746, Train: 0.9868, Val: 0.9875, Test: 0.9771
Epoch: 0215, Loss: 0.9738, Train: 0.9864, Val: 0.9871, Test: 0.9767
Epoch: 0216, Loss: 0.9730, Train: 0.9860, Val: 0.9867, Test: 0.9764
Epoch: 0217, Loss: 0.9722, Train: 0.9856, Val: 0.9863, Test: 0.9760
Epoch: 0218, Loss: 0.9715, Train: 0.9852, Val: 0.9860, Test: 0.9757
Epoch: 0219, Loss: 0.9707, Train: 0.9848, Val: 0.9856, Test: 0.9753
Epoch: 0220, Loss: 0.9699, Train: 0.9845, Val: 0.9852, Test: 0.9749
Epoch: 0221, Loss: 0.9691, Train: 0.9841, Val: 0.9849, Test: 0.9746
Epoch: 0222, Loss: 0.9684, Train: 0.9837, Val: 0.9845, Test: 0.9742
Epoch: 0223, Loss: 0.9676, Train: 0.9833, Val: 0.9841, Test: 0.9739
Epoch: 0224, Loss: 0.9669, Train: 0.9829, Val: 0.9838, Test: 0.9735
Epoch: 0225, Loss: 0.9661, Train: 0.9825, Val: 0.9834, Test: 0.9732
Epoch: 0226, Loss: 0.9654, Train: 0.9822, Val: 0.9831, Test: 0.9729
Epoch: 0227, Loss: 0.9646, Train: 0.9818, Val: 0.9827, Test: 0.9725
Epoch: 0228, Loss: 0.9639, Train: 0.9814, Val: 0.9823, Test: 0.9722
Epoch: 0229, Loss: 0.9632, Train: 0.9810, Val: 0.9820, Test: 0.9718
Epoch: 0230, Loss: 0.9625, Train: 0.9807, Val: 0.9816, Test: 0.9715
Epoch: 0231, Loss: 0.9617, Train: 0.9803, Val: 0.9813, Test: 0.9712
Epoch: 0232, Loss: 0.9610, Train: 0.9800, Val: 0.9810, Test: 0.9709
Epoch: 0233, Loss: 0.9603, Train: 0.9796, Val: 0.9806, Test: 0.9705
Epoch: 0234, Loss: 0.9596, Train: 0.9792, Val: 0.9803, Test: 0.9702
Epoch: 0235, Loss: 0.9589, Train: 0.9789, Val: 0.9799, Test: 0.9699
Epoch: 0236, Loss: 0.9582, Train: 0.9785, Val: 0.9796, Test: 0.9696
Epoch: 0237, Loss: 0.9575, Train: 0.9782, Val: 0.9793, Test: 0.9692
Epoch: 0238, Loss: 0.9568, Train: 0.9778, Val: 0.9789, Test: 0.9689
Epoch: 0239, Loss: 0.9561, Train: 0.9775, Val: 0.9786, Test: 0.9686
Epoch: 0240, Loss: 0.9554, Train: 0.9771, Val: 0.9783, Test: 0.9683
Epoch: 0241, Loss: 0.9548, Train: 0.9768, Val: 0.9780, Test: 0.9680
Epoch: 0242, Loss: 0.9541, Train: 0.9764, Val: 0.9777, Test: 0.9677
Epoch: 0243, Loss: 0.9534, Train: 0.9761, Val: 0.9773, Test: 0.9674
Epoch: 0244, Loss: 0.9528, Train: 0.9758, Val: 0.9770, Test: 0.9671
Epoch: 0245, Loss: 0.9521, Train: 0.9754, Val: 0.9767, Test: 0.9668
Epoch: 0246, Loss: 0.9515, Train: 0.9751, Val: 0.9764, Test: 0.9665
Epoch: 0247, Loss: 0.9508, Train: 0.9748, Val: 0.9761, Test: 0.9662
Epoch: 0248, Loss: 0.9502, Train: 0.9744, Val: 0.9758, Test: 0.9659
Epoch: 0249, Loss: 0.9495, Train: 0.9741, Val: 0.9755, Test: 0.9656
Epoch: 0250, Loss: 0.9489, Train: 0.9738, Val: 0.9752, Test: 0.9653
Epoch: 0251, Loss: 0.9483, Train: 0.9735, Val: 0.9749, Test: 0.9650
Epoch: 0252, Loss: 0.9476, Train: 0.9731, Val: 0.9746, Test: 0.9648
Epoch: 0253, Loss: 0.9470, Train: 0.9728, Val: 0.9743, Test: 0.9645
Epoch: 0254, Loss: 0.9464, Train: 0.9725, Val: 0.9740, Test: 0.9642
Epoch: 0255, Loss: 0.9458, Train: 0.9722, Val: 0.9737, Test: 0.9639
Epoch: 0256, Loss: 0.9452, Train: 0.9719, Val: 0.9734, Test: 0.9637
Epoch: 0257, Loss: 0.9446, Train: 0.9716, Val: 0.9731, Test: 0.9634
Epoch: 0258, Loss: 0.9440, Train: 0.9713, Val: 0.9729, Test: 0.9631
Epoch: 0259, Loss: 0.9434, Train: 0.9710, Val: 0.9726, Test: 0.9629
Epoch: 0260, Loss: 0.9428, Train: 0.9707, Val: 0.9723, Test: 0.9626
Epoch: 0261, Loss: 0.9422, Train: 0.9704, Val: 0.9720, Test: 0.9623
Epoch: 0262, Loss: 0.9416, Train: 0.9701, Val: 0.9718, Test: 0.9621
Epoch: 0263, Loss: 0.9410, Train: 0.9698, Val: 0.9715, Test: 0.9618
Epoch: 0264, Loss: 0.9405, Train: 0.9695, Val: 0.9712, Test: 0.9616
Epoch: 0265, Loss: 0.9399, Train: 0.9692, Val: 0.9710, Test: 0.9613
Epoch: 0266, Loss: 0.9393, Train: 0.9689, Val: 0.9707, Test: 0.9610
Epoch: 0267, Loss: 0.9388, Train: 0.9686, Val: 0.9704, Test: 0.9608
Epoch: 0268, Loss: 0.9382, Train: 0.9683, Val: 0.9702, Test: 0.9606
Epoch: 0269, Loss: 0.9376, Train: 0.9680, Val: 0.9699, Test: 0.9603
Epoch: 0270, Loss: 0.9371, Train: 0.9678, Val: 0.9697, Test: 0.9601
Epoch: 0271, Loss: 0.9365, Train: 0.9675, Val: 0.9694, Test: 0.9598
Epoch: 0272, Loss: 0.9360, Train: 0.9672, Val: 0.9692, Test: 0.9596
Epoch: 0273, Loss: 0.9355, Train: 0.9669, Val: 0.9689, Test: 0.9594
Epoch: 0274, Loss: 0.9349, Train: 0.9666, Val: 0.9687, Test: 0.9591
Epoch: 0275, Loss: 0.9344, Train: 0.9664, Val: 0.9684, Test: 0.9589
Epoch: 0276, Loss: 0.9338, Train: 0.9661, Val: 0.9682, Test: 0.9587
Epoch: 0277, Loss: 0.9333, Train: 0.9658, Val: 0.9679, Test: 0.9584
Epoch: 0278, Loss: 0.9328, Train: 0.9655, Val: 0.9677, Test: 0.9582
Epoch: 0279, Loss: 0.9323, Train: 0.9653, Val: 0.9674, Test: 0.9580
Epoch: 0280, Loss: 0.9317, Train: 0.9650, Val: 0.9672, Test: 0.9578
Epoch: 0281, Loss: 0.9312, Train: 0.9647, Val: 0.9670, Test: 0.9575
Epoch: 0282, Loss: 0.9307, Train: 0.9645, Val: 0.9667, Test: 0.9573
Epoch: 0283, Loss: 0.9302, Train: 0.9642, Val: 0.9665, Test: 0.9571
Epoch: 0284, Loss: 0.9297, Train: 0.9639, Val: 0.9662, Test: 0.9569
Epoch: 0285, Loss: 0.9292, Train: 0.9637, Val: 0.9660, Test: 0.9566
Epoch: 0286, Loss: 0.9287, Train: 0.9634, Val: 0.9658, Test: 0.9564
Epoch: 0287, Loss: 0.9282, Train: 0.9632, Val: 0.9656, Test: 0.9562
Epoch: 0288, Loss: 0.9277, Train: 0.9629, Val: 0.9653, Test: 0.9560
Epoch: 0289, Loss: 0.9272, Train: 0.9627, Val: 0.9651, Test: 0.9558
Epoch: 0290, Loss: 0.9267, Train: 0.9624, Val: 0.9649, Test: 0.9556
Epoch: 0291, Loss: 0.9262, Train: 0.9621, Val: 0.9647, Test: 0.9553
Epoch: 0292, Loss: 0.9257, Train: 0.9619, Val: 0.9644, Test: 0.9551
Epoch: 0293, Loss: 0.9253, Train: 0.9617, Val: 0.9642, Test: 0.9549
Epoch: 0294, Loss: 0.9248, Train: 0.9614, Val: 0.9640, Test: 0.9547
Epoch: 0295, Loss: 0.9243, Train: 0.9612, Val: 0.9638, Test: 0.9545
Epoch: 0296, Loss: 0.9238, Train: 0.9609, Val: 0.9636, Test: 0.9543
Epoch: 0297, Loss: 0.9233, Train: 0.9607, Val: 0.9634, Test: 0.9541
Epoch: 0298, Loss: 0.9229, Train: 0.9604, Val: 0.9631, Test: 0.9539
Epoch: 0299, Loss: 0.9224, Train: 0.9602, Val: 0.9629, Test: 0.9537
Epoch: 0300, Loss: 0.9219, Train: 0.9599, Val: 0.9627, Test: 0.9535
Epoch: 0301, Loss: 0.9215, Train: 0.9597, Val: 0.9625, Test: 0.9533
Epoch: 0302, Loss: 0.9210, Train: 0.9595, Val: 0.9623, Test: 0.9531
Epoch: 0303, Loss: 0.9206, Train: 0.9592, Val: 0.9621, Test: 0.9529
Epoch: 0304, Loss: 0.9201, Train: 0.9590, Val: 0.9619, Test: 0.9527
Epoch: 0305, Loss: 0.9196, Train: 0.9587, Val: 0.9617, Test: 0.9525
Epoch: 0306, Loss: 0.9192, Train: 0.9585, Val: 0.9615, Test: 0.9523
Epoch: 0307, Loss: 0.9187, Train: 0.9583, Val: 0.9613, Test: 0.9521
Epoch: 0308, Loss: 0.9183, Train: 0.9580, Val: 0.9611, Test: 0.9519
Epoch: 0309, Loss: 0.9178, Train: 0.9578, Val: 0.9609, Test: 0.9517
Epoch: 0310, Loss: 0.9174, Train: 0.9576, Val: 0.9607, Test: 0.9515
Epoch: 0311, Loss: 0.9169, Train: 0.9573, Val: 0.9605, Test: 0.9513
Epoch: 0312, Loss: 0.9165, Train: 0.9571, Val: 0.9603, Test: 0.9511
Epoch: 0313, Loss: 0.9160, Train: 0.9569, Val: 0.9601, Test: 0.9510
Epoch: 0314, Loss: 0.9156, Train: 0.9566, Val: 0.9599, Test: 0.9508
Epoch: 0315, Loss: 0.9152, Train: 0.9564, Val: 0.9597, Test: 0.9506
Epoch: 0316, Loss: 0.9147, Train: 0.9562, Val: 0.9595, Test: 0.9504
Epoch: 0317, Loss: 0.9143, Train: 0.9559, Val: 0.9593, Test: 0.9502
Epoch: 0318, Loss: 0.9138, Train: 0.9557, Val: 0.9591, Test: 0.9500
Epoch: 0319, Loss: 0.9134, Train: 0.9555, Val: 0.9589, Test: 0.9498
Epoch: 0320, Loss: 0.9130, Train: 0.9553, Val: 0.9587, Test: 0.9496
Epoch: 0321, Loss: 0.9125, Train: 0.9550, Val: 0.9585, Test: 0.9495
Epoch: 0322, Loss: 0.9121, Train: 0.9548, Val: 0.9584, Test: 0.9493
Epoch: 0323, Loss: 0.9117, Train: 0.9546, Val: 0.9582, Test: 0.9491
Epoch: 0324, Loss: 0.9112, Train: 0.9544, Val: 0.9580, Test: 0.9489
Epoch: 0325, Loss: 0.9108, Train: 0.9541, Val: 0.9578, Test: 0.9487
Epoch: 0326, Loss: 0.9104, Train: 0.9539, Val: 0.9576, Test: 0.9485
Epoch: 0327, Loss: 0.9099, Train: 0.9537, Val: 0.9574, Test: 0.9483
Epoch: 0328, Loss: 0.9095, Train: 0.9535, Val: 0.9572, Test: 0.9481
Epoch: 0329, Loss: 0.9091, Train: 0.9532, Val: 0.9571, Test: 0.9479
Epoch: 0330, Loss: 0.9087, Train: 0.9530, Val: 0.9569, Test: 0.9478
Epoch: 0331, Loss: 0.9082, Train: 0.9528, Val: 0.9567, Test: 0.9476
Epoch: 0332, Loss: 0.9078, Train: 0.9526, Val: 0.9565, Test: 0.9474
Epoch: 0333, Loss: 0.9074, Train: 0.9523, Val: 0.9563, Test: 0.9472
Epoch: 0334, Loss: 0.9070, Train: 0.9521, Val: 0.9561, Test: 0.9470
Epoch: 0335, Loss: 0.9065, Train: 0.9519, Val: 0.9560, Test: 0.9469
Epoch: 0336, Loss: 0.9061, Train: 0.9517, Val: 0.9558, Test: 0.9467
Epoch: 0337, Loss: 0.9057, Train: 0.9515, Val: 0.9556, Test: 0.9465
Epoch: 0338, Loss: 0.9053, Train: 0.9512, Val: 0.9554, Test: 0.9463
Epoch: 0339, Loss: 0.9049, Train: 0.9510, Val: 0.9552, Test: 0.9461
Epoch: 0340, Loss: 0.9045, Train: 0.9508, Val: 0.9550, Test: 0.9459
Epoch: 0341, Loss: 0.9040, Train: 0.9506, Val: 0.9549, Test: 0.9458
Epoch: 0342, Loss: 0.9036, Train: 0.9504, Val: 0.9547, Test: 0.9456
Epoch: 0343, Loss: 0.9032, Train: 0.9501, Val: 0.9545, Test: 0.9454
Epoch: 0344, Loss: 0.9028, Train: 0.9499, Val: 0.9543, Test: 0.9452
Epoch: 0345, Loss: 0.9024, Train: 0.9497, Val: 0.9542, Test: 0.9451
Epoch: 0346, Loss: 0.9020, Train: 0.9495, Val: 0.9540, Test: 0.9449
Epoch: 0347, Loss: 0.9016, Train: 0.9493, Val: 0.9538, Test: 0.9447
Epoch: 0348, Loss: 0.9011, Train: 0.9491, Val: 0.9536, Test: 0.9445
Epoch: 0349, Loss: 0.9007, Train: 0.9489, Val: 0.9535, Test: 0.9444
Epoch: 0350, Loss: 0.9003, Train: 0.9486, Val: 0.9533, Test: 0.9442
Epoch: 0351, Loss: 0.8999, Train: 0.9484, Val: 0.9531, Test: 0.9440
Epoch: 0352, Loss: 0.8995, Train: 0.9482, Val: 0.9529, Test: 0.9439
Epoch: 0353, Loss: 0.8991, Train: 0.9480, Val: 0.9527, Test: 0.9437
Epoch: 0354, Loss: 0.8987, Train: 0.9478, Val: 0.9525, Test: 0.9435
Epoch: 0355, Loss: 0.8983, Train: 0.9475, Val: 0.9524, Test: 0.9433
Epoch: 0356, Loss: 0.8979, Train: 0.9473, Val: 0.9522, Test: 0.9432
Epoch: 0357, Loss: 0.8974, Train: 0.9471, Val: 0.9520, Test: 0.9430
Epoch: 0358, Loss: 0.8970, Train: 0.9469, Val: 0.9519, Test: 0.9428
Epoch: 0359, Loss: 0.8966, Train: 0.9467, Val: 0.9517, Test: 0.9427
Epoch: 0360, Loss: 0.8962, Train: 0.9464, Val: 0.9515, Test: 0.9425
Epoch: 0361, Loss: 0.8958, Train: 0.9462, Val: 0.9513, Test: 0.9423
Epoch: 0362, Loss: 0.8954, Train: 0.9460, Val: 0.9512, Test: 0.9421
Epoch: 0363, Loss: 0.8949, Train: 0.9458, Val: 0.9510, Test: 0.9420
Epoch: 0364, Loss: 0.8945, Train: 0.9456, Val: 0.9508, Test: 0.9418
Epoch: 0365, Loss: 0.8941, Train: 0.9453, Val: 0.9506, Test: 0.9416
Epoch: 0366, Loss: 0.8937, Train: 0.9451, Val: 0.9504, Test: 0.9414
Epoch: 0367, Loss: 0.8933, Train: 0.9449, Val: 0.9503, Test: 0.9412
Epoch: 0368, Loss: 0.8928, Train: 0.9447, Val: 0.9501, Test: 0.9411
Epoch: 0369, Loss: 0.8924, Train: 0.9444, Val: 0.9499, Test: 0.9409
Epoch: 0370, Loss: 0.8920, Train: 0.9442, Val: 0.9497, Test: 0.9407
Epoch: 0371, Loss: 0.8915, Train: 0.9440, Val: 0.9496, Test: 0.9405
Epoch: 0372, Loss: 0.8911, Train: 0.9438, Val: 0.9494, Test: 0.9403
Epoch: 0373, Loss: 0.8907, Train: 0.9435, Val: 0.9492, Test: 0.9401
Epoch: 0374, Loss: 0.8902, Train: 0.9433, Val: 0.9490, Test: 0.9399
Epoch: 0375, Loss: 0.8898, Train: 0.9431, Val: 0.9488, Test: 0.9398
Epoch: 0376, Loss: 0.8894, Train: 0.9428, Val: 0.9486, Test: 0.9396
Epoch: 0377, Loss: 0.8889, Train: 0.9426, Val: 0.9484, Test: 0.9394
Epoch: 0378, Loss: 0.8885, Train: 0.9424, Val: 0.9483, Test: 0.9392
Epoch: 0379, Loss: 0.8880, Train: 0.9421, Val: 0.9481, Test: 0.9390
Epoch: 0380, Loss: 0.8876, Train: 0.9419, Val: 0.9479, Test: 0.9388
Epoch: 0381, Loss: 0.8871, Train: 0.9416, Val: 0.9477, Test: 0.9386
Epoch: 0382, Loss: 0.8867, Train: 0.9414, Val: 0.9475, Test: 0.9384
Epoch: 0383, Loss: 0.8862, Train: 0.9412, Val: 0.9473, Test: 0.9382
Epoch: 0384, Loss: 0.8858, Train: 0.9409, Val: 0.9471, Test: 0.9380
Epoch: 0385, Loss: 0.8853, Train: 0.9407, Val: 0.9470, Test: 0.9378
Epoch: 0386, Loss: 0.8849, Train: 0.9404, Val: 0.9468, Test: 0.9377
Epoch: 0387, Loss: 0.8844, Train: 0.9402, Val: 0.9466, Test: 0.9375
Epoch: 0388, Loss: 0.8840, Train: 0.9399, Val: 0.9464, Test: 0.9373
Epoch: 0389, Loss: 0.8835, Train: 0.9397, Val: 0.9462, Test: 0.9371
Epoch: 0390, Loss: 0.8830, Train: 0.9394, Val: 0.9460, Test: 0.9369
Epoch: 0391, Loss: 0.8826, Train: 0.9392, Val: 0.9458, Test: 0.9367
Epoch: 0392, Loss: 0.8821, Train: 0.9390, Val: 0.9456, Test: 0.9365
Epoch: 0393, Loss: 0.8816, Train: 0.9387, Val: 0.9454, Test: 0.9363
Epoch: 0394, Loss: 0.8812, Train: 0.9384, Val: 0.9452, Test: 0.9361
Epoch: 0395, Loss: 0.8807, Train: 0.9382, Val: 0.9450, Test: 0.9359
Epoch: 0396, Loss: 0.8802, Train: 0.9379, Val: 0.9448, Test: 0.9357
Epoch: 0397, Loss: 0.8798, Train: 0.9377, Val: 0.9447, Test: 0.9355
Epoch: 0398, Loss: 0.8793, Train: 0.9374, Val: 0.9445, Test: 0.9353
Epoch: 0399, Loss: 0.8788, Train: 0.9372, Val: 0.9443, Test: 0.9351
Epoch: 0400, Loss: 0.8783, Train: 0.9369, Val: 0.9441, Test: 0.9349
Epoch: 0401, Loss: 0.8778, Train: 0.9367, Val: 0.9439, Test: 0.9347
Epoch: 0402, Loss: 0.8774, Train: 0.9364, Val: 0.9437, Test: 0.9345
Epoch: 0403, Loss: 0.8769, Train: 0.9361, Val: 0.9435, Test: 0.9344
Epoch: 0404, Loss: 0.8764, Train: 0.9359, Val: 0.9433, Test: 0.9342
Epoch: 0405, Loss: 0.8759, Train: 0.9356, Val: 0.9431, Test: 0.9340
Epoch: 0406, Loss: 0.8754, Train: 0.9354, Val: 0.9429, Test: 0.9338
Epoch: 0407, Loss: 0.8749, Train: 0.9351, Val: 0.9427, Test: 0.9336
Epoch: 0408, Loss: 0.8744, Train: 0.9348, Val: 0.9425, Test: 0.9333
Epoch: 0409, Loss: 0.8739, Train: 0.9346, Val: 0.9423, Test: 0.9331
Epoch: 0410, Loss: 0.8734, Train: 0.9343, Val: 0.9421, Test: 0.9329
Epoch: 0411, Loss: 0.8729, Train: 0.9340, Val: 0.9419, Test: 0.9327
Epoch: 0412, Loss: 0.8724, Train: 0.9337, Val: 0.9417, Test: 0.9325
Epoch: 0413, Loss: 0.8719, Train: 0.9335, Val: 0.9415, Test: 0.9323
Epoch: 0414, Loss: 0.8714, Train: 0.9332, Val: 0.9413, Test: 0.9321
Epoch: 0415, Loss: 0.8709, Train: 0.9329, Val: 0.9411, Test: 0.9319
Epoch: 0416, Loss: 0.8703, Train: 0.9326, Val: 0.9409, Test: 0.9317
Epoch: 0417, Loss: 0.8698, Train: 0.9324, Val: 0.9407, Test: 0.9315
Epoch: 0418, Loss: 0.8693, Train: 0.9321, Val: 0.9405, Test: 0.9312
Epoch: 0419, Loss: 0.8688, Train: 0.9318, Val: 0.9403, Test: 0.9310
Epoch: 0420, Loss: 0.8683, Train: 0.9315, Val: 0.9401, Test: 0.9308
Epoch: 0421, Loss: 0.8677, Train: 0.9312, Val: 0.9398, Test: 0.9306
Epoch: 0422, Loss: 0.8672, Train: 0.9310, Val: 0.9396, Test: 0.9303
Epoch: 0423, Loss: 0.8667, Train: 0.9307, Val: 0.9394, Test: 0.9301
Epoch: 0424, Loss: 0.8662, Train: 0.9304, Val: 0.9392, Test: 0.9299
Epoch: 0425, Loss: 0.8656, Train: 0.9301, Val: 0.9390, Test: 0.9297
Epoch: 0426, Loss: 0.8651, Train: 0.9298, Val: 0.9388, Test: 0.9295
Epoch: 0427, Loss: 0.8646, Train: 0.9295, Val: 0.9386, Test: 0.9293
Epoch: 0428, Loss: 0.8640, Train: 0.9292, Val: 0.9384, Test: 0.9290
Epoch: 0429, Loss: 0.8635, Train: 0.9289, Val: 0.9382, Test: 0.9288
Epoch: 0430, Loss: 0.8629, Train: 0.9286, Val: 0.9380, Test: 0.9285
Epoch: 0431, Loss: 0.8623, Train: 0.9283, Val: 0.9378, Test: 0.9283
Epoch: 0432, Loss: 0.8618, Train: 0.9280, Val: 0.9376, Test: 0.9281
Epoch: 0433, Loss: 0.8612, Train: 0.9277, Val: 0.9373, Test: 0.9278
Epoch: 0434, Loss: 0.8607, Train: 0.9274, Val: 0.9371, Test: 0.9276
Epoch: 0435, Loss: 0.8601, Train: 0.9271, Val: 0.9369, Test: 0.9274
Epoch: 0436, Loss: 0.8595, Train: 0.9268, Val: 0.9367, Test: 0.9271
Epoch: 0437, Loss: 0.8590, Train: 0.9265, Val: 0.9365, Test: 0.9269
Epoch: 0438, Loss: 0.8584, Train: 0.9262, Val: 0.9363, Test: 0.9266
Epoch: 0439, Loss: 0.8578, Train: 0.9259, Val: 0.9361, Test: 0.9264
Epoch: 0440, Loss: 0.8573, Train: 0.9256, Val: 0.9358, Test: 0.9262
Epoch: 0441, Loss: 0.8567, Train: 0.9252, Val: 0.9356, Test: 0.9259
Epoch: 0442, Loss: 0.8561, Train: 0.9249, Val: 0.9354, Test: 0.9257
Epoch: 0443, Loss: 0.8555, Train: 0.9246, Val: 0.9352, Test: 0.9254
Epoch: 0444, Loss: 0.8549, Train: 0.9243, Val: 0.9350, Test: 0.9252
Epoch: 0445, Loss: 0.8543, Train: 0.9240, Val: 0.9348, Test: 0.9249
Epoch: 0446, Loss: 0.8537, Train: 0.9236, Val: 0.9346, Test: 0.9246
Epoch: 0447, Loss: 0.8531, Train: 0.9233, Val: 0.9344, Test: 0.9244
Epoch: 0448, Loss: 0.8525, Train: 0.9230, Val: 0.9342, Test: 0.9241
Epoch: 0449, Loss: 0.8519, Train: 0.9227, Val: 0.9339, Test: 0.9238
Epoch: 0450, Loss: 0.8513, Train: 0.9223, Val: 0.9337, Test: 0.9236
Epoch: 0451, Loss: 0.8507, Train: 0.9220, Val: 0.9335, Test: 0.9234
Epoch: 0452, Loss: 0.8501, Train: 0.9217, Val: 0.9332, Test: 0.9231
Epoch: 0453, Loss: 0.8495, Train: 0.9213, Val: 0.9330, Test: 0.9228
Epoch: 0454, Loss: 0.8489, Train: 0.9210, Val: 0.9328, Test: 0.9226
Epoch: 0455, Loss: 0.8483, Train: 0.9207, Val: 0.9325, Test: 0.9223
Epoch: 0456, Loss: 0.8476, Train: 0.9203, Val: 0.9323, Test: 0.9220
Epoch: 0457, Loss: 0.8470, Train: 0.9200, Val: 0.9320, Test: 0.9218
Epoch: 0458, Loss: 0.8464, Train: 0.9196, Val: 0.9318, Test: 0.9215
Epoch: 0459, Loss: 0.8458, Train: 0.9193, Val: 0.9316, Test: 0.9212
Epoch: 0460, Loss: 0.8451, Train: 0.9189, Val: 0.9313, Test: 0.9210
Epoch: 0461, Loss: 0.8445, Train: 0.9186, Val: 0.9311, Test: 0.9207
Epoch: 0462, Loss: 0.8438, Train: 0.9182, Val: 0.9309, Test: 0.9205
Epoch: 0463, Loss: 0.8432, Train: 0.9179, Val: 0.9306, Test: 0.9202
Epoch: 0464, Loss: 0.8426, Train: 0.9175, Val: 0.9304, Test: 0.9199
Epoch: 0465, Loss: 0.8419, Train: 0.9172, Val: 0.9302, Test: 0.9197
Epoch: 0466, Loss: 0.8413, Train: 0.9168, Val: 0.9300, Test: 0.9194
Epoch: 0467, Loss: 0.8406, Train: 0.9165, Val: 0.9297, Test: 0.9191
Epoch: 0468, Loss: 0.8400, Train: 0.9161, Val: 0.9295, Test: 0.9189
Epoch: 0469, Loss: 0.8393, Train: 0.9158, Val: 0.9293, Test: 0.9186
Epoch: 0470, Loss: 0.8386, Train: 0.9154, Val: 0.9291, Test: 0.9184
Epoch: 0471, Loss: 0.8380, Train: 0.9150, Val: 0.9289, Test: 0.9181
Epoch: 0472, Loss: 0.8373, Train: 0.9147, Val: 0.9286, Test: 0.9178
Epoch: 0473, Loss: 0.8367, Train: 0.9143, Val: 0.9284, Test: 0.9176
Epoch: 0474, Loss: 0.8360, Train: 0.9140, Val: 0.9283, Test: 0.9173
Epoch: 0475, Loss: 0.8353, Train: 0.9136, Val: 0.9281, Test: 0.9171
Epoch: 0476, Loss: 0.8347, Train: 0.9132, Val: 0.9279, Test: 0.9169
Epoch: 0477, Loss: 0.8340, Train: 0.9129, Val: 0.9277, Test: 0.9166
Epoch: 0478, Loss: 0.8334, Train: 0.9125, Val: 0.9275, Test: 0.9164
Epoch: 0479, Loss: 0.8327, Train: 0.9121, Val: 0.9273, Test: 0.9162
Epoch: 0480, Loss: 0.8320, Train: 0.9118, Val: 0.9271, Test: 0.9159
Epoch: 0481, Loss: 0.8314, Train: 0.9114, Val: 0.9269, Test: 0.9157
Epoch: 0482, Loss: 0.8307, Train: 0.9111, Val: 0.9267, Test: 0.9154
Epoch: 0483, Loss: 0.8300, Train: 0.9107, Val: 0.9266, Test: 0.9152
Epoch: 0484, Loss: 0.8294, Train: 0.9103, Val: 0.9264, Test: 0.9150
Epoch: 0485, Loss: 0.8287, Train: 0.9100, Val: 0.9262, Test: 0.9147
Epoch: 0486, Loss: 0.8281, Train: 0.9096, Val: 0.9260, Test: 0.9145
Epoch: 0487, Loss: 0.8274, Train: 0.9092, Val: 0.9259, Test: 0.9143
Epoch: 0488, Loss: 0.8268, Train: 0.9089, Val: 0.9257, Test: 0.9141
Epoch: 0489, Loss: 0.8261, Train: 0.9085, Val: 0.9255, Test: 0.9139
Epoch: 0490, Loss: 0.8255, Train: 0.9082, Val: 0.9254, Test: 0.9137
Epoch: 0491, Loss: 0.8248, Train: 0.9078, Val: 0.9252, Test: 0.9134
Epoch: 0492, Loss: 0.8241, Train: 0.9074, Val: 0.9250, Test: 0.9132
Epoch: 0493, Loss: 0.8235, Train: 0.9071, Val: 0.9249, Test: 0.9130
Epoch: 0494, Loss: 0.8228, Train: 0.9067, Val: 0.9247, Test: 0.9128
Epoch: 0495, Loss: 0.8222, Train: 0.9064, Val: 0.9245, Test: 0.9126
Epoch: 0496, Loss: 0.8215, Train: 0.9060, Val: 0.9244, Test: 0.9124
Epoch: 0497, Loss: 0.8209, Train: 0.9057, Val: 0.9242, Test: 0.9122
Epoch: 0498, Loss: 0.8203, Train: 0.9053, Val: 0.9241, Test: 0.9120
Epoch: 0499, Loss: 0.8196, Train: 0.9049, Val: 0.9239, Test: 0.9118
Epoch: 0500, Loss: 0.8190, Train: 0.9046, Val: 0.9238, Test: 0.9116
Epoch: 0501, Loss: 0.8183, Train: 0.9042, Val: 0.9236, Test: 0.9113
Epoch: 0502, Loss: 0.8177, Train: 0.9039, Val: 0.9235, Test: 0.9111
Epoch: 0503, Loss: 0.8170, Train: 0.9035, Val: 0.9233, Test: 0.9109
Epoch: 0504, Loss: 0.8164, Train: 0.9032, Val: 0.9232, Test: 0.9107
Epoch: 0505, Loss: 0.8157, Train: 0.9028, Val: 0.9231, Test: 0.9105
Epoch: 0506, Loss: 0.8151, Train: 0.9024, Val: 0.9230, Test: 0.9103
Epoch: 0507, Loss: 0.8144, Train: 0.9021, Val: 0.9228, Test: 0.9101
Epoch: 0508, Loss: 0.8138, Train: 0.9017, Val: 0.9227, Test: 0.9099
Epoch: 0509, Loss: 0.8131, Train: 0.9014, Val: 0.9225, Test: 0.9096
Epoch: 0510, Loss: 0.8125, Train: 0.9010, Val: 0.9224, Test: 0.9095
Epoch: 0511, Loss: 0.8119, Train: 0.9007, Val: 0.9222, Test: 0.9092
Epoch: 0512, Loss: 0.8112, Train: 0.9003, Val: 0.9222, Test: 0.9091
Epoch: 0513, Loss: 0.8106, Train: 0.9000, Val: 0.9220, Test: 0.9089
Epoch: 0514, Loss: 0.8100, Train: 0.8996, Val: 0.9219, Test: 0.9088
Epoch: 0515, Loss: 0.8094, Train: 0.8993, Val: 0.9217, Test: 0.9085
Epoch: 0516, Loss: 0.8087, Train: 0.8989, Val: 0.9216, Test: 0.9084
Epoch: 0517, Loss: 0.8081, Train: 0.8986, Val: 0.9215, Test: 0.9083
Epoch: 0518, Loss: 0.8075, Train: 0.8983, Val: 0.9213, Test: 0.9081
Epoch: 0519, Loss: 0.8069, Train: 0.8979, Val: 0.9212, Test: 0.9080
Epoch: 0520, Loss: 0.8062, Train: 0.8975, Val: 0.9210, Test: 0.9078
Epoch: 0521, Loss: 0.8056, Train: 0.8972, Val: 0.9208, Test: 0.9076
Epoch: 0522, Loss: 0.8051, Train: 0.8969, Val: 0.9208, Test: 0.9075
Epoch: 0523, Loss: 0.8045, Train: 0.8965, Val: 0.9206, Test: 0.9072
Epoch: 0524, Loss: 0.8038, Train: 0.8962, Val: 0.9204, Test: 0.9071
Epoch: 0525, Loss: 0.8032, Train: 0.8959, Val: 0.9204, Test: 0.9070
Epoch: 0526, Loss: 0.8027, Train: 0.8956, Val: 0.9202, Test: 0.9068
Epoch: 0527, Loss: 0.8021, Train: 0.8952, Val: 0.9202, Test: 0.9067
Epoch: 0528, Loss: 0.8015, Train: 0.8949, Val: 0.9200, Test: 0.9066
Epoch: 0529, Loss: 0.8009, Train: 0.8946, Val: 0.9199, Test: 0.9064
Epoch: 0530, Loss: 0.8003, Train: 0.8943, Val: 0.9198, Test: 0.9063
Epoch: 0531, Loss: 0.7998, Train: 0.8939, Val: 0.9196, Test: 0.9062
Epoch: 0532, Loss: 0.7992, Train: 0.8936, Val: 0.9195, Test: 0.9060
Epoch: 0533, Loss: 0.7986, Train: 0.8933, Val: 0.9194, Test: 0.9059
Epoch: 0534, Loss: 0.7980, Train: 0.8930, Val: 0.9193, Test: 0.9058
Epoch: 0535, Loss: 0.7974, Train: 0.8926, Val: 0.9192, Test: 0.9057
Epoch: 0536, Loss: 0.7969, Train: 0.8923, Val: 0.9190, Test: 0.9056
Epoch: 0537, Loss: 0.7963, Train: 0.8920, Val: 0.9190, Test: 0.9055
Epoch: 0538, Loss: 0.7957, Train: 0.8917, Val: 0.9189, Test: 0.9054
Epoch: 0539, Loss: 0.7951, Train: 0.8913, Val: 0.9188, Test: 0.9052
Epoch: 0540, Loss: 0.7946, Train: 0.8910, Val: 0.9187, Test: 0.9051
Epoch: 0541, Loss: 0.7940, Train: 0.8907, Val: 0.9185, Test: 0.9050
Epoch: 0542, Loss: 0.7934, Train: 0.8904, Val: 0.9185, Test: 0.9049
Epoch: 0543, Loss: 0.7929, Train: 0.8901, Val: 0.9183, Test: 0.9048
Epoch: 0544, Loss: 0.7923, Train: 0.8898, Val: 0.9183, Test: 0.9048
Epoch: 0545, Loss: 0.7917, Train: 0.8894, Val: 0.9182, Test: 0.9046
Epoch: 0546, Loss: 0.7912, Train: 0.8891, Val: 0.9181, Test: 0.9046
Epoch: 0547, Loss: 0.7906, Train: 0.8888, Val: 0.9180, Test: 0.9045
Epoch: 0548, Loss: 0.7901, Train: 0.8885, Val: 0.9180, Test: 0.9045
Epoch: 0549, Loss: 0.7895, Train: 0.8882, Val: 0.9179, Test: 0.9044
Epoch: 0550, Loss: 0.7890, Train: 0.8879, Val: 0.9178, Test: 0.9044
Epoch: 0551, Loss: 0.7884, Train: 0.8876, Val: 0.9177, Test: 0.9042
Epoch: 0552, Loss: 0.7879, Train: 0.8873, Val: 0.9176, Test: 0.9042
Epoch: 0553, Loss: 0.7874, Train: 0.8870, Val: 0.9175, Test: 0.9040
Epoch: 0554, Loss: 0.7869, Train: 0.8868, Val: 0.9175, Test: 0.9041
Epoch: 0555, Loss: 0.7864, Train: 0.8864, Val: 0.9173, Test: 0.9039
Epoch: 0556, Loss: 0.7859, Train: 0.8861, Val: 0.9173, Test: 0.9039
Epoch: 0557, Loss: 0.7853, Train: 0.8858, Val: 0.9171, Test: 0.9037
Epoch: 0558, Loss: 0.7847, Train: 0.8855, Val: 0.9170, Test: 0.9036
Epoch: 0559, Loss: 0.7842, Train: 0.8852, Val: 0.9170, Test: 0.9036
Epoch: 0560, Loss: 0.7837, Train: 0.8850, Val: 0.9168, Test: 0.9035
Epoch: 0561, Loss: 0.7832, Train: 0.8847, Val: 0.9169, Test: 0.9035
Epoch: 0562, Loss: 0.7827, Train: 0.8844, Val: 0.9167, Test: 0.9033
Epoch: 0563, Loss: 0.7822, Train: 0.8841, Val: 0.9166, Test: 0.9033
Epoch: 0564, Loss: 0.7816, Train: 0.8838, Val: 0.9165, Test: 0.9032
Epoch: 0565, Loss: 0.7811, Train: 0.8835, Val: 0.9163, Test: 0.9031
Epoch: 0566, Loss: 0.7806, Train: 0.8832, Val: 0.9163, Test: 0.9031
Epoch: 0567, Loss: 0.7802, Train: 0.8829, Val: 0.9162, Test: 0.9030
Epoch: 0568, Loss: 0.7797, Train: 0.8827, Val: 0.9162, Test: 0.9030
Epoch: 0569, Loss: 0.7792, Train: 0.8824, Val: 0.9160, Test: 0.9028
Epoch: 0570, Loss: 0.7787, Train: 0.8821, Val: 0.9160, Test: 0.9029
Epoch: 0571, Loss: 0.7782, Train: 0.8818, Val: 0.9158, Test: 0.9027
Epoch: 0572, Loss: 0.7776, Train: 0.8815, Val: 0.9157, Test: 0.9027
Epoch: 0573, Loss: 0.7771, Train: 0.8812, Val: 0.9157, Test: 0.9026
Epoch: 0574, Loss: 0.7766, Train: 0.8809, Val: 0.9155, Test: 0.9026
Epoch: 0575, Loss: 0.7761, Train: 0.8807, Val: 0.9155, Test: 0.9026
Epoch: 0576, Loss: 0.7757, Train: 0.8804, Val: 0.9154, Test: 0.9025
Epoch: 0577, Loss: 0.7752, Train: 0.8802, Val: 0.9154, Test: 0.9026
Epoch: 0578, Loss: 0.7748, Train: 0.8799, Val: 0.9152, Test: 0.9025
Epoch: 0579, Loss: 0.7743, Train: 0.8796, Val: 0.9153, Test: 0.9025
Epoch: 0580, Loss: 0.7738, Train: 0.8793, Val: 0.9150, Test: 0.9024
Epoch: 0581, Loss: 0.7733, Train: 0.8790, Val: 0.9150, Test: 0.9024
Epoch: 0582, Loss: 0.7728, Train: 0.8787, Val: 0.9149, Test: 0.9023
Epoch: 0583, Loss: 0.7723, Train: 0.8785, Val: 0.9148, Test: 0.9022
Epoch: 0584, Loss: 0.7718, Train: 0.8782, Val: 0.9148, Test: 0.9022
Epoch: 0585, Loss: 0.7713, Train: 0.8780, Val: 0.9147, Test: 0.9021
Epoch: 0586, Loss: 0.7709, Train: 0.8777, Val: 0.9147, Test: 0.9022
Epoch: 0587, Loss: 0.7705, Train: 0.8775, Val: 0.9146, Test: 0.9021
Epoch: 0588, Loss: 0.7701, Train: 0.8772, Val: 0.9146, Test: 0.9022
Epoch: 0589, Loss: 0.7696, Train: 0.8769, Val: 0.9144, Test: 0.9020
Epoch: 0590, Loss: 0.7691, Train: 0.8766, Val: 0.9143, Test: 0.9020
Epoch: 0591, Loss: 0.7685, Train: 0.8763, Val: 0.9142, Test: 0.9019
Epoch: 0592, Loss: 0.7681, Train: 0.8761, Val: 0.9142, Test: 0.9019
Epoch: 0593, Loss: 0.7676, Train: 0.8759, Val: 0.9142, Test: 0.9020
Epoch: 0594, Loss: 0.7672, Train: 0.8756, Val: 0.9140, Test: 0.9019
Epoch: 0595, Loss: 0.7668, Train: 0.8754, Val: 0.9141, Test: 0.9020
Epoch: 0596, Loss: 0.7664, Train: 0.8751, Val: 0.9139, Test: 0.9019
Epoch: 0597, Loss: 0.7660, Train: 0.8749, Val: 0.9139, Test: 0.9019
Epoch: 0598, Loss: 0.7655, Train: 0.8746, Val: 0.9137, Test: 0.9017
Epoch: 0599, Loss: 0.7650, Train: 0.8743, Val: 0.9136, Test: 0.9017
Epoch: 0600, Loss: 0.7645, Train: 0.8740, Val: 0.9135, Test: 0.9017
Epoch: 0601, Loss: 0.7640, Train: 0.8738, Val: 0.9134, Test: 0.9017
Epoch: 0602, Loss: 0.7637, Train: 0.8736, Val: 0.9135, Test: 0.9018
Epoch: 0603, Loss: 0.7633, Train: 0.8734, Val: 0.9133, Test: 0.9016
Epoch: 0604, Loss: 0.7629, Train: 0.8731, Val: 0.9133, Test: 0.9018
Epoch: 0605, Loss: 0.7624, Train: 0.8728, Val: 0.9131, Test: 0.9016
Epoch: 0606, Loss: 0.7620, Train: 0.8726, Val: 0.9131, Test: 0.9016
Epoch: 0607, Loss: 0.7615, Train: 0.8723, Val: 0.9130, Test: 0.9016
Epoch: 0608, Loss: 0.7610, Train: 0.8721, Val: 0.9129, Test: 0.9015
Epoch: 0609, Loss: 0.7606, Train: 0.8719, Val: 0.9129, Test: 0.9016
Epoch: 0610, Loss: 0.7602, Train: 0.8717, Val: 0.9127, Test: 0.9015
Epoch: 0611, Loss: 0.7599, Train: 0.8715, Val: 0.9128, Test: 0.9016
Epoch: 0612, Loss: 0.7595, Train: 0.8712, Val: 0.9126, Test: 0.9014
Epoch: 0613, Loss: 0.7592, Train: 0.8710, Val: 0.9127, Test: 0.9015
Epoch: 0614, Loss: 0.7587, Train: 0.8707, Val: 0.9125, Test: 0.9013
Epoch: 0615, Loss: 0.7582, Train: 0.8704, Val: 0.9124, Test: 0.9014
Epoch: 0616, Loss: 0.7578, Train: 0.8702, Val: 0.9123, Test: 0.9013
Epoch: 0617, Loss: 0.7573, Train: 0.8700, Val: 0.9122, Test: 0.9013
Epoch: 0618, Loss: 0.7570, Train: 0.8698, Val: 0.9122, Test: 0.9014
Epoch: 0619, Loss: 0.7566, Train: 0.8696, Val: 0.9121, Test: 0.9013
Epoch: 0620, Loss: 0.7563, Train: 0.8694, Val: 0.9122, Test: 0.9014
Epoch: 0621, Loss: 0.7559, Train: 0.8692, Val: 0.9119, Test: 0.9012
Epoch: 0622, Loss: 0.7556, Train: 0.8689, Val: 0.9120, Test: 0.9013
Epoch: 0623, Loss: 0.7551, Train: 0.8686, Val: 0.9117, Test: 0.9011
Epoch: 0624, Loss: 0.7546, Train: 0.8684, Val: 0.9117, Test: 0.9012
Epoch: 0625, Loss: 0.7542, Train: 0.8682, Val: 0.9116, Test: 0.9012
Epoch: 0626, Loss: 0.7538, Train: 0.8680, Val: 0.9115, Test: 0.9011
Epoch: 0627, Loss: 0.7535, Train: 0.8678, Val: 0.9116, Test: 0.9013
Epoch: 0628, Loss: 0.7531, Train: 0.8676, Val: 0.9114, Test: 0.9011
Epoch: 0629, Loss: 0.7528, Train: 0.8674, Val: 0.9115, Test: 0.9013
Epoch: 0630, Loss: 0.7524, Train: 0.8671, Val: 0.9113, Test: 0.9011
Epoch: 0631, Loss: 0.7520, Train: 0.8669, Val: 0.9113, Test: 0.9012
Epoch: 0632, Loss: 0.7516, Train: 0.8666, Val: 0.9112, Test: 0.9011
Epoch: 0633, Loss: 0.7512, Train: 0.8664, Val: 0.9112, Test: 0.9011
Epoch: 0634, Loss: 0.7508, Train: 0.8662, Val: 0.9111, Test: 0.9011
Epoch: 0635, Loss: 0.7504, Train: 0.8660, Val: 0.9110, Test: 0.9011
Epoch: 0636, Loss: 0.7500, Train: 0.8658, Val: 0.9111, Test: 0.9012
Epoch: 0637, Loss: 0.7497, Train: 0.8656, Val: 0.9110, Test: 0.9011
Epoch: 0638, Loss: 0.7494, Train: 0.8655, Val: 0.9111, Test: 0.9013
Epoch: 0639, Loss: 0.7491, Train: 0.8653, Val: 0.9110, Test: 0.9011
Epoch: 0640, Loss: 0.7489, Train: 0.8652, Val: 0.9112, Test: 0.9015
Epoch: 0641, Loss: 0.7487, Train: 0.8650, Val: 0.9110, Test: 0.9012
Epoch: 0642, Loss: 0.7483, Train: 0.8646, Val: 0.9109, Test: 0.9013
Epoch: 0643, Loss: 0.7477, Train: 0.8643, Val: 0.9107, Test: 0.9011
Epoch: 0644, Loss: 0.7471, Train: 0.8641, Val: 0.9107, Test: 0.9011
Epoch: 0645, Loss: 0.7468, Train: 0.8640, Val: 0.9108, Test: 0.9013
Epoch: 0646, Loss: 0.7466, Train: 0.8639, Val: 0.9106, Test: 0.9011
Epoch: 0647, Loss: 0.7464, Train: 0.8636, Val: 0.9107, Test: 0.9013
Epoch: 0648, Loss: 0.7459, Train: 0.8633, Val: 0.9105, Test: 0.9010
Epoch: 0649, Loss: 0.7454, Train: 0.8631, Val: 0.9104, Test: 0.9010
Epoch: 0650, Loss: 0.7450, Train: 0.8629, Val: 0.9105, Test: 0.9012
Epoch: 0651, Loss: 0.7448, Train: 0.8628, Val: 0.9104, Test: 0.9011
Epoch: 0652, Loss: 0.7446, Train: 0.8626, Val: 0.9105, Test: 0.9013
Epoch: 0653, Loss: 0.7442, Train: 0.8624, Val: 0.9103, Test: 0.9010
Epoch: 0654, Loss: 0.7438, Train: 0.8621, Val: 0.9102, Test: 0.9011
Epoch: 0655, Loss: 0.7433, Train: 0.8619, Val: 0.9102, Test: 0.9011
Epoch: 0656, Loss: 0.7430, Train: 0.8618, Val: 0.9101, Test: 0.9010
Epoch: 0657, Loss: 0.7427, Train: 0.8616, Val: 0.9102, Test: 0.9012
Epoch: 0658, Loss: 0.7425, Train: 0.8614, Val: 0.9101, Test: 0.9010
Epoch: 0659, Loss: 0.7422, Train: 0.8612, Val: 0.9102, Test: 0.9011
Epoch: 0660, Loss: 0.7418, Train: 0.8610, Val: 0.9100, Test: 0.9010
Epoch: 0661, Loss: 0.7414, Train: 0.8608, Val: 0.9100, Test: 0.9010
Epoch: 0662, Loss: 0.7410, Train: 0.8606, Val: 0.9100, Test: 0.9010
Epoch: 0663, Loss: 0.7407, Train: 0.8604, Val: 0.9100, Test: 0.9010
Epoch: 0664, Loss: 0.7404, Train: 0.8603, Val: 0.9101, Test: 0.9011
Epoch: 0665, Loss: 0.7401, Train: 0.8601, Val: 0.9100, Test: 0.9010
Epoch: 0666, Loss: 0.7399, Train: 0.8599, Val: 0.9101, Test: 0.9012
Epoch: 0667, Loss: 0.7396, Train: 0.8597, Val: 0.9099, Test: 0.9010
Epoch: 0668, Loss: 0.7392, Train: 0.8595, Val: 0.9100, Test: 0.9011
Epoch: 0669, Loss: 0.7388, Train: 0.8593, Val: 0.9098, Test: 0.9010
Epoch: 0670, Loss: 0.7384, Train: 0.8591, Val: 0.9098, Test: 0.9010
Epoch: 0671, Loss: 0.7381, Train: 0.8589, Val: 0.9098, Test: 0.9011
Epoch: 0672, Loss: 0.7378, Train: 0.8587, Val: 0.9098, Test: 0.9010
Epoch: 0673, Loss: 0.7375, Train: 0.8586, Val: 0.9098, Test: 0.9012
Epoch: 0674, Loss: 0.7372, Train: 0.8584, Val: 0.9097, Test: 0.9011
Epoch: 0675, Loss: 0.7370, Train: 0.8583, Val: 0.9099, Test: 0.9013
Epoch: 0676, Loss: 0.7368, Train: 0.8582, Val: 0.9098, Test: 0.9011
Epoch: 0677, Loss: 0.7366, Train: 0.8580, Val: 0.9099, Test: 0.9014
Epoch: 0678, Loss: 0.7363, Train: 0.8578, Val: 0.9097, Test: 0.9011
Epoch: 0679, Loss: 0.7359, Train: 0.8575, Val: 0.9097, Test: 0.9012
Epoch: 0680, Loss: 0.7354, Train: 0.8572, Val: 0.9096, Test: 0.9010
Epoch: 0681, Loss: 0.7350, Train: 0.8571, Val: 0.9096, Test: 0.9010
Epoch: 0682, Loss: 0.7347, Train: 0.8570, Val: 0.9097, Test: 0.9012
Epoch: 0683, Loss: 0.7345, Train: 0.8569, Val: 0.9096, Test: 0.9011
Epoch: 0684, Loss: 0.7344, Train: 0.8567, Val: 0.9098, Test: 0.9013
Epoch: 0685, Loss: 0.7341, Train: 0.8565, Val: 0.9095, Test: 0.9011
Epoch: 0686, Loss: 0.7337, Train: 0.8563, Val: 0.9096, Test: 0.9012
Epoch: 0687, Loss: 0.7333, Train: 0.8560, Val: 0.9095, Test: 0.9011
Epoch: 0688, Loss: 0.7329, Train: 0.8559, Val: 0.9094, Test: 0.9011
Epoch: 0689, Loss: 0.7326, Train: 0.8558, Val: 0.9096, Test: 0.9012
Epoch: 0690, Loss: 0.7324, Train: 0.8556, Val: 0.9095, Test: 0.9011
Epoch: 0691, Loss: 0.7322, Train: 0.8555, Val: 0.9097, Test: 0.9014
Epoch: 0692, Loss: 0.7320, Train: 0.8553, Val: 0.9095, Test: 0.9012
Epoch: 0693, Loss: 0.7317, Train: 0.8552, Val: 0.9097, Test: 0.9013
Epoch: 0694, Loss: 0.7314, Train: 0.8549, Val: 0.9094, Test: 0.9011
Epoch: 0695, Loss: 0.7310, Train: 0.8547, Val: 0.9094, Test: 0.9011
Epoch: 0696, Loss: 0.7306, Train: 0.8545, Val: 0.9093, Test: 0.9010
Epoch: 0697, Loss: 0.7303, Train: 0.8544, Val: 0.9093, Test: 0.9010
Epoch: 0698, Loss: 0.7300, Train: 0.8542, Val: 0.9093, Test: 0.9011
Epoch: 0699, Loss: 0.7298, Train: 0.8541, Val: 0.9093, Test: 0.9011
Epoch: 0700, Loss: 0.7296, Train: 0.8540, Val: 0.9095, Test: 0.9013


6.2 inductive

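For validation/testing, simply run the trained model on the new graph. There is no official example for this yet; I will add one once there is.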


Reposted from: https://blog.csdn.net/PolarisRisingWar/article/details/128130870