├── datasets.keep ├── inference.ipynb ├── datasets ├── TwiBot-20 │ └── put preprocessed dataset here.txt ├── Cresci-2015 │ └── put preprocessed dataset here.txt ├── Cresci-2017 │ └── put preprocessed dataset here.txt └── Midterm-2018 │ └── put preprocessed dataset here.txt ├── LM.py ├── README.md ├── model_building.py ├── RGT.py ├── parser_args.py ├── utils.py ├── dataloader.py ├── SimpleHGN.py ├── lmbot.yaml ├── main.py ├── GNNs.py └── trainer.py /datasets.keep: -------------------------------------------------------------------------------- 1 | 1 2 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/TwiBot-20/put preprocessed dataset here.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/Cresci-2015/put preprocessed dataset here.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/Cresci-2017/put preprocessed dataset here.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/Midterm-2018/put preprocessed dataset here.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /LM.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForMaskedLM, AutoModel 2 | import torch.nn as nn 3 | from torch_geometric.nn.models import MLP 4 | import torch 5 | 6 | class LM_Model(nn.Module): 7 | def __init__(self, model_config): 8 | super().__init__() 9 | self.LM_model_name = model_config['lm_model'] 10 | if self.LM_model_name == 'deberta': 11 | self.LM = AutoModel.from_pretrained('microsoft/deberta-v3-base') 12 | elif self.LM_model_name == 'roberta': 13 | self.LM = AutoModel.from_pretrained('roberta-base') 14 | elif self.LM_model_name == 'bert': 15 | self.LM = AutoModel.from_pretrained('bert-base-uncased') 16 | elif self.LM_model_name == 'twhin-bert': 17 | self.LM = AutoModel.from_pretrained('Twitter/twhin-bert-base') 18 | elif self.LM_model_name == 'xlm-roberta': 19 | self.LM = AutoModel.from_pretrained('xlm-roberta-base') 20 | else: 21 | raise ValueError() 22 | 23 | self.classifier = MLP(in_channels=self.LM.config.hidden_size, hidden_channels=model_config['classifier_hidden_dim'], out_channels=2, num_layers=model_config['classifier_n_layers'], act=model_config['activation']) 24 | 25 | self.LM.config.hidden_dropout_prob = model_config['lm_dropout'] 26 | self.LM.attention_probs_dropout_prob = model_config['att_dropout'] 27 | 28 | def forward(self, tokenized_tensors): 29 | out = self.LM(output_hidden_states=True, **tokenized_tensors)['hidden_states'] 30 | embedding = out[-1].mean(dim=1) 31 | 32 | return embedding.detach(), self.classifier(embedding) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LMBot: Distilling Graph Knowledge into Language Model for Graph-less Deployment in 
Twitter Bot Detection (WSDM 2024) 2 | Official implementation of [LMBot: Distilling Graph Knowledge into Language Model for Graph-less Deployment in Twitter Bot Detection](https://arxiv.org/abs/2306.17408) 3 | 4 | ## Requirements 5 | Run the following commands to create the environment for reproduction (for CUDA 10.2): 6 | ``` 7 | conda env create -f lmbot.yaml 8 | conda activate lmbot 9 | pip install torch==1.12.0+cu102 torchvision==0.13.0+cu102 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu102 10 | ``` 11 | For ```pyg_lib```, ```torch_cluster```, ```torch_scatter```, ```torch_sparse``` and ```torch_spline_conv```, please download the wheels [here](https://data.pyg.org/whl/torch-1.12.0%2Bcu102.html) and install them locally. 12 | ``` 13 | pip install pyg_lib-0.1.0+pt112cu102-cp39-cp39-linux_x86_64.whl torch_cluster-1.6.0+pt112cu102-cp39-cp39-linux_x86_64.whl torch_scatter-2.1.0+pt112cu102-cp39-cp39-linux_x86_64.whl torch_sparse-0.6.16+pt112cu102-cp39-cp39-linux_x86_64.whl torch_spline_conv-1.2.1+pt112cu102-cp39-cp39-linux_x86_64.whl 14 | ``` 15 | ## Data preparation 16 | Please download our preprocessed datasets [here](https://drive.google.com/drive/folders/1kbI3uJQCn3e8CN3d9iUeUNSIOuJCbDUj?usp=sharing) and put them in the ```datasets``` folder. 17 | 18 | ## Training 19 | Run the following command to train on ```TwiBot-20```: 20 | ``` 21 | python main.py --project_name lmbot --experiment_name TwiBot-20 --dataset TwiBot-20 --device 0 --LM_pretrain_epochs 4.5 --alpha 0.5 --max_iter 10 --batch_size_LM 32 --use_GNN 22 | ``` 23 | Run the following command to train on ```Cresci-2015```: 24 | ``` 25 | python main.py --project_name lmbot --experiment_name Cresci-2015 --dataset Cresci-2015 --device 0 --LM_pretrain_epochs 2.5 --alpha 0.5 --max_iter 10 --batch_size_LM 32 --use_GNN --LM_eval_patience 10 --hidden_dim 64 26 | ``` 27 | Run the following command to train on ```Cresci-2017```: 28 | ``` 29 | python main.py --project_name lmbot --experiment_name Cresci-2017 --dataset Cresci-2017 --device 0 --LM_pretrain_epochs 3 --alpha 0.5 --max_iter 10 --batch_size_LM 32 --LM_eval_patience 20 30 | ``` 31 | Run the following command to train on ```Midterm-2018```: 32 | ``` 33 | python main.py --project_name lmbot --experiment_name Midterm-2018 --dataset Midterm-2018 --device 0 --LM_pretrain_epochs 2 --batch_size_LM 32 --LM_eval_patience 50 34 | ``` 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /model_building.py: -------------------------------------------------------------------------------- 1 | from LM import LM_Model 2 | from GNNs import RGCN, HGT, SimpleHGN, RGT 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | from transformers import AutoTokenizer 7 | 8 | import torch 9 | 10 | 11 | def build_LM_model(model_config): 12 | # build LM_model 13 | LM_model = LM_Model(model_config).to(model_config['device']) 14 | # build tokenizer 15 | LM_model_name = model_config['lm_model'].lower() 16 | 17 | if LM_model_name == 'deberta': 18 | LM_tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base') 19 | elif LM_model_name == 'roberta-f': 20 | LM_tokenizer = AutoTokenizer.from_pretrained('yzxjb/roberta-finetuned-20') 21 | elif LM_model_name == 'roberta': 22 | LM_tokenizer = AutoTokenizer.from_pretrained('roberta-base') 23 | elif LM_model_name == 'bert': 24 | LM_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') 25 | elif LM_model_name == 'twhin-bert': 26 | LM_tokenizer = AutoTokenizer.from_pretrained('Twitter/twhin-bert-base') 27 | elif 
LM_model_name == 'xlm-roberta': 28 | LM_tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base') 29 | if LM_model_name != 'roberta-f': 30 | special_tokens_dict = {'additional_special_tokens': ['DESCRIPTION:','METADATA:','TWEET:']} 31 | LM_tokenizer.add_special_tokens(special_tokens_dict) 32 | tokens_list = ["@USER", '#HASHTAG', "HTTPURL", 'EMOJI', 'RT', 'None'] 33 | LM_tokenizer.add_tokens(tokens_list) 34 | LM_model.LM.resize_token_embeddings(len(LM_tokenizer)) 35 | print('Information about LM model:') 36 | print('total params:', sum(p.numel() for p in LM_model.parameters())) 37 | return LM_model, LM_tokenizer 38 | 39 | 40 | def build_GNN_model(model_config): 41 | # build GNN_model 42 | GNN_model_name = model_config['GNN_model'].lower() 43 | if GNN_model_name == 'rgcn': 44 | GNN_model = RGCN(model_config).to(model_config['device']) 45 | elif GNN_model_name == 'rgt': 46 | GNN_model = RGT(model_config).to(model_config['device']) 47 | elif GNN_model_name == 'simplehgn': 48 | GNN_model = SimpleHGN(model_config).to(model_config['device']) 49 | elif GNN_model_name == 'hgt': 50 | GNN_model = HGT(model_config).to(model_config['device']) 51 | 52 | else: 53 | raise ValueError('') 54 | 55 | print('Information about GNN model:') 56 | print(GNN_model) 57 | print('total params:', sum(p.numel() for p in GNN_model.parameters())) 58 | 59 | 60 | return GNN_model 61 | 62 | 63 | -------------------------------------------------------------------------------- /RGT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import TransformerConv 4 | 5 | 6 | class SemanticAttention(nn.Module): 7 | def __init__(self, in_size, num_head, out_size, hidden_size=128): 8 | super(SemanticAttention, self).__init__() 9 | self.num_head = num_head 10 | self.att_layers = nn.ModuleList() 11 | # multi-head attention 12 | for i in range(num_head): 13 | self.att_layers.append( 14 | nn.Sequential( 15 | nn.Linear(in_size, hidden_size), 16 | nn.Tanh(), 17 | nn.Linear(hidden_size, 1, bias=False)) 18 | ) 19 | 20 | def forward(self, z, return_beta): 21 | w = self.att_layers[0](z).mean(0) 22 | beta = torch.softmax(w, dim=0) 23 | if return_beta==True: 24 | print(beta) 25 | beta = beta.expand((z.shape[0],) + beta.shape) 26 | output = (beta * z).sum(1) 27 | 28 | for i in range(1, self.num_head): 29 | w = self.att_layers[i](z).mean(0) 30 | beta = torch.softmax(w, dim=0) 31 | if return_beta == True: 32 | print(beta) 33 | beta = beta.expand((z.shape[0],) + beta.shape) 34 | temp = (beta * z).sum(1) 35 | output += temp 36 | # print('pre_feature',pre_features.size()) 37 | return output / self.num_head 38 | 39 | class RGTLayer(nn.Module): 40 | def __init__(self, num_edge_type, in_size, out_size, layer_num_heads, semantic_head, dropout): 41 | super(RGTLayer, self).__init__() 42 | self.gated = nn.Sequential( 43 | nn.Linear(in_size + out_size, in_size), 44 | nn.Sigmoid() 45 | ) 46 | 47 | self.activation = nn.ELU() 48 | self.gat_layers = nn.ModuleList() 49 | for i in range(int(num_edge_type)): 50 | self.gat_layers.append(TransformerConv(in_channels=in_size, out_channels=out_size, heads=layer_num_heads, dropout=dropout, concat=False)) 51 | 52 | self.semantic_attention = SemanticAttention(in_size=out_size, num_head = semantic_head, out_size = out_size) 53 | 54 | def forward(self, features, edge_index_list, beta = False, agg = None): 55 | 56 | u = self.gat_layers[0](features, edge_index_list[0].squeeze(0)).flatten(1) #.unsqueeze(1) 57 | a = 
self.gated(torch.cat((u, features), dim = 1)) 58 | 59 | semantic_embeddings = (torch.mul(torch.tanh(u), a) + torch.mul(features, (1-a))).unsqueeze(1) 60 | 61 | for i in range(1,len(edge_index_list)): 62 | 63 | u = self.gat_layers[i](features, edge_index_list[i].squeeze(0)).flatten(1) 64 | a = self.gated(torch.cat((u, features), dim = 1)) 65 | output = torch.mul(torch.tanh(u), a) + torch.mul(features, (1-a)) 66 | semantic_embeddings=torch.cat((semantic_embeddings,output.unsqueeze(1)), dim = 1) 67 | 68 | if agg == 'max': 69 | return semantic_embeddings.max(dim = 1)[0] 70 | elif agg == 'min': 71 | return semantic_embeddings.min(dim = 1)[0] 72 | elif agg == 'sum': 73 | return semantic_embeddings.sum(1) 74 | elif agg == 'mean': 75 | return semantic_embeddings.mean(1) 76 | else: 77 | return self.semantic_attention(semantic_embeddings, return_beta = beta) -------------------------------------------------------------------------------- /parser_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parser_args(): 5 | parser = argparse.ArgumentParser() 6 | # wandb argument 7 | parser.add_argument('--project_name', type=str) 8 | parser.add_argument('--experiment_name', type=str) 9 | # dataset argument 10 | parser.add_argument('--dataset', type=str) 11 | parser.add_argument('--batch_size_LM', type=int, default=32) 12 | parser.add_argument('--batch_size_GNN', type=int, default=300000) 13 | parser.add_argument('--batch_size_MLP', type=int, default=300000) 14 | parser.add_argument('--raw_data_filepath', type=str, default='./data/raw/') 15 | parser.add_argument('--is_processed', type=bool, default=True) 16 | parser.add_argument('--reset_split', type=str, default='1,1,8') 17 | 18 | # model argument 19 | 20 | parser.add_argument('--LM_model', type=str, default='roberta') 21 | parser.add_argument('--max_length', type=int, default=512) 22 | parser.add_argument('--optimizer_LM', type=str, default='adamw') 23 | parser.add_argument('--dropout', type=float, default=0.4) 24 | parser.add_argument('--LM_classifier_n_layers', type=int, default=2) 25 | parser.add_argument('--LM_classifier_hidden_dim', type=int, default=128) 26 | parser.add_argument('--LM_dropout', type=float, default=0.1) 27 | parser.add_argument('--LM_att_dropout', type=float, default=0.1) 28 | parser.add_argument('--label_smoothing_factor', type=float, default=0) 29 | parser.add_argument('--warmup', type=float, default=0.6) 30 | 31 | parser.add_argument('--use_GNN', action='store_true') 32 | parser.add_argument('--GNN_model', type=str, default='rgcn') 33 | parser.add_argument('--n_layers', type=int, default=2) 34 | parser.add_argument('--hidden_dim', type=int, default=128) 35 | parser.add_argument('--n_relations', type=int, default=2) 36 | parser.add_argument('--activation', type=str, default='leakyrelu') 37 | parser.add_argument('--optimizer_GNN', type=str, default='adamw') 38 | parser.add_argument('--GNN_dropout', type=float, default=0.4) 39 | parser.add_argument('--att_heads', type=int, default=8) 40 | parser.add_argument('--SimpleHGN_att_res', type=float, default=0.2) 41 | parser.add_argument('--RGT_semantic_heads', type=int, default=8) 42 | 43 | parser.add_argument('--MLP_n_layers', type=int, default=3) 44 | parser.add_argument('--MLP_hidden_dim', type=int, default=128) 45 | parser.add_argument('--optimizer_MLP', type=str, default='adamw') 46 | parser.add_argument('--MLP_dropout', type=float, default=0.4) 47 | 48 | 49 | 50 | # train evaluation test argument 51 | parser.add_argument('--seeds', type=str, default='1,2,3,4,5') 
52 | parser.add_argument('--device', type=int, default=-1) 53 | parser.add_argument('--LM_pretrain_epochs', type=float, default=5) 54 | parser.add_argument('--MLP_KD_epochs', type=int, default=300) 55 | parser.add_argument('--LM_eval_patience', type=int, default=20) 56 | parser.add_argument('--LM_accumulation', type=int, default=1) 57 | parser.add_argument('--max_iters', type=int, default=10) 58 | parser.add_argument('--GNN_epochs_per_iter', type=int, default=200) 59 | parser.add_argument('--LM_epochs_per_iter', type=int, default=3) 60 | parser.add_argument('--MLP_epochs_per_iter', type=int, default=300) 61 | parser.add_argument('--temperature', type=float, default=3) 62 | parser.add_argument('--pl_ratio_LM', type=float, default=0.5) 63 | parser.add_argument('--pl_ratio_GNN', type=float, default=0) 64 | parser.add_argument('--pl_ratio_MLP', type=float, default=0.5) 65 | parser.add_argument('--alpha', type=float, default=0.7) 66 | parser.add_argument('--beta', type=float, default=0) 67 | parser.add_argument('--gamma', type=float, default=0.7) 68 | 69 | parser.add_argument('--lr_LM', type=float, default=1e-5) 70 | parser.add_argument('--weight_decay_LM', type=float, default=0.01) 71 | 72 | parser.add_argument('--lr_GNN', type=float, default=5e-4) 73 | parser.add_argument('--weight_decay_GNN', type=float, default=1e-5) 74 | 75 | parser.add_argument('--lr_MLP', type=float, default=5e-4) 76 | parser.add_argument('--weight_decay_MLP', type=float, default=1e-5) 77 | 78 | 79 | args = parser.parse_args() 80 | return args -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random, os 3 | import numpy as np 4 | import wandb 5 | import json 6 | from pathlib import Path 7 | 8 | def seed_setting(seed_number): 9 | random.seed(seed_number) 10 | os.environ['PYTHONHASHSEED'] = str(seed_number) 11 | np.random.seed(seed_number) 12 | torch.manual_seed(seed_number) 13 | torch.cuda.manual_seed(seed_number) 14 | torch.cuda.manual_seed_all(seed_number) 15 | torch.backends.cudnn.benchmark = False 16 | torch.backends.cudnn.deterministic = True 17 | torch.backends.cudnn.enabled = False 18 | 19 | def setup_wandb(args, seed): 20 | run = wandb.init( 21 | project=args.project_name, 22 | name=args.experiment_name + f'_seed_{seed}', 23 | config=args 24 | ) 25 | return run 26 | 27 | 28 | def load_raw_data(dataset, use_GNN): 29 | data_filepath = f'./datasets/{dataset}/' 30 | print('Loading data...') 31 | train_idx = torch.load(data_filepath+'train_idx.pt') 32 | valid_idx = torch.load(data_filepath+'valid_idx.pt') 33 | test_idx = torch.load(data_filepath+'test_idx.pt') 34 | user_text = json.load(open(data_filepath+'norm_user_text.json')) 35 | labels = torch.load(data_filepath+'labels.pt') 36 | if use_GNN: 37 | edge_index = torch.load(data_filepath+'edge_index.pt') 38 | edge_type = torch.load(data_filepath+'edge_type.pt') 39 | return {'train_idx': train_idx, 40 | 'valid_idx': valid_idx, 41 | 'test_idx': test_idx, 42 | 'user_text': user_text, 43 | 'labels': labels, 44 | 'edge_index': edge_index, 45 | 'edge_type': edge_type} 46 | else: 47 | return {'train_idx': train_idx, 48 | 'valid_idx': valid_idx, 49 | 'test_idx': test_idx, 50 | 'user_text': user_text, 51 | 'labels': labels} 52 | 53 | 54 | def load_distilled_knowledge(from_which_model, intermediate_data_filepath, iter): 55 | if from_which_model == 'LM': 56 | embeddings = torch.load(intermediate_data_filepath / 
f'embeddings_iter_{iter}.pt') 57 | soft_labels = torch.load(intermediate_data_filepath / f'soft_labels_iter_{iter}.pt') 58 | return embeddings, soft_labels 59 | 60 | elif from_which_model == 'GNN': 61 | 62 | soft_labels = torch.load(intermediate_data_filepath / f'soft_labels_iter_{iter}.pt') 63 | return soft_labels 64 | 65 | elif from_which_model == 'MLP': 66 | soft_labels = torch.load(intermediate_data_filepath / f'soft_labels_iter_{iter}.pt') 67 | return soft_labels 68 | 69 | else: 70 | raise ValueError('"from_which_model" should be "LM", "GNN" or "MLP".') 71 | 72 | 73 | def prepare_path(experiment_name): 74 | experiment_path = Path(experiment_name) 75 | ckpt_filepath = experiment_path / 'checkpoints' 76 | MLP_ckpt_filepath = ckpt_filepath / 'MLP' 77 | LM_ckpt_filepath = ckpt_filepath / 'LM' 78 | GNN_ckpt_filepath = ckpt_filepath / 'GNN' 79 | MLP_KD_ckpt_filepath = Path('MLP_KD') 80 | LM_prt_ckpt_filepath = ckpt_filepath / 'LM_pretrain' 81 | GNN_prt_ckpt_filepath = ckpt_filepath / 'GNN_pretrain' 82 | LM_prt_ckpt_filepath.mkdir(exist_ok=True, parents=True) 83 | GNN_prt_ckpt_filepath.mkdir(exist_ok=True, parents=True) 84 | LM_ckpt_filepath.mkdir(exist_ok=True, parents=True) 85 | GNN_ckpt_filepath.mkdir(exist_ok=True, parents=True) 86 | MLP_KD_ckpt_filepath.mkdir(exist_ok=True, parents=True) 87 | MLP_ckpt_filepath.mkdir(exist_ok=True, parents=True) 88 | 89 | LM_intermediate_data_filepath = experiment_path / 'intermediate' / 'LM' 90 | GNN_intermediate_data_filepath = experiment_path / 'intermediate' / 'GNN' 91 | MLP_intermediate_data_filepath = experiment_path / 'intermediate' / 'MLP' 92 | LM_intermediate_data_filepath.mkdir(exist_ok=True, parents=True) 93 | GNN_intermediate_data_filepath.mkdir(exist_ok=True, parents=True) 94 | MLP_intermediate_data_filepath.mkdir(exist_ok=True, parents=True) 95 | 96 | return LM_prt_ckpt_filepath, GNN_prt_ckpt_filepath, MLP_KD_ckpt_filepath, LM_ckpt_filepath, GNN_ckpt_filepath, MLP_ckpt_filepath, LM_intermediate_data_filepath, GNN_intermediate_data_filepath, MLP_intermediate_data_filepath 97 | 98 | 99 | 100 | def reset_split(n_nodes, ratio): 101 | idx = torch.randperm(n_nodes) 102 | split = list(map(int, ratio.split(','))) 103 | train_ratio = split[0] / sum(split) 104 | valid_ratio = split[1] / sum(split) 105 | 106 | train_idx = idx[: int(train_ratio * n_nodes)] 107 | valid_idx = idx[int(train_ratio * n_nodes): int((train_ratio + valid_ratio) * n_nodes)] 108 | test_idx = idx[int((train_ratio + valid_ratio) * n_nodes):] 109 | return train_idx, valid_idx, test_idx 110 | 111 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import numpy as np 3 | import torch 4 | from torch_geometric.loader import NeighborLoader 5 | from torch_geometric.data import Data 6 | 7 | 8 | class LM_dataset(Dataset): 9 | def __init__(self, user_text: list, labels: torch.Tensor, is_pl: torch.LongTensor=None): 10 | super().__init__() 11 | self.user_text = user_text 12 | self.labels = labels 13 | self.is_pl = is_pl 14 | 15 | def __getitem__(self, index): 16 | if self.is_pl is None: 17 | text = self.user_text[index] 18 | label = self.labels[index] 19 | return text, label 20 | else: 21 | text = self.user_text[index] 22 | label = self.labels[index] 23 | is_pl = self.is_pl[index] 24 | return text, label, is_pl 25 | 26 | def __len__(self): 27 | return len(self.user_text) 28 | 29 | 30 | class 
MLP_dataset(Dataset): 31 | def __init__(self, LM_embeddings: torch.Tensor, labels: torch.Tensor, is_pl: torch.LongTensor=None): 32 | super().__init__() 33 | self.LM_embeddings = LM_embeddings 34 | self.labels = labels 35 | self.is_pl = is_pl 36 | 37 | def __getitem__(self, index): 38 | if self.is_pl is None: 39 | emb = self.LM_embeddings[index] 40 | label = self.labels[index] 41 | return emb, label 42 | else: 43 | emb = self.LM_embeddings[index] 44 | label = self.labels[index] 45 | is_pl = self.is_pl[index] 46 | return emb, label, is_pl 47 | 48 | def __len__(self): 49 | return self.LM_embeddings.shape[0] 50 | 51 | def build_LM_dataloader(dataloader_config, idx, user_seq, labels, mode, is_pl=None): 52 | batch_size = dataloader_config['batch_size'] 53 | 54 | if mode == 'train': 55 | user_text = [] 56 | for i in idx: 57 | user_text.append(user_seq[i.item()]) 58 | loader = DataLoader(dataset=LM_dataset(user_text, labels[idx], is_pl), batch_size=batch_size, shuffle=True) 59 | 60 | elif mode == 'pretrain': 61 | user_text = [] 62 | for i in idx: 63 | user_text.append(user_seq[i.item()]) 64 | loader = DataLoader(dataset=LM_dataset(user_text, labels[idx]), batch_size=batch_size, shuffle=True) 65 | 66 | 67 | elif mode == 'infer': 68 | loader = DataLoader(dataset=LM_dataset(user_seq, labels), batch_size=batch_size*5) 69 | 70 | 71 | elif mode == 'eval': 72 | user_text = [] 73 | for i in idx: 74 | user_text.append(user_seq[i.item()]) 75 | loader = DataLoader(dataset=LM_dataset(user_text, labels[idx]), batch_size=batch_size*5) 76 | 77 | else: 78 | raise ValueError('mode should be in ["train", "eval", "infer", "pretrain"].') 79 | 80 | return loader 81 | 82 | 83 | def build_GNN_dataloader(dataloader_config, idx, LM_embedding, labels, edge_index, edge_type, mode, is_pl=None): 84 | batch_size = dataloader_config['batch_size'] 85 | n_layers = dataloader_config['n_layers'] 86 | 87 | data = Data(x=LM_embedding, edge_index=edge_index, edge_type=edge_type, labels=labels) 88 | data.num_nodes = LM_embedding.shape[0] 89 | if mode == 'train' or mode == 'pretrain': 90 | data.is_pl = is_pl 91 | loader = NeighborLoader(data=data, num_neighbors=[-1]*n_layers, batch_size=batch_size, input_nodes=idx, shuffle=True) 92 | 93 | elif mode == 'eval': 94 | loader = NeighborLoader(data=data, num_neighbors=[-1]*n_layers, input_nodes=idx, batch_size=batch_size) 95 | 96 | elif mode == 'infer': 97 | loader = NeighborLoader(data=data, num_neighbors=[-1]*n_layers, batch_size=batch_size) 98 | else: 99 | raise ValueError('mode should be in ["train", "valid", "test", "infer"].') 100 | 101 | return loader 102 | 103 | 104 | def build_MLP_dataloader(dataloader_config, idx, LM_embeddings, labels, mode, is_pl=None): 105 | batch_size = dataloader_config['batch_size'] 106 | 107 | if mode == 'train': 108 | loader = DataLoader(dataset=MLP_dataset(LM_embeddings[idx], labels[idx], is_pl), batch_size=batch_size, shuffle=True) 109 | 110 | elif mode == 'pretrain': 111 | loader = DataLoader(dataset=MLP_dataset(LM_embeddings[idx], labels[idx]), batch_size=batch_size, shuffle=True) 112 | 113 | elif mode == 'eval': 114 | loader = DataLoader(dataset=MLP_dataset(LM_embeddings[idx], labels[idx]), batch_size=batch_size*10) 115 | 116 | elif mode == 'infer': 117 | loader = DataLoader(dataset=MLP_dataset(LM_embeddings, labels), batch_size=batch_size*10) 118 | 119 | else: 120 | raise ValueError('mode should be in ["train", "eval", "infer", "pretrain"].') 121 | 122 | return loader 123 | 
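For reference, a minimal usage sketch for the loader builders above (not part of the repository; the toy `user_seq`, one-hot `labels`, and `train_idx` below are made-up placeholders):
```
import torch
from dataloader import build_LM_dataloader

# Toy stand-ins: 4 users with preprocessed text and one-hot labels (human = [1, 0], bot = [0, 1]).
user_seq = ["DESCRIPTION: None METADATA: 0 0 TWEET: RT None"] * 4  # placeholder sequences
labels = torch.tensor([[1., 0.], [0., 1.], [1., 0.], [0., 1.]])
train_idx = torch.tensor([0, 1, 2])

# 'pretrain' mode keeps only the indexed users and shuffles them into mini-batches.
loader = build_LM_dataloader({'batch_size': 2}, train_idx, user_seq, labels, mode='pretrain')
for texts, batch_labels in loader:
    print(len(texts), batch_labels.shape)  # e.g. 2 torch.Size([2, 2])
```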
-------------------------------------------------------------------------------- /SimpleHGN.py: -------------------------------------------------------------------------------- 1 | # from torch_geometric.nn import MessagePassing 2 | # import torch 3 | # import torch.nn as nn 4 | # from torch.nn import Linear 5 | # from torch_geometric.utils import softmax 6 | # import torch.nn.functional as F 7 | 8 | 9 | 10 | # class SimpleHGNConv(MessagePassing): 11 | # def __init__(self, in_dim, hidden_dim, num_edge_type, edge_emb_dim, num_heads, beta=0, is_final=False): 12 | # ''' 13 | # if self.is_final: 14 | # out_dim = hidden_dim * num_heads 15 | # else: 16 | # out_dim = hidden_dim 17 | # ''' 18 | # super(SimpleHGNConv, self).__init__(aggr='add') 19 | # self.in_dim = in_dim 20 | # self.hidden_dim = hidden_dim 21 | # self.beta = beta 22 | # self.num_heads = num_heads 23 | # self.is_final = is_final 24 | 25 | # self.W = nn.Parameter(torch.empty((in_dim, num_heads*hidden_dim))) 26 | # self.W_r = nn.Parameter(torch.empty((edge_emb_dim, num_heads*hidden_dim))) 27 | # self.edge_emb = nn.Parameter(torch.empty((num_edge_type, edge_emb_dim))) 28 | # self.a = nn.Parameter(torch.empty((1, num_heads, 3*hidden_dim))) 29 | # self.W_res = nn.Parameter(torch.empty((in_dim, num_heads*hidden_dim))) 30 | 31 | # self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 32 | # self.elu = nn.ELU() 33 | 34 | # nn.init.xavier_uniform_(self.W, gain=1.414) 35 | # nn.init.xavier_uniform_(self.W_r, gain=1.414) 36 | # nn.init.xavier_uniform_(self.edge_emb, gain=1.414) 37 | # nn.init.xavier_uniform_(self.a, gain=1.414) 38 | # nn.init.xavier_uniform_(self.W_res, gain=1.414) 39 | 40 | # def forward(self, x, edge_index, edge_tpye, att_res=None, node_res=None): 41 | # ''' 42 | # x has shape (num_nodes, in_dim) 43 | # edge_index has shape (2, num_edges) 44 | # edge_tpye has shape (num_edges, ) 45 | # ''' 46 | # out = self.propagate(x=x, edge_index=edge_index, edge_tpye=edge_tpye, att_res=att_res) 47 | # if self.is_final: 48 | # out = out.view(-1, self.num_heads, self.hidden_dim) 49 | # out = self.elu(out.sum(dim=1) / self.num_heads) 50 | # out = F.normalize(out, dim=1) 51 | # else: 52 | # if node_res is not None: 53 | # out = self.elu(out + torch.matmul(node_res * self.W_res)) 54 | # return out, self.att.detach() 55 | 56 | # def message(self, x_i, x_j, edge_tpye, att_res, index, ptr, size_i): 57 | # v = torch.matmul(x_j, self.W) 58 | # k = torch.matmul(x_j, self.W).view(-1, self.num_heads, self.hidden_dim) 59 | # q = torch.matmul(x_i, self.W).view(-1, self.num_heads, self.hidden_dim) 60 | # ''' 61 | # q, k, v has shape (num_edges, num_heads, hidden_dim) 62 | # ''' 63 | # # print(k.shape) 64 | # att = self.leakyrelu((self.a * torch.cat([q, k, torch.matmul(self.edge_emb[edge_tpye], self.W_r).view(-1, self.num_heads, self.hidden_dim)], dim=-1)).sum(dim=-1)) 65 | # att = softmax(att, index, ptr, size_i) 66 | # # print(att.shape) 67 | # self.att = att 68 | # ''' 69 | # att has shape (num_edges, num_heads) 70 | # ''' 71 | # att = att if att_res is None else self.beta * att_res + (1 - self.beta) * att 72 | 73 | # out = att.repeat(1, self.hidden_dim) * v 74 | # # print(out.shape) 75 | # ''' 76 | # out has shape (num_edges, num_heads*hidden_dim) 77 | # ''' 78 | # return out 79 | 80 | # def update(self, aggr_out): 81 | # return aggr_out 82 | 83 | 84 | import torch 85 | from torch_geometric.nn import MessagePassing 86 | from torch_geometric.utils import softmax 87 | import torch.nn.functional as F 88 | 89 | class SimpleHGNConv(MessagePassing): 90 | def 
__init__(self, in_channels, out_channels, num_edge_type, rel_dim, beta=None, final_layer=False): 91 | super(SimpleHGNConv, self).__init__(aggr = "add", node_dim=0) 92 | self.W = torch.nn.Linear(in_channels, out_channels, bias=False) 93 | self.W_r = torch.nn.Linear(rel_dim, out_channels, bias=False) 94 | self.a = torch.nn.Linear(3*out_channels, 1, bias=False) 95 | self.W_res = torch.nn.Linear(in_channels, out_channels, bias=False) 96 | self.rel_emb = torch.nn.Embedding(num_edge_type, rel_dim) 97 | self.beta = beta 98 | self.leaky_relu = torch.nn.LeakyReLU(0.2) 99 | self.ELU = torch.nn.ELU() 100 | self.final = final_layer 101 | 102 | def init_weight(self): 103 | for m in self.modules(): 104 | if isinstance(m, torch.nn.Linear): 105 | torch.nn.init.xavier_uniform_(m.weight.data) 106 | 107 | def forward(self, x, edge_index, edge_type, pre_alpha=None): 108 | 109 | node_emb = self.propagate(x=x, edge_index=edge_index, edge_type=edge_type, pre_alpha=pre_alpha) 110 | output = node_emb + self.W_res(x) 111 | output = self.ELU(output) 112 | if self.final: 113 | output = F.normalize(output, dim=1) 114 | 115 | return output, self.alpha.detach() 116 | 117 | def message(self, x_i, x_j, edge_type, pre_alpha, index, ptr, size_i): 118 | out = self.W(x_j) 119 | rel_emb = self.rel_emb(edge_type) 120 | alpha = self.leaky_relu(self.a(torch.cat((self.W(x_i), self.W(x_j), self.W_r(rel_emb)), dim=1))) 121 | alpha = softmax(alpha, index, ptr, size_i) 122 | if pre_alpha is not None and self.beta is not None: 123 | self.alpha = alpha*(1-self.beta) + pre_alpha*(self.beta) 124 | else: 125 | self.alpha = alpha 126 | out = out * alpha.view(-1,1) 127 | return out 128 | 129 | def update(self, aggr_out): 130 | return aggr_out -------------------------------------------------------------------------------- /lmbot.yaml: -------------------------------------------------------------------------------- 1 | name: lmbot 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - asttokens=2.0.5=pyhd3eb1b0_0 9 | - backcall=0.2.0=pyhd3eb1b0_0 10 | - blas=1.0=mkl 11 | - bottleneck=1.3.5=py39h7deecbd_0 12 | - ca-certificates=2022.10.11=h06a4308_0 13 | - certifi=2022.12.7=py39h06a4308_0 14 | - debugpy=1.5.1=py39h295c915_0 15 | - decorator=5.1.1=pyhd3eb1b0_0 16 | - entrypoints=0.4=py39h06a4308_0 17 | - executing=0.8.3=pyhd3eb1b0_0 18 | - intel-openmp=2021.4.0=h06a4308_3561 19 | - ipykernel=6.15.2=py39h06a4308_0 20 | - ipython=8.7.0=py39h06a4308_0 21 | - jedi=0.18.1=py39h06a4308_1 22 | - jupyter_client=7.4.8=py39h06a4308_0 23 | - jupyter_core=4.11.2=py39h06a4308_0 24 | - ld_impl_linux-64=2.38=h1181459_1 25 | - libffi=3.4.2=h6a678d5_6 26 | - libgcc-ng=11.2.0=h1234567_1 27 | - libgomp=11.2.0=h1234567_1 28 | - libsodium=1.0.18=h7b6447c_0 29 | - libstdcxx-ng=11.2.0=h1234567_1 30 | - matplotlib-inline=0.1.6=py39h06a4308_0 31 | - mkl=2021.4.0=h06a4308_640 32 | - mkl-service=2.4.0=py39h7f8727e_0 33 | - mkl_fft=1.3.1=py39hd3c417c_0 34 | - mkl_random=1.2.2=py39h51133e4_0 35 | - ncurses=6.3=h5eee18b_3 36 | - nest-asyncio=1.5.5=py39h06a4308_0 37 | - numexpr=2.8.4=py39he184ba9_0 38 | - numpy-base=1.23.5=py39h31eccc5_0 39 | - openssl=1.1.1s=h7f8727e_0 40 | - packaging=22.0=py39h06a4308_0 41 | - parso=0.8.3=pyhd3eb1b0_0 42 | - pexpect=4.8.0=pyhd3eb1b0_3 43 | - pickleshare=0.7.5=pyhd3eb1b0_1003 44 | - pip=22.3.1=py39h06a4308_0 45 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 46 | - ptyprocess=0.7.0=pyhd3eb1b0_2 47 | - pure_eval=0.2.2=pyhd3eb1b0_0 48 | - pygments=2.11.2=pyhd3eb1b0_0 49 | - 
python=3.9.15=h7a1cb2a_2 50 | - python-dateutil=2.8.2=pyhd3eb1b0_0 51 | - pytz=2022.7=py39h06a4308_0 52 | - pyzmq=23.2.0=py39h6a678d5_0 53 | - readline=8.2=h5eee18b_0 54 | - setuptools=65.5.0=py39h06a4308_0 55 | - six=1.16.0=pyhd3eb1b0_1 56 | - sqlite=3.40.0=h5082296_0 57 | - stack_data=0.2.0=pyhd3eb1b0_0 58 | - tk=8.6.12=h1ccaba5_0 59 | - tornado=6.2=py39h5eee18b_0 60 | - traitlets=5.7.1=py39h06a4308_0 61 | - tzdata=2022g=h04d1e81_0 62 | - wcwidth=0.2.5=pyhd3eb1b0_0 63 | - wheel=0.37.1=pyhd3eb1b0_0 64 | - xz=5.2.8=h5eee18b_0 65 | - zeromq=4.3.4=h2531618_0 66 | - zlib=1.2.13=h5eee18b_0 67 | - pip: 68 | - absl-py==1.4.0 69 | - accelerate==0.21.0 70 | - aiohttp==3.8.4 71 | - aiosignal==1.3.1 72 | - alembic==1.10.2 73 | - appdirs==1.4.4 74 | - async-timeout==4.0.2 75 | - attrs==22.2.0 76 | - axial-positional-embedding==0.2.1 77 | - beautifulsoup4==4.12.2 78 | - braceexpand==0.1.7 79 | - cachetools==5.3.0 80 | - charset-normalizer==2.1.1 81 | - click==8.1.3 82 | - cmaes==0.9.1 83 | - cmake==3.27.2 84 | - colorlog==6.7.0 85 | - contourpy==1.0.7 86 | - cycler==0.11.0 87 | - datasets==2.11.0 88 | - dgl==1.0.1 89 | - dig==0.1.10 90 | - dill==0.3.6 91 | - docker-pycreds==0.4.0 92 | - docopt==0.6.2 93 | - docutils==0.20.1 94 | - einops==0.6.0 95 | - emoji==2.2.0 96 | - fastjsonschema==2.16.3 97 | - filelock==3.8.2 98 | - fonttools==4.39.4 99 | - frozenlist==1.3.3 100 | - fsspec==2023.3.0 101 | - future==0.18.3 102 | - futures==3.0.5 103 | - gitdb==4.0.10 104 | - gitpython==3.1.31 105 | - google-auth==2.16.2 106 | - google-auth-oauthlib==0.4.6 107 | - googledrivedownloader==0.4 108 | - greenlet==2.0.2 109 | - grpcio==1.51.3 110 | - huggingface-hub==0.16.4 111 | - idna==3.4 112 | - ijson==3.2.0.post0 113 | - imbalanced-learn==0.10.1 114 | - imblearn==0.0 115 | - importlib-metadata==6.0.0 116 | - importlib-resources==5.12.0 117 | - jinja2==3.1.2 118 | - joblib==1.2.0 119 | - jsonschema==4.17.3 120 | - kiwisolver==1.4.4 121 | - levenshtein==0.20.9 122 | - lightning-utilities==0.8.0 123 | - linformer==0.2.1 124 | - lit==16.0.6 125 | - littleutils==0.2.2 126 | - local-attention==1.5.7 127 | - mako==1.2.4 128 | - markdown==3.4.1 129 | - markupsafe==2.1.1 130 | - matplotlib==3.7.1 131 | - mpmath==1.3.0 132 | - multidict==6.0.4 133 | - multiprocess==0.70.14 134 | - mypy-extensions==1.0.0 135 | - natsort==8.4.0 136 | - nbformat==5.7.3 137 | - networkx==2.8.8 138 | - nltk==3.8.1 139 | - numpy==1.24.0 140 | - nvidia-cublas-cu11==11.10.3.66 141 | - nvidia-cuda-cupti-cu11==11.7.101 142 | - nvidia-cuda-nvrtc-cu11==11.7.99 143 | - nvidia-cuda-runtime-cu11==11.7.99 144 | - nvidia-cudnn-cu11==8.5.0.96 145 | - nvidia-cufft-cu11==10.9.0.58 146 | - nvidia-curand-cu11==10.2.10.91 147 | - nvidia-cusolver-cu11==11.4.0.1 148 | - nvidia-cusparse-cu11==11.7.4.91 149 | - nvidia-nccl-cu11==2.14.3 150 | - nvidia-nvtx-cu11==11.7.91 151 | - oauthlib==3.2.2 152 | - ogb==1.3.5 153 | - optuna==3.1.0 154 | - orjson==3.9.4 155 | - outdated==0.2.2 156 | - pandas==1.4.0 157 | - pathtools==0.1.2 158 | - peft==0.4.0 159 | - pillow==9.3.0 160 | - plotly==5.13.1 161 | - product-key-memory==0.1.10 162 | - protobuf==3.20.0 163 | - psutil==5.9.4 164 | - pyarrow==11.0.0 165 | - pyasn1==0.4.8 166 | - pyasn1-modules==0.2.8 167 | - pydeprecate==0.3.1 168 | - pyparsing==3.0.9 169 | - pyre-extensions==0.0.29 170 | - pyrsistent==0.19.3 171 | - pytorch-lightning==1.4.9 172 | - pyyaml==6.0 173 | - rapidfuzz==2.15.1 174 | - regex==2022.10.31 175 | - requests==2.28.2 176 | - requests-oauthlib==1.3.1 177 | - responses==0.18.0 178 | - rlpython==0.10.3 179 | 
- rsa==4.9 180 | - safetensors==0.3.2 181 | - scikit-learn==1.2.0 182 | - scipy==1.9.3 183 | - sentencepiece==0.1.97 184 | - sentry-sdk==1.17.0 185 | - setproctitle==1.3.2 186 | - smmap==5.0.0 187 | - soupsieve==2.4.1 188 | - sparselinear==0.0.5 189 | - sqlalchemy==2.0.5.post1 190 | - sympy==1.12 191 | - tenacity==8.2.2 192 | - tensorboard==2.12.0 193 | - tensorboard-data-server==0.7.0 194 | - tensorboard-plugin-wit==1.8.1 195 | - thop==0.1.1-2209072238 196 | - threadpoolctl==3.1.0 197 | - tokenizers==0.13.2 198 | - torch-geometric==2.2.0 199 | - torchmetrics==0.11.4 200 | - torchsummary==1.5.1 201 | - tqdm==4.64.1 202 | - transformers==4.29.0 203 | - triton==2.0.0 204 | - typing-extensions==4.4.0 205 | - typing-inspect==0.9.0 206 | - urllib3==1.26.13 207 | - wandb==0.14.0 208 | - webdataset==0.2.48 209 | - werkzeug==2.2.3 210 | - xxhash==3.2.0 211 | - yarl==1.8.2 212 | - zipp==3.15.0 213 | 214 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from parser_args import parser_args 2 | from utils import * 3 | from trainer import LM_Trainer, GNN_Trainer, MLP_Trainer 4 | import torch.nn as nn 5 | 6 | def main(args): 7 | 8 | for seed in list(map(int, args.seeds.strip().split(','))): 9 | 10 | seed_setting(seed) 11 | LM_prt_ckpt_filepath, GNN_prt_ckpt_filepath, MLP_KD_ckpt_filepath, LM_ckpt_filepath, GNN_ckpt_filepath, MLP_ckpt_filepath, LM_intermediate_data_filepath, GNN_intermediate_data_filepath, MLP_intermediate_data_filepath = prepare_path(args.experiment_name + f'_seed_{seed}') 12 | 13 | data = load_raw_data(args.dataset, args.use_GNN) 14 | 15 | run = setup_wandb(args, seed) 16 | 17 | if args.reset_split != '-1': 18 | train_idx, valid_idx, test_idx = reset_split(len(data['user_text']), args.reset_split) 19 | data['train_idx'], data['valid_idx'], data['test_idx'] = train_idx, valid_idx, test_idx 20 | 21 | LMTrainer = LM_Trainer( 22 | model_name=args.LM_model, 23 | classifier_n_layers=args.LM_classifier_n_layers, 24 | classifier_hidden_dim=args.LM_classifier_hidden_dim, 25 | device=args.device, 26 | pretrain_epochs=args.LM_pretrain_epochs, 27 | optimizer_name=args.optimizer_LM, 28 | lr=args.lr_LM, 29 | weight_decay=args.weight_decay_LM, 30 | dropout=args.dropout, 31 | att_dropout=args.LM_att_dropout, 32 | lm_dropout=args.LM_dropout, 33 | warmup=args.warmup, 34 | label_smoothing_factor=args.label_smoothing_factor, 35 | pl_weight=args.alpha, 36 | max_length=args.max_length, 37 | batch_size=args.batch_size_LM, 38 | grad_accumulation=args.LM_accumulation, 39 | lm_epochs_per_iter=args.LM_epochs_per_iter, 40 | temperature=args.temperature, 41 | pl_ratio=args.pl_ratio_LM, 42 | intermediate_data_filepath=LM_intermediate_data_filepath, 43 | ckpt_filepath=LM_ckpt_filepath, 44 | pretrain_ckpt_filepath=LM_prt_ckpt_filepath, 45 | raw_data_filepath=args.raw_data_filepath, 46 | train_idx=data['train_idx'], 47 | valid_idx=data['valid_idx'], 48 | test_idx=data['test_idx'], 49 | hard_labels=data['labels'], 50 | user_seq=data['user_text'], 51 | run=run, 52 | eval_patience=args.LM_eval_patience, 53 | activation=args.activation 54 | ) 55 | 56 | MLPTrainer = MLP_Trainer( 57 | device=args.device, 58 | optimizer_name=args.optimizer_MLP, 59 | lr=args.lr_MLP, 60 | weight_decay=args.weight_decay_MLP, 61 | dropout=args.MLP_dropout, 62 | pl_weight=args.gamma, 63 | batch_size=args.batch_size_MLP, 64 | n_layers=args.MLP_n_layers, 65 | hidden_dim=args.MLP_hidden_dim, 66 | activation=args.activation, 
67 | glnn_epochs=args.MLP_KD_epochs, 68 | mlp_epochs_per_iter=args.MLP_epochs_per_iter, 69 | temperature=args.temperature, 70 | pl_ratio=args.pl_ratio_MLP, 71 | intermediate_data_filepath=MLP_intermediate_data_filepath, 72 | ckpt_filepath=MLP_ckpt_filepath, 73 | KD_ckpt_filepath=MLP_KD_ckpt_filepath, 74 | train_idx=data['train_idx'], 75 | valid_idx=data['valid_idx'], 76 | test_idx=data['test_idx'], 77 | hard_labels=data['labels'], 78 | run=run, 79 | seed=seed, 80 | use_gnn = args.use_GNN 81 | ) 82 | 83 | 84 | LMTrainer.build_model() 85 | LMTrainer.pretrain() 86 | MLPTrainer.build_model() 87 | 88 | if args.use_GNN: 89 | GNNTrainer = GNN_Trainer( 90 | model_name=args.GNN_model, 91 | device=args.device, 92 | optimizer_name=args.optimizer_GNN, 93 | lr=args.lr_GNN, 94 | weight_decay=args.weight_decay_GNN, 95 | dropout=args.GNN_dropout, 96 | pl_weight=args.beta, 97 | batch_size=args.batch_size_GNN, 98 | gnn_n_layers=args.n_layers, 99 | n_relations=args.n_relations, 100 | activation=args.activation, 101 | gnn_epochs_per_iter=args.GNN_epochs_per_iter, 102 | temperature=args.temperature, 103 | pl_ratio=args.pl_ratio_GNN, 104 | intermediate_data_filepath=GNN_intermediate_data_filepath, 105 | ckpt_filepath=GNN_ckpt_filepath, 106 | pretrain_ckpt_filepath=GNN_prt_ckpt_filepath, 107 | train_idx=data['train_idx'], 108 | valid_idx=data['valid_idx'], 109 | test_idx=data['test_idx'], 110 | hard_labels=data['labels'], 111 | edge_index=data['edge_index'], 112 | edge_type=data['edge_type'], 113 | run=run, 114 | SimpleHGN_att_res=args.SimpleHGN_att_res, 115 | att_heads=args.att_heads, 116 | RGT_semantic_heads=args.RGT_semantic_heads, 117 | gnn_hidden_dim=args.hidden_dim, 118 | lm_name = args.LM_model 119 | ) 120 | GNNTrainer.build_model() 121 | for iter in range(args.max_iters): 122 | print(f'------Iter: {iter}/{args.max_iters-1}------') 123 | 124 | embeddings_LM, soft_labels_LM = load_distilled_knowledge('LM', LM_intermediate_data_filepath, iter-1) 125 | flag = GNNTrainer.train(embeddings_LM, soft_labels_LM) 126 | GNNTrainer.infer(embeddings_LM) 127 | if flag: 128 | print(f'Early stop by GNN at iter {iter}!') 129 | break 130 | 131 | soft_labels_GNN = load_distilled_knowledge('GNN', GNN_intermediate_data_filepath, iter) 132 | flag = LMTrainer.train(soft_labels_GNN) 133 | LMTrainer.infer() 134 | if flag: 135 | print(f'Early stop by LM at iter {iter}!') 136 | break 137 | 138 | print(f'Best LM is iter {LMTrainer.best_iter} epoch {LMTrainer.best_epoch}!') 139 | LMTrainer.test() 140 | 141 | print(f'Best GNN is iter {GNNTrainer.best_iter} epoch {GNNTrainer.best_epoch}!') 142 | embeddings_LM = LMTrainer.load_embedding(GNNTrainer.best_iter-1) 143 | GNNTrainer.test(embeddings_LM) 144 | 145 | # soft_labels_GNN = GNNTrainer.load_soft_labels(GNNTrainer.best_iter) 146 | # MLPTrainer.KD_GLNN(embeddings_LM, soft_labels_GNN) 147 | 148 | GNNTrainer.save_results(args.experiment_name + f'_seed_{seed}/results_GNN.json') 149 | 150 | 151 | else: 152 | for iter in range(args.max_iters): 153 | print(f'------Iter: {iter}/{args.max_iters-1}------') 154 | 155 | embeddings_LM, soft_labels_LM = load_distilled_knowledge('LM', LM_intermediate_data_filepath, iter-1) 156 | flag = MLPTrainer.train(embeddings_LM, soft_labels_LM) 157 | MLPTrainer.infer(embeddings_LM) 158 | if flag: 159 | print(f'Early stop by MLP at iter {iter}!') 160 | break 161 | 162 | soft_labels_MLP = load_distilled_knowledge('MLP', MLP_intermediate_data_filepath, iter) 163 | flag = LMTrainer.train(soft_labels_MLP) 164 | LMTrainer.infer() 165 | if flag: 166 | print(f'Early 
stop by LM at iter {iter}!') 167 | break 168 | 169 | print(f'Best LM is iter {LMTrainer.best_iter} epoch {LMTrainer.best_epoch}!') 170 | LMTrainer.test() 171 | 172 | print(f'Best MLP is iter {MLPTrainer.best_iter} epoch {MLPTrainer.best_epoch}!') 173 | embeddings_LM = LMTrainer.load_embedding(MLPTrainer.best_iter-1) 174 | MLPTrainer.test(embeddings_LM) 175 | 176 | MLPTrainer.save_results(args.experiment_name + f'_seed_{seed}/results_MLP.json') 177 | LMTrainer.save_results(args.experiment_name + f'_seed_{seed}/results_LM.json') 178 | 179 | if __name__ == '__main__': 180 | args = parser_args() 181 | main(args) 182 | -------------------------------------------------------------------------------- /GNNs.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_geometric.nn import RGCNConv, HGTConv 3 | from SimpleHGN import SimpleHGNConv 4 | import torch 5 | from RGT import RGTLayer 6 | from torch_geometric.nn.models import MLP 7 | 8 | 9 | class RGCN(nn.Module): 10 | def __init__(self, model_config): 11 | super().__init__() 12 | self.hidden_dim = model_config['gnn_hidden_dim'] 13 | self.n_layers = model_config['gnn_n_layers'] 14 | self.convs = nn.ModuleList([]) 15 | self.linear_in = nn.Linear(model_config['lm_input_dim'], self.hidden_dim) 16 | 17 | for i in range(self.n_layers): 18 | self.convs.append(RGCNConv(self.hidden_dim, self.hidden_dim, model_config['n_relations'])) 19 | 20 | self.dropout = nn.Dropout(model_config['dropout']) 21 | 22 | self.activation_name = model_config['activation'].lower() 23 | if self.activation_name == 'leakyrelu': 24 | self.activation = nn.LeakyReLU() 25 | elif self.activation_name == 'relu': 26 | self.activation = nn.ReLU() 27 | elif self.activation_name == 'elu': 28 | self.activation = nn.ELU() 29 | else: 30 | raise ValueError('Please choose activation function from "leakyrelu", "relu" or "elu".') 31 | 32 | self.linear_pool = nn.Linear(self.hidden_dim, self.hidden_dim) 33 | self.linear_out = nn.Linear(self.hidden_dim, 2) 34 | 35 | def forward(self, x, edge_index, edge_type): 36 | x = self.linear_in(x) 37 | x = self.dropout(x) 38 | for i in range(self.n_layers): 39 | x = self.convs[i](x, edge_index, edge_type) 40 | x = self.activation(x) 41 | x = self.linear_pool(x) 42 | x = self.activation(x) 43 | x = self.dropout(x) 44 | return self.linear_out(x) 45 | 46 | 47 | class SimpleHGN(nn.Module): 48 | def __init__(self, model_config): 49 | super().__init__() 50 | self.hidden_dim = model_config['hidden_dim'] 51 | self.n_layers = model_config['n_layers'] 52 | self.heads = model_config['att_heads'] 53 | self.edge_res = model_config['SimpleHGN_att_res'] 54 | 55 | 56 | self.convs = nn.ModuleList([]) 57 | for i in range(self.n_layers): 58 | if i == 0: 59 | self.convs.append(SimpleHGNConv(768, self.hidden_dim, model_config['n_relations'], 32, beta=self.edge_res)) 60 | elif i == self.n_layers - 1: 61 | self.convs.append(SimpleHGNConv(self.hidden_dim, self.hidden_dim, model_config['n_relations'], 32, beta=self.edge_res, final_layer=True)) 62 | else: 63 | self.convs.append(SimpleHGNConv(self.hidden_dim, self.hidden_dim, model_config['n_relations'], 32, beta=self.edge_res)) 64 | 65 | self.dropout = nn.Dropout(model_config['dropout']) 66 | 67 | self.activation_name = model_config['activation'].lower() 68 | if self.activation_name == 'leakyrelu': 69 | self.activation = nn.LeakyReLU() 70 | elif self.activation_name == 'relu': 71 | self.activation = nn.ReLU() 72 | elif self.activation_name == 'elu': 73 | 
self.activation = nn.ELU() 74 | else: 75 | raise ValueError('Please choose activation function from "leakyrelu", "relu" or "elu".') 76 | 77 | self.linear_pool = nn.Linear(self.hidden_dim, self.hidden_dim) 78 | self.linear_out = nn.Linear(self.hidden_dim, 2) 79 | 80 | def forward(self, x, edge_index, edge_type): 81 | for i in range(self.n_layers): 82 | if i == 0: 83 | x, att_res = self.convs[i](x, edge_index, edge_type) 84 | x = self.dropout(x) 85 | else: 86 | x, att_res = self.convs[i](x, edge_index, edge_type, pre_alpha=att_res) 87 | x = self.dropout(x) 88 | x = self.linear_pool(x) 89 | x = self.dropout(x) 90 | x = self.activation(x) 91 | 92 | return self.linear_out(x) 93 | 94 | 95 | 96 | class HGT(nn.Module): 97 | def __init__(self, model_config): 98 | super().__init__() 99 | self.hidden_dim = model_config['gnn_hidden_dim'] 100 | self.n_layers = model_config['gnn_n_layers'] 101 | self.heads = model_config['att_heads'] 102 | self.metadata = (['user'], [('user', 'follower', 'user'), ('user', 'following', 'user')]) 103 | self.mlp = MLP_Model(model_config) 104 | 105 | 106 | self.convs = nn.ModuleList([]) 107 | for i in range(self.n_layers): 108 | self.convs.append(HGTConv(self.hidden_dim, self.hidden_dim, self.metadata, self.heads)) 109 | 110 | self.dropout = nn.Dropout(model_config['dropout']) 111 | 112 | self.activation_name = model_config['activation'].lower() 113 | if self.activation_name == 'leakyrelu': 114 | self.activation = nn.LeakyReLU() 115 | elif self.activation_name == 'relu': 116 | self.activation = nn.ReLU() 117 | elif self.activation_name == 'elu': 118 | self.activation = nn.ELU() 119 | else: 120 | raise ValueError('Please choose activation function from "leakyrelu", "relu" or "elu".') 121 | 122 | self.linear_pool = nn.Linear(self.hidden_dim, self.hidden_dim) 123 | self.linear_out = nn.Linear(self.hidden_dim, 2) 124 | 125 | def prepare_data_for_HGT(self, x, edge_index, edge_type): 126 | x_dict = {self.metadata[0][0]: x} 127 | edge_index_dict = {} 128 | for i in range(len(self.metadata[1])): 129 | 130 | edge_index_dict[self.metadata[1][i]] = edge_index[:, edge_type==i] 131 | return x_dict, edge_index_dict 132 | 133 | def forward(self, x1, x2, x3, edge_index, edge_type): 134 | x = self.mlp(x1, x2, x3) 135 | x = self.dropout(x) 136 | x_dict, edge_index_dict = self.prepare_data_for_HGT(x, edge_index, edge_type) 137 | for i in range(self.n_layers): 138 | x = self.convs[i](x_dict, edge_index_dict) 139 | x[self.metadata[0][0]] = self.activation(self.dropout(x[self.metadata[0][0]])) 140 | 141 | x = self.linear_pool(x[self.metadata[0][0]]) 142 | x = self.dropout(x) 143 | x = self.activation(x) 144 | 145 | return self.linear_out(x) 146 | 147 | 148 | 149 | class RGT(nn.Module): 150 | def __init__(self, model_config): 151 | super().__init__() 152 | self.hidden_dim = model_config['gnn_hidden_dim'] 153 | self.n_layers = model_config['gnn_n_layers'] 154 | self.n_relations = model_config['n_relations'] 155 | self.mlp = MLP_Model(model_config) 156 | self.convs = nn.ModuleList([]) 157 | for i in range(self.n_layers): 158 | self.convs.append(RGTLayer(self.n_relations, self.hidden_dim, self.hidden_dim, model_config['att_heads'], model_config['RGT_semantic_heads'], dropout=model_config['dropout'])) 159 | 160 | self.dropout = nn.Dropout(model_config['dropout']) 161 | 162 | self.activation_name = model_config['activation'].lower() 163 | if self.activation_name == 'leakyrelu': 164 | self.activation = nn.LeakyReLU() 165 | elif self.activation_name == 'relu': 166 | self.activation = nn.ReLU() 167 | 
elif self.activation_name == 'elu': 168 | self.activation = nn.ELU() 169 | else: 170 | raise ValueError('Please choose activation function from "leakyrelu", "relu" or "elu".') 171 | 172 | self.linear_pool = nn.Linear(self.hidden_dim, self.hidden_dim) 173 | self.linear_out = nn.Linear(self.hidden_dim, 2) 174 | 175 | def prepare_data_for_RGT(self, x, edge_index, edge_type): 176 | edge_index_list = [] 177 | for i in range(self.n_relations): 178 | edge_index_list.append(edge_index[:, edge_type==i]) 179 | return x, edge_index_list 180 | 181 | def forward(self, LM_embedding, x_numerical, x_categorical, edge_index, edge_type): 182 | x = self.mlp(LM_embedding, x_numerical, x_categorical) 183 | x, edge_index_list = self.prepare_data_for_RGT(x, edge_index, edge_type) 184 | # x = self.input_norm(x) 185 | for i in range(self.n_layers): 186 | x = self.convs[i](x, edge_index_list) 187 | x = self.activation(x) 188 | x = self.linear_pool(x) 189 | x = self.activation(x) 190 | x = self.dropout(x) 191 | return self.linear_out(x) 192 | 193 | 194 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import CrossEntropyLoss, KLDivLoss 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | from sklearn.metrics import f1_score, accuracy_score 6 | import numpy as np 7 | from model_building import build_LM_model, build_GNN_model 8 | from dataloader import build_LM_dataloader, build_GNN_dataloader, build_MLP_dataloader 9 | import os 10 | import json 11 | from pathlib import Path 12 | from transformers.optimization import get_cosine_schedule_with_warmup 13 | from torch.optim.lr_scheduler import CosineAnnealingLR 14 | from torch_geometric.nn.models import MLP 15 | 16 | class LM_Trainer: 17 | def __init__( 18 | self, 19 | model_name, 20 | classifier_n_layers, 21 | classifier_hidden_dim, 22 | device, 23 | pretrain_epochs, 24 | optimizer_name, 25 | lr, 26 | weight_decay, 27 | dropout, 28 | att_dropout, 29 | lm_dropout, 30 | activation, 31 | warmup, 32 | label_smoothing_factor, 33 | pl_weight, 34 | max_length, 35 | batch_size, 36 | grad_accumulation, 37 | lm_epochs_per_iter, 38 | temperature, 39 | pl_ratio, 40 | eval_patience, 41 | intermediate_data_filepath, 42 | ckpt_filepath, 43 | pretrain_ckpt_filepath, 44 | raw_data_filepath, 45 | train_idx, 46 | valid_idx, 47 | test_idx, 48 | hard_labels, 49 | user_seq, 50 | run): 51 | 52 | self.model_name = model_name 53 | self.device = device 54 | self.pretrain_epochs = pretrain_epochs 55 | self.optimizer_name = optimizer_name.lower() 56 | self.lr = lr 57 | self.weight_decay = weight_decay 58 | self.dropout = dropout 59 | self.att_dropout = att_dropout 60 | self.lm_dropout = lm_dropout 61 | self.warmup = warmup 62 | self.label_smoothing_factor = label_smoothing_factor 63 | self.pl_weight = pl_weight 64 | self.max_length = max_length 65 | self.batch_size = batch_size 66 | self.grad_accumulation = grad_accumulation 67 | self.lm_epochs_per_iter = lm_epochs_per_iter 68 | self.temperature = temperature 69 | self.pl_ratio = pl_ratio 70 | self.eval_patience = eval_patience 71 | self.intermediate_data_filepath = intermediate_data_filepath 72 | self.ckpt_filepath = ckpt_filepath 73 | self.pretrain_ckpt_filepath = pretrain_ckpt_filepath 74 | self.raw_data_filepath = Path(raw_data_filepath) 75 | self.train_idx = train_idx 76 | self.valid_idx = valid_idx 77 | self.test_idx = test_idx 78 | self.hard_labels = hard_labels 79 | 
self.user_seq = user_seq 80 | self.run = run 81 | self.do_mlm_task = False 82 | 83 | self.iter = 0 84 | self.best_iter = 0 85 | self.best_valid_acc = 0 86 | self.best_epoch = 0 87 | self.criterion = CrossEntropyLoss(label_smoothing=label_smoothing_factor) 88 | self.KD_criterion = KLDivLoss(log_target=False, reduction='batchmean') 89 | self.results = {} 90 | 91 | 92 | self.get_train_idx_all() 93 | self.pretrain_steps_per_epoch = self.train_idx.shape[0] // self.batch_size + 1 94 | self.pretrain_steps = int(self.pretrain_steps_per_epoch * self.pretrain_epochs) 95 | self.train_steps_per_iter = (self.train_idx_all.shape[0] // self.batch_size + 1) * self.lm_epochs_per_iter 96 | self.optimizer_args = dict(lr=lr, weight_decay=weight_decay) 97 | 98 | self.model_config = { 99 | 'lm_model': model_name, 100 | 'dropout': dropout, 101 | 'att_dropout': att_dropout, 102 | 'lm_dropout': self.lm_dropout, 103 | 'classifier_n_layers': classifier_n_layers, 104 | 'classifier_hidden_dim': classifier_hidden_dim, 105 | 'activation': activation, 106 | 'device': device, 107 | 'return_mlm_loss': True if self.do_mlm_task else False 108 | } 109 | 110 | self.dataloader_config = { 111 | 'batch_size': batch_size, 112 | 'pl_ratio': pl_ratio 113 | } 114 | 115 | 116 | 117 | 118 | def build_model(self): 119 | self.model, self.tokenizer = build_LM_model(self.model_config) 120 | self.DESCRIPTION_id = self.tokenizer.convert_tokens_to_ids('DESCRIPTION:') 121 | self.TWEET_id = self.tokenizer.convert_tokens_to_ids('TWEET:') 122 | self.METADATA_id = self.tokenizer.convert_tokens_to_ids('METADATA:') 123 | 124 | def get_optimizer(self, parameters): 125 | 126 | if self.optimizer_name == "adam": 127 | optimizer = torch.optim.Adam(parameters, **self.optimizer_args) 128 | elif self.optimizer_name == "adamw": 129 | optimizer = torch.optim.AdamW(parameters, **self.optimizer_args) 130 | elif self.optimizer_name == "adadelta": 131 | optimizer = torch.optim.Adadelta(parameters, **self.optimizer_args) 132 | elif self.optimizer_name == "radam": 133 | optimizer = torch.optim.RAdam(parameters, **self.optimizer_args) 134 | else: 135 | return NotImplementedError 136 | 137 | return optimizer 138 | 139 | def get_scheduler(self, optimizer, mode='train'): 140 | if mode == 'pretrain': 141 | return get_cosine_schedule_with_warmup(optimizer, self.pretrain_steps_per_epoch * self.warmup, self.pretrain_steps) 142 | else: 143 | return CosineAnnealingLR(optimizer, T_max=self.train_steps_per_iter, eta_min=0) 144 | 145 | def get_initial_embeddings(self): 146 | 147 | if not os.path.exists(self.raw_data_filepath / f'embeddings_{self.model_name.lower()}.pt'): 148 | print('Generating initial GNN embeddings...') 149 | self.infer(True) 150 | 151 | embeddings = torch.load(self.raw_data_filepath / f'embeddings_{self.model_name.lower()}.pt') 152 | return embeddings 153 | 154 | def pretrain(self): 155 | print('LM pretraining start!') 156 | optimizer = self.get_optimizer(self.model.parameters()) 157 | scheduler = self.get_scheduler(optimizer, 'pretrain') 158 | if os.listdir(self.pretrain_ckpt_filepath) and os.path.exists(self.intermediate_data_filepath / 'embeddings_iter_-1.pt'): 159 | print('Pretrain checkpoint exists, loading from checkpoint...') 160 | print('Please make sure you use the same parameter setting as the one of the pretrain checkpoint!') 161 | ckpt = torch.load(self.pretrain_ckpt_filepath / os.listdir(self.pretrain_ckpt_filepath)[0]) 162 | self.model.load_state_dict(ckpt['model']) 163 | # self.optimizer.load_state_dict(ckpt['optimizer']) 164 | # 
self.scheduler.load_state_dict(ckpt['scheduler']) 165 | # embeddings = torch.load(self.intermediate_data_filepath / 'embeddings_iter_-1.pt') 166 | test_acc, test_f1 = self.eval('test') 167 | self.results['pretrain accuracy'] = test_acc 168 | self.results['pretrain f1'] = test_f1 169 | 170 | else: 171 | step = 0 172 | valid_acc_best = 0 173 | valid_step_best = 0 174 | 175 | torch.save({'model': self.model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, self.pretrain_ckpt_filepath / 'best.pkl') 176 | 177 | train_loader = build_LM_dataloader(self.dataloader_config, self.train_idx, self.user_seq, self.hard_labels, 'pretrain') 178 | 179 | for epoch in range(int(self.pretrain_epochs)+1): 180 | self.model.train() 181 | print(f'------LM Pretraining Epoch: {epoch}/{int(self.pretrain_epochs)}------') 182 | for batch in tqdm(train_loader): 183 | step += 1 184 | if step >= self.pretrain_steps: 185 | break 186 | tokenized_tensors, labels, _ = self.batch_to_tensor(batch) 187 | 188 | _, output = self.model(tokenized_tensors) 189 | loss = self.criterion(output, labels) 190 | loss /= self.grad_accumulation 191 | loss.backward() 192 | self.run.log({'LM Pretrain Loss': loss.item()}) 193 | 194 | if step % self.grad_accumulation == 0: 195 | optimizer.step() 196 | optimizer.zero_grad() 197 | scheduler.step() 198 | 199 | if step % self.eval_patience == 0: 200 | valid_acc, valid_f1 = self.eval() 201 | 202 | print(f'LM Pretrain Valid Accuracy = {valid_acc}') 203 | print(f'LM Pretrain Valid F1 = {valid_f1}') 204 | self.run.log({'LM Pretrain Valid Accuracy': valid_acc}) 205 | self.run.log({'LM Pretrain Valid F1': valid_f1}) 206 | 207 | if valid_acc > valid_acc_best: 208 | valid_acc_best = valid_acc 209 | valid_step_best = step 210 | 211 | torch.save({'model': self.model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, self.pretrain_ckpt_filepath / 'best.pkl') 212 | 213 | 214 | print(f'The highest pretrain valid accuracy is {valid_acc_best}!') 215 | print(f'Load model from step {valid_step_best}') 216 | self.model.eval() 217 | all_outputs = [] 218 | all_labels = [] 219 | embeddings = [] 220 | infer_loader = build_LM_dataloader(self.dataloader_config, None, self.user_seq, self.hard_labels, mode='infer') 221 | with torch.no_grad(): 222 | ckpt = torch.load(self.pretrain_ckpt_filepath / 'best.pkl') 223 | self.model.load_state_dict(ckpt['model']) 224 | optimizer.load_state_dict(ckpt['optimizer']) 225 | scheduler.load_state_dict(ckpt['scheduler']) 226 | for batch in tqdm(infer_loader): 227 | tokenized_tensors, labels, _ = self.batch_to_tensor(batch) 228 | embedding, output = self.model(tokenized_tensors) 229 | embeddings.append(embedding.cpu()) 230 | all_outputs.append(output.cpu()) 231 | all_labels.append(labels.cpu()) 232 | 233 | all_outputs = torch.cat(all_outputs, dim=0) 234 | all_labels = torch.cat(all_labels, dim=0) 235 | embeddings = torch.cat(embeddings, dim=0) 236 | soft_labels = torch.softmax(all_outputs / self.temperature, dim=1) 237 | soft_labels[self.train_idx] = all_labels[self.train_idx] 238 | 239 | test_predictions = torch.argmax(all_outputs[self.test_idx], dim=1).numpy() 240 | test_labels = torch.argmax(all_labels[self.test_idx], dim=1).numpy() 241 | torch.save(embeddings, self.intermediate_data_filepath / 'embeddings_iter_-1.pt') 242 | torch.save(soft_labels, self.intermediate_data_filepath / 'soft_labels_iter_-1.pt') 243 | 244 | test_acc = accuracy_score(test_predictions, test_labels) 245 | test_f1 = 
f1_score(test_predictions, test_labels) 246 | self.results['pretrain accuracy'] = test_acc 247 | self.results['pretrain f1'] = test_f1 248 | 249 | 250 | print(f'LM Pretrain Test Accuracy = {test_acc}') 251 | print(f'LM Pretrain Test F1 = {test_f1}') 252 | self.run.log({'LM Pretrain Test Accuracy': test_acc}) 253 | self.run.log({'LM Pretrain Test F1': test_f1}) 254 | 255 | 256 | 257 | 258 | def train(self, soft_labels): 259 | for param in self.model.classifier.parameters(): 260 | param.requires_grad = False 261 | parameters = filter(lambda p: p.requires_grad, self.model.parameters()) 262 | optimizer = self.get_optimizer(parameters) 263 | scheduler = self.get_scheduler(optimizer) 264 | 265 | early_stop_flag = True 266 | print('LM training start!') 267 | step = 0 268 | train_loader = build_LM_dataloader(self.dataloader_config, self.train_idx_all, self.user_seq, soft_labels, 'train', self.is_pl) 269 | 270 | 271 | for epoch in range(self.lm_epochs_per_iter): 272 | self.model.train() 273 | print(f'This is iter {self.iter} epoch {epoch}/{self.lm_epochs_per_iter-1}') 274 | 275 | for batch in tqdm(train_loader): 276 | step += 1 277 | 278 | tokenized_tensors, labels, is_pl = self.batch_to_tensor(batch) 279 | 280 | _, output = self.model(tokenized_tensors) 281 | 282 | pl_idx = torch.nonzero(is_pl == 1).squeeze() 283 | rl_idx = torch.nonzero(is_pl == 0).squeeze() 284 | 285 | if pl_idx.numel() == 0: 286 | loss = self.criterion(output[rl_idx], labels[rl_idx]) 287 | elif rl_idx.numel() == 0: 288 | loss = self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), labels[pl_idx]) 289 | else: 290 | loss_KD = self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), labels[pl_idx]) 291 | loss_H = self.criterion(output[rl_idx], labels[rl_idx]) 292 | self.run.log({'loss_KD': loss_KD.item()}) 293 | self.run.log({'loss_H': loss_H.item()}) 294 | loss = self.pl_weight * loss_KD + (1 - self.pl_weight) * loss_H 295 | 296 | loss /= self.grad_accumulation 297 | loss.backward() 298 | self.run.log({'LM Train Loss': loss.item()}) 299 | 300 | if step % self.grad_accumulation == 0: 301 | 302 | optimizer.step() 303 | optimizer.zero_grad() 304 | scheduler.step() 305 | if step % self.eval_patience == 0: 306 | valid_acc, valid_f1 = self.eval() 307 | 308 | print(f'LM Valid Accuracy = {valid_acc}') 309 | print(f'LM Valid F1 = {valid_f1}') 310 | self.run.log({'LM Valid Accuracy': valid_acc}) 311 | self.run.log({'LM Valid F1': valid_f1}) 312 | 313 | if valid_acc > self.best_valid_acc: 314 | early_stop_flag = False 315 | self.best_valid_acc = valid_acc 316 | self.best_iter = self.iter 317 | self.best_epoch = epoch 318 | torch.save({'model': self.model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, self.ckpt_filepath / 'best.pkl') 319 | 320 | print(f'The highest valid accuracy is {self.best_valid_acc}!') 321 | return early_stop_flag 322 | 323 | def infer(self, provide_embeddings_for_GNN_pretraining=False): 324 | self.model.eval() 325 | infer_loader = build_LM_dataloader(self.dataloader_config, None, self.user_seq, self.hard_labels, mode='infer') 326 | all_outputs = [] 327 | all_labels = [] 328 | embeddings = [] 329 | with torch.no_grad(): 330 | if provide_embeddings_for_GNN_pretraining: 331 | for batch in tqdm(infer_loader): 332 | tokenized_tensors, labels, _ = self.batch_to_tensor(batch) 333 | embedding, _ = self.model(tokenized_tensors) 334 | embeddings.append(embedding.cpu()) 335 | embeddings = torch.cat(embeddings, dim=0) 336 | torch.save(embeddings, 
self.raw_data_filepath / f'embeddings_{self.model_name.lower()}.pt') 337 | 338 | else: 339 | ckpt = torch.load(self.ckpt_filepath / 'best.pkl') 340 | self.model.load_state_dict(ckpt['model']) 341 | # self.optimizer.load_state_dict(ckpt['optimizer']) 342 | # self.scheduler.load_state_dict(ckpt['scheduler']) 343 | for batch in tqdm(infer_loader): 344 | tokenized_tensors, labels, _ = self.batch_to_tensor(batch) 345 | 346 | embedding, output = self.model(tokenized_tensors) 347 | embeddings.append(embedding.cpu()) 348 | all_outputs.append(output.cpu()) 349 | all_labels.append(labels.cpu()) 350 | 351 | all_outputs = torch.cat(all_outputs, dim=0) 352 | all_labels = torch.cat(all_labels, dim=0) 353 | embeddings = torch.cat(embeddings, dim=0) 354 | 355 | soft_labels = torch.softmax(all_outputs / self.temperature, dim=1) 356 | soft_labels[self.train_idx] = all_labels[self.train_idx] 357 | 358 | torch.save(soft_labels, self.intermediate_data_filepath / f'soft_labels_iter_{self.iter}.pt') 359 | 360 | torch.save(embeddings, self.intermediate_data_filepath / f'embeddings_iter_{self.iter}.pt') 361 | 362 | self.iter += 1 363 | 364 | def eval(self, mode='valid'): 365 | if mode == 'valid': 366 | eval_loader = build_LM_dataloader(self.dataloader_config, self.valid_idx, self.user_seq, self.hard_labels, mode='eval') 367 | elif mode == 'test': 368 | eval_loader = build_LM_dataloader(self.dataloader_config, self.test_idx, self.user_seq, self.hard_labels, mode='eval') 369 | self.model.eval() 370 | 371 | valid_predictions = [] 372 | valid_labels = [] 373 | 374 | with torch.no_grad(): 375 | for batch in tqdm(eval_loader): 376 | tokenized_tensors, labels, _ = self.batch_to_tensor(batch) 377 | 378 | _, output = self.model(tokenized_tensors) 379 | 380 | valid_predictions.append(torch.argmax(output, dim=1).cpu().numpy()) 381 | valid_labels.append(torch.argmax(labels, dim=1).cpu().numpy()) 382 | 383 | valid_predictions = np.concatenate(valid_predictions) 384 | valid_labels = np.concatenate(valid_labels) 385 | valid_acc = accuracy_score(valid_labels, valid_predictions) 386 | valid_f1 = f1_score(valid_labels, valid_predictions) 387 | 388 | return valid_acc, valid_f1 389 | 390 | 391 | def test(self): 392 | print('Computing test accuracy and f1 for LM...') 393 | ckpt = torch.load(self.ckpt_filepath / 'best.pkl') 394 | self.model.load_state_dict(ckpt['model']) 395 | test_acc, test_f1 = self.eval('test') 396 | print(f'LM Test Accuracy = {test_acc}') 397 | print(f'LM Test F1 = {test_f1}') 398 | self.run.log({'LM Test Accuracy': test_acc}) 399 | self.run.log({'LM Test F1': test_f1}) 400 | self.results['accuracy'] = test_acc 401 | self.results['f1'] = test_f1 402 | 403 | def batch_to_tensor(self, batch): 404 | 405 | tokenized_tensors = self.tokenizer(text=batch[0], return_tensors='pt', max_length=self.max_length, truncation=True, padding='longest', add_special_tokens=False) 406 | for key in tokenized_tensors.keys(): 407 | tokenized_tensors[key] = tokenized_tensors[key].to(self.device) 408 | labels = batch[1].to(self.device) 409 | 410 | if len(batch) == 3: 411 | is_pl = batch[2].to(self.device) 412 | return tokenized_tensors, labels, is_pl 413 | else: 414 | return tokenized_tensors, labels, None 415 | 416 | def load_embedding(self, iter): 417 | embeddings = torch.load(self.intermediate_data_filepath / f'embeddings_iter_{iter}.pt') 418 | return embeddings 419 | 420 | def save_results(self, path): 421 | json.dump(self.results, open(path, 'w'), indent=4) 422 | 423 | def get_train_idx_all(self): 424 | n_total = 
self.hard_labels.shape[0] 425 | all = set(np.arange(n_total)) 426 | exclude = set(self.train_idx.numpy()) 427 | n = self.train_idx.shape[0] 428 | pl_ratio_LM = min(self.pl_ratio, (n_total - n) / n) 429 | n_pl_LM = int(n * pl_ratio_LM) 430 | pl_idx = torch.LongTensor(np.random.choice(np.array(list(all - exclude)), n_pl_LM, replace=False)) 431 | self.train_idx_all = torch.cat((self.train_idx, pl_idx)) 432 | self.is_pl = torch.ones_like(self.train_idx_all, dtype=torch.int64) 433 | self.is_pl[0: n] = 0 434 | 435 | 436 | class GNN_Trainer: 437 | def __init__( 438 | self, 439 | model_name, 440 | device, 441 | optimizer_name, 442 | lr, 443 | weight_decay, 444 | dropout, 445 | pl_weight, 446 | batch_size, 447 | gnn_n_layers, 448 | n_relations, 449 | activation, 450 | gnn_epochs_per_iter, 451 | temperature, 452 | pl_ratio, 453 | intermediate_data_filepath, 454 | ckpt_filepath, 455 | pretrain_ckpt_filepath, 456 | train_idx, 457 | valid_idx, 458 | test_idx, 459 | hard_labels, 460 | edge_index, 461 | edge_type, 462 | run, 463 | SimpleHGN_att_res, 464 | att_heads, 465 | RGT_semantic_heads, 466 | gnn_hidden_dim, 467 | lm_name 468 | ): 469 | 470 | self.model_name = model_name 471 | self.device = device 472 | self.optimizer_name = optimizer_name 473 | self.lr = lr 474 | self.weight_decay = weight_decay 475 | self.pl_weight = pl_weight 476 | self.dropout = dropout 477 | self.batch_size = batch_size 478 | self.gnn_n_layers = gnn_n_layers 479 | self.n_relations = n_relations 480 | self.activation = activation 481 | self.gnn_epochs_per_iter = gnn_epochs_per_iter 482 | self.temperature = temperature 483 | self.pl_ratio = pl_ratio 484 | self.intermediate_data_filepath = intermediate_data_filepath 485 | self.ckpt_filepath = ckpt_filepath 486 | self.pretrain_ckpt_filepath = pretrain_ckpt_filepath 487 | self.train_idx = train_idx 488 | self.valid_idx = valid_idx 489 | self.test_idx = test_idx 490 | self.hard_labels = hard_labels 491 | self.edge_index = edge_index 492 | self.edge_type = edge_type 493 | self.run = run 494 | self.SimpleHGN_att_res = SimpleHGN_att_res 495 | self.att_heads = att_heads 496 | self.RGT_semantic_heads = RGT_semantic_heads 497 | self.gnn_hidden_dim = gnn_hidden_dim 498 | self.lm_input_dim = 1024 if lm_name.lower() in ['roberta-large'] else 768 499 | self.iter = 0 500 | self.best_iter = 0 501 | self.best_valid_acc = 0 502 | self.best_valid_epoch = 0 503 | self.criterion = CrossEntropyLoss() 504 | self.KD_criterion = KLDivLoss(log_target=False, reduction='batchmean') 505 | 506 | 507 | self.results = {} 508 | self.get_train_idx_all() 509 | self.optimizer_args = dict(lr=lr, weight_decay=weight_decay) 510 | 511 | self.model_config = { 512 | 'GNN_model': model_name, 513 | 'optimizer': optimizer_name, 514 | 'gnn_n_layers': gnn_n_layers, 515 | 'n_relations': n_relations, 516 | 'activation': activation, 517 | 'dropout': dropout, 518 | 'gnn_hidden_dim': gnn_hidden_dim, 519 | 'lm_input_dim': self.lm_input_dim, 520 | 'SimpleHGN_att_res': SimpleHGN_att_res, 521 | 'att_heads': att_heads, 522 | 'RGT_semantic_heads': RGT_semantic_heads, 523 | 'device': device 524 | } 525 | 526 | self.dataloader_config = { 527 | 'batch_size': batch_size, 528 | 'n_layers': gnn_n_layers 529 | } 530 | 531 | 532 | 533 | def build_model(self): 534 | self.model = build_GNN_model(self.model_config) 535 | 536 | def get_scheduler(self, optimizer): 537 | return CosineAnnealingLR(optimizer, T_max=self.gnn_epochs_per_iter, eta_min=0) 538 | 539 | 540 | def get_optimizer(self): 541 | 542 | if self.optimizer_name == "adam": 543 | 
            optimizer = torch.optim.Adam(self.model.parameters(), **self.optimizer_args)
544 |         elif self.optimizer_name == "adamw":
545 |             optimizer = torch.optim.AdamW(self.model.parameters(), **self.optimizer_args)
546 |         elif self.optimizer_name == "adadelta":
547 |             optimizer = torch.optim.Adadelta(self.model.parameters(), **self.optimizer_args)
548 |         elif self.optimizer_name == "radam":
549 |             optimizer = torch.optim.RAdam(self.model.parameters(), **self.optimizer_args)
550 |         else:
551 |             raise NotImplementedError(f'unsupported optimizer: {self.optimizer_name}')
552 | 
553 |         return optimizer
554 | 
555 |     def train(self, embeddings_LM, soft_labels):
556 |         early_stop_flag = True
557 | 
558 |         optimizer = self.get_optimizer()
559 |         scheduler = self.get_scheduler(optimizer)
560 |         print('GNN training start!')
561 |         print(f'This is iter {self.iter}')
562 | 
563 |         train_loader = build_GNN_dataloader(self.dataloader_config, self.train_idx_all, embeddings_LM, soft_labels, self.edge_index, self.edge_type, mode='train', is_pl=self.is_pl)
564 | 
565 |         for epoch in tqdm(range(self.gnn_epochs_per_iter)):
566 |             self.model.train()
567 |             for batch in train_loader:
568 |                 optimizer.zero_grad()
569 |                 batch_size = batch.batch_size
570 |                 x_batch = batch.x.to(self.device)
571 | 
572 |                 edge_index_batch = batch.edge_index.to(self.device)
573 |                 edge_type_batch = batch.edge_type.to(self.device)
574 |                 is_pl = batch.is_pl[0: batch_size].to(self.device)
575 |                 labels = batch.labels[0: batch_size].to(self.device)
576 | 
577 |                 output = self.model(x_batch, edge_index_batch, edge_type_batch)
578 |                 output = output[0: batch_size]
579 | 
580 |                 pl_idx = torch.nonzero(is_pl == 1).squeeze()
581 |                 rl_idx = torch.nonzero(is_pl == 0).squeeze()
582 | 
583 | 
584 |                 if pl_idx.numel() == 0:
585 |                     loss = self.criterion(output[rl_idx], labels[rl_idx])
586 |                 elif rl_idx.numel() == 0:
587 |                     loss = self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), labels[pl_idx])
588 |                 else:
589 |                     # loss = self.pl_weight * self.criterion(output[pl_idx], labels[pl_idx]) + (1 - self.pl_weight) * self.criterion(output[rl_idx], labels[rl_idx])
590 |                     loss = self.pl_weight * self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), labels[pl_idx]) + (1 - self.pl_weight) * self.criterion(output[rl_idx], labels[rl_idx])
591 | 
592 |                 loss.backward()
593 |                 optimizer.step()
594 |                 scheduler.step()
595 |                 self.run.log({'GNN Train Loss': loss.item()})
596 | 
597 | 
598 |             valid_acc, valid_f1 = self.eval(embeddings_LM)
599 | 
600 |             self.run.log({'GNN Valid Accuracy': valid_acc})
601 |             self.run.log({'GNN Valid F1': valid_f1})
602 | 
603 |             if valid_acc > self.best_valid_acc:
604 |                 early_stop_flag = False
605 |                 self.best_valid_acc = valid_acc
606 |                 self.best_epoch = epoch
607 |                 self.best_iter = self.iter
608 |                 torch.save({'model': self.model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, self.ckpt_filepath / 'best.pkl')
609 |         print(f'The highest valid accuracy is {self.best_valid_acc}!')
610 |         return early_stop_flag
611 | 
612 |     def infer(self, embeddings_LM):
613 |         self.model.eval()
614 |         infer_loader = build_GNN_dataloader(self.dataloader_config, None, embeddings_LM, self.hard_labels, self.edge_index, self.edge_type, mode='infer')
615 | 
616 |         all_outputs = []
617 |         all_labels = []
618 |         with torch.no_grad():
619 |             ckpt = torch.load(self.ckpt_filepath / 'best.pkl')
620 |             self.model.load_state_dict(ckpt['model'])
621 |             # self.optimizer.load_state_dict(ckpt['optimizer'])
622 |             # self.scheduler.load_state_dict(ckpt['scheduler'])
623 |             for batch in infer_loader:
624 | 
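# Editor's note (not in the original source): the GNN loaders built in dataloader.py appear to
# follow the PyG NeighborLoader convention (an assumption), where the first `batch.batch_size`
# rows of every mini-batch are the seed users and the remaining rows are sampled neighbours.
# That is why outputs and labels are sliced with [0: batch_size] below before the temperature-
# scaled soft labels are collected for the next distillation iteration.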
batch_size = batch.batch_size 625 | x_batch = batch.x.to(self.device) 626 | 627 | edge_index_batch = batch.edge_index.to(self.device) 628 | edge_type_batch = batch.edge_type.to(self.device) 629 | labels = batch.labels[0: batch_size].to(self.device) 630 | 631 | output = self.model(x_batch, edge_index_batch, edge_type_batch) 632 | output = output[0: batch_size] 633 | 634 | all_outputs.append(output.cpu()) 635 | all_labels.append(labels.cpu()) 636 | 637 | all_outputs = torch.cat(all_outputs, dim=0) 638 | all_labels = torch.cat(all_labels, dim=0) 639 | soft_labels = torch.softmax(all_outputs / self.temperature, dim=1) 640 | soft_labels[self.train_idx] = all_labels[self.train_idx] 641 | 642 | torch.save(soft_labels, self.intermediate_data_filepath / f'soft_labels_iter_{self.iter}.pt') 643 | 644 | self.iter += 1 645 | 646 | 647 | def eval(self, embeddings_LM, mode='valid'): 648 | if mode == 'valid': 649 | eval_loader = build_GNN_dataloader(self.dataloader_config, self.valid_idx, embeddings_LM, self.hard_labels, self.edge_index, self.edge_type, mode='eval') 650 | elif mode == 'test': 651 | eval_loader = build_GNN_dataloader(self.dataloader_config, self.test_idx, embeddings_LM, self.hard_labels, self.edge_index, self.edge_type, mode='eval') 652 | self.model.eval() 653 | 654 | valid_predictions = [] 655 | valid_labels = [] 656 | 657 | with torch.no_grad(): 658 | for batch in eval_loader: 659 | batch_size = batch.batch_size 660 | x_batch = batch.x.to(self.device) 661 | edge_index_batch = batch.edge_index.to(self.device) 662 | edge_type_batch = batch.edge_type.to(self.device) 663 | labels = batch.labels[0: batch_size].to(self.device) 664 | 665 | output = self.model(x_batch, edge_index_batch, edge_type_batch) 666 | output = output[0: batch_size] 667 | 668 | valid_predictions.append(torch.argmax(output, dim=1).cpu().numpy()) 669 | valid_labels.append(torch.argmax(labels, dim=1).cpu().numpy()) 670 | 671 | valid_predictions = np.concatenate(valid_predictions) 672 | valid_labels = np.concatenate(valid_labels) 673 | valid_acc = accuracy_score(valid_labels, valid_predictions) 674 | valid_f1 = f1_score(valid_labels, valid_predictions) 675 | 676 | return valid_acc, valid_f1 677 | 678 | 679 | 680 | def test(self, embeddings_LM): 681 | print('Computing test accuracy and f1 for GNN...') 682 | ckpt = torch.load(self.ckpt_filepath / 'best.pkl') 683 | self.model.load_state_dict(ckpt['model']) 684 | test_acc, test_f1 = self.eval(embeddings_LM, 'test') 685 | print(f'GNN Test Accuracy = {test_acc}') 686 | print(f'GNN Test F1 = {test_f1}') 687 | self.run.log({'GNN Test Accuracy': test_acc}) 688 | self.run.log({'GNN Test F1': test_f1}) 689 | self.results['accuracy'] = test_acc 690 | self.results['f1'] = test_f1 691 | 692 | def load_soft_labels(self, iter): 693 | soft_labels = torch.load(self.intermediate_data_filepath / f'soft_labels_iter_{iter}.pt') 694 | return soft_labels 695 | 696 | def save_results(self, path): 697 | json.dump(self.results, open(path, 'w'), indent=4) 698 | 699 | def get_train_idx_all(self): 700 | n_total = self.hard_labels.shape[0] 701 | all = set(np.arange(n_total)) 702 | exclude = set(self.train_idx.numpy()) 703 | n = self.train_idx.shape[0] 704 | pl_ratio_GNN = min(self.pl_ratio, (n_total - n) / n) 705 | n_pl_GNN = int(n * pl_ratio_GNN) 706 | self.pl_idx = torch.LongTensor(np.random.choice(np.array(list(all - exclude)), n_pl_GNN, replace=False)) 707 | self.train_idx_all = torch.cat((self.train_idx, self.pl_idx)) 708 | self.is_pl = torch.ones_like(self.train_idx_all, dtype=torch.int64) 709 | 
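# Editor's note (not in the original source): train_idx_all is the concatenation of the
# ground-truth training nodes and a random sample of unlabelled nodes (pl_idx). The mask set
# below marks the first n entries (real labels) with 0 and the sampled pseudo-labelled entries
# with 1; the trainers later use this mask to route rows to the CE loss or the KD loss.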
        self.is_pl[0: n] = 0
710 | 
711 | 
712 | 
713 | 
714 | class MLP_Trainer:
715 |     def __init__(
716 |         self,
717 |         device,
718 |         optimizer_name,
719 |         lr,
720 |         weight_decay,
721 |         dropout,
722 |         pl_weight,
723 |         batch_size,
724 |         n_layers,
725 |         hidden_dim,
726 |         activation,
727 |         glnn_epochs,
728 |         mlp_epochs_per_iter,
729 |         temperature,
730 |         pl_ratio,
731 |         intermediate_data_filepath,
732 |         ckpt_filepath,
733 |         KD_ckpt_filepath,
734 |         train_idx,
735 |         valid_idx,
736 |         test_idx,
737 |         hard_labels,
738 |         run,
739 |         seed,
740 |         use_gnn):
741 | 
742 |         self.device = device
743 |         self.optimizer_name = optimizer_name
744 |         self.lr = lr
745 |         self.weight_decay = weight_decay
746 |         self.pl_weight = pl_weight
747 |         self.dropout = dropout
748 |         self.batch_size = batch_size
749 |         self.n_layers = n_layers
750 |         self.hidden_dim = hidden_dim
751 |         self.activation = activation
752 |         self.mlp_epochs_per_iter = mlp_epochs_per_iter
753 |         self.glnn_epochs = glnn_epochs
754 |         self.temperature = temperature
755 |         self.pl_ratio = pl_ratio
756 |         self.intermediate_data_filepath = intermediate_data_filepath
757 |         self.ckpt_filepath = ckpt_filepath
758 |         self.KD_ckpt_filepath = KD_ckpt_filepath
759 |         self.train_idx = train_idx
760 |         self.valid_idx = valid_idx
761 |         self.test_idx = test_idx
762 |         self.hard_labels = hard_labels
763 |         self.run = run
764 |         self.seed = seed
765 |         self.use_gnn = use_gnn
766 |         self.iter = 0
767 |         self.best_iter = 0
768 |         self.best_valid_acc = 0
769 |         self.best_valid_epoch = 0
770 |         self.criterion = CrossEntropyLoss()
771 |         self.KD_criterion = KLDivLoss(log_target=False, reduction='batchmean')
772 | 
773 | 
774 |         self.get_train_idx_all()
775 |         self.results = {}
776 | 
777 |         self.dataloader_config = {
778 |             'batch_size': batch_size
779 |         }
780 | 
781 |         self.optimizer_args = dict(lr=lr, weight_decay=weight_decay)
782 | 
783 |     def get_scheduler(self, optimizer, T_max):
784 |         return CosineAnnealingLR(optimizer, T_max=T_max, eta_min=0)
785 | 
786 | 
787 |     def get_optimizer(self):
788 |         if self.optimizer_name == "adam":
789 |             optimizer = torch.optim.Adam(self.model.parameters(), **self.optimizer_args)
790 |         elif self.optimizer_name == "adamw":
791 |             optimizer = torch.optim.AdamW(self.model.parameters(), **self.optimizer_args)
792 |         elif self.optimizer_name == "adadelta":
793 |             optimizer = torch.optim.Adadelta(self.model.parameters(), **self.optimizer_args)
794 |         elif self.optimizer_name == "radam":
795 |             optimizer = torch.optim.RAdam(self.model.parameters(), **self.optimizer_args)
796 |         else:
797 |             raise NotImplementedError(f'unsupported optimizer: {self.optimizer_name}')
798 | 
799 |         return optimizer
800 | 
801 |     def build_model(self):
802 |         if self.use_gnn:
803 |             self.model = MLP(in_channels=768, hidden_channels=self.hidden_dim, out_channels=2, dropout=self.dropout, act=self.activation, num_layers=self.n_layers).to(self.device)
804 |         else:
805 |             ckpt = torch.load(self.KD_ckpt_filepath / f'seed_{self.seed}_best.pkl')
806 |             self.model = MLP(**ckpt['model_params']).to(self.device)
807 |             # self.model.load_state_dict(ckpt['model'])
808 | 
809 | 
810 |     def KD_GLNN(self, LM_embeddings, soft_labels):
811 |         print('Distilling from GNN to GLNN')
812 |         train_loader = build_MLP_dataloader(self.dataloader_config, self.train_idx_all, LM_embeddings, soft_labels, mode='train', is_pl=self.is_pl)
813 | 
814 |         optimizer = self.get_optimizer()
815 |         scheduler = self.get_scheduler(optimizer, self.glnn_epochs)
816 | 
817 |         for epoch in tqdm(range(self.glnn_epochs)):
818 |             self.model.train()
819 |             for batch in train_loader:
820 |                 optimizer.zero_grad()
821 |                 LM_embedding, label, 
is_pl = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device) 822 | output = self.model(LM_embedding) 823 | 824 | pl_idx = torch.nonzero(is_pl == 1).squeeze() 825 | rl_idx = torch.nonzero(is_pl == 0).squeeze() 826 | 827 | if pl_idx.numel() == 0: 828 | loss = self.criterion(output[rl_idx], label[rl_idx]) 829 | elif rl_idx.numel() == 0: 830 | loss = self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), label[pl_idx]) 831 | else: 832 | # loss = self.pl_weight * self.criterion(output[pl_idx], labels[pl_idx]) + (1 - self.pl_weight) * self.criterion(output[rl_idx], labels[rl_idx]) 833 | loss = self.pl_weight * self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), label[pl_idx]) + (1 - self.pl_weight) * self.criterion(output[rl_idx], label[rl_idx]) 834 | 835 | loss.backward() 836 | optimizer.step() 837 | scheduler.step() 838 | self.run.log({'GLNN KD Train Loss': loss.item()}) 839 | 840 | 841 | valid_acc, valid_f1 = self.eval(LM_embeddings) 842 | 843 | self.run.log({'GLNN KD Valid Accuracy': valid_acc}) 844 | self.run.log({'GLNN KD Valid F1': valid_f1}) 845 | 846 | if valid_acc > self.best_valid_acc: 847 | self.best_valid_acc = valid_acc 848 | self.best_epoch = epoch 849 | torch.save({'model': self.model.state_dict(), 'model_params': {'num_layers': self.n_layers, 'hidden_channels': self.hidden_dim, 'dropout': self.dropout, 'act': self.activation, 'in_channels': 768, 'out_channels': 2}}, self.KD_ckpt_filepath / f'seed_{self.seed}_best.pkl') 850 | 851 | print(f'The highest valid accuracy is {self.best_valid_acc}!') 852 | print(f'Save model from epoch {self.best_epoch}') 853 | 854 | def train(self, LM_embeddings, soft_labels): 855 | print('MLP training start!') 856 | print(f'This is iter {self.iter}') 857 | train_loader = build_MLP_dataloader(self.dataloader_config, self.train_idx_all, LM_embeddings, soft_labels, mode='train', is_pl=self.is_pl) 858 | early_stop_flag = True 859 | optimizer = self.get_optimizer() 860 | scheduler = self.get_scheduler(optimizer, self.glnn_epochs) 861 | for epoch in tqdm(range(self.mlp_epochs_per_iter)): 862 | self.model.train() 863 | for batch in train_loader: 864 | optimizer.zero_grad() 865 | LM_embedding, label, is_pl = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device) 866 | output = self.model(LM_embedding) 867 | 868 | pl_idx = torch.nonzero(is_pl == 1).squeeze() 869 | rl_idx = torch.nonzero(is_pl == 0).squeeze() 870 | 871 | if pl_idx.numel() == 0: 872 | loss = self.criterion(output[rl_idx], label[rl_idx]) 873 | elif rl_idx.numel() == 0: 874 | loss = self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), label[pl_idx]) 875 | else: 876 | loss = self.pl_weight * self.KD_criterion(F.log_softmax(output[pl_idx] / self.temperature, dim=-1), label[pl_idx]) + (1 - self.pl_weight) * self.criterion(output[rl_idx], label[rl_idx]) 877 | 878 | loss.backward() 879 | optimizer.step() 880 | scheduler.step() 881 | self.run.log({'MLP Train Loss': loss.item()}) 882 | 883 | 884 | 885 | valid_acc, valid_f1 = self.eval(LM_embeddings) 886 | 887 | self.run.log({'MLP Valid Accuracy': valid_acc}) 888 | self.run.log({'MLP Valid F1': valid_f1}) 889 | 890 | if valid_acc > self.best_valid_acc: 891 | early_stop_flag = False 892 | self.best_valid_acc = valid_acc 893 | self.best_epoch = epoch 894 | self.best_iter = self.iter 895 | torch.save({'model': self.model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, self.ckpt_filepath / 'best.pkl') 896 | 
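# Editor's note (not in the original source): early_stop_flag stays True only when no epoch in
# this iteration improved on best_valid_acc; the caller (presumably the iterative LM/GNN/MLP
# co-training loop in main.py) is expected to stop iterating once a trainer returns True.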
print(f'The highest valid accuracy is {self.best_valid_acc}!') 897 | return early_stop_flag 898 | 899 | def infer(self, LM_embeddings): 900 | self.model.eval() 901 | infer_loader = build_MLP_dataloader(self.dataloader_config, None, LM_embeddings, self.hard_labels, mode='infer') 902 | all_outputs = [] 903 | all_labels = [] 904 | with torch.no_grad(): 905 | ckpt = torch.load(self.ckpt_filepath / 'best.pkl') 906 | self.model.load_state_dict(ckpt['model']) 907 | for batch in infer_loader: 908 | LM_embedding, label = batch[0].to(self.device), batch[1].to(self.device) 909 | output = self.model(LM_embedding) 910 | all_outputs.append(output.cpu()) 911 | all_labels.append(label.cpu()) 912 | 913 | all_outputs = torch.cat(all_outputs, dim=0) 914 | all_labels = torch.cat(all_labels, dim=0) 915 | 916 | soft_labels = torch.softmax(all_outputs / self.temperature, dim=1) 917 | soft_labels[self.train_idx] = all_labels[self.train_idx] 918 | 919 | torch.save(soft_labels, self.intermediate_data_filepath / f'soft_labels_iter_{self.iter}.pt') 920 | 921 | self.iter += 1 922 | 923 | def eval(self, LM_embeddings, mode='valid'): 924 | if mode == 'valid': 925 | eval_loader = build_MLP_dataloader(self.dataloader_config, self.valid_idx, LM_embeddings, self.hard_labels, mode='eval') 926 | elif mode == 'test': 927 | eval_loader = build_MLP_dataloader(self.dataloader_config, self.test_idx, LM_embeddings, self.hard_labels, mode='eval') 928 | self.model.eval() 929 | 930 | valid_predictions = [] 931 | valid_labels = [] 932 | 933 | with torch.no_grad(): 934 | for batch in eval_loader: 935 | LM_embedding, label = batch[0].to(self.device), batch[1].to(self.device) 936 | output = self.model(LM_embedding) 937 | 938 | valid_predictions.append(torch.argmax(output, dim=1).cpu().numpy()) 939 | valid_labels.append(torch.argmax(label, dim=1).cpu().numpy()) 940 | 941 | valid_predictions = np.concatenate(valid_predictions) 942 | valid_labels = np.concatenate(valid_labels) 943 | valid_acc = accuracy_score(valid_labels, valid_predictions) 944 | valid_f1 = f1_score(valid_labels, valid_predictions) 945 | 946 | return valid_acc, valid_f1 947 | 948 | def test(self, LM_embeddings): 949 | print('Computing test accuracy and f1 for MLP...') 950 | ckpt = torch.load(self.ckpt_filepath / 'best.pkl') 951 | self.model.load_state_dict(ckpt['model']) 952 | test_acc, test_f1 = self.eval(LM_embeddings, 'test') 953 | print(f'MLP Test Accuracy = {test_acc}') 954 | print(f'MLP Test F1 = {test_f1}') 955 | self.run.log({'MLP Test Accuracy': test_acc}) 956 | self.run.log({'MLP Test F1': test_f1}) 957 | self.results['accuracy'] = test_acc 958 | self.results['f1'] = test_f1 959 | 960 | def save_results(self, path): 961 | json.dump(self.results, open(path, 'w'), indent=4) 962 | 963 | def get_train_idx_all(self): 964 | n_total = self.hard_labels.shape[0] 965 | all = set(np.arange(n_total)) 966 | exclude = set(self.train_idx.numpy()) 967 | n = self.train_idx.shape[0] 968 | pl_ratio_LM = min(self.pl_ratio, (n_total - n) / n) 969 | n_pl_LM = int(n * pl_ratio_LM) 970 | pl_idx = torch.LongTensor(np.random.choice(np.array(list(all - exclude)), n_pl_LM, replace=False)) 971 | self.train_idx_all = torch.cat((self.train_idx, pl_idx)) 972 | self.is_pl = torch.ones_like(self.train_idx_all, dtype=torch.int64) 973 | self.is_pl[0: n] = 0 974 | 975 | --------------------------------------------------------------------------------
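# ----------------------------------------------------------------------------
# Editor's sketch (not part of the repository): a minimal, self-contained view of the
# temperature-scaled distillation objective that LM_Trainer.train, GNN_Trainer.train and
# MLP_Trainer.KD_GLNN all implement inline. Names such as `pl_weight`, `temperature` and
# `is_pl` mirror the trainers' attributes; the function name, the toy tensors and the default
# values below are illustrative assumptions only. Teacher rows (is_pl == 1) carry soft labels
# produced by softmax(teacher_logits / T) in the infer() methods; KLDivLoss expects the
# student side as log-probabilities, hence the explicit F.log_softmax.
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, KLDivLoss

def distillation_loss(output, labels, is_pl, pl_weight=0.5, temperature=2.0):
    """output: [B, 2] student logits; labels: [B, 2] one-hot (ground truth) or soft (teacher)
    targets; is_pl: [B] with 1 for pseudo-labelled rows and 0 for ground-truth rows."""
    ce = CrossEntropyLoss()
    kd = KLDivLoss(log_target=False, reduction='batchmean')
    pl_idx = torch.nonzero(is_pl == 1).squeeze(-1)
    rl_idx = torch.nonzero(is_pl == 0).squeeze(-1)
    if pl_idx.numel() == 0:      # batch contains only ground-truth labels
        return ce(output[rl_idx], labels[rl_idx])
    if rl_idx.numel() == 0:      # batch contains only teacher soft labels
        return kd(F.log_softmax(output[pl_idx] / temperature, dim=-1), labels[pl_idx])
    loss_kd = kd(F.log_softmax(output[pl_idx] / temperature, dim=-1), labels[pl_idx])
    loss_ce = ce(output[rl_idx], labels[rl_idx])
    return pl_weight * loss_kd + (1 - pl_weight) * loss_ce

# Toy usage: 4 users, rows 0-1 carry one-hot ground truth, rows 2-3 carry teacher soft labels.
logits = torch.randn(4, 2)
targets = torch.tensor([[1., 0.], [0., 1.], [0.7, 0.3], [0.2, 0.8]])
mask = torch.tensor([0, 0, 1, 1])
print(distillation_loss(logits, targets, mask))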