├── 2020-1-1.txt
├── README.md
├── Task4-链路预测.py
├── Task7-cluster_gcn.py
├── Task8-GIN
│   ├── gin_conv.py
│   ├── gin_graph.py
│   ├── gin_node.py
│   ├── main.py
│   └── pcqm4m_data.py
└── Task9-graph_classification
    ├── gin_conv.py
    ├── gin_graph.py
    ├── gin_node.py
    └── main.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GNN_datawhale

--------------------------------------------------------------------------------
/Task4-链路预测.py:
--------------------------------------------------------------------------------
import os.path as osp

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling, train_test_split_edges

dataset = Planetoid('G:/chenyu/GNN/dataset/Cora', name='Cora', transform=T.NormalizeFeatures())  # downloads the dataset; if it already exists under root, it is loaded directly
data = dataset[0]  # the dataset contains a single graph (len(dataset) == 1); the transform is applied here
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
print(data)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)  # [2, E]
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # *: element-wise product (inner-product decoder)

    def decode_all(self, z):
        prob_adj = z @ z.t()  # @: matrix multiplication; dispatches to the appropriate matmul routine
        return (prob_adj > 0).nonzero(as_tuple=False).t()

    def forward(self, x, pos_edge_index, neg_edge_index):
        return self.decode(self.encode(x, pos_edge_index), pos_edge_index, neg_edge_index)


def get_link_labels(pos_edge_index, neg_edge_index):
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float, device=device)  # relies on the module-level `device` defined below
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels
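
# Illustrative sketch (not part of the original pipeline; the "_demo" helper is
# hypothetical and never called): how the inner-product decoder and the label
# layout produced by get_link_labels fit together.
def _demo_decoder_labels():
    z = torch.randn(4, 64)                        # toy embeddings for 4 nodes
    pos = torch.tensor([[0, 1], [1, 2]])          # 2 positive edges: 0-1, 1-2
    neg = torch.tensor([[0, 3], [2, 3]])          # 2 sampled negative edges
    edge_index = torch.cat([pos, neg], dim=-1)    # [2, 4]
    logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # one score per candidate edge
    labels = torch.tensor([1., 1., 0., 0.])       # positives first, then negatives
    return logits, labels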

def train(data, model, optimizer, criterion):
    model.train()

    neg_edge_index = negative_sampling(  # negative sampling on the training set; the sampled negatives may differ every epoch
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))

    optimizer.zero_grad()
    # link_logits = model(data.x, data.train_pos_edge_index, neg_edge_index)  # equivalent single call via forward
    z = model.encode(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index).to(data.x.device)  # 1 for positive training edges, 0 for sampled negatives
    loss = criterion(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def mytest(data, model):
    model.eval()

    z = model.encode(data.x, data.train_pos_edge_index)

    results = []
    for prefix in ['val', 'test']:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        link_probs = link_logits.sigmoid()  # probability that a link exists
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return results


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, 64).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = F.binary_cross_entropy_with_logits
best_val_auc = test_auc = 0
for epoch in range(1, 101):
    loss = train(data, model, optimizer, criterion)
    val_auc, tmp_test_auc = mytest(data, model)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = tmp_test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, Test: {test_auc:.4f}')

# prediction
z = model.encode(data.x, data.train_pos_edge_index)
final_edge_index = model.decode_all(z)

--------------------------------------------------------------------------------
/Task7-cluster_gcn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from tqdm import tqdm
from typing import Optional, Tuple

from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, matmul
from torch_geometric.datasets import Reddit, Reddit2, PPI
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import add_self_loops, degree
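
# Hedged sketch (hypothetical "_demo" helper, never called anywhere): the
# diagonal-enhancement normalisation computed inside the GCNConv below, shown
# on a tiny dense matrix: A_hat = D^{-1}(A + I) + lamb * diag(A + I).
def _demo_diag_enhancement(lamb: float = 1.0):
    A = torch.tensor([[0., 1., 0.],
                      [1., 0., 1.],
                      [0., 1., 0.]])                          # toy 3-node path graph
    A_tilde = A + torch.eye(3)                                # add self-loops
    deg = A_tilde.sum(dim=1)                                  # degrees incl. self-loop
    A_hat = A_tilde / deg.view(-1, 1)                         # row-normalise: D^{-1} A_tilde
    A_hat = A_hat + lamb * torch.diag(torch.diag(A_tilde))    # boost the diagonal
    return A_hat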
class GCNConv(MessagePassing):
    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self, in_channels: int, out_channels: int,
                 improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = True, lamb: float = 1., **kwargs):

        kwargs.setdefault('aggr', 'add')
        super(GCNConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_edge_index = None
        self._cached_adj_t = None

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        self.lamb = lamb
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        # Build A_tilde = A + I and apply the diagonal-enhancement
        # normalisation D^{-1} A_tilde + lamb * diag(A_tilde).
        adjmat, _ = add_self_loops(edge_index=edge_index, edge_weight=edge_weight)
        row, col = adjmat
        adjmat = SparseTensor(row=row, col=col, value=torch.ones(row.size(0), device=x.device),
                              sparse_sizes=(x.size(0), x.size(0)))
        deg_norm = degree(col, adjmat.size(0), dtype=x.dtype).pow(-1)  # float dtype: an integer tensor cannot be raised to a negative power
        A_hat = adjmat.to_dense()
        A_hat = deg_norm.view(-1, 1) * A_hat + self.lamb * torch.diag(A_hat) * torch.eye(A_hat.shape[0])  # view(-1, 1) normalises rows (D^{-1} A), not columns
        A_hat = SparseTensor.from_dense(A_hat)

        x = x @ self.weight

        # propagate_type: (x: Tensor)
        out = self.propagate(A_hat, x=x)
        if self.bias is not None:
            out += self.bias
        return out

    def message_and_aggregate(self, A_hat: SparseTensor, x: Tensor) -> Tensor:
        return matmul(A_hat, x, reduce=self.aggr)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.convs = ModuleList(
            [GCNConv(in_channels, 128),
             GCNConv(128, out_channels)])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)
    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        for i, conv in enumerate(self.convs):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = conv((x, x_target), edge_index)  # NOTE: the custom GCNConv above only accepts a plain Tensor; this bipartite (x, x_target) call works with PyG's built-in convs such as SAGEConv
                if i != len(self.convs) - 1:
                    x = F.relu(x)
                xs.append(x.cpu())

                pbar.update(batch_size)

            x_all = torch.cat(xs, dim=0)

        pbar.close()

        return x_all


def train():
    model.train()

    total_loss = total_nodes = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()

        nodes = batch.train_mask.sum().item()
        total_loss += loss.item() * nodes
        total_nodes += nodes

    return total_loss / total_nodes


@torch.no_grad()
def test():  # Inference should be performed on the full graph.
    model.eval()

    out = model.inference(data.x)
    y_pred = out.argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = y_pred[mask].eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs


if __name__ == '__main__':
    # Download https://data.dgl.ai/dataset/reddit.zip into the data/Reddit folder first
    # dataset = Reddit('data/Reddit')
    # dataset = PPI('data/PPI')
    dataset = Reddit2('data/Reddit2')
    data = dataset[0]

    # graph clustering
    cluster_data = ClusterData(data, num_parts=500, recursive=False,
                               save_dir=dataset.processed_dir)
    train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
                                 num_workers=12)
    # no clustering: full-neighbourhood sampler, used for full-graph inference
    subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024,
                                      shuffle=False, num_workers=12)

    device = 'cpu'  # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(dataset.num_features, dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    for epoch in range(1, 31):
        loss = train()
        if epoch % 5 == 0:
            train_acc, val_acc, test_acc = test()
            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
                  f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
        else:
            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
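
# Hedged sketch (hypothetical "_demo" helper, never called): how ClusterData
# partitions a graph and ClusterLoader merges several clusters into one
# mini-batch. Assumes torch-sparse was built with METIS support.
def _demo_cluster_batching():
    from torch_geometric.data import Data
    num_nodes = 100
    edge_index = torch.randint(0, num_nodes, (2, 400))        # toy random graph
    toy = Data(x=torch.randn(num_nodes, 8), edge_index=edge_index, num_nodes=num_nodes)
    clusters = ClusterData(toy, num_parts=10)                 # 10 METIS partitions
    loader = ClusterLoader(clusters, batch_size=2, shuffle=True)  # 2 clusters per batch
    for batch in loader:
        print(batch.num_nodes, batch.edge_index.size(1))      # only within-batch edges are kept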

--------------------------------------------------------------------------------
/Task8-GIN/gin_conv.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality
        '''

        super(GINConv, self).__init__(aggr="add")

        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)  # first map categorical edge attributes to edge embeddings
        # GIN update: h_v = MLP((1 + eps) * h_v + sum_{u in N(v)} ReLU(h_u + e_uv))
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

--------------------------------------------------------------------------------
/Task8-GIN/gin_graph.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding


class GINGraphPooling(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module

        This module first embeds every node of the graph with GINNodeEmbedding, then pools the node embeddings into a graph embedding, and finally applies one linear layer to obtain the final graph representation.

        Args:
            num_tasks (int, optional): number of labels to be predicted (also the dimension of the graph representation). Defaults to 1.
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): either "last" or "sum"; "last" keeps only the node embeddings of the final layer, "sum" sums the node embeddings over all layers. Defaults to "last".
            graph_pooling (str, optional): pooling method over node embeddings; one of "sum", "mean", "max", "attention" and "set2set". Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, clamp the output to [0, 50] to keep predictions non-negative and in a plausible range
            return torch.clamp(output, min=0, max=50)
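
# Hedged usage sketch (hypothetical "_demo" helper, never called): constructing
# the pooling model. "set2set" doubles the graph-embedding width, which is why
# the prediction head above uses 2 * emb_dim in that case.
def _demo_build_models():
    sum_model = GINGraphPooling(num_tasks=1, num_layers=5, emb_dim=300, graph_pooling="sum")
    s2s_model = GINGraphPooling(num_tasks=1, num_layers=5, emb_dim=300, graph_pooling="set2set")
    return sum_model, s2s_model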

--------------------------------------------------------------------------------
/Task8-GIN/gin_node.py:
--------------------------------------------------------------------------------
import torch
from ogb.graphproppred.mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F


# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module
        Embeds the nodes of a graph with a stack of GINConv layers.
        """

        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # first map categorical atom attributes to atom embeddings
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation
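
# Illustrative sketch (hypothetical "_demo" helper, never called): with
# num_layers = L, h_list above holds L + 1 tensors (the input atom embedding
# plus one per GINConv layer); JK="sum" adds them element-wise, JK="last"
# keeps only the final one.
def _demo_jk(h_list, mode="last"):
    if mode == "last":
        return h_list[-1]
    return torch.stack(h_list, dim=0).sum(dim=0)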

--------------------------------------------------------------------------------
/Task8-GIN/main.py:
--------------------------------------------------------------------------------
import os
import torch
from tqdm import tqdm
from ogb.lsc import PCQM4MEvaluator
from torch_geometric.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from pcqm4m_data import MyPCQM4MDataset
from gin_graph import GINGraphPooling


def train(model, device, loader, optimizer, criterion_fn):
    model.train()
    loss_accum = 0

    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)
        pred = model(batch).view(-1,)
        optimizer.zero_grad()
        loss = criterion_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        loss_accum += loss.detach().cpu().item()

    return loss_accum / (step + 1)  # average training (MSE) loss


def eval(model, device, loader, evaluator):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for _, batch in enumerate(tqdm(loader)):
            batch = batch.to(device)
            pred = model(batch).view(-1,)
            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    return evaluator.eval(input_dict)["mae"]


def test(model, device, loader):
    model.eval()
    y_pred = []

    with torch.no_grad():
        for _, batch in enumerate(loader):
            batch = batch.to(device)
            pred = model(batch).view(-1,)
            y_pred.append(pred.detach().cpu())

    y_pred = torch.cat(y_pred, dim=0)
    return y_pred


# hyper-parameters (flattened from the argparse configuration in Task9's main.py)
device = 0
num_layers = 5
graph_pooling = 'sum'
emb_dim = 256
drop_ratio = 0.
save_test = True  # originally argparse's action='store_true' flag
batch_size = 512
epochs = 100
weight_decay = 0.00001
early_stop = 10
num_workers = 4
dataset_root = "dataset"

device = torch.device(f'cuda:{device}') if torch.cuda.is_available() else torch.device('cpu')
nn_params = {
    'num_layers': num_layers,
    'emb_dim': emb_dim,
    'drop_ratio': drop_ratio,
    'graph_pooling': graph_pooling
}

# automatic dataloading and splitting
dataset = MyPCQM4MDataset(root=dataset_root)
split_idx = dataset.get_idx_split()
train_data = dataset[split_idx['train']]
valid_data = dataset[split_idx['valid']]
test_data = dataset[split_idx['test']]
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# automatic evaluator. takes dataset name as input
evaluator = PCQM4MEvaluator()
criterion_fn = torch.nn.MSELoss()

model = GINGraphPooling(**nn_params).to(device)
num_params = sum(p.numel() for p in model.parameters())

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=weight_decay)
scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

not_improved = 0
best_valid_mae = 9999
for epoch in range(1, epochs + 1):
    train_mae = train(model, device, train_loader, optimizer, criterion_fn)

    valid_mae = eval(model, device, valid_loader, evaluator)

    print({'Train': train_mae, 'Validation': valid_mae})

    if valid_mae < best_valid_mae:
        best_valid_mae = valid_mae
        not_improved = 0
    else:
        not_improved += 1
        if not_improved == early_stop:  # early stopping, mirroring Task9's main.py
            print(f'Have not improved for {not_improved} epochs.')
            break

    scheduler.step()
    print(f'Best validation MAE so far: {best_valid_mae}')
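
# Hedged usage sketch (hypothetical "_demo" helper, never called):
# PCQM4MEvaluator.eval expects a dict with 1-D "y_true" and "y_pred" of equal
# length and returns {"mae": ...}. Toy values only.
def _demo_evaluator():
    e = PCQM4MEvaluator()
    out = e.eval({"y_true": torch.tensor([1.0, 2.0]),
                  "y_pred": torch.tensor([1.5, 2.5])})
    print(out["mae"])  # 0.5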

--------------------------------------------------------------------------------
/Task8-GIN/pcqm4m_data.py:
--------------------------------------------------------------------------------
import os
import os.path as osp

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil

RDLogger.DisableLog('rdApp.*')


class MyPCQM4MDataset(Dataset):

    def __init__(self, root):
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
        super(MyPCQM4MDataset, self).__init__(root)

        filepath = osp.join(root, 'raw/data.csv.gz')
        data_df = pd.read_csv(filepath)
        self.smiles_list = data_df['smiles']
        self.homolumogap_list = data_df['homolumogap']

    @property
    def raw_file_names(self):
        return 'data.csv.gz'

    def download(self):
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        shutil.move(osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz'), osp.join(self.root, 'raw/data.csv.gz'))

    def len(self):
        return len(self.smiles_list)

    def get(self, idx):
        smiles, homolumogap = self.smiles_list[idx], self.homolumogap_list[idx]
        graph = smiles2graph(smiles)

        assert(len(graph['edge_feat']) == graph['edge_index'].shape[1])
        assert(len(graph['node_feat']) == graph['num_nodes'])

        x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
        edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
        y = torch.Tensor([homolumogap])
        num_nodes = int(graph['num_nodes'])
        data = Data(x, edge_index, edge_attr, y, num_nodes=num_nodes)
        return data

    def get_idx_split(self):
        split_dict = replace_numpy_with_torchtensor(torch.load(osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')))
        return split_dict


if __name__ == "__main__":
    dataset = MyPCQM4MDataset('dataset')
    from torch_geometric.data import DataLoader
    from tqdm import tqdm
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
    for batch in tqdm(dataloader):
        pass

--------------------------------------------------------------------------------
/Task9-graph_classification/gin_conv.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality
        '''

        super(GINConv, self).__init__(aggr="add")

        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)  # first map categorical edge attributes to edge embeddings
        # GIN update: h_v = MLP((1 + eps) * h_v + sum_{u in N(v)} ReLU(h_u + e_uv))
        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
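
# Hedged sketch (hypothetical "_demo" helper, never called): a single GINConv
# forward pass on a toy graph. Assumes edge_attr carries categorical OGB bond
# features of shape [num_edges, 3] for BondEncoder; zeros index the first
# category of each feature.
def _demo_gin_conv():
    conv = GINConv(emb_dim=64)
    x = torch.randn(4, 64)                                    # 4 node embeddings
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])   # directed 4-cycle
    edge_attr = torch.zeros(4, 3, dtype=torch.long)
    return conv(x, edge_index, edge_attr)                     # shape [4, 64]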

--------------------------------------------------------------------------------
/Task9-graph_classification/gin_graph.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding


class GINGraphPooling(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module

        This module first embeds every node of the graph with GINNodeEmbedding, then pools the node embeddings into a graph embedding, and finally applies one linear layer to obtain the final graph representation.

        Args:
            num_tasks (int, optional): number of labels to be predicted (also the dimension of the graph representation). Defaults to 1.
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): either "last" or "sum"; "last" keeps only the node embeddings of the final layer, "sum" sums the node embeddings over all layers. Defaults to "last".
            graph_pooling (str, optional): pooling method over node embeddings; one of "sum", "mean", "max", "attention" and "set2set". Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, clamp the output to [0, 50] to keep predictions non-negative and in a plausible range
            return torch.clamp(output, min=0, max=50)
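
# Illustrative sketch (hypothetical "_demo" helper, never called):
# global_add_pool sums node embeddings per graph according to the batch
# vector. Toy example with two graphs of 2 and 3 nodes.
def _demo_pool():
    h = torch.ones(5, 4)
    batch = torch.tensor([0, 0, 1, 1, 1])   # node-to-graph assignment
    return global_add_pool(h, batch)        # [[2, 2, 2, 2], [3, 3, 3, 3]]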

--------------------------------------------------------------------------------
/Task9-graph_classification/gin_node.py:
--------------------------------------------------------------------------------
import torch
from ogb.graphproppred.mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F


# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module
        Embeds the nodes of a graph with a stack of GINConv layers.
        """

        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # first map categorical atom attributes to atom embeddings
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation

--------------------------------------------------------------------------------
/Task9-graph_classification/main.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import os.path as osp
import torch
import argparse
from tqdm import tqdm
from torch_geometric.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

from gin_graph import GINGraphPooling

from torch.utils.tensorboard import SummaryWriter


def parse_args():

    parser = argparse.ArgumentParser(description='Graph data mining with GNN')
    parser.add_argument('--task_name', type=str, default='GINGraphPooling',
                        help='task name')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--graph_pooling', type=str, default='sum',
                        help='graph pooling strategy mean or sum (default: sum)')
    parser.add_argument('--emb_dim', type=int, default=256,
                        help='dimensionality of hidden units in GNNs (default: 256)')
    parser.add_argument('--drop_ratio', type=float, default=0.,
                        help='dropout ratio (default: 0.)')
    parser.add_argument('--save_test', action='store_true')
    parser.add_argument('--batch_size', type=int, default=512,
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--weight_decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--early_stop', type=int, default=10,
                        help='early stop (default: 10)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers (default: 4)')
    parser.add_argument('--dataset_root', type=str, default="dataset",
                        help='dataset root')
    args = parser.parse_args()

    return args


def prepartion(args):
    save_dir = os.path.join('saves', args.task_name)
    if os.path.exists(save_dir):
        for idx in range(1000):
            if not os.path.exists(save_dir + '=' + str(idx)):
                save_dir = save_dir + '=' + str(idx)
                break

    args.save_dir = save_dir
    os.makedirs(args.save_dir, exist_ok=True)
    args.device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    args.output_file = open(os.path.join(args.save_dir, 'output'), 'a')
    print(args, file=args.output_file, flush=True)


def train(model, device, loader, optimizer, criterion_fn):
    model.train()
    loss_accum = 0

    for step, batch in enumerate(tqdm(loader)):
        batch = batch.to(device)
        pred = model(batch).view(-1,)
        optimizer.zero_grad()
        loss = criterion_fn(pred, batch.y.to(torch.float))
        loss.backward()
        optimizer.step()
        # NOTE: the scheduler is stepped once per epoch in main(); stepping the
        # StepLR (step_size=30 epochs) per batch here would decay the LR far too fast.
        loss_accum += loss.detach().cpu().item()

    return loss_accum / (step + 1)


def eval(model, device, loader):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for _, batch in enumerate(tqdm(loader)):
            batch = batch.to(device)
            pred = model(batch).view(-1,)
            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    # The following inlines the OGB evaluator's MAE computation.
    '''
        y_true: numpy.ndarray or torch.Tensor of shape (num_graphs,)
        y_pred: numpy.ndarray or torch.Tensor of shape (num_graphs,)
        y_true and y_pred need to be of the same type (either numpy.ndarray or torch.Tensor)
    '''
    assert('y_pred' in input_dict)
    assert('y_true' in input_dict)

    y_pred, y_true = input_dict['y_pred'], input_dict['y_true']

    assert((isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray))
           or
           (isinstance(y_true, torch.Tensor) and isinstance(y_pred, torch.Tensor)))
    assert(y_true.shape == y_pred.shape)
    assert(len(y_true.shape) == 1)

    if isinstance(y_true, torch.Tensor):
        return {'mae': torch.mean(torch.abs(y_pred - y_true)).cpu().item()}
    else:
        return {'mae': float(np.mean(np.absolute(y_pred - y_true)))}


def predict(model, device, loader):
    model.eval()
    y_pred = []

    with torch.no_grad():
        for _, batch in enumerate(loader):
            batch = batch.to(device)
            pred = model(batch).view(-1,)
            y_pred.append(pred.detach().cpu())

    y_pred = torch.cat(y_pred, dim=0)
    return y_pred


def save_test_submission(input_dict, dir_path):
    '''
        save test submission file at dir_path
    '''
    assert('y_pred' in input_dict)
    y_pred = input_dict['y_pred']

    if not osp.exists(dir_path):
        os.makedirs(dir_path)

    filename = osp.join(dir_path, 'y_pred')
    assert(isinstance(filename, str))
    assert(isinstance(y_pred, np.ndarray) or isinstance(y_pred, torch.Tensor))

    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.numpy()
    y_pred = y_pred.astype(np.float32)
    np.savez_compressed(filename, y_pred=y_pred)


def main(args):
    prepartion(args)
    nn_params = {
        'num_layers': args.num_layers,
        'emb_dim': args.emb_dim,
        'drop_ratio': args.drop_ratio,
        'graph_pooling': args.graph_pooling
    }

    # automatic dataloading and splitting
    # NOTE: ogbg-molhiv is a binary classification benchmark; this script
    # nevertheless treats the labels as regression targets (MSE loss, MAE metric).
    from ogb.graphproppred.dataset_pyg import PygGraphPropPredDataset
    dataset = PygGraphPropPredDataset(name='ogbg-molhiv')

    split_idx = dataset.get_idx_split()
    train_data = dataset[split_idx['train']]
    valid_data = dataset[split_idx['valid']]
    test_data = dataset[split_idx['test']]
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    # evaluation metric: MAE, computed inline in eval() above
    criterion_fn = torch.nn.MSELoss()

    device = args.device

    model = GINGraphPooling(**nn_params).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f'#Params: {num_params}', file=args.output_file, flush=True)
    print(model, file=args.output_file, flush=True)

    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

    writer = SummaryWriter(log_dir=args.save_dir)
    not_improved = 0
    best_valid_mae = 9999
    for epoch in range(1, args.epochs + 1):
        print("=====Epoch {}".format(epoch), file=args.output_file, flush=True)
        print('Training...', file=args.output_file, flush=True)
        train_mae = train(model, device, train_loader, optimizer, criterion_fn)

        print('Evaluating...', file=args.output_file, flush=True)
        valid_mae = eval(model, device, valid_loader)['mae']

        print({'Train': train_mae, 'Validation': valid_mae}, file=args.output_file, flush=True)

        writer.add_scalar('valid/mae', valid_mae, epoch)
        writer.add_scalar('train/mae', train_mae, epoch)

        if valid_mae < best_valid_mae:
            best_valid_mae = valid_mae
            if args.save_test:
                print('Saving checkpoint...', file=args.output_file, flush=True)
                checkpoint = {
                    'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae, 'num_params': num_params
                }
                torch.save(checkpoint, os.path.join(args.save_dir, 'checkpoint.pt'))
                print('Predicting on test data...', file=args.output_file, flush=True)
                y_pred = predict(model, device, test_loader)
                print('Saving test submission file...', file=args.output_file, flush=True)
                save_test_submission({'y_pred': y_pred}, args.save_dir)

            not_improved = 0
        else:
            not_improved += 1
            if not_improved == args.early_stop:
                print(f"Have not improved for {not_improved} epochs.", file=args.output_file, flush=True)
                break

        scheduler.step()
        print(f'Best validation MAE so far: {best_valid_mae}', file=args.output_file, flush=True)

    writer.close()
    args.output_file.close()


if __name__ == "__main__":
    args = parse_args()
    main(args)
--------------------------------------------------------------------------------