├── data
│   ├── __init__.py
│   ├── filter.py
│   ├── graph_dataset.py
│   └── preprocess_data.py
├── modeling
│   ├── __init__.py
│   ├── tokenizer.py
│   ├── duplex.py
│   ├── model_mlp.py
│   ├── model.py
│   ├── model_magnet.py
│   ├── gatconv.py
│   └── magnet.py
├── utils
│   ├── __init__.py
│   ├── loss.py
│   ├── common_utils.py
│   └── train_utils.py
├── requirements.txt
├── preprocess.py
├── LEGAL.md
├── accelerate_ds_config.yaml
├── data_sample
│   ├── instruction_downstream.jsonl
│   ├── instruction_graph.jsonl
│   └── pretrain.jsonl
├── configs
│   ├── config_pretrain.json
│   └── config_instruction.json
├── README.md
├── arguments.py
├── run.py
└── LICENSE.md

/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import *
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
2 | from .tokenizer import build_tokenizer
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common_utils import *
2 | from .loss import loss_func
3 | from .train_utils import accelerate_train
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.28.0
2 | deepspeed==0.9.3
3 | dgl==1.1.3+cu117
4 | flash-attn==2.3.6
5 | networkx==3.1
6 | numpy==1.23.5
7 | pandas==1.5.3
8 | peft==0.7.0
9 | tensorboard==2.11.0
10 | torch==2.0.1
11 | transformers==4.37.0
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers import set_seed
3 | from arguments import prepare_args
4 | from data.filter import filter
5 |
6 | # get args
7 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
8 | args = prepare_args()
9 |
10 |
11 | def main():
12 |
13 |     set_seed(args.seed)
14 |
15 |     filter(args)
16 |
17 |
18 | if __name__ == "__main__":
19 |     main()
--------------------------------------------------------------------------------
/LEGAL.md:
--------------------------------------------------------------------------------
1 | Legal Disclaimer
2 |
3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comments in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
4 |
5 | 法律免责声明
6 |
7 | 关于代码注释部分，中文注释为官方版本，其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致，当中文注释与其它语言注释存在不一致时，请以中文注释为准。
--------------------------------------------------------------------------------
/modeling/tokenizer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 |
3 | def build_tokenizer(args):
4 |     tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, trust_remote_code=True)
5 |     tokenizer.add_special_tokens({"additional_special_tokens": [args.graph_pad_token]})
6 |     # there are extra embeddings in deepseek coder, so there is no need to resize the model
7 |     args.graph_pad_id = tokenizer.convert_tokens_to_ids(args.graph_pad_token)
8 |     tokenizer.pad_token = tokenizer.eos_token
9 |     tokenizer.pad_token_id = tokenizer.eos_token_id
10 |     return tokenizer
11 |
--------------------------------------------------------------------------------
/accelerate_ds_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 |   gradient_accumulation_steps: 1
5 |   gradient_clipping: 1.0
6 |   offload_optimizer_device: none
7 |   offload_param_device: none
8 |   zero3_init_flag: false
9 |   zero_stage: 2
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: 'bf16'
15 | num_machines: 1
16 | num_processes: 2
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
--------------------------------------------------------------------------------
/data_sample/instruction_downstream.jsonl:
--------------------------------------------------------------------------------
1 | {"human":"# Fix the bug in the following Java program\n\n# Buggy\npublic static TYPE_1 init ( java.lang.String name , java.util.Date date ) {\n TYPE_1 VAR_1 = new TYPE_1 ( );\n VAR_1.METHOD_1( name );\n java.util.Calendar VAR_2 = java.util.Calendar.getInstance ( );\n VAR_2.METHOD_2( date );\n VAR_1.METHOD_3( VAR_2 );\n return VAR_1;\n}\n\n# Fixed\n","bot":"public static TYPE_1 init ( java.lang.String name , java.util.Date date ) {\n TYPE_1 VAR_1 = new TYPE_1 ( );\n VAR_1.METHOD_1( name );\n java.util.Calendar VAR_2 = null;\n if ( date != null ) {\n VAR_2 = java.util.Calendar.getInstance ( );\n VAR_2.METHOD_2( date );\n }\n VAR_1.METHOD_3( VAR_2 );\n return VAR_1;\n}"}
2 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import CrossEntropyLoss
3 |
4 |
5 | def loss_func(outputs, labels, loss_mask):
6 |
7 |     lm_logits = outputs["logits"].contiguous()
8 |     labels = labels.to(device=lm_logits.device).contiguous()
9 |     loss_mask = loss_mask.to(device=lm_logits.device)
10 |     # logits: (bs, l, v); labels, loss_mask: (bs, l)
11 |
12 |     # lm loss
13 |     bsz = labels.shape[0]
14 |     loss_fct = CrossEntropyLoss(reduction='none')
15 |     losses = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))  # logits: (bs * l, v); labels: (bs * l,)
16 |     # losses -> (bs, l)
17 |     losses = losses.contiguous().view(bsz, -1)
18 |
19 |     loss_mask = loss_mask.view(-1)
20 |     losses = losses.view(-1)
21 |     loss_lm = torch.sum(losses * loss_mask) / loss_mask.sum()
22 |
23 |     return loss_lm
24 |
--------------------------------------------------------------------------------
/configs/config_pretrain.json:
--------------------------------------------------------------------------------
1 | {
2 |   "data_dir": "path_to_pretrain_data",
3 |   "node_type_embedding": "node_type_embedding.pth",
4 |   "output_dir": "output/qwen2.5-coder-1.5b",
5 |   "tb_dir": "output/tb/qwen2.5-coder-1.5b",
6 |   "pretrained_model_path": "qwen2.5-coder-1.5b",
7 |   "model_type": "qwen2.5",
8 |   "mode": "pt",
9 |   "graph_token_num": 256,
10 |   "graph_hidden_dim": 1024,
11 |   "data_split": "99.5,0.5",
12 |   "per_device_train_batch_size": 1,
13 |   "per_device_eval_batch_size": 1,
14
| "learning_rate": 1e-5, 15 | "min_lr": 1e-6, 16 | "weight_decay": 0.1, 17 | "gradient_accumulation_steps": 1, 18 | "lr_scheduler_type": "cosine", 19 | "num_warmup_steps": 1000, 20 | "num_train_epochs": 50, 21 | "seed": 42, 22 | "seq_length": 4096, 23 | "log_interval": 10, 24 | "checkpointing_steps": 1000, 25 | "evaluation_steps": 1000, 26 | "epoch_checkpointing": true, 27 | "early_stopping": false, 28 | "early_stopping_stall_num": 15 29 | } 30 | -------------------------------------------------------------------------------- /data_sample/instruction_graph.jsonl: -------------------------------------------------------------------------------- 1 | {"node_ids":[7,15,13,17,15,13,17,6,13,12,13,30,41,13,20,41,13,20,17,11,1,41,13,17,2,13,17,1,40,13,30,30,0,18,13,13,4,18,13,4,18,13,13,11,1,41,13,17,2,13,17,1,40,13,30,30,4,18,18,13,13,18,13,13,23,13,10,6,13],"edge_index":[[14,13],[17,16],[23,22],[22,25],[29,25],[29,29],[22,29],[36,33],[16,34],[22,35],[29,35],[13,38],[16,42],[47,46],[46,49],[53,49],[53,53],[46,53],[33,61],[16,62],[46,63],[53,63],[65,65]],"question":"```\nimport java.util.Scanner;\nimport java.util.Arrays;\n\npublic class Main{\n\tpublic static void main(String[] args){\n\t\tScanner scan = new Scanner(System.in);\n\t\tint[] heights = new int[10];\n\t\tfor(int i = 0; i < 10; i++){\n\t\t\theights[i] = scan.nextInt();\n\t\t}\n\n\t\tArrays.sort(heights);\n\t\tfor(int i = 9; i >= 7; i--){\n\t\t\tSystem.out.println(heights[i]);\n\t\t}\n\t}\n}\n```\n\nIs there an edge from identifier expression `args` to identifier expression `args` in this Java data flow graph?","bot":"Yes, that is the case. `args` is directly connected to `args`."} 2 | -------------------------------------------------------------------------------- /configs/config_instruction.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "path_to_instruction_data", 3 | "node_type_embedding": "node_type_embedding.pth", 4 | "output_dir": "output/instruction/qwen2.5-coder-1.5b", 5 | "tb_dir": "output/instruction/tb/qwen2.5-coder-1.5b", 6 | "pretrained_model_path": "qwen2.5-coder-1.5b", 7 | "checkpoint": "output/qwen2.5-coder-1.5b/epoch_10", 8 | "model_type": "qwen2.5", 9 | "mode": "ft", 10 | "lora": false, 11 | "lora_rank": 64, 12 | "graph_token_num": 256, 13 | "graph_hidden_dim": 1024, 14 | "data_split": "99.5,0.5", 15 | "per_device_train_batch_size": 3, 16 | "per_device_eval_batch_size": 3, 17 | "learning_rate": 5e-6, 18 | "min_lr": 1e-7, 19 | "weight_decay": 0.1, 20 | "gradient_accumulation_steps": 1, 21 | "lr_scheduler_type": "cosine", 22 | "num_warmup_steps": 1000, 23 | "num_train_epochs": 10, 24 | "seed": 42, 25 | "seq_length": 4096, 26 | "log_interval": 10, 27 | "checkpointing_steps": 500, 28 | "evaluation_steps": 500, 29 | "epoch_checkpointing": true, 30 | "early_stopping": false, 31 | "early_stopping_stall_num": 25 32 | } 33 | -------------------------------------------------------------------------------- /data/filter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm.auto import tqdm 3 | import pandas as pd 4 | import numpy as np 5 | from data.preprocess_data import UniformEncoder 6 | 7 | 8 | def filter(args): 9 | 10 | encoder = UniformEncoder(args) 11 | encoder.initializer() 12 | 13 | 14 | files = os.listdir(args.data_dir) 15 | jsonl_files = [f for f in files if f.endswith('.jsonl')] 16 | 17 | for file in jsonl_files: 18 | file_name = f"{args.data_dir}/{file}" 19 | df = pd.read_json(file_name, lines=True) 
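        # The loop below replays the training-time encoding on every sample:
        # encode_graph / encode_text return None when the tokenized sequence exceeds
        # the configured seq_length, so only the indices of samples that fit are kept
        # and the jsonl file is rewritten in place without the overlong rows.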
20 | valid_idx = [] 21 | for i in tqdm(range(len(df))): 22 | if 'node_ids' in df.keys(): 23 | features = encoder.encode_graph(df.loc[i]) 24 | else: 25 | features = encoder.encode_text(df.loc[i]) 26 | # returns None if too long 27 | if features is not None: 28 | valid_idx.append(i) 29 | valid_idx = np.array(valid_idx) 30 | len_orig = len(df) 31 | print(f"{file}: {len_orig} -> ", end='') 32 | df = df.loc[valid_idx] 33 | print(f"{len(df)}") 34 | if len(df) != len_orig: 35 | print(f'updating {args.data_dir}/{file}') 36 | df.to_json(f"{args.data_dir}/{file}", orient='records', lines=True) 37 | -------------------------------------------------------------------------------- /data_sample/pretrain.jsonl: -------------------------------------------------------------------------------- 1 | {"node_ids":[7,0,13,17,0,13,17,12,13,28,13,4,13,17,2,13,17,17,28,13,4,13,17,2,13,17,17,0,13,2,13,13,4,13,2,2,2,2,4,13,13,17,4,13,13,17,4,13,2,13,13,4,13,10,17,13,10,17,13,10,12,13],"edge_index":[[0,1],[1,2],[1,3],[0,4],[4,5],[4,6],[0,7],[60,8],[9,10],[9,11],[11,12],[11,13],[11,14],[14,15],[14,16],[11,17],[18,19],[18,20],[20,21],[20,22],[20,23],[23,24],[23,25],[20,26],[27,28],[27,29],[29,30],[29,31],[32,33],[32,34],[34,35],[35,36],[36,37],[37,38],[38,39],[38,40],[37,41],[36,42],[42,43],[42,44],[35,45],[34,46],[46,47],[46,48],[48,49],[48,50],[0,51],[51,52],[0,53],[53,54],[53,55],[0,56],[56,57],[56,58],[0,59],[59,60],[59,61]],"text":["M = 9\nN = 9\n\ndef main():\n for i in range(1,M+1,1):\n for j in range(1,N+1,1):\n mult = i * j\n print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))\nmain()","M = 9","M","9","N = 9","N","9","def main():\n for i in range(1,M+1,1):\n for j in range(1,N+1,1):\n mult = i * j\n print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))","main","for i in range(1,M+1,1):\n for j in range(1,N+1,1):\n mult = i * j\n print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))","i","range(1,M+1,1)","range","1","M+1","M","1","1","for j in range(1,N+1,1):\n mult = i * j\n print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))","j","range(1,N+1,1)","range","1","N+1","N","1","1","mult = i * j","mult","i * j","i","j","print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))","print","str(i) + \"x\" + str(j) + \"=\" + str(i * j)","str(i) + \"x\" + str(j) + \"=\"","str(i) + \"x\" + str(j)","str(i) + \"x\"","str(i)","str","i","\"x\"","str(j)","str","j","\"=\"","str(i * j)","str","i * j","i","j","main()","main","M = 9","9","M","N = 9","9","N","def main():\n for i in range(1,M+1,1):\n for j in range(1,N+1,1):\n mult = i * j\n print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))","def main():\n for i in range(1,M+1,1):\n for j in range(1,N+1,1):\n mult = i * j\n print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))","main"],"source":"M = 9\nN = 9\n\ndef main():\n for i in range(1,M+1,1):\n for j in range(1,N+1,1):\n mult = i * j\n print(str(i) + \"x\" + str(j) + \"=\" + str(i * j))\nmain()\n"} 2 | -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # print out arguments in a nice way 5 | def print_args(args, accelerator): 6 | # 计算所有键的最大字符串长度 7 | max_key_length = max(len(str(key)) for key in vars(args).keys()) 8 | 9 | message = "" 10 | message += "====" * 40 + "\n" 11 | message += '\n'.join([f'{k:<{max_key_length}} : {v}' for k, v in vars(args).items()]) + "\n" 12 | message += "====" * 40 + "\n" 13 | accelerator.print(message) 14 | 15 | 16 | def count_parameters(model): 17 
| return sum(p.numel() for p in model.parameters()) 18 | 19 | 20 | def print_with_rank(accelerator, msg): 21 | print(accelerator.process_index, msg) 22 | 23 | 24 | def print_rank_0(*message): 25 | """If distributed is initialized print only on rank 0.""" 26 | if torch.distributed.is_initialized(): 27 | if torch.distributed.get_rank() == 0: 28 | print(*message, flush=True) 29 | else: 30 | print(*message, flush=True) 31 | 32 | 33 | def print_rank_0_highlight(*message): 34 | """If distributed is initialized print only on rank 0.""" 35 | if torch.distributed.is_initialized(): 36 | if torch.distributed.get_rank() == 0: 37 | print('=='*100) 38 | print(*message, flush=True) 39 | print('=='*100) 40 | else: 41 | print('=='*100) 42 | print(*message, flush=True) 43 | print('=='*100) 44 | 45 | 46 | def print_highlight(*message): 47 | print('=='*100) 48 | print(*message) 49 | print('=='*100) 50 | 51 | 52 | def get_computation_speed(batch_size_per_device, seq_len, step_time): 53 | 54 | return batch_size_per_device * seq_len / (step_time + 1e-12) 55 | 56 | 57 | def touch_print(accelerator, batch, num_tokens=10): 58 | """touch first and last tokens and labels for debugging usage""" 59 | accelerator.print(f"step 1 batch shape: {batch['input_ids'].shape},\n" 60 | f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}" 61 | f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}") 62 | accelerator.print(f"first {num_tokens} input_ids and loss_mask") 63 | for pt in range(1): 64 | accelerator.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") 65 | accelerator.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") -------------------------------------------------------------------------------- /data/graph_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm.auto import tqdm 3 | import pandas as pd 4 | import numpy as np 5 | from datasets import Dataset, concatenate_datasets 6 | from data.preprocess_data import UniformEncoder 7 | 8 | 9 | def load_dataset(args, accelerator): 10 | all_data_fields = ['node_ids', 'edge_index', 'source'] if args.mode == 'pt' else ['node_ids', 'edge_index', 'question', 'human', 'bot'] 11 | 12 | encoder = UniformEncoder(args) 13 | encoder.initializer() 14 | 15 | splits = [] 16 | splits_string = args.data_split 17 | if splits_string.find(",") != -1: 18 | splits = [float(s) for s in splits_string.split(",")] 19 | elif splits_string.find("/") != -1: 20 | splits = [float(s) for s in splits_string.split("/")] 21 | else: 22 | splits = [float(splits_string)] 23 | while len(splits) < 2: 24 | splits.append(0.0) 25 | splits = splits[:2] 26 | accelerator.print(f'data splits: {splits}') 27 | 28 | files = os.listdir(args.data_dir) 29 | jsonl_files = [f for f in files if f.endswith('.jsonl')] 30 | 31 | dfs = [] 32 | if args.mode == 'ft': 33 | dfs_ft = [] 34 | for file in jsonl_files: 35 | file_name = f"{args.data_dir}/{file}" 36 | df = pd.read_json(file_name, lines=True) 37 | if args.mode == 'ft' and 'node_ids' not in df.keys(): 38 | dfs_ft.append(df) 39 | else: 40 | dfs.append(df) 41 | 42 | df = pd.concat(dfs) 43 | dataset = Dataset.from_dict({k: df[k].to_list() for k in df.keys() if k in all_data_fields}) 44 | # shuffle and split 45 | dataset_split = dataset.train_test_split(train_size=splits[0]/100.0, shuffle=True, seed=args.seed) 46 | 47 | if args.mode == 'pt': 48 | accelerator.print(dataset_split) 49 | return dataset_split['train'], 
dataset_split['test']
50 |
51 |     else:
52 |         # shuffle each finetune dataset (Java-Python, Python-Java) separately to avoid data leakage
53 |         datasets_ft = [Dataset.from_dict({k: df_ft[k].to_list() for k in df_ft.keys() if k in all_data_fields}) for df_ft in dfs_ft]
54 |         dataset_splits_ft = [ds_ft.train_test_split(train_size=99/100.0, shuffle=True, seed=42) for ds_ft in datasets_ft]
55 |         dataset_ft_train = concatenate_datasets([ds_ft['train'] for ds_ft in dataset_splits_ft])
56 |         dataset_ft_valid = concatenate_datasets([ds_ft['test'] for ds_ft in dataset_splits_ft])
57 |
58 |         accelerator.print('Graph dataset:')
59 |         accelerator.print(dataset_split)
60 |         accelerator.print('Finetune dataset (train):')
61 |         accelerator.print(dataset_ft_train)
62 |         accelerator.print('Finetune dataset (valid):')
63 |         accelerator.print(dataset_ft_valid)
64 |
65 |         return dataset_split['train'], dataset_split['test'], dataset_ft_train, dataset_ft_valid
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GALLa
2 |
3 | ## Introduction
4 |
5 | Repo for the ACL 2025 paper [GALLa: Graph Aligned Large Language Models for Improved Source Code Understanding](https://arxiv.org/abs/2409.04183), which aligns LLMs to code structural graphs (e.g. AST, DFG) to enhance their understanding and semantic representation of code.
6 |
7 | ## Setup
8 |
9 | ### Environment
10 |
11 | GALLa uses a GNN (Graph Neural Network) and an adapter to bridge graphs and LLMs, similar in essence to vision language models such as LLaVA or Qwen-VL. This codebase uses DGL to implement the GNN. See `requirements.txt` for more details.
12 |
13 | ### Models
14 |
15 | The training of GALLa consists of three modules: GNN, adapter, and LLM. We use [DUPLEX](https://arxiv.org/abs/2406.05391) as the GNN and a single-layer cross-attention as the adapter by default. We also implement two alternatives: [MagNet](https://arxiv.org/abs/2102.11391) as the GNN or an MLP as the adapter. To use these alternatives, rename `model_magnet.py` or `model_mlp.py` to `model.py` in the `modeling` folder.
16 |
17 | At inference time, only the LLM is used (which means you can use models trained with GALLa just the way you use the base LLM). Currently we support LLaMA-2, LLaMA-3, Phi-1, StarCoder, CodeGen, and Qwen2.5.
18 |
19 | ### Data
20 |
21 | For the first stage (pretraining), only graph data is required (see `data_sample/pretrain.jsonl` for an example). The data should be one or multiple jsonl files stored in a folder, and three fields are required:
22 | - `node_ids`: IDs of the nodes in DGL format
23 | - `edge_index`: edge indices in DGL format
24 | - `source`: the program's source code
25 |
26 | In the first stage, the model is trained to recover source code from the graph. You will also need to pass a node embedding matrix (see `configs/config_pretrain.json`), which is a .pth file containing an $N\times d$ tensor, where $N$ is the total number of node types (43 in our case), and $d$ is the node embedding dimension. In our experiments, we used codet5p-embedding to generate these node type embeddings (a sketch is provided below).
27 |
28 | For the second stage (instruction finetuning), two types of data are required: graph data and downstream task data, both stored in one or more jsonl files.
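As a reference for the node embedding matrix mentioned above, here is a minimal, hypothetical sketch of generating `node_type_embedding.pth` with codet5p-embedding. The checkpoint id and node type names are illustrative assumptions, not part of this repo:

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "Salesforce/codet5p-110m-embedding"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

# One short name per node type; extend this list to all 43 node types of your parser.
node_types = ["module", "function_definition", "identifier"]
with torch.no_grad():
    # this checkpoint's forward directly returns a (1, d) embedding per input
    rows = [model(tokenizer.encode(name, return_tensors="pt")) for name in node_types]
torch.save(torch.cat(rows, dim=0), "node_type_embedding.pth")  # (N, d) tensor
```

The embedding dimension should match `graph_embedding_dim` in `arguments.py` (256 by default, which is also the output dimension of codet5p-110m-embedding); if you use a different embedding model, set `graph_embedding_dim` accordingly.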
29 |
30 | Graph data in the second stage should include four fields:
31 | - `node_ids`: same as stage 1
32 | - `edge_index`: same as stage 1
33 | - `question`: a question about the graph
34 | - `bot`: the answer to the question
35 |
36 | Downstream task data should include two fields:
37 | - `human`: the question, or prompt
38 | - `bot`: the answer, or response
39 |
40 | Examples of these two types of data are provided in the `data_sample` folder. You should place both types of data in the same folder for training.
41 |
42 | ## Run
43 |
44 | First stage:
45 |
46 | ```
47 | accelerate launch --config_file accelerate_ds_config.yaml run.py --train_config configs/config_pretrain.json
48 | ```
49 |
50 | Second stage:
51 | ```
52 | accelerate launch --config_file accelerate_ds_config.yaml run.py --train_config configs/config_instruction.json
53 | ```
54 |
55 | Notes:
56 | - `checkpoint` in the second stage's training config should be the path to the checkpoint saved in the first stage's training
57 | - LoRA is supported for the second stage. Simply set `lora` to `true` in the config.
58 | - The models saved during second stage training have exactly the same architecture as the base LLM you are using, and can be used in the same way (e.g. with Hugging Face Transformers or vLLM).
59 | - When evaluating the trained models, please format the problems consistently with the downstream task training data in stage 2 for best performance.
60 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 | import argparse, json
3 | from typing import List, Union
4 | import torch
5 |
6 | @dataclass
7 | class TrainArgs:
8 |
9 |     # train data paths on shared FS
10 |     data_dir: Union[str, List[str]]
11 |     node_type_embedding: str
12 |
13 |     # output dir for saving adapters in peft or full ckpts in full-parameter training
14 |     output_dir: str
15 |
16 |     # tensorboard dir for saving tensorboard logs
17 |     tb_dir: str
18 |
19 |     # pretrained_model_path: path to the base model to be trained
20 |     pretrained_model_path: str
21 |
22 |     # whether to load a pretrained checkpoint for finetuning
23 |     checkpoint: Union[str, None] = None
24 |
25 |     # model type
26 |     model_type: str = 'phi'
27 |
28 |     # training mode: "pt" for pretraining, "ft" for instruction finetuning
29 |     mode: str = "ft"
30 |
31 |     # graph embedding dimension
32 |     graph_embedding_dim: int = 256
33 |     graph_hidden_dim: int = 1024
34 |
35 |     # number of graph node types
36 |     graph_node_types: int = 43
37 |
38 |     # graph token placeholder
39 |     graph_pad_token: str = "<|graph_pad|>"
40 |     graph_token_num: int = 128
41 |
42 |     # train/valid/test split
43 |     data_split: str = "99,1,0"
44 |
45 |     # micro train batch size
46 |     per_device_train_batch_size: int = 8
47 |
48 |     # micro eval batch size, always the same as the micro train batch size
49 |     per_device_eval_batch_size: int = 8
50 |
51 |     # lora (for stage 2 only)
52 |     lora: bool = False
53 |     lora_rank: int = 64
54 |     lora_alpha: int = 16
55 |
56 |     # initial lr
57 |     learning_rate: float = 5e-5
58 |
59 |     # minimum lr
60 |     min_lr: float = 5e-6
61 |
62 |     # weight decay
63 |     weight_decay: float = 0.1
64 |
65 |     # gradient_accumulation_steps
66 |     gradient_accumulation_steps: int = 1
67 |
68 |     # lr_scheduler_type
69 |     lr_scheduler_type: str = "cosine"
70 |
71 |     # num_warmup_steps
72 |     num_warmup_steps: int = 300
73 |
74 |     # num_train_epochs
75 |     num_train_epochs: int = 4
76 |
77 |     #
seed for reproducing 78 | seed: int = 42 79 | 80 | # seq_length, context length 81 | seq_length: int = 4096 82 | 83 | # num of steps for logging training loss 84 | log_interval: int = 10 85 | 86 | # num of steps for saving ckpt 87 | checkpointing_steps: int = 100 88 | 89 | # num of steps for evaluation(eval_loss), better same as checkpointing steps 90 | evaluation_steps: int = 100 91 | 92 | # max train steps, if None, depends on num_train_epochs 93 | max_train_steps: Union[None, int] = None 94 | 95 | # if checkpointing every epoch, maybe True in sst 96 | epoch_checkpointing: bool = False 97 | 98 | # if early stop when eval loss is not converging in the past early_stopping_stall_num evaluation point 99 | early_stopping: bool = True 100 | early_stopping_stall_num: int = 5 101 | 102 | #ATTENTION_CLASSES = { "eager": Normal Attention, "flash_attention_2": FlashAttention2} 103 | attn_implementation: str = "flash_attention_2" 104 | 105 | def dict(self): 106 | return {k: str(v) for k, v in asdict(self).items()} 107 | 108 | def prepare_args(): 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument("--train_config", type=str, default=None) 111 | parsed = parser.parse_args() 112 | with open(parsed.train_config, 'r') as f: 113 | train_config = json.load(f) 114 | 115 | args = TrainArgs(**train_config) 116 | if not torch.cuda.is_available(): 117 | args.attn_implementation = 'eager' 118 | if args.model_type in ['codegen']: 119 | args.attn_implementation = 'eager' 120 | 121 | return args -------------------------------------------------------------------------------- /modeling/duplex.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | import torch 4 | from .gatconv import GATConv 5 | 6 | 7 | class DUPLEX(nn.Module): 8 | def __init__(self, args): 9 | super(DUPLEX, self).__init__() 10 | self.args = args 11 | self.activation = F.relu 12 | self.am_layers = nn.ModuleList() 13 | self.ph_layers = nn.ModuleList() 14 | self.dropout = nn.Dropout(args.dr_rate) 15 | self.n_layers = args.n_layers 16 | self.fusion_layer = args.fusion_layer 17 | # ----- am layer ----- 18 | self.am_layers.append(GATConv(args.input_dim, args.hidden_dim//args.head, num_heads=args.head)) 19 | self.ph_layers.append(GATConv(args.input_dim, args.hidden_dim//args.head, num_heads=args.head)) 20 | if args.fusion == 'add': 21 | for i in range(1, args.n_layers-1): 22 | self.am_layers.append(GATConv(args.hidden_dim, args.hidden_dim//args.head, num_heads=args.head)) 23 | self.ph_layers.append(GATConv(args.hidden_dim, args.hidden_dim//args.head, num_heads=args.head)) 24 | if self.n_layers==1: 25 | self.am_agg_layer = GATConv(args.input_dim, args.hidden_dim//args.head, num_heads=args.head) 26 | self.ph_agg_layer = GATConv(args.input_dim, args.hidden_dim//args.head, num_heads=args.head) 27 | else: 28 | self.am_agg_layer = GATConv(args.hidden_dim, args.hidden_dim//args.head, num_heads=args.head) 29 | self.ph_agg_layer = GATConv(args.hidden_dim, args.hidden_dim//args.head, num_heads=args.head) 30 | self.am_layers.append(GATConv(args.hidden_dim, args.output_dim, num_heads=args.head)) 31 | self.ph_layers.append(GATConv(args.hidden_dim, args.output_dim, num_heads=args.head)) 32 | 33 | else: 34 | for i in range(0, args.n_layers-2): 35 | self.am_layers.append(GATConv(args.hidden_dim, args.hidden_dim//args.head, num_heads=args.head)) 36 | self.ph_layers.append(GATConv(args.hidden_dim, args.hidden_dim//args.head, num_heads=args.head)) 37 | 
|             self.ph_layers.append(GATConv(args.hidden_dim, args.output_dim, num_heads=args.head))
38 |             self.am_layers.append(GATConv(args.hidden_dim, args.output_dim, num_heads=args.head))
39 |         self.projector = nn.Linear(args.output_dim*2, args.output_dim)
40 |
41 |     def forward(self, g, input_am, input_ph):
42 |         h_am = input_am
43 |         h_ph = input_ph
44 |         for i in range(self.args.n_layers):
45 |             am_layer = self.am_layers[i]
46 |             ph_layer = self.ph_layers[i]
47 |
48 |             if self.args.fusion == 'add':
49 |                 if i == self.fusion_layer:
50 |                     h_am_agg = self.am_agg_layer(g, h_ph) # agg new
51 |                     h_ph_agg = self.ph_agg_layer(g, h_am) # agg new
52 |
53 |                     h_am = am_layer(g, h_am) # am_new
54 |                     h_ph = ph_layer(g, h_ph) # ph_new
55 |
56 |                     h_am = h_am + h_am_agg
57 |                     h_ph = h_ph + h_ph_agg
58 |                 else:
59 |                     h_am = am_layer(g, h_am) # am_new
60 |                     h_ph = ph_layer(g, h_ph) # ph_new
61 |             else:
62 |                 h_am = am_layer(g, h_am) # am_new
63 |                 h_ph = ph_layer(g, h_ph)
64 |
65 |             if i < self.n_layers-1:
66 |                 h_am = h_am.flatten(1)
67 |                 h_am = self.activation(h_am)
68 |                 h_am = self.dropout(h_am)
69 |                 h_ph = h_ph.flatten(1)
70 |                 h_ph = self.activation(h_ph)
71 |                 h_ph = self.dropout(h_ph)
72 |             else:
73 |                 h_am = h_am.mean(1)
74 |                 h_ph = h_ph.mean(1)
75 |                 if i < self.n_layers-1:
76 |                     h_am = self.activation(h_am)
77 |                     h_am = self.dropout(h_am)
78 |                     h_ph = self.activation(h_ph)
79 |                     h_ph = self.dropout(h_ph)
80 |                 else:
81 |                     output = self.projector(torch.cat((h_am, h_ph), dim=-1))
82 |                     # pred_score = self.classifier(torch.cat((h_am, h_ph), dim=-1))
83 |                     # continue
84 |         return output
85 |
--------------------------------------------------------------------------------
/modeling/model_mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .duplex import DUPLEX
4 | from transformers import AutoModelForCausalLM, AutoModel
5 | from utils import count_parameters, print_rank_0, print_highlight, print_rank_0_highlight
6 | from peft import (
7 |     LoraConfig,
8 |     TaskType,
9 |     get_peft_model,
10 |     PeftModel,
11 | )
12 |
13 |
14 | class Adapter(nn.Module):
15 |     def __init__(self, args):
16 |         super(Adapter, self).__init__()
17 |         self.args = args
18 |
19 |         # the last layer in the GNN is a linear (without activation)
20 |         self.net = nn.Sequential(
21 |             nn.SiLU(),
22 |             nn.Linear(args.graph_hidden_dim, args.lm_hidden_size),
23 |             nn.SiLU(),
24 |             nn.Linear(args.lm_hidden_size, args.lm_hidden_size),
25 |             nn.SiLU(),
26 |             nn.Linear(args.lm_hidden_size, args.lm_hidden_size),
27 |         )
28 |         print_rank_0(f"Parameters of MLP adapter: {count_parameters(self.net) / 1e6:.1f}M")
29 |
30 |     def forward(self, features, batch):
31 |         # features: (sum(num_node), d_embed)
32 |         embeddings = self.net(features)
33 |         return embeddings
34 |
35 |
36 | class Model(nn.Module):
37 |     def __init__(self, args, vocab):
38 |         super(Model, self).__init__()
39 |         self.num_heads = 8
40 |
41 |         # language model
42 |         self.lm = AutoModelForCausalLM.from_pretrained(
43 |             args.pretrained_model_path,
44 |             attn_implementation=args.attn_implementation,
45 |             torch_dtype="auto",
46 |             trust_remote_code=True,
47 |         )
48 |         self.lm.gradient_checkpointing_enable()
49 |         self.lm.config.use_cache = False  # silence the warnings. Please re-enable for inference!
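        # Note: this MLP variant maps each graph node to one LM-sized embedding, and
        # forward() below writes those embeddings into the graph_pad_token placeholder
        # positions one-to-one. This appears to assume that the number of placeholder
        # tokens in a sample equals its node count, unlike the cross-attention adapter
        # in model.py, which always emits a fixed graph_token_num embeddings per sample.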
50 |         if args.model_type in ['starcoder', 'llama3']:
51 |             self.lm.resize_token_embeddings(vocab)
52 |         # lora
53 |         if args.lora:
54 |             peft_config = LoraConfig(
55 |                 task_type=TaskType.CAUSAL_LM,
56 |                 inference_mode=False,
57 |                 r=args.lora_rank,
58 |                 lora_alpha=args.lora_alpha,
59 |                 target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
60 |             )
61 |             self.lm = get_peft_model(self.lm, peft_config)
62 |         print_rank_0(f"Parameters of language model: {count_parameters(self.lm) / 1e9:.2f}B")
63 |
64 |         # graph model
65 |         self.embed_dim = args.graph_embedding_dim
66 |         class Args:
67 |             dr_rate: float = 0.1
68 |             n_layers: int = 3
69 |             fusion_layer: int = 1
70 |             input_dim: int = args.graph_embedding_dim
71 |             hidden_dim: int = args.graph_hidden_dim
72 |             output_dim: int = args.graph_hidden_dim
73 |             head: int = 1
74 |             fusion = None
75 |
76 |         args_gnn = Args()
77 |         self.gnn = DUPLEX(args_gnn)
78 |         print_rank_0(f"Parameters of GNN: {count_parameters(self.gnn) / 1000:.1f}K")
79 |
80 |         # update args
81 |         args.num_heads = self.num_heads
82 |         args.lm_hidden_size = self.lm.config.hidden_size
83 |         self.args = args
84 |
85 |         # adapter
86 |         self.adapter = Adapter(args)
87 |         print_rank_0(f"Parameters of adapter (MLP): {count_parameters(self.adapter) / 1e6:.1f}M")
88 |
89 |         if args.checkpoint:
90 |             print_rank_0_highlight(f"Loading existing checkpoint: {args.checkpoint}")
91 |             self.gnn.load_state_dict(torch.load(f"{args.checkpoint}/GNN.pth"))
92 |             self.adapter.load_state_dict(torch.load(f"{args.checkpoint}/adapter.pth"))
93 |
94 |     def forward(self, x):
95 |         bs = x['input_ids'].shape[0]
96 |
97 |         if 'graph_embedding' in x.keys():
98 |             # embedding -> (bs, l, d_lm), bf16
99 |             if self.args.model_type in ['llama3', 'phi', 'qwen2.5']:
100 |                 if not self.args.lora:
101 |                     inputs_embeds = self.lm.model.embed_tokens(x['input_ids'])
102 |                 else:
103 |                     inputs_embeds = self.lm.base_model.model.model.embed_tokens(x['input_ids'])
104 |             elif self.args.model_type in ['starcoder', 'codegen']:
105 |                 inputs_embeds = self.lm.transformer.wte(x['input_ids'])
106 |             else:
107 |                 raise NotImplementedError()
108 |
109 |             # x['graph_embedding']: (sum(num_nodes), d_embed)
110 |             # x.g.edges(): (2, sum(num_edges))
111 |             embeddings = x['graph_embedding'].to(self.gnn.am_layers[0].attn_l.dtype)
112 |
113 |             # GNN -> (sum(num_node), d_embed), bf16
114 |             # features = self.magnet(real=embeddings, imag=embeddings, edge_index=x['edge_index'])
115 |             features = self.gnn(x['g'], embeddings, embeddings)
116 |
117 |             # adapter -> (sum(num_node), d_lm)
118 |             embeddings = self.adapter(features, x['batch_num_nodes'])
119 |
120 |             inputs_embeds = inputs_embeds.reshape(-1, inputs_embeds.shape[-1])
121 |             idx = torch.nonzero(x['input_ids'].reshape(-1) == self.args.graph_pad_id).squeeze()
122 |             inputs_embeds[idx] = embeddings
123 |             inputs_embeds = inputs_embeds.reshape(bs, -1, inputs_embeds.shape[-1])
124 |
125 |             # lm
126 |             # print_rank_0('start lm forward')
127 |             outputs = self.lm(inputs_embeds=inputs_embeds,
128 |                               return_dict=True)
129 |             return outputs
130 |
131 |         else:
132 |             return self.lm(input_ids=x['input_ids'], return_dict=True)
133 |
--------------------------------------------------------------------------------
/data/preprocess_data.py:
--------------------------------------------------------------------------------
1 | from modeling import build_tokenizer
2 | import random
3 |
4 | table = {ord(f): ord(t) for f, t in zip(
5 |     u',。！？：【】（）%#@&1234567890',
6 |     u',.!?:[]()%#@&1234567890')}
7 |
8 |
9 | def 
punctuation_format(text: str): 10 | # Replace non-breaking space with space 11 | # text = text.strip() + '\n' 12 | text = text.replace('\u202f', ' ').replace('\xa0', ' ') 13 | # change chinese punctuation to english ones 14 | # text = text.translate(table) 15 | if not text.endswith("\n"): 16 | text += "\n" 17 | return text 18 | 19 | 20 | def format_eol(text): 21 | if not text.endswith("\n"): 22 | text += "\n" 23 | return text 24 | 25 | 26 | def get_white_space(): 27 | r = random.random() 28 | return '' if r < 0.33 else (' ' if r < 0.66 else '\n') 29 | 30 | 31 | def gen_prompt(tokenizer, data, graph_pad_id, graph_token_num): 32 | # randomly select [graph tokens, question] or [question, graph tokens] 33 | if random.random() < 0.5: 34 | return tokenizer.encode(f"{data['question']}{get_white_space()}", add_special_tokens=False) + [graph_pad_id] * graph_token_num + tokenizer.encode('\n', add_special_tokens=False) 35 | else: 36 | return [graph_pad_id] * graph_token_num + tokenizer.encode(f"{get_white_space()}{format_eol(data['question'])}", add_special_tokens=False) 37 | 38 | 39 | class Encoder(object): 40 | def __init__(self, args): 41 | self.args = args 42 | # seq_length - 1 for shifting 43 | self.seq_length = args.seq_length - 1 44 | 45 | def initializer(self): 46 | self.tokenizer = build_tokenizer(self.args) 47 | 48 | self.HUMAN = 'human' 49 | self.BOT = 'bot' 50 | self.SYSTEM = 'system' 51 | self.ROLE_START_MARKER = '' 52 | self.ROLE_END_MARKER = '\n' 53 | 54 | self.human_marker_ids = self.tokenizer.encode(f"{self.ROLE_START_MARKER}{self.HUMAN}{self.ROLE_END_MARKER}", add_special_tokens=False) 55 | self.bot_marker_ids = self.tokenizer.encode(f"{self.ROLE_START_MARKER}{self.BOT}{self.ROLE_END_MARKER}", add_special_tokens=False) 56 | self.system_marker_ids = self.tokenizer.encode(f"{self.ROLE_START_MARKER}{self.SYSTEM}{self.ROLE_END_MARKER}", add_special_tokens=False) 57 | self.sft_end_marker_ids = [self.tokenizer.eos_token_id] 58 | self.role_to_markerid = {self.HUMAN: self.human_marker_ids, self.BOT: self.bot_marker_ids, self.SYSTEM: self.system_marker_ids} 59 | 60 | self.default_system_ids = self.system_marker_ids + self.tokenizer.encode('You are an AI code assistant. You will be given a task. 
You must provide an accurate answer according to the requirements.\n', add_special_tokens=False) 61 | 62 | def padding(self, input_ids, loss_mask): 63 | pad_id = self.tokenizer.pad_token_id 64 | assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}" 65 | input_ids += [pad_id] * (self.seq_length - len(input_ids)) 66 | loss_mask += [0] * (self.seq_length - len(loss_mask)) 67 | return { 68 | "input_ids": input_ids, 69 | "loss_mask": loss_mask 70 | } 71 | 72 | 73 | class UniformEncoder(Encoder): 74 | def __init__(self, args): 75 | super().__init__(args) 76 | 77 | def encode_graph(self, data): 78 | input_ids, loss_mask = [], [] 79 | 80 | if self.args.mode == 'ft': 81 | # system 82 | # input_ids += self.default_system_ids 83 | # loss_mask += [0] * len(self.default_system_ids) 84 | # human 85 | content_ids = gen_prompt(self.tokenizer, data, self.args.graph_pad_id, self.args.graph_token_num) 86 | input_ids += self.human_marker_ids + content_ids 87 | loss_mask += [0] * (len(self.human_marker_ids) + len(content_ids)) 88 | # bot 89 | content_ids = self.tokenizer.encode(data['bot'], add_special_tokens=False) + self.sft_end_marker_ids 90 | input_ids += self.bot_marker_ids + content_ids 91 | loss_mask += [0] * len(self.bot_marker_ids) + [1] * len(content_ids) 92 | 93 | elif self.args.mode == 'pt': 94 | # graph 95 | content_ids = [self.args.graph_pad_id] * self.args.graph_token_num 96 | input_ids += content_ids 97 | loss_mask += [0] * len(content_ids) 98 | # source code 99 | content_ids = self.tokenizer.encode(data['source'], add_special_tokens=False) + self.sft_end_marker_ids 100 | input_ids += content_ids 101 | loss_mask += [1] * len(content_ids) 102 | 103 | else: 104 | raise NotImplementedError() 105 | 106 | assert len(input_ids) == len(loss_mask) 107 | if len(input_ids) <= self.seq_length: 108 | features = self.padding(input_ids, loss_mask) 109 | features['node_ids'] = data['node_ids'] 110 | features['edge_index'] = data['edge_index'] 111 | return features 112 | 113 | # drop if too long 114 | else: 115 | return None 116 | 117 | def encode_text(self, data): 118 | input_ids, loss_mask = [], [] 119 | 120 | # system 121 | # input_ids += self.default_system_ids 122 | # loss_mask += [0] * len(self.default_system_ids) 123 | # human 124 | content_ids = self.tokenizer.encode(data['human'], add_special_tokens=False) 125 | input_ids += self.human_marker_ids + content_ids 126 | loss_mask += [0] * (len(self.human_marker_ids) + len(content_ids)) 127 | # bot 128 | content_ids = self.tokenizer.encode(data['bot'], add_special_tokens=False) + self.sft_end_marker_ids 129 | input_ids += self.bot_marker_ids + content_ids 130 | loss_mask += [0] * len(self.bot_marker_ids) + [1] * len(content_ids) 131 | 132 | assert len(input_ids) == len(loss_mask) 133 | if len(input_ids) <= self.seq_length: 134 | features = self.padding(input_ids, loss_mask) 135 | return features 136 | 137 | # drop if too long 138 | else: 139 | return None -------------------------------------------------------------------------------- /modeling/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from .magnet import MagNet 4 | from .duplex import DUPLEX 5 | from transformers import AutoModelForCausalLM, AutoModel 6 | from utils import count_parameters, print_rank_0, print_highlight, print_rank_0_highlight 7 | from peft import ( 8 | LoraConfig, 9 | TaskType, 10 | get_peft_model, 11 | PeftModel, 12 | ) 13 | 14 | 15 | class 
Adapter(nn.Module): 16 | def __init__(self, args): 17 | super(Adapter, self).__init__() 18 | self.args = args 19 | 20 | # learnable query: (graph_token_num, lm_d) 21 | self.q = nn.Parameter(torch.randn(args.graph_token_num, args.lm_hidden_size)) 22 | # cross attention 23 | self.attn = nn.MultiheadAttention( 24 | embed_dim=args.lm_hidden_size, 25 | num_heads=args.num_heads, 26 | kdim=args.graph_hidden_dim, 27 | vdim=args.graph_hidden_dim, 28 | batch_first=True 29 | ) 30 | print_rank_0(f"Parameters of learnable query: {self.q.numel() / 1e6:.1f}M") 31 | print_rank_0(f"Parameters of cross attention: {count_parameters(self.attn) / 1e6:.1f}M") 32 | 33 | def forward(self, features, batch): 34 | # reshape features from (sum(num_node), d_embed) to (bs, max(num_node), d_embed) 35 | features_2d = features.to(self.q.dtype) 36 | bs = len(batch) 37 | max_n = batch.max().item() 38 | features = torch.zeros((bs, max_n, features_2d.shape[-1]), dtype=features_2d.dtype, device=features_2d.device) 39 | start_idx = 0 40 | for i in range(bs): 41 | end_idx = start_idx + batch[i] 42 | features[i, :batch[i]] = features_2d[start_idx: end_idx] 43 | start_idx = end_idx 44 | 45 | # adapter -> (bs, num_graph_tokens, d_lm) 46 | # expand querys to (bs, n_query, d_lm) 47 | queries = self.q.expand(bs, -1, -1) 48 | 49 | # mask should be shape (bs, S), where S is source seq length 50 | # note that Pytorch documentation refers to query as "target", and key/value as "source" 51 | mask = torch.arange(max_n, device=features.device).expand(bs, max_n) < batch.unsqueeze(1) 52 | mask = ~mask # positions set to True are not allowed to attend 53 | 54 | embeddings = self.attn(queries, features, features, key_padding_mask=mask, need_weights=False)[0] 55 | return embeddings 56 | 57 | 58 | class Model(nn.Module): 59 | def __init__(self, args, vocab): 60 | super(Model, self).__init__() 61 | self.num_heads = 8 62 | 63 | # language model 64 | self.lm = AutoModelForCausalLM.from_pretrained( 65 | args.pretrained_model_path, 66 | attn_implementation=args.attn_implementation, 67 | torch_dtype="auto", 68 | trust_remote_code=True, 69 | ) 70 | self.lm.gradient_checkpointing_enable() 71 | self.lm.config.use_cache = False # silence the warnings. Please re-enable for inference! 
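        # The graph_pad_token special token added in modeling/tokenizer.py enlarges
        # the vocabulary for these model families, so their embedding matrices are
        # resized here; the other supported models appear to ship with spare
        # embedding slots (see the deepseek coder note in modeling/tokenizer.py).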
72 |         if args.model_type in ['starcoder', 'llama3', 'llama2']:
73 |             self.lm.resize_token_embeddings(vocab)
74 |         # lora
75 |         if args.lora:
76 |             peft_config = LoraConfig(
77 |                 task_type=TaskType.CAUSAL_LM,
78 |                 inference_mode=False,
79 |                 r=args.lora_rank,
80 |                 lora_alpha=args.lora_alpha,
81 |                 target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
82 |             )
83 |             self.lm = get_peft_model(self.lm, peft_config)
84 |         print_rank_0(f"Parameters of language model: {count_parameters(self.lm) / 1e9:.2f}B")
85 |
86 |         # graph model
87 |         self.embed_dim = args.graph_embedding_dim
88 |         class Args:
89 |             dr_rate: float = 0.1
90 |             n_layers: int = 3
91 |             fusion_layer: int = 1
92 |             input_dim: int = args.graph_embedding_dim
93 |             hidden_dim: int = args.graph_hidden_dim
94 |             output_dim: int = args.graph_hidden_dim
95 |             head: int = 1
96 |             fusion = None
97 |
98 |         args_gnn = Args()
99 |         self.gnn = DUPLEX(args_gnn)
100 |         print_rank_0(f"Parameters of GNN: {count_parameters(self.gnn) / 1000:.1f}K")
101 |
102 |         # update args
103 |         args.num_heads = self.num_heads
104 |         args.lm_hidden_size = self.lm.config.hidden_size
105 |         self.args = args
106 |
107 |         # adapter
108 |         self.adapter = Adapter(args)
109 |         print_rank_0(f"Parameters of adapter (attention + query): {count_parameters(self.adapter) / 1e6:.1f}M")
110 |
111 |         if args.checkpoint:
112 |             print_rank_0_highlight(f"Loading existing checkpoint: {args.checkpoint}")
113 |             self.gnn.load_state_dict(torch.load(f"{args.checkpoint}/GNN.pth"))
114 |             self.adapter.load_state_dict(torch.load(f"{args.checkpoint}/adapter.pth"))
115 |
116 |     def forward(self, x):
117 |         bs = x['input_ids'].shape[0]
118 |
119 |         if 'graph_embedding' in x.keys():
120 |             # embedding -> (bs, l, d_lm), bf16
121 |             if self.args.model_type in ['llama3', 'phi', 'llama2', 'qwen2.5']:
122 |                 if not self.args.lora:
123 |                     inputs_embeds = self.lm.model.embed_tokens(x['input_ids'])
124 |                 else:
125 |                     inputs_embeds = self.lm.base_model.model.model.embed_tokens(x['input_ids'])
126 |             elif self.args.model_type in ['starcoder', 'codegen']:
127 |                 inputs_embeds = self.lm.transformer.wte(x['input_ids'])
128 |             else:
129 |                 raise NotImplementedError()
130 |
131 |             # x['graph_embedding']: (sum(num_nodes), d_embed)
132 |             # x.g.edges(): (2, sum(num_edges))
133 |             embeddings = x['graph_embedding'].to(self.gnn.am_layers[0].attn_l.dtype)
134 |
135 |             # GNN -> (sum(num_node), d_embed), bf16
136 |             # features = self.magnet(real=embeddings, imag=embeddings, edge_index=x['edge_index'])
137 |             features = self.gnn(x['g'], embeddings, embeddings)
138 |
139 |             # adapter -> (bs, num_graph_tokens, d_lm)
140 |             embeddings = self.adapter(features, x['batch_num_nodes'])
141 |
142 |             # if lora, inputs_embeds will have no grad_fn and are thus leaf tensors - can't apply in-place operations
143 |             if self.args.lora:
144 |                 inputs_embeds = inputs_embeds.clone()
145 |
146 |             for i in range(bs):
147 |                 # replace graph embedding
148 |                 graph_token_positions = (x['input_ids'][i] == self.args.graph_pad_id).long().nonzero().squeeze()
149 |                 pos_start = graph_token_positions.min()
150 |                 pos_end = graph_token_positions.max()
151 |                 inputs_embeds[i, pos_start:pos_end+1] = embeddings[i]
152 |
153 |             # lm
154 |             # print_rank_0('start lm forward')
155 |             outputs = self.lm(inputs_embeds=inputs_embeds,
156 |                               return_dict=True)
157 |             return outputs
158 |
159 |         else:
160 |             return self.lm(input_ids=x['input_ids'], return_dict=True)
161 |
--------------------------------------------------------------------------------
/modeling/model_magnet.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .magnet import MagNet 4 | from transformers import AutoModelForCausalLM, AutoModel 5 | from utils import count_parameters, print_rank_0, print_highlight, print_rank_0_highlight 6 | from peft import ( 7 | LoraConfig, 8 | TaskType, 9 | get_peft_model, 10 | PeftModel, 11 | ) 12 | 13 | 14 | class Adapter(nn.Module): 15 | def __init__(self, args): 16 | super(Adapter, self).__init__() 17 | self.args = args 18 | 19 | # learnable query: (graph_token_num, lm_d) 20 | self.q = nn.Parameter(torch.randn(args.graph_token_num, args.lm_hidden_size)) 21 | # cross attention 22 | self.attn = nn.MultiheadAttention( 23 | embed_dim=args.lm_hidden_size, 24 | num_heads=args.num_heads, 25 | kdim=args.graph_hidden_dim, 26 | vdim=args.graph_hidden_dim, 27 | batch_first=True 28 | ) 29 | print_rank_0(f"Parameters of learnable query: {self.q.numel() / 1e6:.1f}M") 30 | print_rank_0(f"Parameters of cross attention: {count_parameters(self.attn) / 1e6:.1f}M") 31 | 32 | def forward(self, features, batch): 33 | # # reshape features from (sum(num_node), d_embed) to (bs, max(num_node), d_embed) 34 | # features_2d = features.to(self.q.dtype) 35 | # bs = len(batch) 36 | # max_n = batch.max().item() 37 | # features = torch.zeros((bs, max_n, features_2d.shape[-1]), dtype=features_2d.dtype, device=features_2d.device) 38 | # start_idx = 0 39 | # for i in range(bs): 40 | # end_idx = start_idx + batch[i] 41 | # features[i, :batch[i]] = features_2d[start_idx: end_idx] 42 | # start_idx = end_idx 43 | 44 | # # adapter -> (bs, num_graph_tokens, d_lm) 45 | # # expand querys to (bs, n_query, d_lm) 46 | # queries = self.q.expand(bs, -1, -1) 47 | 48 | # # mask should be shape (bs, S), where S is source seq length 49 | # # note that Pytorch documentation refers to query as "target", and key/value as "source" 50 | # mask = torch.arange(max_n, device=features.device).expand(bs, max_n) < batch.unsqueeze(1) 51 | # mask = ~mask # positions set to True are not allowed to attend 52 | 53 | # embeddings = self.attn(queries, features, features, key_padding_mask=mask, need_weights=False)[0] 54 | # return embeddings 55 | 56 | # reshape features from (sum(num_node), d_embed) to (bs, max(num_node), d_embed) 57 | features_2d = features.to(self.q.dtype) 58 | bincount = batch.bincount() # bincount: (bs,) 59 | bs = len(bincount) 60 | max_n = bincount.max().item() 61 | features = torch.zeros((bs, max_n, features_2d.shape[-1]), dtype=features_2d.dtype, device=features_2d.device) 62 | for i in range(bs): 63 | features[i, :bincount[i]] = features_2d[batch == i] 64 | 65 | # adapter -> (bs, num_graph_tokens, d_lm) 66 | queries = self.q.expand(bs, -1, -1) 67 | # key padding mask should be shape (L, S), where L is target seq length, and S is source seq length 68 | # note that Pytorch documentation refers to query as "target", and key/value as "source" 69 | 70 | mask = torch.arange(max_n, device=features.device).expand(bs, max_n) < bincount.unsqueeze(1) 71 | mask = ~mask # positions set to True are not allowed to attend 72 | 73 | embeddings = self.attn(queries, features, features, key_padding_mask=mask, need_weights=False)[0] 74 | return embeddings 75 | 76 | 77 | class Model(nn.Module): 78 | def __init__(self, args, vocab): 79 | super(Model, self).__init__() 80 | self.num_heads = 8 81 | 82 | # language model 83 | self.lm = AutoModelForCausalLM.from_pretrained( 84 | args.pretrained_model_path, 85 | 
attn_implementation=args.attn_implementation,
86 |             torch_dtype=torch.bfloat16,
87 |             trust_remote_code=True,
88 |         )
89 |         self.lm.gradient_checkpointing_enable()
90 |         self.lm.config.use_cache = False  # silence the warnings. Please re-enable for inference!
91 |         if args.model_type in ['starcoder', 'llama3']:
92 |             self.lm.resize_token_embeddings(vocab)
93 |         # lora
94 |         if args.lora:
95 |             peft_config = LoraConfig(
96 |                 task_type=TaskType.CAUSAL_LM,
97 |                 inference_mode=False,
98 |                 r=args.lora_rank,
99 |                 lora_alpha=args.lora_alpha,
100 |                 target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
101 |             )
102 |             self.lm = get_peft_model(self.lm, peft_config)
103 |         print_rank_0(f"Parameters of language model: {count_parameters(self.lm) / 1e9:.2f}B")
104 |
105 |         # graph model
106 |         self.embed_dim = args.graph_embedding_dim
107 |         self.gnn = MagNet(self.embed_dim, output_dim=args.graph_hidden_dim, hidden=args.graph_hidden_dim)
108 |         print_rank_0(f"Parameters of GNN: {count_parameters(self.gnn) / 1000:.1f}K")
109 |
110 |         # update args
111 |         args.num_heads = self.num_heads
112 |         args.lm_hidden_size = self.lm.config.hidden_size
113 |         self.args = args
114 |
115 |         # adapter
116 |         self.adapter = Adapter(args)
117 |         print_rank_0(f"Parameters of adapter (attention + query): {count_parameters(self.adapter) / 1e6:.1f}M")
118 |
119 |         if args.checkpoint:
120 |             print_rank_0_highlight(f"Loading existing checkpoint: {args.checkpoint}")
121 |             self.gnn.load_state_dict(torch.load(f"{args.checkpoint}/GNN.pth"))
122 |             self.adapter.load_state_dict(torch.load(f"{args.checkpoint}/adapter.pth"))
123 |
124 |     def forward(self, x):
125 |         bs = x['input_ids'].shape[0]
126 |
127 |         if 'graph_embedding' in x.keys():
128 |             # embedding -> (bs, l, d_lm), bf16
129 |             if self.args.model_type in ['llama3', 'phi', 'qwen2.5']:
130 |                 if not self.args.lora:
131 |                     inputs_embeds = self.lm.model.embed_tokens(x['input_ids'])
132 |                 else:
133 |                     inputs_embeds = self.lm.base_model.model.model.embed_tokens(x['input_ids'])
134 |             elif self.args.model_type in ['starcoder', 'codegen']:
135 |                 inputs_embeds = self.lm.transformer.wte(x['input_ids'])
136 |             else:
137 |                 raise NotImplementedError()
138 |
139 |             # x['graph_embedding']: (sum(num_nodes), d_embed)
140 |             # x.g.edges(): (2, sum(num_edges))
141 |             embeddings = x['graph_embedding'].to(self.gnn.Chebs[0].weight.dtype)
142 |
143 |             # GNN -> (sum(num_node), d_embed), bf16
144 |             features = self.gnn(real=embeddings, imag=embeddings, edge_index=x['edge_index'])
145 |             # features = self.gnn(x['g'], embeddings, embeddings)
146 |
147 |             # adapter -> (bs, num_graph_tokens, d_lm)
148 |             embeddings = self.adapter(features, x['batch'])
149 |
150 |             # if lora, inputs_embeds will have no grad_fn and are thus leaf tensors - can't apply in-place operations
151 |             if self.args.lora:
152 |                 inputs_embeds = inputs_embeds.clone()
153 |
154 |             for i in range(bs):
155 |                 # replace graph embedding
156 |                 graph_token_positions = (x['input_ids'][i] == self.args.graph_pad_id).long().nonzero().squeeze()
157 |                 pos_start = graph_token_positions.min()
158 |                 pos_end = graph_token_positions.max()
159 |                 inputs_embeds[i, pos_start:pos_end+1] = embeddings[i]
160 |
161 |             # lm
162 |             # print_rank_0('start lm forward')
163 |             outputs = self.lm(inputs_embeds=inputs_embeds,
164 |                               return_dict=True)
165 |             return outputs
166 |
167 |         else:
168 |             return self.lm(input_ids=x['input_ids'], return_dict=True)
169 |
--------------------------------------------------------------------------------
/run.py:
-------------------------------------------------------------------------------- 1 | import time, os, json, math, logging 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import DataLoader 5 | import dgl 6 | # from deepspeed.ops.adam import FusedAdam as AdamW 7 | from torch.optim import AdamW 8 | from accelerate import Accelerator 9 | from accelerate.logging import get_logger 10 | from transformers import ( 11 | set_seed, 12 | get_scheduler, 13 | ) 14 | from arguments import prepare_args 15 | from data.graph_dataset import load_dataset 16 | from data.preprocess_data import UniformEncoder 17 | from modeling import Model, build_tokenizer 18 | from utils import print_args, accelerate_train, print_with_rank 19 | 20 | # get args 21 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 22 | args = prepare_args() 23 | 24 | # start accelerator 25 | set_seed(args.seed) 26 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) 27 | 28 | # print and save args in the main process 29 | print_args(args, accelerator) 30 | if accelerator.is_main_process: 31 | if not os.path.exists(args.output_dir): 32 | os.makedirs(args.output_dir) 33 | with open(os.path.join(args.output_dir, "args.json"), "w") as f: 34 | json.dump(args.dict(), f, indent=2) 35 | 36 | # prepare logger 37 | logger = get_logger(__name__) 38 | logging.basicConfig( 39 | format="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", 40 | datefmt="%Y-%m-%d %H:%M:%S", 41 | level=logging.INFO, 42 | ) 43 | logger.info(accelerator.state, main_process_only=False) 44 | 45 | # used in collate function 46 | node_type_embedding = torch.load(args.node_type_embedding) 47 | encoder = UniformEncoder(args) 48 | encoder.initializer() 49 | 50 | 51 | def collate_fn(instances): 52 | input_ids, loss_mask, node_ids, edge_index = [], [], [], [] 53 | for instance in instances: 54 | if 'node_ids' in instance.keys(): 55 | features = encoder.encode_graph(instance) 56 | else: 57 | features = encoder.encode_text(instance) 58 | if features is not None: 59 | input_ids.append(features['input_ids']) 60 | loss_mask.append(features['loss_mask']) 61 | node_ids.append(features.get('node_ids', None)) 62 | edge_index.append(features.get('edge_index', None)) 63 | 64 | n = len(input_ids) 65 | 66 | # batch graphs 67 | if 'node_ids' in instances[0].keys(): 68 | graphs = [] 69 | for i in range(n): 70 | edges = torch.LongTensor(edge_index[i]).t().contiguous() 71 | g = dgl.graph((edges[0,:], edges[1,:]), num_nodes=len(node_ids[i])) 72 | g.ndata['x'] = node_type_embedding[torch.tensor(node_ids[i])] 73 | graphs.append(g) 74 | batch = dgl.batch(graphs) 75 | # edge_index is changed, features remain the same 76 | # batch.ndata['x']: (sum(num_nodes), d_embed) 77 | # batch.edges(): (2, sum(num_edges)) 78 | else: 79 | batch = None 80 | 81 | result_batch = { 82 | 'g': batch, 83 | 'graph_embedding': batch.ndata['x'], 84 | 'batch_num_nodes': batch.batch_num_nodes() 85 | } if batch else {} 86 | 87 | loss_mask = torch.tensor(loss_mask).long() 88 | # dynamic padding 89 | last_one_pos = (loss_mask == 1).long().cumsum(dim=1).argmax(dim=1) 90 | # get last non-padding position 91 | max_pos = last_one_pos.max().item() + 1 92 | 93 | result_batch['loss_mask'] = loss_mask.float()[:, 1:max_pos].contiguous() 94 | input_ids = torch.tensor(input_ids).long() 95 | result_batch['input_ids'] = input_ids[:, :max_pos - 1].contiguous() 96 | result_batch['labels'] = input_ids[:, 1:max_pos].contiguous() 97 | 98 | return result_batch 99 | 100 | 101 | def main(): 102 | t0 = 
time.time()
103 |
104 |     # set seed
105 |     set_seed(args.seed)
106 |
107 |     # load dataset
108 |     if args.mode == 'pt':
109 |         train_dataset, valid_dataset = load_dataset(args, accelerator)
110 |     else:
111 |         train_dataset, valid_dataset, train_dataset_ft, valid_dataset_ft = load_dataset(args, accelerator)
112 |
113 |     t1 = time.time()
114 |     logger.info(f"Dataset loading time: {t1 - t0:.2f}s")
115 |
116 |     # load model
117 |     tokenizer = build_tokenizer(args)
118 |     model = Model(args, len(tokenizer))
119 |
120 |     # print(model.lm.device)
121 |     t2 = time.time()
122 |     logger.info(f"Model loading time: {t2 - t1:.2f}s")
123 |
124 |     # dataloader
125 |     train_dataloader = DataLoader(
126 |         train_dataset, shuffle=True, collate_fn=collate_fn,
127 |         batch_size=args.per_device_train_batch_size, pin_memory=True
128 |     )
129 |     valid_dataloader = DataLoader(
130 |         valid_dataset, collate_fn=collate_fn,
131 |         batch_size=args.per_device_eval_batch_size, pin_memory=True
132 |     )
133 |     if args.mode == 'ft':
134 |         train_dataloader_ft = DataLoader(
135 |             train_dataset_ft, shuffle=True, collate_fn=collate_fn,
136 |             batch_size=args.per_device_train_batch_size, pin_memory=True
137 |         )
138 |         valid_dataloader_ft = DataLoader(
139 |             valid_dataset_ft, collate_fn=collate_fn,
140 |             batch_size=args.per_device_eval_batch_size, pin_memory=True
141 |         )
142 |     else:
143 |         train_dataloader_ft, valid_dataloader_ft = None, None
144 |
145 |     # if finetuning, train all params; otherwise pretrain only the GNN and adapter
146 |     if args.mode == 'ft':
147 |         trained_params = model.parameters()
148 |     elif args.mode == 'pt':
149 |         trained_params = [p for p in model.gnn.parameters()] + [p for p in model.adapter.parameters()]
150 |     else:
151 |         raise NotImplementedError()
152 |     optimizer = AdamW(
153 |         trained_params,
154 |         weight_decay=args.weight_decay,
155 |         lr=args.learning_rate,
156 |         betas=(0.9, 0.95),
157 |     )
158 |
159 |     # Scheduler and math around the number of training steps.
160 |     overrode_max_train_steps = False
161 |     num_update_steps_per_epoch = math.ceil((len(train_dataloader) + (len(train_dataloader_ft) if args.mode == 'ft' else 0)) / args.gradient_accumulation_steps)
162 |     if args.max_train_steps is None:
163 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
164 |         overrode_max_train_steps = True
165 |
166 |     lr_scheduler = get_scheduler(
167 |         name=args.lr_scheduler_type,
168 |         optimizer=optimizer,
169 |         num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
170 |         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
171 |     )
172 |
173 |     logger.info(f"{'=='*100}\nbefore accelerator preparation: [dataloader: {len(train_dataloader)}][epochs: {args.num_train_epochs}][total steps: {args.max_train_steps}]\n{'=='*100}")
174 |     if torch.cuda.is_available():
175 |         model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare(
176 |             model, train_dataloader, valid_dataloader, optimizer, lr_scheduler
177 |         )
178 |         if args.mode == 'ft':
179 |             train_dataloader_ft, valid_dataloader_ft = accelerator.prepare(train_dataloader_ft, valid_dataloader_ft)
180 |
181 |     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
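    # Illustrative arithmetic (hypothetical numbers): with 1000 train batches, no ft
    # batches and gradient_accumulation_steps=4, each epoch yields ceil(1000 / 4) = 250
    # update steps, so num_train_epochs=3 gives max_train_steps = 750.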
182 |     num_update_steps_per_epoch = math.ceil((len(train_dataloader) + (len(train_dataloader_ft) if args.mode == 'ft' else 0)) / args.gradient_accumulation_steps)
183 |     if overrode_max_train_steps:
184 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
185 |     # Afterward we recalculate our number of training epochs
186 |     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
187 |     logger.info(f"{'=='*100}\nafter accelerator preparation: [dataloader: {len(train_dataloader)}][epochs: {args.num_train_epochs}][total steps: {args.max_train_steps}]\n{'=='*100}")
188 |
189 |     # Train!
190 |     accelerate_train(accelerator,
191 |                      model,
192 |                      train_dataloader,
193 |                      valid_dataloader,
194 |                      train_dataloader_ft,
195 |                      valid_dataloader_ft,
196 |                      optimizer,
197 |                      lr_scheduler,
198 |                      tokenizer,
199 |                      len(train_dataset),
200 |                      args)
201 |
202 |
203 | if __name__ == "__main__":
204 |     main()
--------------------------------------------------------------------------------
/modeling/gatconv.py:
--------------------------------------------------------------------------------
1 | # borrowed from DGL
2 | """Torch modules for graph attention networks (GAT)."""
3 | # pylint: disable= no-member, arguments-differ, invalid-name
4 | import torch as th
5 | from torch import nn
6 |
7 | # dgl.nn.pytorch.conv.gatconv
8 | from dgl import function as fn
9 | from dgl.base import DGLError
10 | from dgl.utils import expand_as_pair
11 | from dgl.nn.functional import edge_softmax
12 | from dgl.nn.pytorch.utils import Identity
13 |
14 | class GATConv(nn.Module):
15 |     def __init__(
16 |         self,
17 |         in_feats,
18 |         out_feats,
19 |         num_heads,
20 |         feat_drop=0.0,
21 |         attn_drop=0.0,
22 |         negative_slope=0.2,
23 |         residual=False,
24 |         activation=None,
25 |         allow_zero_in_degree=True,
26 |         bias=True,
27 |     ):
28 |         super(GATConv, self).__init__()
29 |         self._num_heads = num_heads
30 |         self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
31 |         self._out_feats = out_feats
32 |         self._allow_zero_in_degree = allow_zero_in_degree
33 |         if isinstance(in_feats, tuple):
34 |             self.fc_src = nn.Linear(
35 |                 self._in_src_feats, out_feats * num_heads, bias=False
36 |             )
37 |             self.fc_dst = nn.Linear(
38 |                 self._in_dst_feats, out_feats * num_heads, bias=False
39 |             )
40 |         else:
41 |             self.fc = nn.Linear(
42 |                 self._in_src_feats, out_feats * num_heads, bias=False
43 |             )
44 |         self.attn_l = nn.Parameter(
45 |             th.FloatTensor(size=(1, num_heads, out_feats))
46 |         )
47 |         self.attn_r = nn.Parameter(
48 |             th.FloatTensor(size=(1, num_heads, out_feats))
49 |         )
50 |         self.feat_drop = nn.Dropout(feat_drop)
51 |         self.attn_drop = nn.Dropout(attn_drop)
52 |         self.leaky_relu = nn.LeakyReLU(negative_slope)
53 |
54 |         self.has_linear_res = False
55 |         self.has_explicit_bias = False
56 |         if residual:
57 |             if self._in_dst_feats != out_feats * num_heads:
58 |                 self.res_fc = nn.Linear(
59 |                     self._in_dst_feats, num_heads * out_feats, bias=bias
60 |                 )
61 |                 self.has_linear_res = True
62 |             else:
63 |                 self.res_fc = Identity()
64 |         else:
65 |             self.register_buffer("res_fc", None)
66 |
67 |         if bias and not self.has_linear_res:
68 |             self.bias = nn.Parameter(
69 |                 th.FloatTensor(size=(num_heads * out_feats,))
70 |             )
71 |             self.has_explicit_bias = True
72 |         else:
73 |             self.register_buffer("bias", None)
74 |
75 |         self.reset_parameters()
76 |         self.activation = activation
77 |
78 |     def reset_parameters(self):
79 |         """
80 |
81 |         Description
82 |         -----------
83 |         Reinitialize learnable parameters.
84 |
85 |         Note
86 |         ----
87 |         The fc weights :math:`W^{(l)}` are initialized using Glorot uniform initialization.
88 |         The attention weights are initialized using the Xavier method.
89 |         """
90 |         gain = nn.init.calculate_gain("relu")
91 |         if hasattr(self, "fc"):
92 |             nn.init.xavier_normal_(self.fc.weight, gain=gain)
93 |         else:
94 |             nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
95 |             nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
96 |         nn.init.xavier_normal_(self.attn_l, gain=gain)
97 |         nn.init.xavier_normal_(self.attn_r, gain=gain)
98 |         if self.has_explicit_bias:
99 |             nn.init.constant_(self.bias, 0)
100 |         if isinstance(self.res_fc, nn.Linear):
101 |             nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
102 |             if self.res_fc.bias is not None:
103 |                 nn.init.constant_(self.res_fc.bias, 0)
104 |
105 |
106 |     def set_allow_zero_in_degree(self, set_value):
107 |         r"""
108 |
109 |         Description
110 |         -----------
111 |         Set allow_zero_in_degree flag.
112 |
113 |         Parameters
114 |         ----------
115 |         set_value : bool
116 |             The value to be set to the flag.
117 |         """
118 |         self._allow_zero_in_degree = set_value
119 |
120 |     def forward(self, graph, feat, edge_weight=None, get_attention=False):
121 |         r"""
122 |
123 |         Description
124 |         -----------
125 |         Compute graph attention network layer.
126 |
127 |         Parameters
128 |         ----------
129 |         graph : DGLGraph
130 |             The graph.
131 |         feat : torch.Tensor or pair of torch.Tensor
132 |             If a torch.Tensor is given, the input feature of shape :math:`(N, *, D_{in})` where
133 |             :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
134 |             If a pair of torch.Tensor is given, the pair must contain two tensors of shape
135 |             :math:`(N_{in}, *, D_{in_{src}})` and :math:`(N_{out}, *, D_{in_{dst}})`.
136 |         edge_weight : torch.Tensor, optional
137 |             A 1D tensor of edge weight values. Shape: :math:`(|E|,)`.
138 |         get_attention : bool, optional
139 |             Whether to return the attention values. Default to False.
140 |
141 |         Returns
142 |         -------
143 |         torch.Tensor
144 |             The output feature of shape :math:`(N, *, H, D_{out})` where :math:`H`
145 |             is the number of heads, and :math:`D_{out}` is size of output feature.
146 |         torch.Tensor, optional
147 |             The attention values of shape :math:`(E, *, H, 1)`, where :math:`E` is the number of
148 |             edges. This is returned only when :attr:`get_attention` is ``True``.
149 |
150 |         Raises
151 |         ------
152 |         DGLError
153 |             If there are 0-in-degree nodes in the input graph, it will raise DGLError
154 |             since no message will be passed to those nodes. This will cause invalid output.
155 |             The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
156 |         """
157 |         with graph.local_scope():
158 |             if not self._allow_zero_in_degree:
159 |                 if (graph.in_degrees() == 0).any():
160 |                     raise DGLError(
161 |                         "There are 0-in-degree nodes in the graph, "
162 |                         "output for those nodes will be invalid. "
163 |                         "This is harmful for some applications, "
164 |                         "causing silent performance regression. "
165 |                         "Adding self-loop on the input graph by "
166 |                         "calling `g = dgl.add_self_loop(g)` will resolve "
167 |                         "the issue. Setting ``allow_zero_in_degree`` "
168 |                         "to be `True` when constructing this module will "
169 |                         "suppress the check and let the code run."
170 |                     )
171 |
172 |             if isinstance(feat, tuple):
173 |                 src_prefix_shape = feat[0].shape[:-1]
174 |                 dst_prefix_shape = feat[1].shape[:-1]
175 |                 h_src = self.feat_drop(feat[0])
176 |                 h_dst = self.feat_drop(feat[1])
177 |                 if not hasattr(self, "fc_src"):
178 |                     feat_src = self.fc(h_src).view(
179 |                         *src_prefix_shape, self._num_heads, self._out_feats
180 |                     )
181 |                     feat_dst = self.fc(h_dst).view(
182 |                         *dst_prefix_shape, self._num_heads, self._out_feats
183 |                     )
184 |                 else:
185 |                     feat_src = self.fc_src(h_src).view(
186 |                         *src_prefix_shape, self._num_heads, self._out_feats
187 |                     )
188 |                     feat_dst = self.fc_dst(h_dst).view(
189 |                         *dst_prefix_shape, self._num_heads, self._out_feats
190 |                     )
191 |             else:
192 |                 src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
193 |                 h_src = h_dst = self.feat_drop(feat)
194 |                 feat_src = feat_dst = self.fc(h_src).view(
195 |                     *src_prefix_shape, self._num_heads, self._out_feats
196 |                 )
197 |                 if graph.is_block:
198 |                     feat_dst = feat_src[: graph.number_of_dst_nodes()]
199 |                     h_dst = h_dst[: graph.number_of_dst_nodes()]
200 |                     dst_prefix_shape = (
201 |                         graph.number_of_dst_nodes(),
202 |                     ) + dst_prefix_shape[1:]
203 |             # NOTE: GAT paper uses "first concatenation then linear projection"
204 |             # to compute attention scores, while ours is "first projection then
205 |             # addition"; the two approaches are mathematically equivalent:
206 |             # We decompose the weight vector a mentioned in the paper into
207 |             # [a_l || a_r], then
208 |             # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
209 |             # Our implementation is much more efficient because we do not need to
210 |             # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
211 |             # addition could be optimized with DGL's built-in function u_add_v,
212 |             # which further speeds up computation and saves memory footprint.
213 |             el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
214 |             er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
215 |             graph.srcdata.update({"ft": feat_src, "el": el})
216 |             graph.dstdata.update({"er": er})
217 |             # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
218 |             graph.apply_edges(fn.u_add_v("el", "er", "e"))
219 |             e = self.leaky_relu(graph.edata.pop("e"))
220 |             # compute softmax
221 |             graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
222 |             if edge_weight is not None:
223 |                 graph.edata["a"] = graph.edata["a"] * edge_weight.tile(
224 |                     1, self._num_heads, 1
225 |                 ).transpose(0, 2)
226 |             # message passing
227 |             graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
228 |             rst = graph.dstdata["ft"]
229 |             # residual
230 |             if self.res_fc is not None:
231 |                 # Use -1 rather than self._num_heads to handle broadcasting
232 |                 resval = self.res_fc(h_dst).view(
233 |                     *dst_prefix_shape, -1, self._out_feats
234 |                 )
235 |                 rst = rst + resval
236 |             # bias
237 |             if self.has_explicit_bias:
238 |                 rst = rst + self.bias.view(
239 |                     *((1,) * len(dst_prefix_shape)),
240 |                     self._num_heads,
241 |                     self._out_feats
242 |                 )
243 |             # activation
244 |             if self.activation:
245 |                 rst = self.activation(rst)
246 |
247 |             if get_attention:
248 |                 return rst, graph.edata["a"]
249 |             else:
250 |                 return rst
251 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 |
2 |                                  Apache License
3 |                            Version 2.0, January 2004
4 |                         http://www.apache.org/licenses/
5 |
6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 |    1. Definitions.
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math, random 4 | import torch 5 | from tqdm.auto import tqdm 6 | 7 | sys.path.append("..") 8 | from utils import loss_func, touch_print 9 | from torch.utils.tensorboard import SummaryWriter 10 | from accelerate.logging import get_logger 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): 16 | for key, value in log_dict.items(): 17 | summary_writer.add_scalar(f'{key}', value, completed_steps) 18 | 19 | 20 | def accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir: str, completed_steps: int, args): 21 | accelerator.wait_for_everyone() 22 | 23 | accelerator.print(f"[CHECKPOINT] Saving checkpoint") 24 | 25 | if accelerator.is_main_process: 26 | tokenizer.save_pretrained(output_dir) 27 | torch.save(accelerator.get_state_dict(model.gnn), f"{output_dir}/GNN.pth") 28 | torch.save(accelerator.get_state_dict(model.adapter), f"{output_dir}/adapter.pth") 29 | 30 | if args.mode != 'pt': 31 | unwrapped_model = accelerator.unwrap_model(model) 32 | unwrapped_model.lm.save_pretrained( 33 | output_dir, 34 | is_main_process=accelerator.is_main_process, 35 | save_function=accelerator.save, 36 | state_dict=accelerator.get_state_dict(model.lm) 37 | ) 38 | 39 | accelerator.print( 40 | f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved" 41 | ) 42 | accelerator.wait_for_everyone() 43 | 44 | 45 | def accelerate_monitor(accelerator, reduce_loss, reduce_loss_ft, args, completed_steps, 46 | lr_scheduler, optimizer, summary_writer, reduce_step, reduce_step_ft): 47 | """ 48 | gather reduce_loss from all N devices. 49 | train logging and tensorboarding. 
50 | """ 51 | if type(reduce_loss) != int: 52 | reduce_losses = accelerator.gather(reduce_loss) 53 | 54 | train_loss = torch.mean(reduce_losses) / reduce_step 55 | 56 | # logging and tensorboard 57 | logger.info( 58 | f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]" 59 | f"[gather shape={reduce_losses.shape}][lr={lr_scheduler.get_lr()[0]:.4e}, {optimizer.param_groups[0]['lr']:.4e}]", 60 | ) 61 | 62 | train_log_dict = {"training_loss": train_loss, "lr": lr_scheduler.get_lr()[0]} 63 | else: 64 | train_log_dict = {"lr": lr_scheduler.get_lr()[0]} 65 | 66 | if args.mode == 'ft' and type(reduce_loss_ft) != int: 67 | reduce_loss_ft = accelerator.gather(reduce_loss_ft) 68 | train_loss_ft = torch.mean(reduce_loss_ft) / reduce_step_ft 69 | train_log_dict["training_loss_ft"] = train_loss_ft 70 | 71 | if accelerator.is_main_process: 72 | write_tensorboard(summary_writer, train_log_dict, completed_steps) 73 | 74 | 75 | def accelerate_evaluate(accelerator, model, valid_dataloader, valid_dataloader_ft, args, completed_steps, step, min_eval_loss, stall_num, 76 | best_step, summary_writer): 77 | """ 78 | evaluate the model at current completed_steps on valid_dataloader and gather eval_loss on all devices. 79 | eval logging and tensorboarding. 80 | """ 81 | losses = [] 82 | for batch in valid_dataloader: 83 | with torch.no_grad(): 84 | outputs = model(batch) 85 | 86 | loss = loss_func( 87 | outputs=outputs, 88 | labels=batch['labels'], 89 | loss_mask=batch['loss_mask'], 90 | ) 91 | 92 | losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) 93 | # print(losses[-1].shape) 94 | 95 | accelerator.wait_for_everyone() 96 | valid_batch_num = len(losses) 97 | gathered_size = losses[0].shape 98 | losses = torch.cat(losses) 99 | 100 | try: 101 | eval_loss = torch.mean(losses) 102 | if eval_loss <= min_eval_loss: 103 | min_eval_loss = eval_loss 104 | stall_num = 0 105 | best_step = completed_steps 106 | else: 107 | stall_num += 1 108 | perplexity = math.exp(eval_loss) 109 | except OverflowError: 110 | perplexity = float("inf") 111 | 112 | logger.info(f"[EVAL][global_steps={step + 1}][completed_steps={completed_steps}]" 113 | f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]" 114 | f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}]") 115 | eval_log_dict = {"valid_loss": eval_loss.float(), 116 | "perplexity": perplexity} 117 | 118 | if args.mode == 'ft': 119 | losses = [] 120 | for batch in valid_dataloader_ft: 121 | with torch.no_grad(): 122 | outputs = model(batch) 123 | 124 | loss = loss_func( 125 | outputs=outputs, 126 | labels=batch['labels'], 127 | loss_mask=batch['loss_mask'], 128 | ) 129 | 130 | losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) 131 | 132 | accelerator.wait_for_everyone() 133 | valid_batch_num_ft = len(losses) 134 | gathered_size_ft = losses[0].shape 135 | losses = torch.cat(losses) 136 | 137 | try: 138 | eval_loss_ft = torch.mean(losses) 139 | perplexity_ft = math.exp(eval_loss_ft) 140 | except OverflowError: 141 | perplexity_ft = float("inf") 142 | 143 | logger.info(f"[valid_batch_num_ft={valid_batch_num_ft}], [gather_size_ft={gathered_size_ft}]" 144 | f"[perplexity_ft={perplexity_ft:.4f}][eval_loss_ft={eval_loss_ft:.6f}]") 145 | eval_log_dict["valid_loss_ft"] = eval_loss_ft.float() 146 | eval_log_dict["perplexity_ft"] = perplexity_ft 147 | 148 | if accelerator.is_main_process: 149 | write_tensorboard(summary_writer, eval_log_dict, completed_steps) 150 | 151 | return eval_loss, 
min_eval_loss, stall_num, best_step
152 |
153 |
154 | def accelerate_train(accelerator, model, train_dataloader, valid_dataloader, train_dataloader_ft, valid_dataloader_ft,
155 |                      optimizer, lr_scheduler, tokenizer, total_train_dataset_size, args):
156 |     # tensorboard writer
157 |     summary_writer = SummaryWriter(log_dir=args.tb_dir) if accelerator.is_main_process else None
158 |     # Train!
159 |     total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
160 |     logger.info("**************************************** Running training ****************************************")
161 |     logger.info(f" Num examples = {total_train_dataset_size}")
162 |     logger.info(f" Num Epochs = {args.num_train_epochs}")
163 |     logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
164 |     logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
165 |     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
166 |     logger.info(f" Total optimization(update/completed) steps = {args.max_train_steps}")
167 |     logger.info(f" Complete/Optimization steps per Epoch = {args.max_train_steps // args.num_train_epochs}")
168 |     logger.info("***************************************************************************************************")
169 |
170 |     # Only show the progress bar once on each machine.
171 |     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
172 |
173 |     # set starting_epoch, completed_steps and resume_step of train_dataloader
174 |     completed_steps = 0
175 |     starting_epoch = 0
176 |
177 |     # monitor minimum eval_loss, stalling num, and best_step
178 |     min_eval_loss = float('inf')
179 |     stall_num = 0
180 |     best_step = None
181 |
182 |     # monitor train loss
183 |     reduce_loss, reduce_step = 0, 0
184 |     reduce_loss_ft, reduce_step_ft = 0, 0
185 |
186 |     # Training Loop!
187 |     for epoch in range(starting_epoch, args.num_train_epochs):
188 |         if args.early_stopping and stall_num == args.early_stopping_stall_num:
189 |             break
190 |
191 |         # prepare dataloaders
192 |         train_dataloader_iter = iter(train_dataloader)
193 |         next_idx, next_ft_idx = 0, 0
194 |         if args.mode == 'ft':
195 |             train_dataloader_ft_iter = iter(train_dataloader_ft)
196 |             ratio = len(train_dataloader_ft) / (len(train_dataloader) + len(train_dataloader_ft))
197 |             loss_ratio = min(len(train_dataloader_ft) / len(train_dataloader), 1)
198 |
199 |         def get_batch(next_idx, next_ft_idx):
200 |             if args.mode == 'pt' or next_ft_idx == len(train_dataloader_ft):
201 |                 # pt mode, or ft mode with the ft dataloader exhausted
202 |                 next_idx += 1
203 |                 return next(train_dataloader_iter), next_idx, next_ft_idx
204 |             elif next_idx == len(train_dataloader):
205 |                 # graph dataloader is exhausted
206 |                 next_ft_idx += 1
207 |                 return next(train_dataloader_ft_iter), next_idx, next_ft_idx
208 |             else:
209 |                 assert next_idx < len(train_dataloader) and next_ft_idx < len(train_dataloader_ft)
210 |                 if random.random() < ratio:
211 |                     next_ft_idx += 1
212 |                     return next(train_dataloader_ft_iter), next_idx, next_ft_idx
213 |                 else:
214 |                     next_idx += 1
215 |                     return next(train_dataloader_iter), next_idx, next_ft_idx
216 |
217 |         print(f"length of graph dataloader: {len(train_dataloader)}")
218 |         if args.mode == 'ft':
219 |             print(f"length of ft dataloader: {len(train_dataloader_ft)}, ratio: {ratio}")
220 |
221 |         model.train()
222 |         # Inner Loop!
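        # Illustrative sketch of the mixed sampling above (hypothetical sizes): with
        # len(train_dataloader)=300 graph batches and len(train_dataloader_ft)=100 text
        # batches, ratio = 100 / (300 + 100) = 0.25, so roughly one step in four draws
        # an ft batch until either iterator is exhausted, while
        # loss_ratio = min(100 / 300, 1) = 1/3 down-weights the graph loss accordingly.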
223 |         for step in range(len(train_dataloader) + len(train_dataloader_ft) if args.mode == 'ft' else len(train_dataloader)):
224 |
225 |             with accelerator.accumulate(model):
226 |                 batch, next_idx, next_ft_idx = get_batch(next_idx, next_ft_idx)
227 |                 # if step == 0:
228 |                 #     touch_print(accelerator, batch, num_tokens=10)
229 |                 # forward
230 |                 outputs = model(batch)
231 |
232 |                 # loss
233 |                 loss = loss_func(
234 |                     outputs=outputs,
235 |                     labels=batch['labels'],
236 |                     loss_mask=batch['loss_mask'],
237 |                 ) * (loss_ratio if 'graph_embedding' in batch.keys() and args.mode == 'ft' else 1)
238 |
239 |                 # backward
240 |                 accelerator.backward(loss)
241 |
242 |                 # update(sync_gradients)
243 |                 optimizer.step()
244 |                 lr_scheduler.step()
245 |                 optimizer.zero_grad()
246 |                 # support args.min_lr
247 |                 if optimizer.param_groups[0]['lr'] <= args.min_lr:
248 |                     optimizer.param_groups[0]['lr'] = args.min_lr
249 |
250 |                 # accumulate reduce_loss within a log interval
251 |                 if not torch.isnan(loss):
252 |                     if 'graph_embedding' in batch.keys():
253 |                         reduce_loss += loss.detach().float() / (loss_ratio if 'graph_embedding' in batch.keys() and args.mode == 'ft' else 1)
254 |                         reduce_step += 1
255 |                     else:
256 |                         assert args.mode == 'ft'
257 |                         reduce_loss_ft += loss.detach().float()
258 |                         reduce_step_ft += 1
259 |
260 |                 # If the accelerator has performed an optimization step behind the scenes, a completed step is done.
261 |                 if accelerator.sync_gradients:
262 |
263 |                     completed_steps += 1
264 |                     # monitor the training process with logging and tensorboard
265 |                     if completed_steps % args.log_interval == 0:
266 |                         progress_bar.update(args.log_interval)
267 |                         accelerate_monitor(
268 |                             accelerator, reduce_loss, reduce_loss_ft, args, completed_steps,
269 |                             lr_scheduler, optimizer, summary_writer, reduce_step, reduce_step_ft
270 |                         )
271 |                         reduce_loss, reduce_loss_ft = 0, 0
272 |                         reduce_step, reduce_step_ft = 0, 0
273 |
274 |                     # steps checkpointing
275 |                     if args.checkpointing_steps and completed_steps % args.checkpointing_steps == 0:
276 |                         output_dir = f"step_{completed_steps}"
277 |                         if args.output_dir is not None:
278 |                             output_dir = os.path.join(args.output_dir, output_dir)
279 |                         accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args)
280 |
281 |                     # steps evaluation
282 |                     if completed_steps % args.evaluation_steps == 0:
283 |                         model.eval()
284 |                         eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate(
285 |                             accelerator, model, valid_dataloader, valid_dataloader_ft, args, completed_steps, step,
286 |                             min_eval_loss, stall_num, best_step, summary_writer
287 |                         )
288 |                         model.train()
289 |
290 |                     # early stopping when stalling for more than args.early_stopping_stall_num evaluations
291 |                     if args.early_stopping and stall_num == args.early_stopping_stall_num:
292 |                         accelerator.print(f"[WARNING] Early stopping at {completed_steps}")
293 |                         break
294 |
295 |                     if completed_steps >= args.max_train_steps:
296 |                         break
297 |         accelerator.wait_for_everyone()
298 |
299 |         # epoch checkpointing
300 |         if args.epoch_checkpointing:
301 |             output_dir = f"epoch_{epoch}"
302 |             if args.output_dir is not None:
303 |                 output_dir = os.path.join(args.output_dir, output_dir)
304 |             accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args)
305 |
306 |     if summary_writer:
307 |         summary_writer.close()
308 |
309 |     # final save
310 |     output_dir = f"final_step_{completed_steps}"
311 |     if args.output_dir is not None:
312 |         output_dir = os.path.join(args.output_dir, output_dir)
313 |     accelerate_saving_checkpoint(accelerator, model, tokenizer, output_dir, completed_steps, args)
314 |
--------------------------------------------------------------------------------
/modeling/magnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils import print_rank_0, print_highlight, print_rank_0_highlight
4 | from typing import Optional
5 | import torch.nn.functional as F
6 | from torch_geometric_signed_directed.nn import complex_relu_layer
7 |
8 |
9 | from torch.nn import Parameter
10 | from torch_geometric.nn.inits import zeros, glorot
11 | from torch_geometric.typing import OptTensor
12 | from torch_geometric.nn.conv import MessagePassing
13 | from torch_geometric.utils import remove_self_loops, add_self_loops
14 | from torch_geometric_signed_directed.utils.directed.get_magnetic_Laplacian import get_magnetic_Laplacian
15 |
16 |
17 | class MagNetConv(MessagePassing):
18 |     r"""The magnetic graph convolutional operator from the
19 |     `MagNet: A Neural Network for Directed Graphs <https://arxiv.org/abs/2102.11391>`_ paper.
20 |     :math:`\mathbf{\hat{L}}` denotes the scaled and normalized magnetic Laplacian
21 |     :math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}`.
22 |
23 |     Args:
24 |         in_channels (int): Size of each input sample.
25 |         out_channels (int): Size of each output sample.
26 |         K (int): Chebyshev filter size :math:`K`.
27 |         q (float, optional): Initial value of the phase parameter, 0 <= q <= 0.25. Default: 0.25.
28 |         trainable_q (bool, optional): whether to set q to be trainable or not. (default: :obj:`False`)
29 |         normalization (str, optional): The normalization scheme for the magnetic
30 |             Laplacian (default: :obj:`sym`):
31 |             1. :obj:`None`: No normalization
32 |             :math:`\mathbf{L} = \mathbf{D} - \mathbf{A} \odot \exp(i \Theta^{(q)})`
33 |             2. :obj:`"sym"`: Symmetric normalization
34 |             :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
35 |             \mathbf{D}^{-1/2} \odot \exp(i \Theta^{(q)})`
36 |             :math:`\odot` denotes the element-wise multiplication.
37 |         cached (bool, optional): If set to :obj:`True`, the layer will cache
38 |             the __norm__ matrix on first execution, and will use the
39 |             cached version for further executions.
40 |             This parameter should only be set to :obj:`True` in transductive
41 |             learning scenarios. (default: :obj:`False`)
42 |         bias (bool, optional): If set to :obj:`False`, the layer will not learn
43 |             an additive bias. (default: :obj:`True`)
44 |         **kwargs (optional): Additional arguments of
45 |             :class:`torch_geometric.nn.conv.MessagePassing`.
46 |     """
47 |
48 |     def __init__(self, in_channels: int, out_channels: int, K: int, q: float, trainable_q: bool,
49 |                  normalization: str = 'sym', cached: bool = False, bias: bool = True, **kwargs):
50 |         kwargs.setdefault('aggr', 'add')
51 |         super(MagNetConv, self).__init__(**kwargs)
52 |
53 |         assert K > 0
54 |         assert normalization in [None, 'sym'], 'Invalid normalization'
55 |         kwargs.setdefault('flow', 'target_to_source')
56 |
57 |         self.in_channels = in_channels
58 |         self.out_channels = out_channels
59 |         self.normalization = normalization
60 |         self.cached = cached
61 |         self.trainable_q = trainable_q
62 |         if trainable_q:
63 |             self.q = Parameter(torch.Tensor(1).fill_(q))
64 |         else:
65 |             self.q = q
66 |         self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))
67 |
68 |         if bias:
69 |             self.bias = Parameter(torch.Tensor(out_channels))
70 |         else:
71 |             self.register_parameter('bias', None)
72 |
73 |         self.reset_parameters()
74 |
75 |     def reset_parameters(self):
76 |         glorot(self.weight)
77 |         zeros(self.bias)
78 |         self.cached_result = None
79 |         self.cached_num_edges = None
80 |         self.cached_q = None
81 |
82 |     def __norm__(
83 |         self,
84 |         edge_index,
85 |         num_nodes: Optional[int],
86 |         edge_weight: OptTensor,
87 |         q: float,
88 |         normalization: Optional[str],
89 |         lambda_max,
90 |         dtype: Optional[int] = None
91 |     ):
92 |         """
93 |         Get the magnetic Laplacian.
94 |
95 |         Arg types:
96 |             * edge_index (PyTorch Long Tensor) - Edge indices.
97 |             * num_nodes (int, optional) - Number of nodes.
98 |             * edge_weight (PyTorch Float Tensor, optional) - Edge weights corresponding to edge indices.
99 |             * lambda_max (optional, but mandatory if normalization is None) - Largest eigenvalue of Laplacian.
100 |
101 |         Return types:
102 |             * edge_index_real, edge_index_imag, edge_weight_real, edge_weight_imag (PyTorch Float Tensor) - Magnetic Laplacian tensor: real and imaginary edge indices and weights.
103 |         """
104 |         edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
105 |
106 |         edge_index, edge_weight_real, edge_weight_imag = get_magnetic_Laplacian(
107 |             edge_index, edge_weight, normalization, dtype, num_nodes, q
108 |         )
109 |
110 |         edge_weight_real = (2.0 * edge_weight_real) / lambda_max
111 |         edge_weight_real.masked_fill_(edge_weight_real == float("inf"), 0)
112 |         edge_index_imag = edge_index.clone()
113 |
114 |         edge_index_real, edge_weight_real = add_self_loops(
115 |             edge_index, edge_weight_real, fill_value=-1.0, num_nodes=num_nodes
116 |         )
117 |         assert edge_weight_real is not None
118 |
119 |         edge_weight_imag = (2.0 * edge_weight_imag) / lambda_max
120 |         edge_weight_imag.masked_fill_(edge_weight_imag == float("inf"), 0)
121 |
122 |         assert edge_weight_imag is not None
123 |
124 |         return edge_index_real, edge_index_imag, edge_weight_real, edge_weight_imag
125 |
126 |     def forward(
127 |         self,
128 |         x_real: torch.FloatTensor,
129 |         x_imag: torch.FloatTensor,
130 |         edge_index: torch.LongTensor,
131 |         edge_weight: OptTensor = None,
132 |         lambda_max: OptTensor = None,
133 |     ) -> torch.FloatTensor:
134 |         """
135 |         Making a forward pass of the MagNet Convolution layer.
136 |
137 |         Arg types:
138 |             * x_real, x_imag (PyTorch Float Tensor) - Node features.
139 |             * edge_index (PyTorch Long Tensor) - Edge indices.
140 |             * edge_weight (PyTorch Float Tensor, optional) - Edge weights corresponding to edge indices.
141 |             * lambda_max (optional, but mandatory if normalization is None) - Largest eigenvalue of Laplacian.
142 | Return types: 143 | * out_real, out_imag (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (N_nodes, F_out). 144 | """ 145 | if self.trainable_q: 146 | self.q = Parameter(torch.clamp(self.q, 0, 0.25)) 147 | 148 | if self.cached and self.cached_result is not None: 149 | if edge_index.size(1) != self.cached_num_edges: 150 | raise RuntimeError( 151 | 'Cached {} number of edges, but found {}. Please ' 152 | 'disable the caching behavior of this layer by removing ' 153 | 'the `cached=True` argument in its constructor.'.format( 154 | self.cached_num_edges, edge_index.size(1))) 155 | if self.q != self.cached_q: 156 | raise RuntimeError( 157 | 'Cached q is {}, but found {} in input. Please ' 158 | 'disable the caching behavior of this layer by removing ' 159 | 'the `cached=True` argument in its constructor.'.format( 160 | self.cached_q, self.q)) 161 | if not self.cached or self.cached_result is None: 162 | self.cached_num_edges = edge_index.size(1) 163 | if self.trainable_q: 164 | self.cached_q = self.q.detach().item() 165 | else: 166 | self.cached_q = self.q 167 | if self.normalization != 'sym' and lambda_max is None: 168 | if self.trainable_q: 169 | raise RuntimeError( 170 | 'Cannot train q while not calculating maximum eigenvalue of Laplacian!') 171 | _, _, _, lambda_max = get_magnetic_Laplacian( 172 | edge_index, edge_weight, None, q=self.q, return_lambda_max=True 173 | ) 174 | 175 | if lambda_max is None: 176 | lambda_max = torch.tensor( 177 | 2.0, dtype=x_real.dtype, device=x_real.device) 178 | if not isinstance(lambda_max, torch.Tensor): 179 | lambda_max = torch.tensor(lambda_max, dtype=x_real.dtype, 180 | device=x_real.device) 181 | assert lambda_max is not None 182 | edge_index_real, edge_index_imag, norm_real, norm_imag = self.__norm__(edge_index, x_real.size(self.node_dim), 183 | edge_weight, self.q, self.normalization, 184 | lambda_max, dtype=x_real.dtype) 185 | self.cached_result = edge_index_real, edge_index_imag, norm_real, norm_imag 186 | 187 | edge_index_real, edge_index_imag, norm_real, norm_imag = self.cached_result 188 | 189 | Tx_0_real_real = x_real 190 | Tx_0_imag_imag = x_imag 191 | Tx_0_imag_real = x_real 192 | Tx_0_real_imag = x_imag 193 | out_real_real = torch.matmul(Tx_0_real_real, self.weight[0]) 194 | out_imag_imag = torch.matmul(Tx_0_imag_imag, self.weight[0]) 195 | out_imag_real = torch.matmul(Tx_0_imag_real, self.weight[0]) 196 | out_real_imag = torch.matmul(Tx_0_real_imag, self.weight[0]) 197 | 198 | # propagate_type: (x: Tensor, norm: Tensor) 199 | if self.weight.size(0) > 1: 200 | Tx_1_real_real = self.propagate( 201 | edge_index_real, x=x_real, norm=norm_real, size=None).to(self.weight[1].dtype) 202 | out_real_real = out_real_real + \ 203 | torch.matmul(Tx_1_real_real, self.weight[1]) 204 | Tx_1_imag_imag = self.propagate( 205 | edge_index_imag, x=x_imag, norm=norm_imag, size=None).to(self.weight[1].dtype) 206 | out_imag_imag = out_imag_imag + \ 207 | torch.matmul(Tx_1_imag_imag, self.weight[1]) 208 | Tx_1_imag_real = self.propagate( 209 | edge_index_real, x=x_real, norm=norm_real, size=None).to(self.weight[1].dtype) 210 | out_imag_real = out_imag_real + \ 211 | torch.matmul(Tx_1_imag_real, self.weight[1]) 212 | Tx_1_real_imag = self.propagate( 213 | edge_index_imag, x=x_imag, norm=norm_imag, size=None).to(self.weight[1].dtype) 214 | out_real_imag = out_real_imag + \ 215 | torch.matmul(Tx_1_real_imag, self.weight[1]) 216 | 217 | for k in range(2, self.weight.size(0)): 218 | Tx_2_real_real = self.propagate( 219 | edge_index_real, 
x=Tx_1_real_real, norm=norm_real, size=None)
220 |                 Tx_2_real_real = 2. * Tx_2_real_real - Tx_0_real_real
221 |                 out_real_real = out_real_real + \
222 |                     torch.matmul(Tx_2_real_real, self.weight[k])
223 |                 Tx_0_real_real, Tx_1_real_real = Tx_1_real_real, Tx_2_real_real
224 |
225 |                 Tx_2_imag_imag = self.propagate(
226 |                     edge_index_imag, x=Tx_1_imag_imag, norm=norm_imag, size=None)
227 |                 Tx_2_imag_imag = 2. * Tx_2_imag_imag - Tx_0_imag_imag
228 |                 out_imag_imag = out_imag_imag + \
229 |                     torch.matmul(Tx_2_imag_imag, self.weight[k])
230 |                 Tx_0_imag_imag, Tx_1_imag_imag = Tx_1_imag_imag, Tx_2_imag_imag
231 |
232 |                 Tx_2_imag_real = self.propagate(
233 |                     edge_index_real, x=Tx_1_imag_real, norm=norm_real, size=None)
234 |                 Tx_2_imag_real = 2. * Tx_2_imag_real - Tx_0_imag_real
235 |                 out_imag_real = out_imag_real + \
236 |                     torch.matmul(Tx_2_imag_real, self.weight[k])
237 |                 Tx_0_imag_real, Tx_1_imag_real = Tx_1_imag_real, Tx_2_imag_real
238 |
239 |                 Tx_2_real_imag = self.propagate(
240 |                     edge_index_imag, x=Tx_1_real_imag, norm=norm_imag, size=None)
241 |                 Tx_2_real_imag = 2. * Tx_2_real_imag - Tx_0_real_imag
242 |                 out_real_imag = out_real_imag + \
243 |                     torch.matmul(Tx_2_real_imag, self.weight[k])
244 |                 Tx_0_real_imag, Tx_1_real_imag = Tx_1_real_imag, Tx_2_real_imag
245 |
246 |         out_real = out_real_real - out_imag_imag
247 |         out_imag = out_imag_real + out_real_imag
248 |
249 |         if self.bias is not None:
250 |             out_real += self.bias
251 |             out_imag += self.bias
252 |
253 |         return out_real, out_imag
254 |
255 |     def message(self, x_j, norm):
256 |         return norm.view(-1, 1) * x_j
257 |
258 |     def __repr__(self):
259 |         return '{}({}, {}, K={}, normalization={})'.format(
260 |             self.__class__.__name__, self.in_channels, self.out_channels,
261 |             self.weight.size(0), self.normalization)
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 | class MagNet(nn.Module):
270 |     r"""
271 |     Args:
272 |         num_features (int): Size of each input sample.
273 |         hidden (int, optional): Number of hidden channels. Default: 2.
274 |         K (int, optional): Order of the Chebyshev polynomial. Default: 2.
275 |         q (float, optional): Initial value of the phase parameter, 0 <= q <= 0.25. Default: 0.25.
276 |         output_dim (int, optional): output dimension. Default: 2.
277 |         activation (bool, optional): whether to use activation function or not. (default: :obj:`False`)
278 |         trainable_q (bool, optional): whether to set q to be trainable or not. (default: :obj:`False`)
279 |         layer (int, optional): Number of MagNetConv layers. Default: 2.
280 |         dropout (float, optional): Dropout value. (default: :obj:`False`)
281 |         normalization (str, optional): The normalization scheme for the magnetic
282 |             Laplacian (default: :obj:`sym`):
283 |             1. :obj:`None`: No normalization
284 |             :math:`\mathbf{L} = \mathbf{D} - \mathbf{A} \odot \exp(i \Theta^{(q)})`
285 |             2. :obj:`"sym"`: Symmetric normalization
286 |             :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
287 |             \mathbf{D}^{-1/2} \odot \exp(i \Theta^{(q)})`
288 |             :math:`\odot` denotes the element-wise multiplication.
289 |         cached (bool, optional): If set to :obj:`True`, the layer will cache
290 |             the __norm__ matrix on first execution, and will use the
291 |             cached version for further executions.
292 |             This parameter should only be set to :obj:`True` in transductive
293 |             learning scenarios.
(default: :obj:`False`)
294 |     """
295 |
296 |     def __init__(self, num_features: int, hidden: int = 2, q: float = 0.25, K: int = 2, output_dim: int = 2,
297 |                  activation: bool = False, trainable_q: bool = False, layer: int = 2, dropout: float = False, normalization: str = 'sym', cached: bool = False):
298 |         super(MagNet, self).__init__()
299 |
300 |         chebs = nn.ModuleList()
301 |         chebs.append(MagNetConv(in_channels=num_features, out_channels=hidden, K=K,
302 |                                 q=q, trainable_q=trainable_q, normalization=normalization, cached=cached))
303 |         self.normalization = normalization
304 |         self.activation = activation
305 |         if self.activation:
306 |             self.complex_relu = complex_relu_layer()
307 |
308 |         for _ in range(1, layer):
309 |             chebs.append(MagNetConv(in_channels=hidden, out_channels=hidden, K=K,
310 |                                     q=q, trainable_q=trainable_q, normalization=normalization, cached=cached))
311 |
312 |         self.Chebs = chebs
313 |
314 |         self.Conv = nn.Conv1d(2*hidden, output_dim, kernel_size=1)
315 |         self.dropout = dropout
316 |
317 |     def reset_parameters(self):
318 |         for cheb in self.Chebs:
319 |             cheb.reset_parameters()
320 |         self.Conv.reset_parameters()
321 |
322 |     def forward(self, real: torch.FloatTensor, imag: torch.FloatTensor, edge_index: torch.LongTensor,
323 |                 edge_weight: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
324 |         """
325 |         Making a forward pass of the MagNet model.
326 |
327 |         Arg types:
328 |             * real, imag (PyTorch Float Tensor) - Node features.
329 |             * edge_index (PyTorch Long Tensor) - Edge indices.
330 |             * edge_weight (PyTorch Float Tensor, optional) - Edge weights corresponding to edge indices.
331 |         Return types:
332 |             * out (PyTorch Float Tensor) - Node representations for all nodes, with shape (num_nodes, output_dim); log_softmax is commented out below, so these are raw features rather than log class probabilities.
333 |         """
334 |         for cheb in self.Chebs:
335 |             real, imag = cheb(real, imag, edge_index, edge_weight)
336 |             if self.activation:
337 |                 real, imag = self.complex_relu(real, imag)
338 |
339 |         x = torch.cat((real, imag), dim=-1)
340 |
341 |         if self.dropout > 0:
342 |             x = F.dropout(x, self.dropout, training=self.training)
343 |
344 |         x = x.unsqueeze(0)
345 |         x = x.permute((0, 2, 1))
346 |         x = self.Conv(x)
347 |         # x = F.log_softmax(x, dim=1)
348 |         return torch.transpose(x[0], 0, 1)
--------------------------------------------------------------------------------
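# A minimal usage sketch of the MagNet encoder above (illustrative only, not part
# of the repository; shapes and values are hypothetical):
#
#     import torch
#     from modeling.magnet import MagNet
#
#     x = torch.randn(4, 8)                      # (num_nodes, num_features)
#     edge_index = torch.tensor([[0, 1, 2],      # directed edges 0->1, 1->2, 2->3
#                                [1, 2, 3]])
#     gnn = MagNet(num_features=8, hidden=16, output_dim=32)
#     out = gnn(real=x, imag=x, edge_index=edge_index)
#     print(out.shape)                           # torch.Size([4, 32])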