├── README.md
├── datasets
│   └── datasets_path
├── finetune.py
├── finetune_utils.py
├── graphadapter.py
├── preprocess.py
├── pretrain.py
├── pretrain_utils.py
└── prompt_config.py

/README.md:
--------------------------------------------------------------------------------
1 | # Can GNN be Good Adapter for LLMs?
2 | This repository is the implementation of GraphAdapter from [Can GNN be Good Adapter for LLMs?](https://arxiv.org/abs/2402.12984) (WWW 2024).
3 | 
4 | ## Requirements
5 | * python = 3.8
6 | * numpy >= 1.19.5
7 | * pytorch = 1.10.2
8 | * pyg = 2.3.1
9 | * transformers >= 4.28.1
10 | 
11 | For the largest dataset, Arxiv, about 300 GB of storage is required.
12 | ## How to use our code
13 | The datasets used in this paper can be downloaded from [here](https://drive.google.com/drive/folders/13fqwSfY5utv8HibtEoLIAGk7k85W7b2d?usp=sharing). Please download them, place them in the `datasets` directory, and unzip them there.
14 | 
15 | 
16 | ### Step 1. Preprocess data for training
17 | ```
18 | python3 preprocess.py --dataset_name instagram --gpu 0 --plm_path llama2_path --type pretrain
19 | ```
20 | preprocess.py loads the textual data of Instagram and transforms it into token embeddings with Llama 2, which are saved under `--pretrain_save_path` (default `./token_embedding/`). The saved embeddings will be used for training GraphAdapter.
21 | 
22 | 
23 | ### Step 2. Training GraphAdapter
24 | ```
25 | python3 pretrain.py --dataset_name instagram --hiddensize_gnn 64 --hiddensize_fusion 64 --learning_ratio 5e-4 --batch_size 32 --max_epoch 15
26 | ```
27 | 
28 | ### Step 3. Finetuning for downstream tasks
29 | 
30 | GraphAdapter requires prompt embeddings for finetuning:
31 | 
32 | ```
33 | python3 preprocess.py --dataset_name instagram --gpu 0 --plm_path llama2_path --type prompt
34 | 
35 | ```
36 | After preprocessing the dataset, you can finetune GraphAdapter on the downstream task.
37 | ```
38 | python3 finetune.py --dataset_name instagram --gpu 0 --metric roc --save_path your_model_save_path
39 | ```
40 | Note: pretrain.py saves its checkpoints and `model_args.pkl` under `./save_models/<dataset_name>/<hyperparameter_string>/`; pass that directory to finetune.py as `--save_path` so that pretraining and finetuning stay consistent.
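For reference, here is a minimal end-to-end run on Instagram (a sketch: it assumes Llama 2 13B is available at `llama2_path` and keeps the default `weight_decay` and `num_warmup_steps`, so the checkpoint directory name below is only illustrative):
```
# Step 1: token-level preprocessing with Llama 2
python3 preprocess.py --dataset_name instagram --gpu 0 --plm_path llama2_path --type pretrain
# Step 2: pretrain GraphAdapter; checkpoints are written to
# ./save_models/instagram/64_64_SAGE_2_32_0.0005_0.001_15_10/
python3 pretrain.py --dataset_name instagram --hiddensize_gnn 64 --hiddensize_fusion 64 --learning_ratio 5e-4 --batch_size 32 --max_epoch 15
# Step 3: prompt embeddings, then finetuning on the downstream task
# (--step selects a saved epoch; with --max_epoch 15 the last checkpoint is save_model_14.pkl)
python3 preprocess.py --dataset_name instagram --gpu 0 --plm_path llama2_path --type prompt
python3 finetune.py --dataset_name instagram --gpu 0 --metric roc --step 14 --save_path ./save_models/instagram/64_64_SAGE_2_32_0.0005_0.001_15_10/
```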
41 | 42 | ## Citation 43 | If you find our work or dataset useful, please consider citing our work: 44 | ``` 45 | @article{huang2024can, 46 | title={Can GNN be Good Adapter for LLMs?}, 47 | author={Huang, Xuanwen and Han, Kaiqiao and Yang, Yang and Bao, Dezheng and Tao, Quanjin and Chai, Ziwei and Zhu, Qi}, 48 | journal={WWW}, 49 | year={2024} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /datasets/datasets_path: -------------------------------------------------------------------------------- 1 | path of datasets 2 | 3 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from finetune_utils import finetune,load_data_with_prompt_embedding,set_random_seed 2 | import numpy as np 3 | import argparse 4 | import torch 5 | def run_exp(args): 6 | acc_ls = [] 7 | for split in range(0,5): 8 | data=load_data_with_prompt_embedding(args.dataset_name,10,10,split) 9 | print("class_num:", data.y.max()+1) 10 | for i in range(1): 11 | acc = finetune(data,args) 12 | acc_ls.append(acc) 13 | print(np.mean(acc_ls),np.std(acc_ls)) 14 | return acc_ls 15 | 16 | if __name__ == "__main__": 17 | set_random_seed(0) 18 | parser = argparse.ArgumentParser('finetuning GraphAdapter') 19 | parser.add_argument('--dataset_name', type=str, help='dataset to be used', default='instagram', 20 | choices=['arxiv', 'instagram', 'reddit']) 21 | parser.add_argument('--step', type=int, default=20, help='epoch of saved graphadapter') 22 | parser.add_argument('--load_from_pretrain', type=int, default=1, help='whether using pretrained model',choices=[0,1]) 23 | parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate') 24 | parser.add_argument('--metric', type=str, help='metric used for evaluation', default='roc', 25 | choices=['roc', 'acc']) 26 | parser.add_argument('--save_path', type=str, default='./save_models/reddit/128_128_SAGE_2_32_0.0005_0.001_50_10/', help='path of saving embedding') 27 | parser.add_argument('--gpu', type=int, default=0, help='number of gpu to use') 28 | 29 | args = parser.parse_args() 30 | args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu' 31 | 32 | acc_ls = run_exp(args) -------------------------------------------------------------------------------- /finetune_utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_auc_score,accuracy_score,average_precision_score 2 | from transformers import activations 3 | import torch_geometric as tg 4 | from torch_geometric.nn.conv import MessagePassing 5 | from torch_geometric.nn import GATConv,GATv2Conv,SuperGATConv,ResGatedGraphConv,GCN2Conv,GatedGraphConv,SAGEConv 6 | from torch_geometric.data import Data 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | #from utils import load_data 11 | import pandas as pd 12 | import argparse 13 | import joblib 14 | import random 15 | from graphadapter import LinearHead 16 | def set_random_seed(seed: int = 0): 17 | """ 18 | set random seed 19 | :param seed: int, random seed 20 | :return: 21 | """ 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | if torch.cuda.is_available(): 26 | torch.cuda.manual_seed_all(seed) 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = False 29 | 30 | def get_mask(seed): 31 | #seed=4 32 | 
    np.random.seed(seed)
33 |     randint = np.random.randint(0,100,(n,))
34 |     train_mask = torch.tensor((randint<i)).bool()
35 |     val_mask = torch.tensor(((randint>=i)&(randint<j))).bool()
36 |     test_mask = torch.tensor(((randint>=j)&(randint<100))).bool()
37 |     return train_mask,val_mask,test_mask
38 | 
39 | 
40 | def get_mask(n,i,j,seed):
41 |     np.random.seed(seed)
42 |     randint = np.random.randint(0,100,(n,))
43 |     train_mask = torch.tensor((randint<i)).bool()
44 |     val_mask = torch.tensor(((randint>=i)&(randint<j))).bool()
45 |     test_mask = torch.tensor(((randint>=j)&(randint<100))).bool()
46 |     return train_mask,val_mask,test_mask
47 | 
48 | def normal(x):
49 |     x = (x-x.mean(dim=0).view(1,-1))/x.std(dim=0).view(1,-1)
50 |     return x
51 | def load_data_with_prompt_embedding(dataname,train_ratio,val_ratio,split):
52 |     if (train_ratio>=100) or (val_ratio+train_ratio>=100):
53 |         raise ValueError("train or validation ratio out of 100")
54 |     x = np.load(f'./token_embedding/{dataname}/sentence_embeddings.npy')
55 |     edge_index = np.load('./datasets/'+dataname+'/edge_index.npy')
56 |     y = np.load('./datasets/'+dataname+'/y.npy')
57 |     x = torch.tensor(x).float()
58 |     y = torch.tensor(y).long()
59 |     edge_index = torch.tensor(edge_index).T
60 |     edge_index = tg.utils.to_undirected(edge_index)
61 |     edge_index = tg.utils.add_self_loops(edge_index)[0]
62 |     edge_index = tg.utils.sort_edge_index(edge_index)
63 |     data = Data()
64 |     data.x = x.float()
65 |     data.y = y
66 |     if(dataname != 'arxiv'):
67 |         train_mask,val_mask,test_mask = get_mask(x.shape[0],train_ratio,train_ratio+val_ratio,split)
68 |     else:
69 |         train_mask = np.load('./datasets/'+dataname+'/train.npy')
70 |         val_mask = np.load('./datasets/'+dataname+'/vaild.npy')
71 |         test_mask = np.load('./datasets/'+dataname+'/test.npy')
72 |     data.edge_index = edge_index
73 |     data.train_mask = train_mask
74 |     data.val_mask = val_mask
75 |     data.test_mask = test_mask
76 |     return data
77 | def evaluate(out,label,mask,metric = 'acc'):
78 |     if metric == 'roc':
79 |         py = out[:,1][mask].cpu().numpy()
80 |         #val = (out[data.val_mask]==data.y[data.val_mask]).sum()
81 |         # print(data.y)
82 |         gy = label[mask].cpu().long().numpy()
83 |         val = roc_auc_score(gy,py)
84 |         return val
85 |     elif metric == 'acc':
86 |         py = out.max(dim=1)[1][mask].cpu().numpy()
87 |         #val = (out[data.val_mask]==data.y[data.val_mask]).sum()
88 |         # print(data.y)
89 |         gy = label[mask].cpu().long().numpy()
90 |         val = accuracy_score(gy,py)
91 |         return val
92 |     elif metric == 'ap':
93 |         py = out[:,1][mask].cpu().numpy()
94 |         #val = (out[data.val_mask]==data.y[data.val_mask]).sum()
95 |         # print(data.y)
96 |         gy = label[mask].cpu().long().numpy()
97 |         val = average_precision_score(gy,py)
98 |         return val
99 | 
100 | def finetune(data,args):
101 |     model=None
102 |     device = args.device
103 |     pretrain_args = joblib.load(f'{args.save_path}model_args.pkl')
104 |     model = LinearHead(data.x.shape[1],int(data.y.max())+1,pretrain_args)
105 | 
106 |     if(args.load_from_pretrain==True):
107 |         print("load model from save path")
108 |         model.ga.load_state_dict(torch.load(f'{args.save_path}save_model_{args.step}.pkl',map_location='cpu'))
109 | 
110 |     prompt_x = np.load('./prompt_embedding/'+args.dataset_name+'/prompt_embedding.npy')
111 |     #prompt_x = np.load('./token_embedding/'+args.dataset_name+'/sentence_embeddings.npy')
112 |     prompt_x = torch.tensor(prompt_x).float().to(device)
113 | 
114 |     optimizer = torch.optim.AdamW([
115 |         {"params":model.lin.parameters(),"lr":args.learning_rate,'weight_decay':1e-3},
116 |         {"params":model.ga.parameters(),"lr":args.learning_rate,'weight_decay':1e-3},],
117 |     )
118 | 
119 |     data = data.to(device)
120 |     model = model.to(device)
121 | 
122 |     loss=None
123 |     val_acc = 0
124 |     test = 0
125 |     class_weight = torch.tensor([1,1.0]).to(device)
126 |     for i in range(350):
127 |         model.train()
128 |         model.ga.train()
129 |         optimizer.zero_grad()
130 |         out,gate = model(data.x,data.edge_index,prompt_x)
131 |         loss = F.nll_loss(out[data.train_mask],data.y[data.train_mask],weight=class_weight)
132 |         loss.backward()
133 |         optimizer.step()
134 |         with torch.no_grad():
135 |             model.eval()
136 |             model.ga.eval()
137 |             out,eval_gate= model(data.x,data.edge_index,prompt_x)
138 |             val = evaluate(out,data.y,data.val_mask,args.metric)
139 |             if(val>=val_acc):
140 |                 test = evaluate(out,data.y,data.test_mask,args.metric)
141 |                 tr = evaluate(out,data.y,data.train_mask,args.metric)
142 |                 print(f'best {args.metric} in epoch {i}: train:{tr:.4f},valid:{val:.4f},test:{test:.4f}')
143 |                 val_acc=val
144 |                 duration=0
145 |     print('final_loss',loss.item())
146 |     model.eval()
147 |     return test
--------------------------------------------------------------------------------
/graphadapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import torch_sparse
5 | 
6 | import numpy as np
7 | 
8 | 
9 | 
10 | from typing import Union, Tuple, Optional
11 | import torch_geometric as tg
12 | from transformers import activations
13 | from torch_geometric.nn.conv import MessagePassing
14 | from torch_geometric.nn import GATConv,GATv2Conv,SuperGATConv,ResGatedGraphConv,GCN2Conv,GatedGraphConv
15 | from torch_geometric.nn import GCNConv
16 | from torch_geometric.nn import SAGEConv
17 | from torch_geometric.nn import GINConv
18 | from torch_geometric.nn import APPNP
19 | from torch.nn import Linear
20 | from torch_geometric.nn.inits import glorot, zeros
21 | from torch_geometric.typing import (
22 |     Adj,
23 |     OptTensor,
24 |     PairTensor,
25 |     SparseTensor,
26 |     torch_sparse,
27 |     Tensor
28 | )
29 | from torch_geometric.utils import (
30 |     add_self_loops,
31 |     is_torch_sparse_tensor,
32 |     remove_self_loops,
33 |     softmax,
34 | )
35 | 
36 | class LlamaRMSNorm(nn.Module):
37 |     def __init__(self, hidden_size, eps=1e-6):
38 |         """
39 |         LlamaRMSNorm is equivalent to T5LayerNorm
40 |         """
41 |         super().__init__()
42 |         self.weight = nn.Parameter(torch.ones(hidden_size))
43 |         self.variance_epsilon = eps
44 | 
45 |     def forward(self, hidden_states):
46 |         input_dtype = hidden_states.dtype
47 |         hidden_states = hidden_states.to(torch.float32)
48 |         variance = hidden_states.pow(2).mean(-1, keepdim=True)
49 |         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
50 |         return self.weight * hidden_states.to(input_dtype)
51 | 
52 | 
53 | class MLPBlock(MessagePassing):
54 |     def __init__(self,in_shape, hiddensize=64, num_layers=2, batch_norm=True, is_gnn=False, **kwargs):
55 |         super().__init__(aggr='mean',**kwargs)
56 |         self.num_layers=num_layers
57 |         self.lin1 = torch.nn.Linear(in_shape,hiddensize*2,bias=False)
58 |         self.lin2 = torch.nn.Linear(hiddensize*2,hiddensize,bias=False)
59 |         self.bn_first1 = LlamaRMSNorm(hiddensize)
60 |         self.ACT2FN = activations.ACT2FN['silu']
61 |         self.GNN1 = torch.nn.Linear(hiddensize,hiddensize,bias=False)
62 |         self.is_gnn = is_gnn # the MLP variant never propagates over edges, so is_gnn defaults to False
63 |     def forward(self, x,edge_index):
64 |         x = (self.lin1(x))
65 |         x = self.ACT2FN(x)
66 |         x = self.bn_first1(self.lin2(x))
67 |         x = self.ACT2FN(self.GNN1(x))
68 |         return x
69 |     def message(self, x_j,index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor:
70 |         return x_j
71 | 
72 | 
73 | class SAGEBlock(MessagePassing):
74 |     def __init__(self,in_shape, hiddensize=64, num_layers=2, batch_norm=True, is_gnn = True, **kwargs):
75 |         super().__init__(aggr='mean',**kwargs)
76 |         self.num_layers=num_layers
77 |         self.lin1 = torch.nn.Linear(in_shape,hiddensize*2,bias=False)
78 |         self.lin2 = torch.nn.Linear(hiddensize*2,hiddensize,bias=False)
79 |         self.bn_first1 = LlamaRMSNorm(hiddensize)
80 |         self.ACT2FN = activations.ACT2FN['silu']
81 |         self.GNN1 = torch.nn.Linear(hiddensize,hiddensize,bias=False)
82 |         self.is_gnn = is_gnn
83 |     def forward(self, x,edge_index):
84 |         x = (self.lin1(x))
85 |         x = self.ACT2FN(x)
86 |         x = self.bn_first1(self.lin2(x))
87 |         if(self.is_gnn==True):
88 |             x = self.propagate(edge_index, x=x,size=None)
89 |         x = self.ACT2FN(self.GNN1(x))
90 |         if(self.is_gnn==True):
91 |             x = self.propagate(edge_index, x=x,size=None)
92 |         return x
93 |     def message(self, x_j,index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor:
94 |         return x_j
95 | 
96 | class GATBlock(MessagePassing):
97 |     def __init__(self,in_shape, hiddensize=64, num_layers=2, is_gnn=True, **kwargs):
98 |         super().__init__(aggr='sum',**kwargs)
99 |         self.num_layers=num_layers
100 |         self.lin1 = torch.nn.Linear(in_shape,hiddensize*2,bias=False)
101 |         self.lin2 = torch.nn.Linear(hiddensize*2,hiddensize,bias=False)
102 |         self.bn_first1 = LlamaRMSNorm(hiddensize)
103 |         self.bn_first2 = LlamaRMSNorm(hiddensize)
104 |         self.bn_first3 = LlamaRMSNorm(hiddensize)
105 | 
106 | 
107 | 
108 |         self.ACT2FN = activations.ACT2FN['silu']
109 |         self.GNN1 = torch.nn.Linear(hiddensize,hiddensize,bias=False)
110 |         self.GNN2 = torch.nn.Linear(hiddensize,hiddensize,bias=False)
111 | 
112 |         self.heads = 8
113 |         self.out_channels = hiddensize//self.heads
114 |         self.is_gnn = is_gnn # attention-based propagation is enabled by default
115 |         #self.att = torch.nn.Parameter(torch.randn(1, self.heads, 2 * self.out_channels))
116 |         #self.att1 = torch.nn.Parameter(torch.randn(1, self.heads, 2 * self.out_channels))
117 |         self.att_l = torch.nn.Linear(hiddensize,hiddensize)
118 |         self.att_r = torch.nn.Linear(hiddensize,hiddensize)
119 | 
120 |         self.att_l1 = torch.nn.Linear(hiddensize,hiddensize)
121 |         self.att_r1 = torch.nn.Linear(hiddensize,hiddensize)
122 |         self.sqrt=1/np.sqrt(self.out_channels)
123 |     def forward(self, x,edge_index):
124 |         x = (self.lin1(x))
125 |         x = self.ACT2FN(x)
126 |         x = self.bn_first1(self.lin2(x))
127 |         if(self.is_gnn==True):
128 |             x = self.propagate(edge_index, x=x,size=None,layer=0)
129 |         x = self.bn_first2(self.ACT2FN(self.GNN1(x)))
130 |         if(self.is_gnn==True):
131 |             x = self.propagate(edge_index, x=x,size=None,layer=1)
132 |         x = self.bn_first3(self.GNN2(x))
133 | 
134 |         return x
135 |     def message(self,x_i, x_j,index: Tensor, ptr: OptTensor, size_i: Optional[int],layer) -> Tensor:
136 | 
137 |         if(layer==0):
138 |             x_i = self.att_l(x_i)
139 |             x_j = self.att_r(x_j)
140 |         else:
141 |             x_i = self.att_l1(x_i)
142 |             x_j = self.att_r1(x_j)
143 |         x_j = x_j.view(-1, self.heads, self.out_channels)
144 |         x_i = x_i.view(-1, self.heads, self.out_channels)
145 |         alpha = (x_i*x_j).sum(dim=-1)*self.sqrt
146 |         alpha = softmax(alpha, index, ptr, size_i)
147 |         return (x_j * alpha.view(-1, self.heads, 1)).view(-1, self.heads * self.out_channels)
148 | 
149 | class FusionBlock(MessagePassing):
150 |     def __init__(self, gnn_size, llm_size, hidden_size,is_pretraining, **kwargs):
151 |         super().__init__(aggr='mean',**kwargs)
152 |         self.hidden_size = hidden_size
153 |         self.llm_size = llm_size
154 |         self.gnn_size = gnn_size
155 |         self.prompt_lin = torch.nn.Linear(llm_size,hidden_size,bias=False)
156 |         self.g_lin = torch.nn.Linear(hidden_size,hidden_size,bias=False)
157 |         self.fuse1 = torch.nn.Linear(hidden_size*2,hidden_size*10,bias=False)
158 |         self.fuse2 = torch.nn.Linear(hidden_size*10,hidden_size,bias=False)
159 |         self.extend = torch.nn.Linear(hidden_size,llm_size,bias=False)
160 |         self.ACT2FN = activations.ACT2FN['silu']
161 |         self.is_pretraining = is_pretraining
162 |     def forward(self, x,node_ids, prompt):
163 |         node_ids = node_ids.view(-1)
164 |         token = self.prompt_lin(prompt)
165 | 
166 |         out = x[node_ids]
167 |         out = self.g_lin(out)
168 |         out = torch.cat((out,token),dim=1)
169 |         out = self.ACT2FN(self.fuse1(out))
170 |         out = self.fuse2(out)
171 |         if(self.is_pretraining):
172 |             out = self.extend(out)
173 |         return out
174 | 
175 |     def message(self, x_j, k: OptTensor,v,
176 |                 index: Tensor, ptr: OptTensor,q,
177 |                 size_i: Optional[int]) -> Tensor:
178 |         v = v
179 |         return v
180 | 
181 | 
182 | class GraphAdapter(torch.nn.Module):
183 |     def __init__(self,llm_shape, hiddensize_gnn=64, hiddensize_fusion = 64, num_layers=2, GNN_type='SAGE', is_pretraining=True):
184 |         super(GraphAdapter,self).__init__()
185 |         if(GNN_type == 'SAGE'):
186 |             self.graph_encode = SAGEBlock(llm_shape, hiddensize = hiddensize_gnn, num_layers=num_layers)
187 |         elif(GNN_type == 'GAT'):
188 |             self.graph_encode = GATBlock(llm_shape, hiddensize = hiddensize_gnn, num_layers=num_layers)
189 |         elif(GNN_type == 'MLP'):
190 |             self.graph_encode = MLPBlock(llm_shape, hiddensize = hiddensize_gnn, num_layers=num_layers)
191 |         else:
192 |             raise ValueError("GNN_type should be SAGE, GAT, or MLP")
193 |         self.fuse_model = FusionBlock(hiddensize_gnn, llm_shape, hiddensize_fusion,is_pretraining)
194 |     def forward(self, x,edge_index,node_ids=None,prompt=None):
195 |         gx = self.graph_encode(x,edge_index)
196 |         out = self.fuse_model(gx,node_ids,prompt)
197 |         return out
198 | 
199 | 
200 | ## used for downstream task
201 | class LinearHead(torch.nn.Module):
202 |     def __init__(self, x_shape, y_shape, pretrain_args):
203 |         super(LinearHead,self).__init__()
204 |         self.ga = GraphAdapter(llm_shape = x_shape, hiddensize_gnn = pretrain_args.hiddensize_gnn, hiddensize_fusion = pretrain_args.hiddensize_fusion, GNN_type=pretrain_args.GNN_type, num_layers=pretrain_args.num_layers,is_pretraining=False)
205 | 
206 | 
207 |         self.lin = torch.nn.Linear(pretrain_args.hiddensize_fusion,y_shape)
208 |         self.lin.weight = torch.nn.Parameter((self.lin.weight-self.lin.weight.mean()/self.lin.weight.std())*0.1210,requires_grad=True) ## re-initialize the linear head with small weights
209 |     def forward(self, x,edge_index,prompt_embedding):
210 |         x = self.ga(x,edge_index,torch.arange(len(x)),prompt_embedding)
211 |         x = self.lin(x)
212 |         return F.log_softmax(x,dim=1),x
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import tables
2 | import argparse
3 | 
4 | import numpy as np
5 | import torch
6 | import tqdm
7 | from torch import nn
8 | import torch.nn.utils.rnn as rnn_utils
9 | from torch.utils.data import DataLoader
10 | import torch.utils.data as data_
11 | import torch.nn.functional as F
12 | from torch.nn.utils.rnn import pad_sequence
13 | import random
14 | import transformers
15 | from transformers import AutoTokenizer, LlamaForCausalLM,AutoModelForCausalLM,LlamaTokenizer
16 | import os
17 | import shutil
18 | import joblib
19 | from prompt_config import get_template_by_dataset
20 | 
21 | class RawTextData(data_.Dataset):
22 |     def __init__(self, text,node_id):
23 |         self.text = text
24 |         self.node_id = node_id
25 |     def __len__(self):
26 |         return len(self.text)
27 |     def 
__getitem__(self, idx): 28 | return (self.text[idx],self.node_id[idx]) 29 | 30 | def pretrain_collate_fn(data_tuple): 31 | 32 | seq = [torch.tensor(sq[0]) for sq in data_tuple] 33 | node_id = [sq[1] for sq in data_tuple] 34 | seq = pad_sequence(seq, batch_first=True, padding_value=tokenizer.pad_token_id) 35 | node_id = torch.tensor(node_id).view(-1,1) 36 | node_id = node_id.repeat(1,seq.shape[1]) 37 | return seq, node_id 38 | 39 | def build_pretrain_data_by_tables(model,tokenizer,x_text,save_path,template_l_id,device,args): 40 | 41 | template_l_id = tokenizer.encode(template_l)[0:] 42 | template_l_id = torch.tensor(template_l_id).view(1,-1) 43 | 44 | token_embedding_path = save_path+'token_embeddings.h5' 45 | f = tables.open_file(token_embedding_path, mode='w') 46 | atom = tables.Float16Atom() 47 | array_c = f.create_earray(f.root, 'data', atom, (0, 5120)) 48 | f.close() 49 | 50 | sentence_embedding_path = save_path+'sentence_embeddings.h5' 51 | f = tables.open_file(sentence_embedding_path, mode='w') 52 | atom = tables.Float16Atom() 53 | array_c = f.create_earray(f.root, 'data', atom, (0, 5120)) 54 | f.close() 55 | 56 | token_node_ids_path = save_path+'token_node_ids.h5' 57 | f = tables.open_file(token_node_ids_path, mode='w') 58 | atom = tables.IntAtom() 59 | array_c = f.create_earray(f.root, 'data', atom, (0, 1)) 60 | f.close() 61 | 62 | token_label_path = save_path+'token_labels.h5' 63 | f = tables.open_file(token_label_path, mode='w') 64 | atom = tables.IntAtom() 65 | array_c = f.create_earray(f.root, 'data', atom, (0, 1)) 66 | f.close() 67 | 68 | model.to(device) 69 | feature_ls=[] 70 | test_max = 0 71 | for text in list(x_text): 72 | feature_ls.append(text) 73 | print('total node: ', len(feature_ls)) 74 | 75 | feature_ls_ids = [] 76 | for f in tqdm.tqdm(feature_ls): 77 | feature_ls_ids.append(tokenizer(f,padding=True,truncation=True)['input_ids']) 78 | nodedata_ = RawTextData(feature_ls_ids,list(range(len(feature_ls)))) 79 | node_data_loader = DataLoader(nodedata_, batch_size=args.batch_size, shuffle=False,collate_fn=pretrain_collate_fn) 80 | token_node_ids_ls = [] 81 | labels_ls = [] 82 | embeddings_ls = [] 83 | word_num_ls = [] 84 | cls_embeddings_ls = [] 85 | 86 | for i in range(1): 87 | train_position = [] 88 | for (text,node_id) in tqdm.tqdm(node_data_loader): 89 | with torch.no_grad(): 90 | mlm_text_id, labels = text, text[..., 1:].contiguous() 91 | 92 | #print(labels) 93 | mlm_text_id = mlm_text_id[:,1:] 94 | labels = labels[:,1:] 95 | node_id = node_id[:,1:] 96 | 97 | prompt_l = template_l_id.repeat(mlm_text_id.shape[0],1)#.to(device) 98 | prompt_labels = torch.zeros_like(prompt_l) 99 | node_id = torch.cat((prompt_labels-1,node_id),dim=1) 100 | mlm_text_id = torch.cat((prompt_l,mlm_text_id),dim=1) 101 | labels = torch.cat((prompt_labels,labels),dim=1) 102 | 103 | attention_mask = (mlm_text_id != tokenizer.pad_token_id).long()#.half() 104 | 105 | mlm_text_id = mlm_text_id.to(device) 106 | attention_mask = attention_mask.to(device) 107 | embeddings = model.model(input_ids=mlm_text_id, attention_mask=attention_mask)[0] 108 | embedding_dim = embeddings.shape[-1] 109 | prompt_last_position = attention_mask.sum(dim=1)-1 110 | cls_embedding = embeddings.gather(1,prompt_last_position.view(-1,1,1).repeat(1,1,embedding_dim)).view(-1,embedding_dim) 111 | #cls_embeddings_ls.append(cls_embedding.to('cpu')) 112 | 113 | batch_cls_embedding = cls_embedding.to('cpu').numpy() 114 | 115 | embeddings = embeddings[:, :-1, :].contiguous() 116 | node_id = node_id[...,:-1] 117 | num = 
(labels!=-0).sum(dim=1) 118 | token_node_ids = [] 119 | 120 | 121 | node_ids = node_id[labels!=0].view(-1,1).to('cpu').numpy() 122 | 123 | token_node_ids_ls.append(node_id[labels!=0]) 124 | embeddings = embeddings[labels!=0,:].to('cpu').numpy() 125 | labels = labels[labels!=0].view(-1,1).to('cpu').numpy() 126 | 127 | 128 | f = tables.open_file(token_embedding_path, mode='a') 129 | f.root.data.append(embeddings) 130 | f.close() 131 | 132 | f = tables.open_file(sentence_embedding_path, mode='a') 133 | f.root.data.append(batch_cls_embedding) 134 | f.close() 135 | 136 | 137 | f = tables.open_file(token_node_ids_path, mode='a') 138 | f.root.data.append(node_ids) 139 | f.close() 140 | 141 | 142 | f = tables.open_file(token_label_path, mode='a') 143 | f.root.data.append(labels) 144 | f.close() 145 | return token_embedding_path,sentence_embedding_path,token_node_ids_path,token_label_path 146 | 147 | def convert_tables_to_npy(save_path): 148 | token_embedding_path = save_path+'token_embeddings.h5' 149 | token_node_ids_path = save_path+'token_node_ids.h5' 150 | token_label_path = save_path+'token_labels.h5' 151 | sentence_embedding_path = save_path+'sentence_embeddings.h5' 152 | 153 | token_node_ids = tables.open_file(token_node_ids_path, mode='r+').root.data.read() 154 | np.save(save_path+'token_node_ids.npy',token_node_ids[:,0]) 155 | 156 | token_labels = tables.open_file(token_label_path, mode='r+').root.data.read() 157 | np.save(save_path+'token_labels.npy',token_labels[:,0]) 158 | 159 | token_embeddings = tables.open_file(token_embedding_path, mode='r+').root.data.read() 160 | np.save(save_path+'token_embeddings.npy',token_embeddings) 161 | 162 | sentence_embeddings = tables.open_file(sentence_embedding_path, mode='r+').root.data.read() 163 | np.save(save_path+'sentence_embeddings.npy',sentence_embeddings) 164 | return True 165 | 166 | 167 | def get_prompt_embedding(model,tokenizer,x,template_l,template_r,device,args=None): 168 | feature_ls=[] 169 | for text in list(x): 170 | feature_ls.append(text) 171 | feature_ls_ids = [] 172 | for f in feature_ls: 173 | feature_ls_ids.append(tokenizer(template_l+f+template_r,padding=True,truncation=True)['input_ids']) 174 | nodedata_ = RawTextData(feature_ls_ids,list(range(len(feature_ls)))) 175 | node_data_loader = DataLoader(nodedata_, batch_size=args.batch_size, shuffle=False,collate_fn=pretrain_collate_fn) 176 | prompt_embeddings_ls = [] 177 | embedding_dim=model.config.hidden_size 178 | for i in range(1): 179 | train_position = [] 180 | for (text,node_id) in tqdm.tqdm(node_data_loader): 181 | with torch.no_grad(): 182 | text_id, labels = text[:,:], text[:, :] 183 | 184 | attention_mask = (text_id != tokenizer.pad_token_id).long() 185 | text_id = text_id.to(device) 186 | attention_mask = attention_mask.to(device) 187 | output = model.model(input_ids=text_id, attention_mask=attention_mask)[0] 188 | 189 | embeddings = output[..., :-1, :].contiguous() 190 | labels = labels[..., 1:].long() 191 | 192 | prompt_last_position = attention_mask.sum(dim=1)-1 193 | 194 | prompt_embedding = output.gather(1,prompt_last_position.view(-1,1,1).repeat(1,1,embedding_dim)).view(-1,embedding_dim) 195 | prompt_embeddings_ls.append(prompt_embedding.to('cpu')) 196 | prompt_embedding = torch.cat(prompt_embeddings_ls,dim=0) 197 | prompt_embedding = prompt_embedding.numpy() 198 | return prompt_embedding 199 | 200 | def save_lm_head(model): 201 | lm_head_path = "./pretrain_models/head/" 202 | if os.path.exists(lm_head_path): 203 | shutil.rmtree(lm_head_path, True) 204 | 
os.makedirs(lm_head_path) 205 | joblib.dump(model.lm_head.to('cpu'),open(f'{lm_head_path}lm_head.pkl','wb')) 206 | 207 | if __name__ == "__main__": 208 | parser = argparse.ArgumentParser('preprocess text-attributed graph by LLMs to gain the token embedding') 209 | parser.add_argument('--dataset_name', type=str, help='dataset to be used', default='instagram', 210 | choices=['arxiv', 'instagram', 'reddit']) 211 | parser.add_argument('--batch_size', type=int, default=2, help='batch size of llama 2') 212 | parser.add_argument('--plm_path', type=str, default='/data/pretrain_models/llama-2-13b-hf', help='path of llama 2') 213 | parser.add_argument('--gpu', type=int, default=0, help='number of gpu to use') 214 | parser.add_argument('--pretrain_save_path', type=str, default='./token_embedding/', help='path of saving pretrain data') 215 | parser.add_argument('--prompt_save_path', type=str, default='./prompt_embedding/', help='path of saving prompt embedding') 216 | parser.add_argument('--type',type=str,default='all',help='preprocess type',choices = ['pretrain','prompt','all','convert']) 217 | args = parser.parse_args() 218 | 219 | args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu' 220 | device = args.device 221 | save_path = args.pretrain_save_path+args.dataset_name+'/' 222 | 223 | ##load Llama 2 224 | if(args.type != 'convert'): 225 | model = AutoModelForCausalLM.from_pretrained(args.plm_path,low_cpu_mem_usage=True,torch_dtype=torch.float16).to(device) 226 | tokenizer = AutoTokenizer.from_pretrained(args.plm_path,use_fast=False) 227 | tokenizer.pad_token='[PAD]' # for batch preprocess 228 | 229 | save_lm_head(model) 230 | 231 | x_text = np.load(f'./datasets/{args.dataset_name}/x_text.npy') 232 | 233 | if(args.type == 'pretrain') or (args.type=='all'): 234 | 235 | if os.path.exists(save_path): 236 | shutil.rmtree(save_path, True) 237 | os.makedirs(save_path) 238 | 239 | template_l,template_r = get_template_by_dataset(args.dataset_name) 240 | print("template_l:",template_l) 241 | print() 242 | print("template_r",template_r) 243 | token_embedding_path,sentence_embedding_path,token_node_ids_path,token_label_path = build_pretrain_data_by_tables(model,tokenizer,x_text,save_path,template_l,args.device,args) 244 | convert_tables_to_npy(save_path) 245 | 246 | if(args.type == 'convert') or (args.type!='pretrain'): 247 | ## if out-of-memory, and the .h5 data have be saved, consider covert-only to transform .h5 to .npy 248 | convert_tables_to_npy(save_path) 249 | 250 | if(args.type == 'prompt') or (args.type=='all'): 251 | save_path = args.prompt_save_path+args.dataset_name+'/' 252 | template_l,template_r = get_template_by_dataset(args.dataset_name) 253 | if os.path.exists(save_path): 254 | shutil.rmtree(save_path, True) 255 | os.makedirs(save_path) 256 | prompt_embedding = get_prompt_embedding(model,tokenizer,x_text,template_l,template_r,args.device,args) 257 | np.save(f'{save_path}/prompt_embedding.npy',prompt_embedding) 258 | 259 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pretrain_utils import pretrain_graph_adapter 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser(description='graph adapter') 6 | parser.add_argument('--device', type=int, default=0) 7 | parser.add_argument('--dataset_name', type=str, help='dataset to be used', default='instagram', choices=['arxiv', 'instagram', 
'reddit']) 8 | parser.add_argument('--max_epoch', type=int, default=50) 9 | parser.add_argument('--hiddensize_gnn', type=int, default=128) 10 | parser.add_argument('--hiddensize_fusion', type=int, default=128) 11 | parser.add_argument('--num_layers', type=int, default=2) 12 | parser.add_argument('--learning_ratio', type=float, default=1e-4) 13 | parser.add_argument('--weight_decay', type=float, default=1e-3) 14 | parser.add_argument('--num_warmup_steps', type=int, default=10) 15 | parser.add_argument('--batch_size', type=int, default=64) 16 | parser.add_argument('--lm_head_path', type=str, default=f'./pretrain_models/head/lm_head.pkl') 17 | parser.add_argument('--GNN_type', type=str, default='SAGE', choices = ['SAGE','GAT','MLP']) 18 | args = parser.parse_args() 19 | pretrain_graph_adapter(args) -------------------------------------------------------------------------------- /pretrain_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch_geometric as tg 4 | from torch_geometric.data import InMemoryDataset, download_url, Data 5 | from sklearn.metrics import roc_auc_score,accuracy_score 6 | from transformers import get_scheduler 7 | import joblib 8 | import tqdm 9 | import numpy as np 10 | from sklearn.model_selection import train_test_split 11 | import random 12 | from torch.utils.data import DataLoader 13 | import torch.utils.data as data_ 14 | import pandas as pd 15 | import torch.nn.functional as F 16 | import logging 17 | import os 18 | 19 | from torch.autograd import Variable 20 | 21 | from graphadapter import GraphAdapter 22 | 23 | def load_pretrain_graph(dataset_name): 24 | x = np.load(f'./token_embedding/{dataset_name}/sentence_embeddings.npy') 25 | edge_index = np.load(f'./datasets/{dataset_name}/edge_index.npy') 26 | x = torch.tensor(x).float() 27 | edge_index = torch.tensor(edge_index).T 28 | edge_index = tg.utils.to_undirected(edge_index) 29 | edge_index = tg.utils.add_self_loops(edge_index)[0] 30 | edge_index = tg.utils.sort_edge_index(edge_index) 31 | data = Data() 32 | data.x = x.float() 33 | data.edge_index = edge_index 34 | return data 35 | 36 | def load_llm_data(dataset_name = 'instagram'): 37 | token_labels = np.load(f'./token_embedding/{dataset_name}/token_labels.npy') 38 | token_embeddings = np.load(f'./token_embedding/{dataset_name}/token_embeddings.npy') 39 | token_node_ids = np.load(f'./token_embedding/{dataset_name}/token_node_ids.npy') 40 | return token_labels,token_embeddings,token_node_ids 41 | 42 | def get_node_level_token(token_node_ids,token_embeddings,token_labels): 43 | node_token_embeddings=[] 44 | node_token_labels=[] 45 | token_node_ids = token_node_ids.astype(int) 46 | token_labels = token_labels.astype(int) 47 | global node_num 48 | node_num = token_node_ids.max()+1 49 | for i in range(node_num): 50 | node_token_embeddings.append([]) 51 | node_token_labels.append([]) 52 | for node_ids,embed,label in tqdm.tqdm(zip(token_node_ids,token_embeddings,token_labels)): 53 | node_token_embeddings[node_ids].append(embed) 54 | node_token_labels[node_ids].append(label) 55 | return node_token_embeddings,node_token_labels 56 | 57 | def split_pretrain_data(token_labels,token_embeddings,token_node_ids): 58 | y_data = pd.DataFrame() 59 | y = token_labels 60 | node_token_ids = [] 61 | for i in range(token_node_ids.max()+1): 62 | node_token_ids.append([]) 63 | token_number=0 64 | for ids in token_node_ids: 65 | node_token_ids[ids].append(token_number) 66 | token_number+=1 67 | 
X_train = [] 68 | X_test = [] 69 | for e in node_token_ids: 70 | seq_size = len(e) 71 | if(seq_size<2): 72 | continue 73 | l = 0 74 | mid = int(seq_size*0.9) 75 | r = seq_size 76 | if(mid==r): 77 | mid-=1 78 | for i in range(l,mid): 79 | X_train.append(e[i]) 80 | for i in range(mid,r): 81 | X_test.append(e[i]) 82 | X_train = np.array(X_train) 83 | X_test = np.array(X_test) 84 | X_train = X_train.reshape(len(X_train)) 85 | X_test = X_test.reshape(len(X_test)) 86 | 87 | train_token_node_ids = token_node_ids[X_train] 88 | train_token_embeddings = token_embeddings[X_train] 89 | train_token_labels = token_labels[X_train] 90 | 91 | 92 | test_token_node_ids = token_node_ids[X_test] 93 | test_token_embeddings = token_embeddings[X_test] 94 | test_token_labels = token_labels[X_test] 95 | 96 | 97 | train_node_token_embeddings, train_node_token_labels = get_node_level_token(train_token_node_ids, train_token_embeddings,train_token_labels) 98 | eval_node_token_embeddings, eval_node_token_labels = get_node_level_token(test_token_node_ids, test_token_embeddings,test_token_labels) 99 | return train_node_token_embeddings, train_node_token_labels, eval_node_token_embeddings, eval_node_token_labels 100 | 101 | 102 | def load_pretrain_head(lm_head_path = f'./pretrain_models/head/lm_head.pkl'): 103 | try: 104 | pretrain_head = joblib.load(lm_head_path) 105 | except: 106 | raise "lm lead not be found, please see details of preprocess.py" 107 | pretrain_head = pretrain_head.float() 108 | for e in pretrain_head.parameters(): 109 | e.requires_grad=False 110 | return pretrain_head 111 | 112 | 113 | class PretrainData(data_.Dataset): 114 | def __init__(self, node_ids,edge_index,node_token_embeddings,node_token_labels,node_token_weight): 115 | self.node_ids = node_ids 116 | edge_index = edge_index.numpy() 117 | self.neighbor = [] 118 | for i in range(len(node_ids)): 119 | self.neighbor.append([]) 120 | for e in edge_index.T: 121 | self.neighbor[e[1]].append(e[0]) 122 | self.node_token_embeddings = node_token_embeddings 123 | self.node_token_labels = node_token_labels 124 | self.node_token_weight = node_token_weight 125 | def __len__(self): 126 | return len(self.node_ids) 127 | def __getitem__(self, idx): 128 | return (self.node_ids[idx],self.neighbor[idx],self.node_token_embeddings[idx],self.node_token_labels[idx],self.node_token_weight[idx]) 129 | 130 | def pretrain_collate_fn(node_embdding): 131 | i = 0 132 | token_embedding = [] 133 | token_ids = [] 134 | 135 | neighbor_ids = [] 136 | node_ids = [] 137 | token_labels=[] 138 | weight = [] 139 | for node_id,neighbor,node_token,node_labels,node_token_weight in node_embdding: 140 | node_ids+=len(node_token)*[node_id] 141 | weight += list(node_token_weight/np.sum(node_token_weight)) # node level normalize token weights 142 | token_embedding+=node_token 143 | token_labels+=node_labels 144 | token_embedding = np.array(token_embedding) 145 | node_ids = np.array(node_ids) 146 | token_labels = np.array(token_labels) 147 | weight = np.array(weight) 148 | return node_ids,token_embedding,token_labels,weight#,neg_token_embedding 149 | 150 | 151 | def get_node_token_weight(x): 152 | x_map = {} 153 | for e in x: 154 | x_map[e]=0 155 | for e in x: 156 | x_map[e]+=1 157 | node_token_num = [] 158 | for e in x: 159 | node_token_num.append(1/x_map[e]) ## keep token class balance 160 | return node_token_num 161 | 162 | 163 | class LabelSmoothing(torch.nn.Module): 164 | def __init__(self, size, smoothing=0.0): 165 | # using label smoothing can improve the robustness of GraphAdapter 166 | 
super(LabelSmoothing, self).__init__() 167 | self.criterion = torch.nn.KLDivLoss(reduction='none') 168 | self.confidence = 1.0 - smoothing 169 | self.smoothing = smoothing 170 | self.size = size 171 | self.true_dist = None 172 | 173 | def forward(self, x, target): 174 | assert x.size(1) == self.size 175 | true_dist = x.data.clone() 176 | true_dist.fill_(self.smoothing / (self.size - 1)) 177 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 178 | self.true_dist = true_dist 179 | 180 | return self.criterion(x, Variable(true_dist, requires_grad=False)).sum(dim=1) 181 | 182 | 183 | def pretrain_graph_adapter(args): 184 | 185 | dataset_name = args.dataset_name 186 | hiddensize_gnn = args.hiddensize_gnn 187 | hiddensize_fusion = args.hiddensize_fusion 188 | num_layers = args.num_layers 189 | batch_size = args.batch_size 190 | learning_ratio= args.learning_ratio 191 | weight_decay = args.weight_decay 192 | max_epoch = args.max_epoch 193 | num_warmup_steps = args.num_warmup_steps 194 | device = args.device 195 | GNN_type = args.GNN_type 196 | 197 | num_training_steps = args.max_epoch 198 | 199 | global eval_node_token_embeddings 200 | global eval_node_token_labels 201 | global train_node_token_embeddings 202 | global train_node_token_labels 203 | 204 | device = torch.device(device) 205 | 206 | save_path = f'./save_models/{dataset_name}/{hiddensize_gnn}_{hiddensize_fusion}_{GNN_type}_{num_layers}_{batch_size}_{learning_ratio}_{weight_decay}_{max_epoch}_{num_warmup_steps}/' 207 | 208 | if not os.path.exists(save_path): 209 | os.makedirs(save_path) 210 | joblib.dump(args,f'{save_path}model_args.pkl') 211 | 212 | logger = logging.getLogger() 213 | 214 | file_fmt = "%(asctime)s - %(levelname)s - %(message)s" 215 | logging.basicConfig(level=logging.DEBUG, format=file_fmt, filename=f"{save_path}log.txt", filemode="a") 216 | console_handler = logging.StreamHandler() 217 | console_handler.setLevel(level=logging.DEBUG) 218 | console_fmt = "%(asctime)s - %(levelname)s - %(message)s" 219 | fmt1 = logging.Formatter(fmt=console_fmt) 220 | console_handler.setFormatter(fmt=fmt1) 221 | logger.addHandler(console_handler) 222 | 223 | logging.info(f'save_path:{save_path}') 224 | logging.info('load_pretrain_data...') 225 | token_labels,token_embeddings,token_node_ids = load_llm_data(dataset_name = dataset_name) 226 | logging.info(f"load load llm pretrain data, dataset_name:{dataset_name}") 227 | 228 | train_node_token_embeddings, train_node_token_labels,eval_node_token_embeddings, eval_node_token_labels = split_pretrain_data(token_labels,token_embeddings,token_node_ids) 229 | pretrain_head = load_pretrain_head(args.lm_head_path) 230 | logging.info('load_graph_adapter...') 231 | 232 | train_node_token_weight = [] 233 | for e in train_node_token_labels: 234 | x = get_node_token_weight(torch.tensor(e).numpy()) 235 | train_node_token_weight.append(np.array(x)) 236 | eval_node_token_weight = [] 237 | eval_node_token_unique_token = [] 238 | 239 | for e in eval_node_token_labels: 240 | x = get_node_token_weight(torch.tensor(e).numpy()) 241 | eval_node_token_weight.append(np.array(x)) 242 | 243 | 244 | 245 | data= load_pretrain_graph(dataset_name) 246 | logging.info('load_data...OK') 247 | train_data = PretrainData(list(range(data.x.shape[0])),data.edge_index,train_node_token_embeddings,train_node_token_labels,train_node_token_weight) 248 | eval_data = PretrainData(list(range(data.x.shape[0])),data.edge_index,eval_node_token_embeddings,eval_node_token_labels,eval_node_token_weight) 249 | train_loader = 
DataLoader(train_data, batch_size=batch_size, shuffle=True,collate_fn=pretrain_collate_fn, num_workers=16) 250 | eval_loader = DataLoader(eval_data, batch_size=batch_size*5, shuffle=False,collate_fn=pretrain_collate_fn, num_workers=16) 251 | logging.info('data_loader...OK') 252 | 253 | 254 | loss_function = LabelSmoothing(32000, 0.1) # The number of categories is the number of vocabulary lists in LLM 255 | model = GraphAdapter(llm_shape = data.x.shape[1],hiddensize_gnn = hiddensize_gnn, hiddensize_fusion = hiddensize_fusion, num_layers=num_layers,GNN_type=GNN_type,is_pretraining=True) 256 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_ratio, weight_decay=weight_decay) 257 | lr_scheduler = get_scheduler( 258 | name="linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps 259 | ) 260 | model = model.to(device) 261 | data = data.to(device) 262 | pretrain_head = pretrain_head.to(device) 263 | for epoch in range(max_epoch): 264 | total_loss = [] 265 | model.train() 266 | for node_ids,token_embedding,token_labels,weights in tqdm.tqdm(train_loader): 267 | optimizer.zero_grad() 268 | node_ids = torch.tensor(node_ids).view(-1,1).to(device) 269 | token_embedding = torch.tensor(token_embedding).float().to(device) 270 | token_labels = torch.tensor(token_labels).to(device) 271 | 272 | weights = torch.tensor(weights).view(-1,1).to(device) 273 | out1 = model(data.x,data.edge_index,node_ids,token_embedding) 274 | original_y = F.softmax(pretrain_head(token_embedding),dim=1).detach() 275 | 276 | #out2 = F.log_softmax(pretrain_head(out2),dim=1) 277 | pred_y = F.softmax(pretrain_head(out1),dim=1) 278 | pred_y = torch.log((original_y+pred_y)/2) 279 | loss0 = loss_function(pred_y,token_labels) 280 | loss0 = loss0.view(-1,1) 281 | loss0 = loss0*weights 282 | loss0 = loss0.sum()/batch_size 283 | loss = loss0 284 | loss.backward() 285 | optimizer.step() 286 | total_loss += [loss.item()*batch_size] 287 | lr_scheduler.step() 288 | total_eval_loss = [] 289 | 290 | with torch.no_grad(): 291 | model.eval() 292 | for node_ids,token_embedding,token_labels,weights in tqdm.tqdm(eval_loader): 293 | node_ids = torch.tensor(node_ids).view(-1,1).to(device) 294 | token_embedding = torch.tensor(token_embedding).float().to(device) 295 | token_labels = torch.tensor(token_labels).to(device) 296 | weights = torch.tensor(weights).view(-1,1).to(device) 297 | out1 = model(data.x,data.edge_index,node_ids,token_embedding) 298 | pred_y = F.softmax(pretrain_head(out1),dim=1) 299 | original_y = F.softmax(pretrain_head(token_embedding),dim=1) 300 | pred_y = torch.log((original_y+pred_y)/2) 301 | loss = loss_function(pred_y,token_labels) 302 | loss = loss.view(-1,1) 303 | loss = loss*weights 304 | loss = loss.sum() 305 | total_eval_loss += [loss.item()] 306 | 307 | logging.info(f'epoch: {epoch} , loss: {np.sum(total_loss)/data.x.shape[0]}, eval loss: {np.sum(total_eval_loss)/data.x.shape[0]}') 308 | torch.save(model.state_dict(),save_path+f'save_model_{epoch}.pkl') 309 | -------------------------------------------------------------------------------- /prompt_config.py: -------------------------------------------------------------------------------- 1 | def get_template_by_dataset(dataset_name): 2 | ## a prompt can be described as template_l+text_data+template_r 3 | ## template_l would used in pretrain 4 | ## template_r used for infer on downstream task 5 | ## pretrain and downstream task share same template_l 6 | if(dataset_name=='arxiv'): 7 | template_l = "Here is a paper 
published on arXiv. The abstract reads as follows: \n\n"
8 |         template_r = ".\n\nQuestion: Based on the abstract above, this paper is published on \"___\" subject on Arxiv.\nAnswer: \""
9 |     elif(dataset_name == 'instagram'):
10 |         template_l = "Here are an account on Instagram, and its personal profile reads: \n\n"
11 |         template_r = ".\n\nQuestion: Based on the profile provided , this account is a \"___\" (answer in one word) account on Instagram.\nAnswer: \""
12 |     elif(dataset_name == 'reddit'):
13 |         template_l = "It is a user on Reddit, and his last 3 posts are: \n\n"
14 |         template_r = ".\n\nQuestion: Based on the given posts, the style of this user is \"___\" (answer in one word).\nAnswer: \""
15 |     else:
16 |         raise ValueError("the template of this dataset is not registered, please modify prompt_config.py")
17 | 
18 |     return template_l,template_r
19 | 
--------------------------------------------------------------------------------
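For reference, a minimal sketch of how these templates are consumed downstream (mirroring `get_prompt_embedding` in preprocess.py; the profile text below is purely hypothetical):
```
from prompt_config import get_template_by_dataset

# preprocess.py builds the full prompt as template_l + raw node text + template_r
template_l, template_r = get_template_by_dataset('instagram')
profile = "Coffee lover sharing latte art and cafe reviews."  # hypothetical node text
prompt = template_l + profile + template_r
print(prompt)
# the LLM hidden state at the last prompt token is then saved as the node's prompt embedding
```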