├── README.md
├── configs
│   ├── data
│   │   ├── citeseer.yaml
│   │   ├── citeseer_tag.yaml
│   │   ├── cora.yaml
│   │   ├── cora_tag.yaml
│   │   ├── cornell.yaml
│   │   ├── data_defaults.yaml
│   │   ├── pubmed.yaml
│   │   ├── texas.yaml
│   │   └── wisconsin.yaml
│   ├── exp
│   │   ├── icl.yaml
│   │   └── sft.yaml
│   ├── llm
│   │   ├── chatgpt.yaml
│   │   ├── gpt4.yaml
│   │   ├── llama_icl.yaml
│   │   ├── llama_peft.yaml
│   │   └── llm_meta_data.yaml
│   ├── main.yaml
│   ├── model
│   │   └── graph_text.yaml
│   └── prompt
│       ├── graph_tree_meta_data.yaml
│       └── prompts.yaml
├── requirements.txt
└── src
    ├── graph_text
    │   ├── __init__.py
    │   ├── agent.py
    │   ├── conversation.py
    │   ├── graph_instruction_dataset.py
    │   ├── icl.py
    │   ├── llama_flash_attn_monkey_patch.py
    │   ├── model.py
    │   ├── prompts.py
    │   └── samplers.py
    ├── llm
    │   ├── __init__.py
    │   ├── fake_llm.py
    │   ├── gpt.py
    │   ├── llama_icl.py
    │   └── llm.py
    ├── scripts
    │   ├── run_icl.py
    │   └── run_sft.py
    └── utils
        ├── __init__.py
        ├── basics
        │   ├── __init__.py
        │   ├── cfg_utils.py
        │   ├── iterables.py
        │   ├── logging.py
        │   ├── np_utils.py
        │   └── os_utils.py
        ├── data
        │   ├── __init__.py
        │   ├── graph_tree.py
        │   ├── ppr.py
        │   ├── preprocess.py
        │   └── textual_graph.py
        ├── pkg
        │   ├── __init__.py
        │   ├── dict2xml.py
        │   ├── distributed.py
        │   ├── graph_utils.py
        │   └── hf_utils.py
        └── project
            ├── __init__.py
            └── exp.py
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | Source code for [GraphText: Graph Reasoning in Text Space](https://arxiv.org/abs/2310.01089).
3 | 
4 | # Steps to reproduce
5 | 
6 | ## Python Environment
7 | 
8 | ```shell
9 | pip install -r requirements.txt
10 | ```
11 | ## Key Hyper-parameters
12 | Given an ego-graph, GraphText extracts text information (attributes) and relation information to construct a tree.
13 | 
14 | The node text attributes, denoted as `text_info`, form a **set** of attributes derived from the (ego-)graph. The valid items to compose the set are:
15 | - choice: The label of the node, expressed as a choice letter; e.g., in Cora, "D" stands for the class "Neural Network". Note that if the node is not in the training set, the choice will be "NA".
16 | - a{k}x_t: The K-Means clustering index of the original feature propagated $k$ times ($k>=0$). To illustrate, a0x_t is the K-Means clustering index of the raw feature, and a2x_t is the K-Means clustering index of the feature propagated 2 times.
17 | - a{k}y_t: The choice derived from the training labels propagated $k$ times ($k>=1$).
18 | 
19 | 
20 | The relations, denoted as `rel_info`, form a set of relations derived from the (ego-)graph. The valid items to compose the set are:
21 | - spd{k}: The $k$-hop neighborhood given by the shortest-path distance; e.g., spd0 is the center node itself and spd1 is its first-hop neighbors.
22 | - ppr: The top personalized-PageRank neighbors of the center node.
23 | - a{k}x_sim: The feature-similarity graph built on the feature propagated $k$ times; e.g., a0x_sim links nodes with similar raw features.
24 | 
25 | 
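Both sets are passed as dot-separated lists on the command line. Below is a minimal illustrative sketch; the dataset and the particular attribute/relation combination are arbitrary examples chosen for readability, not tuned settings (the tuned commands are listed in the Commands section):

```shell
# Illustrative only: text attributes = label choice + 2nd-order propagated pseudo labels;
# relations = center node (spd0), first-hop neighbors (spd1), and personalized-PageRank neighbors (ppr).
cd src/scripts
python run_icl.py data=cora text_info=choice.a2y_t rel_info=spd0.spd1.ppr
```
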
26 | ## Commands
27 | 
28 | ### Set up the OpenAI API Key
29 | Make sure to **set the OpenAI API key** as an environment variable before running ICL experiments. You can set it either by running
30 | `export OPENAI_API_KEY="YourOwnAPIKey"`, or by changing `configs/main.yaml` for convenience:
31 | 
32 | ```yaml
33 | env:
34 |   vars:
35 |     openai_api_key: ${oc.env:OPENAI_API_KEY,YourAPIKey} # Overwrite this to your API key
36 | ```
37 | 
38 | ### In-context Learning
39 | #### Original Split
40 | ```shell
41 | export OPENAI_API_KEY="YourOwnAPIKey"
42 | cd src/scripts
43 | python run_icl.py data=cora text_info=a2y_t.a3y_t rel_info=spd0.ppr.a2x_sim.a3x_sim
44 | python run_icl.py data=citeseer text_info=a3y_t.a0x_t rel_info=spd0.spd2.ppr.a2x_sim
45 | python run_icl.py data=texas text_info=a2y_t.a3y_t rel_info=spd2
46 | python run_icl.py data=wisconsin text_info=choice.a0x_t rel_info=a0x_sim.spd3
47 | python run_icl.py data=cornell text_info=a1y_t.a4y_t rel_info=spd1.a3x_sim
48 | ```
49 | #### Few-Shot Node Classification
50 | 
51 | ```shell
52 | export OPENAI_API_KEY="YourOwnAPIKey"
53 | cd src/scripts
54 | python run_icl.py data=citeseer data.n_shots=1 text_info=a0x_t.a3y_t rel_info=spd0.spd3
55 | python run_icl.py data=citeseer data.n_shots=3 text_info=a0x_t.a3y_t rel_info=spd0.spd3.a2x_sim.a3x_sim
56 | python run_icl.py data=citeseer data.n_shots=5 text_info=a0x_t.a3y_t rel_info=spd0.spd3.ppr.a3x_sim
57 | python run_icl.py data=citeseer data.n_shots=10 text_info=a0x_t.a3y_t rel_info=spd0.a0x_sim.a1x_sim
58 | python run_icl.py data=citeseer data.n_shots=15 text_info=a0x_t.a3y_t rel_info=spd0.a0x_sim.a1x_sim
59 | python run_icl.py data=citeseer data.n_shots=20 text_info=a0x_t.a3y_t rel_info=spd0.spd3.a2x_sim.a3x_sim
60 | 
61 | python run_icl.py data=texas data.n_shots=1 text_info=a2y_t rel_info=spd0.spd2
62 | python run_icl.py data=texas data.n_shots=3 text_info=choice rel_info=spd3
63 | python run_icl.py data=texas data.n_shots=5 text_info=a2y_t rel_info=spd0.spd2
64 | python run_icl.py data=texas data.n_shots=10 text_info=choice rel_info=spd2
65 | python run_icl.py data=texas data.n_shots=15 text_info=choice rel_info=spd2
66 | python run_icl.py data=texas data.n_shots=20 text_info=choice rel_info=spd2
67 | ```
68 | 
69 | ### Supervised Fine-tuning (SFT)
70 | GraphText supports instruction fine-tuning an LLM on graphs. An MLP is used to map the continuous features to the text space (as tokens). We recommend using BF16 for stable training.
71 | ```shell
72 | cd src/scripts
73 | python run_sft.py exp=sft lora.r=-1 data=citeseer_tag nb_padding=false add_label_name_output=false max_bsz_per_gpu=4 eq_batch_size=16 rel_info=spd0.a0x_sim.ppr text_info=x llm.base_model=llama2-7b node_dropout=0 subgraph_size=3 total_steps=1000
74 | 
75 | python run_sft.py exp=sft lora.r=-1 data=cora_tag nb_padding=false add_label_name_output=false max_bsz_per_gpu=4 eq_batch_size=16 rel_info=spd0.a1x_sim text_info=x llm.base_model=llama2-7b node_dropout=0 subgraph_size=3 total_steps=1000
76 | ```
77 | # Misc
78 | ## Analyze the Results
79 | We highly recommend using Wandb to track the metrics. All results are saved to a CSV file `${out_dir}{split}-${alias}.csv` together with the prompt and the generated text.
80 | 
81 | ## Other Useful Parameters
82 | - `data.n_shots`: Number of shots for few-shot settings.
83 | - `debug`: Specify `debug=true` to use a fake/small LLM in ICL/SFT for debugging (saves time and money when developing).
84 | - `data.max_train_samples`, `data.max_eval_samples`, `data.max_test_samples`: Number of samples for train/eval/test.
85 | - `use_wandb`: Set `use_wandb=true` or `use_wandb=false` to turn Wandb sync on or off.
86 | - `lora.r`: The rank for LoRA (used in SFT experiments only). If `lora.r` < 0, LoRA is turned off and only the projection layer is trained (see the illustrative example below).
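For instance, a sketch of switching from projection-only training to LoRA fine-tuning by overriding the rank; this is an illustrative setting, not a tuned one (note that `lora.alpha` defaults to the value of `lora.r` in `configs/exp/sft.yaml`):

```shell
cd src/scripts
# lora.r=-1 (as in the SFT commands above): LoRA off, only the projection layer is trained.
# lora.r=8: a hypothetical rank-8 LoRA on the q/k/v/o projection modules listed in configs/exp/sft.yaml.
python run_sft.py exp=sft lora.r=8 data=cora_tag rel_info=spd0.a1x_sim text_info=x llm.base_model=llama2-7b total_steps=1000
```
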
87 | 
88 | ## FAQ
89 | ### GPT initialization failed
90 | Error message: `Error locating target 'llm.gpt.GPT', set env var HYDRA_FULL_ERROR=1 to see chained exception.`
91 | Checklist:
92 | - Check that the `openai` package is installed.
93 | - Check that OPENAI_API_KEY is set in your environment variables. Make sure to `export OPENAI_API_KEY="YourOwnAPIKey"` before running the code.
94 | 
95 | 
96 | ## Citation
97 | If you find our work useful, please consider citing it:
98 | ```
99 | @misc{zhao2023graphtext,
100 | title={GraphText: Graph Reasoning in Text Space},
101 | author={Jianan Zhao and Le Zhuo and Yikang Shen and Meng Qu and Kai Liu and Michael Bronstein and Zhaocheng Zhu and Jian Tang},
102 | year={2023},
103 | eprint={2310.01089},
104 | archivePrefix={arXiv},
105 | primaryClass={cs.CL}
106 | }
107 | ```
--------------------------------------------------------------------------------
/configs/data/citeseer.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | defaults:
4 | - data_defaults
5 | 
6 | data:
7 | name: citeseer # To be overwritten by dataset specific values.
8 | alias: Cite
9 | type: dgl
10 | # type: explore
11 | _init_args:
12 | _target_: dgl.data.CiteseerGraphDataset
13 | raw_dir: ${path.data_storage}
14 | dataset_path: temp/data/${data.name}
15 | 
16 | text:
17 | mode: label_name # How to generate text for each node?
18 | label_name:
19 | '0': Agents # 351
20 | '1': Artificial Intelligence # 217
21 | '2': Database # 418
22 | '3': Information Retrieval # 818
23 | '4': Machine Learning # 426
24 | '5': Human computer interaction # 298
25 | label_text: name
26 | #tokenized_folder: ${path.data_cache}${.name}${.mode}_{model}/
27 | #tokenized_flag: ${.tokenized_folder}processed.flag
28 | # * meta_info:
29 | n_labels: 6
30 | n_nodes: 3327
31 | feat_dim: 3703
32 | task_description: >-
33 | You are a helpful assistant that classifies the topic of an academic paper based on the labels of the cited papers. You are going to choose the correct answer from several choices of paper categories: ${data.label_description}
--------------------------------------------------------------------------------
/configs/data/citeseer_tag.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | defaults:
4 | - /data/citeseer
5 | 
6 | data:
7 | name: cite_tag # To be overwritten by dataset specific values.
8 | alias: CiteTAG
9 | type: tag
10 | dataset_path: data/${data.name}/citeseer_random_sbert.pt
11 | feat_dim: 384
12 | text_cutoff: 64
--------------------------------------------------------------------------------
/configs/data/cora.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | defaults:
4 | - data_defaults
5 | 
6 | data:
7 | name: cora # To be overwritten by dataset specific values.
8 | alias: Cora
9 | type: dgl
10 | _init_args:
11 | _target_: dgl.data.CoraGraphDataset
12 | raw_dir: ${path.data_storage}
13 | 
14 | text:
15 | mode: label_name # How to generate text for each node?
16 | label_name: 17 | '0': Theory # 351 18 | '1': Reinforcement Learning # 217 19 | '2': Genetic Algorithm # 418 20 | '3': Neural Network # 818 21 | '4': Probabilistic Method # 426 22 | '5': Case Based # 298 23 | '6': Rule Learning # 180 24 | label_text: name 25 | #tokenized_folder: ${path.data_cache}${.name}${.mode}_{model}/ 26 | #tokenized_flag: ${.tokenized_folder}processed.flag 27 | # * meta_info: 28 | n_labels: 7 29 | n_nodes: 2708 30 | feat_dim: 1433 31 | task_description: >- 32 | You are a helpful assistant that classifies the topic of an academic paper based on the labels of the cited papers. You are going to choose the correct answer from several choices of paper categories: ${data.label_description} -------------------------------------------------------------------------------- /configs/data/cora_tag.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /data/cora 5 | 6 | data: 7 | name: cora_tag # To be overwritten by dataset specific values. 8 | alias: CoraTAG 9 | type: tag 10 | dataset_path: data/${data.name}/cora_random_sbert.pt 11 | feat_dim: 384 12 | text_cutoff: 64 -------------------------------------------------------------------------------- /configs/data/cornell.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - data_defaults 5 | 6 | data: 7 | name: cornell # To be overwritten by dataset specific values. 8 | alias: Cornell 9 | type: dgl 10 | # type: explore 11 | _init_args: 12 | _target_: dgl.data.CornellDataset 13 | raw_dir: ${path.data_storage} 14 | dataset_path: temp/data/${data.name} 15 | 16 | text: 17 | mode: label_name # How to generate text for each node? 18 | label_name: 19 | '0': student 20 | '1': project 21 | '2': course 22 | '3': staff 23 | '4': faculty 24 | label_text: name 25 | #tokenized_folder: ${path.data_cache}${.name}${.mode}_{model}/ 26 | #tokenized_flag: ${.tokenized_folder}processed.flag 27 | # * meta_info: 28 | n_labels: 5 29 | n_nodes: 183 30 | feat_dim: 1703 31 | task_description: >- 32 | You are a helpful assistant that classifies the role of a web page in the network, where nodes represent web pages, and edges are hyperlinks between them. The web pages are manually classified into the five categories: ${data.label_description} -------------------------------------------------------------------------------- /configs/data/data_defaults.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | prompt_templates: 4 | qa_instruction: '' 5 | task_context: '' 6 | 7 | data: 8 | name: ??? # To be overwritten by dataset specific values. 9 | alias: ??? # To be overwritten by dataset specific values. 10 | _target_: .datasets.SeqGraph 11 | type: ??? # To be overwritten by dataset specific values. 12 | n_labels: ??? 13 | n_nodes: ??? 
14 | 15 | # * raw_data: 16 | raw_data_path: ${path.data_storage} 17 | info_file: ${working_dir}graph.info 18 | 19 | # * process: 20 | max_train_samples: 9999999 21 | max_eval_samples: 999999 22 | max_test_samples: 999999 23 | label_description: '' 24 | n_shots: -1 # -1 for default split 25 | #tokenized_folder: ${path.data_cache}${.name}${.mode}_{model}/ 26 | #tokenized_flag: ${.tokenized_folder}processed.flag 27 | 28 | -------------------------------------------------------------------------------- /configs/data/pubmed.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - data_defaults 5 | 6 | data: 7 | name: pubmed # To be overwritten by dataset specific values. 8 | alias: Pub 9 | type: dgl 10 | _init_args: 11 | _target_: dgl.data.PubmedGraphDataset 12 | raw_dir: ${path.data_storage} 13 | 14 | text: 15 | mode: label_name # How to generate text for each node? 16 | label_name: 17 | '0': Experimental # 18 | '1': Type 1 # 19 | '2': Type 2 # 20 | label_text: name 21 | #tokenized_folder: ${path.data_cache}${.name}${.mode}_{model}/ 22 | #tokenized_flag: ${.tokenized_folder}processed.flag 23 | # * meta_info: 24 | n_labels: 3 25 | n_nodes: 19717 26 | feat_dim: 500 27 | -------------------------------------------------------------------------------- /configs/data/texas.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - data_defaults 5 | 6 | data: 7 | name: texas # To be overwritten by dataset specific values. 8 | alias: Texas 9 | type: dgl 10 | # type: explore 11 | _init_args: 12 | _target_: dgl.data.TexasDataset 13 | raw_dir: ${path.data_storage} 14 | dataset_path: temp/data/${data.name} 15 | 16 | text: 17 | mode: label_name # How to generate text for each node? 18 | label_name: 19 | '0': student 20 | '1': project 21 | '2': course 22 | '3': staff 23 | '4': faculty 24 | label_text: name 25 | #tokenized_folder: ${path.data_cache}${.name}${.mode}_{model}/ 26 | #tokenized_flag: ${.tokenized_folder}processed.flag 27 | # * meta_info: 28 | n_labels: 5 29 | n_nodes: 183 30 | feat_dim: 1703 31 | task_description: >- 32 | You are a helpful assistant that classifies the role of a web page in the network, where nodes represent web pages, and edges are hyperlinks between them. The web pages are manually classified into the five categories: ${data.label_description} -------------------------------------------------------------------------------- /configs/data/wisconsin.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - data_defaults 5 | 6 | data: 7 | name: wisconsin # To be overwritten by dataset specific values. 8 | alias: Wisconsin 9 | type: dgl 10 | # type: explore 11 | _init_args: 12 | _target_: dgl.data.WisconsinDataset 13 | raw_dir: ${path.data_storage} 14 | dataset_path: temp/data/${data.name} 15 | 16 | text: 17 | mode: label_name # How to generate text for each node? 
18 | label_name: 19 | '0': student 20 | '1': project 21 | '2': course 22 | '3': staff 23 | '4': faculty 24 | label_text: name 25 | #tokenized_folder: ${path.data_cache}${.name}${.mode}_{model}/ 26 | #tokenized_flag: ${.tokenized_folder}processed.flag 27 | # * meta_info: 28 | n_labels: 5 29 | n_nodes: 251 30 | feat_dim: 1703 31 | task_description: >- 32 | You are a helpful assistant that classifies the role of a web page in the network, where nodes represent web pages, and edges are hyperlinks between them. The web pages are manually classified into the five categories: ${data.label_description} -------------------------------------------------------------------------------- /configs/exp/icl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /llm: chatgpt # Disable Hydra logging 4 | 5 | mode: icl 6 | data: 7 | max_eval_samples: 999999 8 | node_dropout: 0 # NO NEED FOR DROPOUT NODE 9 | wandb_proj: GraphText-ICL-Release 10 | -------------------------------------------------------------------------------- /configs/exp/sft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /llm: llama_peft 4 | 5 | mode: sft 6 | data: 7 | max_eval_samples: 999999 8 | use_demo: false 9 | out_field: c 10 | wandb_proj: GraphText-SFT 11 | 12 | # Prompt 13 | 14 | human_prompt: base 15 | gpt_prompt: base 16 | instruct_prompt: sft 17 | demo_prompt: base 18 | demo_qa_prompt: base 19 | question_prompt: sft 20 | 21 | # @ Encoder 22 | encoder: 23 | dropout: 0.5 24 | input_norm: true 25 | output_norm: true 26 | norm: LN 27 | input_dropout: true 28 | output_dropout: false 29 | new_arg: true 30 | encoder_alias: ${encoder.dropout}do${encoder.dropout} 31 | log_freq: 500 32 | 33 | use_flash_attn: false 34 | dropout: 0.5 35 | eval_choice_only: ${add_class_token} 36 | alias: ${llm.name}-${data.alias}-${rel_info}-${text_info} 37 | 38 | # 39 | add_class_token: true 40 | add_label_name_output: true 41 | add_field_token: true 42 | add_info_token: true 43 | add_pad_token: true 44 | # 45 | eval_metric: val_acc 46 | 47 | 48 | # @ EVALUATE 49 | add_loop_inference: false 50 | use_fwd_eval: false # FIXME To be removed 51 | metrics: [ 'acc' ] 52 | eval_sets: [ 'train', 'val' , 'test' ] 53 | min_eval_step: 100 54 | choice_readout_pos: 0 55 | eval_freq: 100 56 | 57 | # @ LLM 58 | max_tgt_len: 2048 # the maximum sequence length to be generated 59 | max_gen_len: 5 # the maximum sequence length to be generated 60 | lora: 61 | r: -1 # skip LoRA if rank < 1 62 | alpha: ${.r} 63 | dropout: 0.1 64 | target_modules: [ q_proj, v_proj, k_proj, o_proj ] 65 | # modules_to_save: [ embed_tokens, lm_head ] 66 | 67 | # @ Trainer (Agent) 68 | stage: 2 69 | save_freq: 30000 70 | max_epochs: 9999 71 | total_steps: 700 #Number of train steps 72 | frozen_encoder: false 73 | frozen_ori_llm_parameters: true # Fixme 74 | 75 | agent_name: DeepSpeedAgent 76 | ds_config_path: configs/dsconfig/openllama_peft_stage_${stage}.json 77 | 78 | # @ Deepspeed related 79 | use_deepspeed: true # For debug only 80 | eq_batch_size: 4 81 | inf_batch_size: ${oc.select:model._meta_data.inf_bsz,12} 82 | max_bsz_per_gpu: ${oc.select:llm._meta_data.max_bsz_per_gpu,12} 83 | bsz_per_gpu: ${get_bsz_per_gpu:${eq_batch_size}, ${max_bsz_per_gpu}} 84 | grad_acc_steps: ${get_grad_acc_steps:${eq_batch_size}, ${max_bsz_per_gpu}} 85 | 86 | # ! 
Float 87 | use_fp16: true 88 | use_bf16: true 89 | optimizer_type: AdamW 90 | 91 | # ! Optimizer 92 | warmup_rate: 0.1 93 | lr: 5e-5 94 | 95 | ds: # Deepspeed config 96 | train_batch_size: ${eq_batch_size} 97 | train_micro_batch_size_per_gpu: ${bsz_per_gpu} 98 | gradient_accumulation_steps: ${grad_acc_steps} # ! To be overwritten 99 | steps_per_print: 2000 100 | gradient_clipping: 1.0 101 | zero_optimization: 102 | stage: 2 # ??? # Original 2 103 | offload_optimizer: 104 | device: cpu 105 | contiguous_gradients: true 106 | allgather_bucket_size: 500000000 107 | allgather_partitions: true 108 | 109 | fp16: 110 | enabled: ${use_fp16} 111 | opt_level: O2 112 | min_loss_scale: 1 113 | 114 | bf16: 115 | enable: ${use_bf16} 116 | 117 | optimizer: 118 | type: ${optimizer_type} 119 | params: 120 | lr: ${lr} 121 | betas: [ 0.9, 0.95 ] 122 | eps: 1e-8 123 | weight_decay: 0.001 124 | 125 | scheduler: 126 | type: WarmupDecayLR 127 | params: 128 | warmup_min_lr: 0 129 | warmup_max_lr: ${lr} 130 | warmup_num_steps: ${round_mult:${total_steps}, ${warmup_rate}} 131 | total_num_steps: ${total_steps} 132 | 133 | activation_checkpointing: 134 | partition_activations: true 135 | cpu_checkpointing: true 136 | contiguous_memory_optimization: false 137 | number_checkpoints: null 138 | synchronize_checkpoint_boundary: false 139 | profile: false -------------------------------------------------------------------------------- /configs/llm/chatgpt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | llm: 3 | _file_: 4 | name: ChatGPT 5 | openai_name: gpt-3.5-turbo 6 | temperature: 0 # Fix 7 | _target_: llm.gpt.GPT 8 | _type: OpenAI 9 | hidden_dim: 9999 10 | sleep_time: 1 -------------------------------------------------------------------------------- /configs/llm/gpt4.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | llm: 3 | _file_: 4 | name: GPT4 5 | openai_name: gpt-4 6 | temperature: 0 # Fix 7 | _target_: llm.gpt.GPT 8 | _type: OpenAI 9 | hidden_dim: 9999 10 | sleep_time: 5 -------------------------------------------------------------------------------- /configs/llm/llama_icl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - llm_meta_data 5 | # generation hyper-parameters 6 | # generation hyper-parameters 7 | llm: 8 | name: LLaMA_ICL 9 | _target_: llm.llama_icl.LLaMA_ICL 10 | _meta_data: ${_llm_md_lookup.${.base_model}} 11 | local_dir: ${path.hf_local}${.base_model}/ 12 | base_model: tinygpt 13 | hf_name: ${._meta_data.hf_name} 14 | hidden_dim: 4096 15 | max_seq_len: 4096 16 | max_tgt_len: 4096 -------------------------------------------------------------------------------- /configs/llm/llama_peft.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - llm_meta_data 5 | # generation hyper-parameters 6 | llm: 7 | _meta_data: ${_llm_md_lookup.${.base_model}} 8 | local_dir: ${path.hf_local}${.base_model}/ 9 | name: LLaMA-PEFT 10 | base_model: llama2-7b 11 | hf_name: ${._meta_data.hf_name} 12 | hidden_dim: 4096 13 | 14 | # Train Config 15 | max_length: 1024 16 | max_shard_size: 10GB 17 | 18 | max_len: 512 19 | penalty_alpha: 0.6 20 | top_k: 10 21 | top_p: 0.7 22 | random_prefix_len: 5 23 | sample_num: 2 24 | decoding_method: sampling 25 | generate_len: 512 26 | 27 | 28 | 29 | 
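Note: the `_meta_data: ${_llm_md_lookup.${.base_model}}` interpolation above is resolved against the `_llm_md_lookup` table in `configs/llm/llm_meta_data.yaml` (the next file), keyed by `llm.base_model`. A hedged, illustrative example of switching the SFT backbone from the command line (this exact combination is not a tuned setting, and it assumes you have access to the corresponding HuggingFace weights):

```shell
cd src/scripts
# Hypothetical override: hf_name and max_bsz_per_gpu are then taken from the
# llama2-7b-chat entry of _llm_md_lookup in configs/llm/llm_meta_data.yaml.
python run_sft.py exp=sft data=cora_tag llm.base_model=llama2-7b-chat
```
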
-------------------------------------------------------------------------------- /configs/llm/llm_meta_data.yaml: -------------------------------------------------------------------------------- 1 | _llm_md_lookup: 2 | tinygpt: 3 | hf_name: HuggingFaceM4/tiny-random-LlamaForCausalLM 4 | max_bsz_per_gpu: 18 5 | inf_bsz: ${.max_bsz_per_gpu} 6 | llama2-7b: 7 | hf_name: meta-llama/Llama-2-7b-hf 8 | max_bsz_per_gpu: 12 # FP16, 40GB 9 | inf_bsz: 12 # FP16, 40GB 10 | llama2-7b-chat: 11 | hf_name: meta-llama/Llama-2-7b-chat-hf 12 | max_bsz_per_gpu: 12 # FP16, 40GB 13 | inf_bsz: 12 # FP16, 40GB 14 | llama2-13b-chat: 15 | hf_name: meta-llama/Llama-2-13b-chat-hf 16 | max_bsz_per_gpu: 4 # FP16, 40GB 17 | inf_bsz: 8 18 | llama2-70b-chat: 19 | hf_name: meta-llama/Llama-2-70b-chat-hf 20 | max_bsz_per_gpu: 4 # FP16, 40GB 21 | inf_bsz: ${.max_bsz_per_gpu} 22 | 23 | deberta-base: 24 | hf_name: microsoft/deberta-v3-base 25 | inf_bsz: 480 # FP16, 40GB 26 | deberta-large: 27 | hf_name: microsoft/deberta-v3-large 28 | inf_bsz: 280 # FP16, 40GB 29 | -------------------------------------------------------------------------------- /configs/main.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # ! Primary hydra config for ALL models 3 | defaults: 4 | - _self_ # To be overwritten by experimental settings 5 | - optional user: env # (optional environment settings to add 6 | - data: texas 7 | - llm: chatgpt 8 | - model: graph_text 9 | - optional template: ${model}/${data} 10 | - prompt: prompts 11 | - exp: icl 12 | - override /hydra/hydra_logging@_group_: none # Disable Hydra logging 13 | - override /hydra/job_logging@_group_: none # Disable Hydra logging 14 | 15 | debug: false 16 | 17 | # ! Path 18 | # @ Note that path end with /, file end without / 19 | path: 20 | data_cache: data_cache/ 21 | data_storage: data/ 22 | temp: temp/ # Removable 23 | out_dir: output/ # 24 | 25 | env: 26 | vars: 27 | openai_api_key: ${oc.env:OPENAI_API_KEY,YourAPIKey} # Overwrite this to your API key 28 | 29 | working_dir: ${path.temp}working_dir/${.uid}/ # For deletable temporary files. 30 | out_dir: ${path.out_dir}${oc.select:wandb.sweep_id,local}/${model.name}/${.uid}-${.alias}/ # For files to be saved, to be initialized 31 | uid: null # To be generated in the main program 32 | seed: 2023 33 | # 34 | eval_freq: 50 35 | use_wandb: true 36 | alias: ${llm.name}-${data.alias}-${data.n_shots}Shot-Text=${text_info}-Rel=${rel_info} 37 | wandb: 38 | id: null 39 | name: ${alias} 40 | 41 | slurm_id: ${oc.env:SLURM_JOB_ID,null} 42 | logging: 43 | level: info 44 | log_wandb_metric_to_stdout: False 45 | code_version: 14.2 46 | 47 | # @ ?? 48 | hydra: 49 | run: 50 | dir: ../temp/hydra/${now:%Y.%m.%d}/${now:%H.%M.%S} 51 | 52 | # ! 
_file_ related 53 | _unimportant_cfg: 54 | fields: [ gpus, debug, wandb, proj, env, uid, 55 | local_rank, cmd, label_name, logging, 56 | use_wandb, n_nodes, n_labels, alias 57 | ] 58 | postfix: [ _path, _file, _dir, _url ] 59 | prefix: [ _ ] -------------------------------------------------------------------------------- /configs/model/graph_text.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /prompt/graph_tree_meta_data 4 | model: 5 | name: GraphText 6 | 7 | # @ GraphText Settings 8 | #text_info: choice.a1y_t.a2y_t.a0x_t.a1x_t.a2x_t 9 | text_info: a3y 10 | rel_info: spd0.spd1 11 | subgraph_size: 5 12 | nb_padding: false 13 | nb_order: true 14 | node_dropout: 0.5 15 | 16 | tree_hierarchy: attr_type.graph_type # Default config 17 | out_field: choice 18 | 19 | # @ Graph (Relational) Information 20 | sim: 21 | topk: 20 22 | cache_template: ${path.data_cache}${data.name}/{pg_name}Top${.topk}.sim_proxy_graph 23 | spd: 24 | max_hops: ${oc.select:data.max_spd,3} 25 | cache_file: ${path.data_cache}${data.name}/Top${.max_hops}.spd_matrix 26 | 27 | ppr: 28 | # Value Encoding: None stands for no value encoding; Rank group stands for value grouping 29 | # value_encoding: None 30 | # value_encoding: Rank_10 # Rank to 10 levels 31 | # sort_mat_construct_order: [ [ 0, 1, 2 ],[ 1, 0, 2 ], [ 2, 1, 0 ] ] 32 | max_hops: 2 33 | default_alpha: 0.25 34 | cache_template: ${path.data_cache}${data.name}/Top{topk}_eps{eps}_{normalization}norm_alpha={alpha}.ppr_matrix 35 | topk: 32 # Following PPRGo 36 | eps: 1e-4 # Following PPRGo 37 | normalization: sym # sym or row or col 38 | rank: 39 | methods: [ ppr_0.25 ] 40 | # methods: [ ppr_0.5, ppr_0.01 ] 41 | top_k: 32 # Following PPRGo 42 | hidden_dim: 128 43 | 44 | # @ Demo for In-Context-Learning 45 | use_demo: true 46 | demo: 47 | # select_method: first # Fixed seed examples for every sample 48 | select_method: max_degree # Fixed seed examples for every sample 49 | template: '{graph_tree_info}The answer is {label}.' 
50 | # select: class-prototype # Select center of each class cluster 51 | # select: BM25 # Use BM25 for dynamic retrieval 52 | # select: BM25 # Randomly select seed examples 53 | keep_label_description: False 54 | n_separators: 2 # Number of separators between examples 55 | n_samples: ${data.n_labels} # Number of demonstrations 56 | 57 | 58 | local_rank: 0 59 | save_path: ${out_dir}checkpoints/ 60 | 61 | # @ Text Settings 62 | add_pre_instruction: true 63 | pre_instruction_template: short 64 | #remove_quotation: true 65 | remove_quotation: true 66 | 67 | # @ EVALUATE 68 | eval_sets: ['val' , 'test' ] # Evaluate validation and test sets only 69 | #eval_sets: [ 'test' ] # Evaluate on test set only 70 | 71 | conv_template: no_conv 72 | save_file: ${out_dir}{split}-${alias}.csv 73 | -------------------------------------------------------------------------------- /configs/prompt/graph_tree_meta_data.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | tree_node_alias: 3 | # graph continuous feature names 4 | x: feature 5 | y: neighbor labels 6 | a0x: center node feature 7 | a1x: first-order smoothed feature 8 | a2x: second-order smoothed feature 9 | a3x: third-order smoothed feature 10 | a1y: first-order pseudo labels 11 | a2y: second-order pseudo labels 12 | a3y: third-order pseudo labels 13 | a4y: fourth-order pseudo labels 14 | # graph text names 15 | choice: neighbor labels 16 | a0x_t: center node feature 17 | a1x_t: first-order smoothed feature 18 | a2x_t: second-order smoothed feature 19 | a3x_t: third-order smoothed feature 20 | a1y_t: first-order pseudo labels 21 | a2y_t: second-order pseudo labels 22 | a3y_t: third-order pseudo labels 23 | a4y_t: fourth-order pseudo labels 24 | # proxy graph names 25 | a0x_sim: feature similarity graph 26 | a1x_sim: first-order feature similarity graph 27 | a2x_sim: second-order feature similarity graph 28 | a3x_sim: third-order feature similarity graph 29 | spd0: center node 30 | spd1: first-hop neighbor 31 | spd2: second-hop neighbor 32 | spd3: third-hop neighbor 33 | subgraph_size: 3 34 | 35 | attr_mask: 36 | title: CenterOnly 37 | tape: CenterOnly 38 | text: CenterOnly 39 | pg_size: # Size of proxy graph 40 | spd0: 1 41 | spd1: ${subgraph_size} 42 | spd2: ${subgraph_size} 43 | spd3: ${subgraph_size} 44 | pprtopk: ${subgraph_size} # Original subgraphsize 45 | a0x_sim: ${subgraph_size} 46 | a1x_sim: ${subgraph_size} 47 | a2x_sim: ${subgraph_size} 48 | a3x_sim: ${subgraph_size} 49 | a4x_sim: ${subgraph_size} 50 | ppr: ${subgraph_size} 51 | 52 | in_field_description: null # To be initialized afterward 53 | #_in_field_description_lookup: 54 | # choice: neighbor labels 55 | # a0x_t: ': The center node feature.' 
56 | # a1x_t: ': ' 57 | # a2x_t: second-order smoothed feature 58 | # a3x_t: third-order smoothed feature 59 | # a1y_t: first-order pseudo labels 60 | # a2y_t: second-order pseudo labels 61 | # a3y_t: third-order pseudo labels -------------------------------------------------------------------------------- /configs/prompt/prompts.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # Shared Prompt Settings for ALL PROMPTS 3 | 4 | # @ Few-shot Demonstration 5 | human_prompt: base 6 | gpt_prompt: base 7 | instruct_prompt: base 8 | demo_prompt: base 9 | demo_qa_prompt: base 10 | question_prompt: icl 11 | prompt: 12 | style: xml 13 | human: 14 | _target_: graph_text.prompts.Prompt 15 | template: ${prompt_templates.human[${human_prompt}]} 16 | instruction: ${prompt_templates.instruction[${instruct_prompt}]} 17 | demo: # 18 | _target_: graph_text.prompts.Prompt 19 | template: ${prompt_templates.demo[${demo_prompt}]} 20 | demo_qa: # 21 | _target_: graph_text.prompts.Prompt 22 | template: ${prompt_templates.demo_qa.${demo_qa_prompt}.${..style}} 23 | question: # 24 | _target_: graph_text.prompts.Prompt 25 | template: ${prompt_templates.question[${question_prompt}]} 26 | gpt: # @ for SFT target 27 | _target_: graph_text.prompts.Prompt 28 | template: ${prompt_templates.gpt[${gpt_prompt}]} 29 | 30 | prompt_templates: 31 | human: 32 | # @ human 33 | base: >- # Base template for human input 34 | {instruction}{demo}{question} 35 | # @ 36 | instruction: 37 | # @ Text classification prompt 38 | na: '' 39 | # Role + Task Description + Instructions 40 | sft: >- 41 | Your goal is to perform node classification. You are given the information of each node in a xml format. Using the given information of a node, you need to classify the node to several choices: ${data.label_description}. Remember, your answer should be in the form of the class label. 42 | base: >- # Cora 43 | ${data.task_description}\n 44 | # I will give you the following information:\n${in_field_description}\n 45 | demo: 46 | na: '' 47 | base: >- 48 | \nHere are a few examples:\n 49 | {demonstration}\n\n 50 | Now let's answer the question below:\n 51 | demo_qa: 52 | base: 53 | xml: >- 54 | {graph_info}\nWhat's the topic of academic paper given the information above?\n{answer} 55 | json: >- 56 | {graph_info}\nWhat's the topic of academic paper?${data.label_description}\n 57 | {answer} 58 | flatten: >- 59 | {graph_info}\nWhat's the topic of academic paper?${data.label_description}\n 60 | {answer} 61 | random_flatten: >- 62 | {graph_info}\nWhat's the topic of academic paper?${data.label_description}\n 63 | {answer} 64 | xml_wo_text: >- 65 | {graph_info}\nWhat's the topic of academic paper?${data.label_description}\n 66 | {answer} 67 | question: 68 | cla_fstring: >- 69 | {instruction}{demonstration}\n 70 | Question: What's the topic of academic paper [{query_msg}]?${data.label_description}\n 71 | Answer: 72 | label_prop_xml: >- 73 | {information}\n 74 | What's the topic of academic paper?${data.label_description}\n 75 | 76 | label_prop_xml_cot: >- 77 | {information}\n 78 | What's the topic of academic paper?${data.label_description}\n 79 | Generate your answer with around it. Let's think step by step and solve it with the message passing! 
80 | label_prop_xml_demo: >- 81 | {information}\n 82 | What's the topic of the target academic paper with the neighborhood label information above?${data.label_description}\n 83 | {answer} 84 | icl: >- 85 | {graph_info}\nWhat's the topic of the paper given the information above? Valid choices are ${data.label_description}.\nRemember, your answer should be in the form of the class choice wrapped by . 86 | icl_general: >- 87 | {graph_info}\nWhat's the target class given the information above? Valid choices are ${data.label_description}.\nRemember, your answer should be in the form of the class choice wrapped by . 88 | icl_new: >- 89 | {graph_info}\nWhat's the topic of the paper given the information above? Valid choices are ${data.label_description}.\nRemember, you have to give a choice as the answer, and your answer should be in the form of the class choice wrapped by . 90 | icl_tag: >- 91 | {graph_info}\nWhat's the topic of the paper given the information above? Valid choices are ${data.label_description}.\nRemember, your answer should be in the form of a capital letter of choice index wrapped by , e.g., A, B, etc. 92 | icl_zscot: >- 93 | {graph_info}\nWhat's the topic of the paper given the information above? Valid choices are ${data.label_description}.\nLet's think step by step first and then answer the question. Remember, you have to give a choice as the answer, and your answer should be in the form of the class choice wrapped by . 94 | sft: >- 95 | \n{graph_info}\n 96 | gpt: 97 | base: >- # Researcher Arxiv 98 | The answer is: {answer}. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | packaging 2 | torch>=2.0.1 3 | fschat 4 | peft 5 | transformers 6 | torch_geometric 7 | pytz 8 | omegaconf 9 | hydra-core 10 | packaging 11 | peft 12 | deepspeed 13 | sentencepiece 14 | wandb 15 | rich 16 | pandas 17 | easydict 18 | bidict 19 | scikit-learn 20 | dgl 21 | ogb 22 | chunkdot 23 | numba 24 | openai -------------------------------------------------------------------------------- /src/graph_text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndyJZhao/GraphText/73bde25b2fa9bf89b37b041062a4adfe42363652/src/graph_text/__init__.py -------------------------------------------------------------------------------- /src/graph_text/agent.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import types 4 | from collections import OrderedDict 5 | from collections import defaultdict 6 | from collections.abc import Iterable 7 | 8 | import deepspeed 9 | import numpy as np 10 | import torch as th 11 | import torch.optim as optim 12 | from omegaconf import OmegaConf 13 | 14 | from utils.basics import init_path, lot_to_tol, time_logger 15 | from utils.pkg.distributed import master_process_only 16 | from .model import IGNORE_INDEX 17 | 18 | logging.getLogger("transformers").setLevel(logging.WARNING) 19 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 20 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 21 | 22 | 23 | def indices_to_binary(n_classes, indices): 24 | binary_list = [0] * n_classes 25 | for index in indices: 26 | binary_list[index] = 1 27 | return binary_list 28 | 29 | 30 | def compare_eval(split, results, pred1, pred2, logger): 31 | results[f'{split}_inf_agreement_rate'] = np.mean(_ := 
np.array(pred2) == np.array(pred1)) 32 | logger.warning(f"Inference agreement rate: {results[f'{split}_inf_agreement_rate']}") 33 | if results[f'{split}_inf_agreement_rate'] < 1: 34 | logger.warning(f'Different prediction samples:\n{list(zip(pred1[~_], pred2[~_]))[:50]}') 35 | 36 | 37 | class Agent: 38 | 39 | def __init__(self, model, cfg, data, logger): 40 | self.cfg = cfg 41 | self.model = model 42 | self.data = data 43 | self.logger = logger 44 | self.optimizer = optim.Adam(model.parameters(), **cfg.ds.optimizer.params) 45 | self.total_batch_steps = cfg.total_steps * cfg.get('grad_acc_steps', 1) 46 | 47 | def forward(self, batch): 48 | return self.model(batch) 49 | 50 | def backward(self, loss): 51 | # Backward pass and optimization 52 | self.optimizer.zero_grad() # Clear the gradients 53 | loss.backward() # Calculate gradients 54 | self.optimizer.step() # Update weights 55 | 56 | def torch_distributed_barrier(self): 57 | if self.cfg.get('world_size', 1) > 1: 58 | th.distributed.barrier() 59 | 60 | @time_logger() 61 | def evaluate(self, eval_iter_dict, logger): 62 | results = {} 63 | for split, eval_iter in eval_iter_dict.items(): 64 | eval_res = defaultdict(list) 65 | for eval_batch in eval_iter: 66 | if self.cfg.use_fwd_eval: 67 | output = self.forward_eval(eval_batch, res_prefix='') 68 | else: 69 | output = self.predict(eval_batch, self.cfg.eval_choice_only) 70 | for item, value in output.items(): 71 | eval_res[item].append(value) 72 | eval_res = {k: np.concatenate(v) if isinstance(v[0], Iterable) else np.array(v) 73 | for k, v in eval_res.items()} 74 | results.update({f'{split}/{k}': np.array(eval_res[k]).mean() 75 | for k in ['loss', 'token_acc'] if k in eval_res}) 76 | logger.info(f'Example generated output in {split}: {eval_res["dialog"][:3]}\n\n\n') 77 | label, pred = eval_res['label'], eval_res['pred'] 78 | if not self.cfg.add_class_token: 79 | results[f'{split}_valid_choice_rate'] = valid_choice_rate = np.mean(eval_res['is_valid']) 80 | if valid_choice_rate < 1: 81 | logger.warning(f'Failed gold and prediction samples:\n' 82 | f"{list(zip(label[~eval_res['is_valid']], pred[~eval_res['is_valid']]))[:50]}") 83 | if 'acc' in self.cfg.metrics: 84 | results[f'{split}_acc'] = np.mean(label == pred) if len(label) == len(pred) else -1 85 | if 'loop_pred' in eval_res: 86 | results[f'{split}_loop_inf_acc'] = np.mean(label == eval_res['loop_pred']) if len(label) == len( 87 | pred) else -1 88 | compare_eval(split, results, pred, eval_res['loop_pred'], logger) 89 | if 'fwd_pred' in eval_res: 90 | results[f'{split}_fwd_inf_acc'] = np.mean(label == eval_res['fwd_pred']) if len(label) == len( 91 | pred) else -1 92 | compare_eval(split, results, pred, eval_res['fwd_pred'], logger) 93 | # print(f'label {label[:5]}\n\nPred {pred[:5]}\n\nForward Pred:{eval_res["fwd_pred"][:5]}') 94 | # results[f'{split}_fwd_acc'] = np.mean(label == eval_res['fwd_pred']) if len(label) == len(pred) 95 | # else -1 96 | logger.warning(results) 97 | return results 98 | 99 | @th.no_grad() 100 | def predict(self, batch, choice_only=False): 101 | self.model.eval() 102 | node_ids, graph_tree_lol, encode_seq, node_id_to_encode_id, conversation_list = batch 103 | gold_text = [conv[1]['value'] for conv in conversation_list] 104 | inputs = { 105 | 'batch': batch, 106 | 'max_tgt_len': self.cfg.max_gen_len, 107 | 'temperature': 0.2, 108 | 'gold_text': gold_text, 109 | } 110 | batch_output = self.model.generate(inputs, choice_only) 111 | # print(batch_output) # For debug only 112 | batch_output['label'], is_valid = 
lot_to_tol([self.model.match_label_from_text(text) for text in gold_text]) 113 | assert sum(is_valid) == len(is_valid), 'Incorrect gold text generation' 114 | # assert 'Not Found' not in label 115 | batch_output['pred'], batch_output['is_valid'] = lot_to_tol( 116 | [self.model.match_label_from_text(text) for text in batch_output['generated_text']]) 117 | if 'loop_generated_text' in batch_output: 118 | batch_output['loop_pred'], _ = lot_to_tol( 119 | [self.model.match_label_from_text(text) for text in batch_output['loop_generated_text']]) 120 | return batch_output 121 | 122 | @th.no_grad() 123 | def forward_eval(self, batch, res_prefix=''): 124 | self.model.eval() 125 | outputs, targets = self.forward(batch) 126 | batch_size, seq_len = targets.shape 127 | loss = outputs.loss 128 | gpt_gen_target_mask = (targets != IGNORE_INDEX).reshape(-1) # B* (S-1) 129 | cls_token_mask = th.tensor([x.item() in self.model.cls_tokens for x in targets.reshape(-1)]).to( 130 | gpt_gen_target_mask.device) 131 | target_cls_mask = gpt_gen_target_mask & cls_token_mask # [B*S] 132 | # Lookup by target_cls_mask 133 | gt_cls_tokens = targets.reshape(-1)[target_cls_mask] 134 | pred_cls_logits = outputs.logits.reshape(batch_size * seq_len, -1)[target_cls_mask] 135 | # Readout max logits 136 | class_pred = pred_cls_logits[:, self.model.choice_ids].argmax(-1).tolist() 137 | generated_cls_text = [self.model.cls_token_names[c] for c in class_pred] 138 | gold_labels = self.model.tokenizer.convert_ids_to_tokens(gt_cls_tokens) 139 | output = {} 140 | output[f'{res_prefix}loss'] = loss.item() 141 | output[f'{res_prefix}label'] = gold_labels 142 | output[f'{res_prefix}pred'] = generated_cls_text 143 | output[f'{res_prefix}dialog'] = [f"{prompt[0]['value']}GPT:{generated_cls_text[i]}" 144 | for i, prompt in enumerate(batch[4])] 145 | return output 146 | 147 | def train_model_batch(self, batch, current_step=0): 148 | self.model.train() 149 | outputs, targets = self.forward(batch) # Model forward 150 | 151 | loss = outputs.loss 152 | # calculate the token accuracy; 153 | # [B, S-1], BOS and one token shift next token prediction 154 | chosen_tokens = th.max(outputs.logits, dim=-1)[1][:, 1:-1] 155 | labels = targets[:, 2:] # BOS + space 156 | gen_acc = (chosen_tokens.reshape(-1) == labels.reshape(-1)).to(th.long) # [B*S] 157 | valid_mask = (labels != IGNORE_INDEX).reshape(-1) 158 | valid_tokens = gen_acc & valid_mask # [B*S] 159 | gen_acc = valid_tokens.sum().item() / valid_mask.sum().item() if valid_mask.sum() > 0 else 0 160 | 161 | self.backward(loss) 162 | # self.progress.update(self.train_task, advance=1) 163 | return {'train/step': current_step, 'train/loss': round(loss.item(), 4), 'train/token_acc': round(gen_acc, 2), } 164 | 165 | @master_process_only 166 | def save_model(self, save_path, step, is_final=False): 167 | checkpoint_name = f"checkpoint_{step}" if not is_final else "final_model" 168 | # create save directory if not exists 169 | path = init_path(f"{save_path}{checkpoint_name}/") 170 | # only save trainable model parameters 171 | checkpoint = OrderedDict() 172 | for k, v in self.model.named_parameters(): 173 | if v.requires_grad: 174 | checkpoint[k] = v.detach().cpu() 175 | th.save(checkpoint, f'{path}/pytorch_model.pt') 176 | # save tokenizer 177 | self.model.tokenizer.save_pretrained(path) 178 | # save configuration 179 | self.model.llm.config.use_cache = True 180 | self.model.llm.config.save_pretrained(path) 181 | self.model.llm.config.use_cache = False 182 | self.torch_distributed_barrier() 183 | 
self.logger.info(f"Saved model into {path}") 184 | 185 | def load_stage_1_parameters(self, path): 186 | if path is None or not os.path.exists(path): 187 | self.logger.critical(f'Load {path} failed!!!, skipped loading') 188 | ckpt = th.load(path + 'pytorch_model.pt', map_location=th.device('cpu')) 189 | self.model.load_state_dict(ckpt, strict=False) 190 | 191 | def load_stage_1_parameters_prev(self, path): 192 | # Assuming `model` is your new model and `checkpoint` is the loaded checkpoint dictionary 193 | checkpoint = th.load(path, map_location=th.device('cpu')) 194 | 195 | model_dict = self.model.state_dict() 196 | 197 | # Filter out the embedding weights from the checkpoint 198 | filter_list = ['llama_model.base_model.model.model.embed_tokens.weight', 199 | 'llama_model.base_model.model.lm_head.weight'] 200 | checkpoint_filtered = {k: v for k, v in checkpoint.items() if k in model_dict and k not in filter_list} 201 | # Update the existing model parameters from the checkpoint. 202 | model_dict.update(checkpoint_filtered) 203 | 204 | # Set the updated weights to the model 205 | self.model.load_state_dict(model_dict, strict=False) 206 | 207 | if 'llama_model.base_model.model.model.embed_tokens.weight' in checkpoint: 208 | # Now handle the embedding weights separately 209 | old_embedding_weight = checkpoint['llama_model.base_model.model.model.embed_tokens.weight'] 210 | new_embedding_weight = self.model.llm.base_model.model.model.embed_tokens.weight 211 | 212 | # Copy the old weights to the new weights. 213 | new_embedding_weight.data[:old_embedding_weight.size(0)] = old_embedding_weight.data 214 | 215 | # Initialize the new token (you can use any initialization method here) 216 | # new_embedding_weight.data[old_embedding_weight.size(0):].normal_(mean=0.0, std=0.02) 217 | 218 | # Assign the new weights to the embedding layer 219 | self.model.llm.base_model.model.model.embed_tokens.weight = th.nn.Parameter( 220 | new_embedding_weight.clone()) 221 | 222 | if 'llama_model.base_model.model.lm_head.weight' in checkpoint: 223 | old_lm_head_weight = checkpoint['llama_model.base_model.model.lm_head.weight'] 224 | new_lm_head_weight = self.model.llm.base_model.model.lm_head.weight 225 | new_lm_head_weight.data[:old_lm_head_weight.size(0)] = old_lm_head_weight.data 226 | # new_lm_head_weight[old_lm_head_weight.size(0):].normal_(mean=0.0, std=0.02) 227 | self.model.llm.base_model.model.lm_head.weight = th.nn.Parameter(new_lm_head_weight.clone()) 228 | 229 | 230 | class DeepSpeedAgent(Agent): 231 | 232 | def __init__(self, model, cfg, data, logger): 233 | super(DeepSpeedAgent, self).__init__(model, cfg, data, logger) 234 | # load config parameters of deepspeed 235 | ds_params = OmegaConf.to_object(cfg.ds) 236 | self.ds_engine, self.optimizer, _, _ = deepspeed.initialize( 237 | model=model, 238 | model_parameters=model.parameters(), 239 | config_params=ds_params, 240 | dist_init_required=True, 241 | args=types.SimpleNamespace(**cfg) 242 | ) 243 | self.model = self.ds_engine.module # Overwrite with deepspeed module 244 | self.torch_distributed_barrier() 245 | 246 | def forward(self, batch): 247 | return self.ds_engine(batch) 248 | 249 | def backward(self, loss): 250 | self.ds_engine.backward(loss) 251 | self.ds_engine.step() 252 | -------------------------------------------------------------------------------- /src/graph_text/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List 4 
| 5 | 6 | class SeparatorStyle(Enum): 7 | """Different separator style.""" 8 | SINGLE = auto() 9 | TWO = auto() 10 | MPT = auto() 11 | 12 | 13 | @dataclasses.dataclass 14 | class Conversation: 15 | """A class that keeps all conversation history.""" 16 | system: str 17 | roles: List[str] 18 | messages: List[List[str]] 19 | offset: int 20 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 21 | sep: str = "###" 22 | sep2: str = None 23 | version: str = "Unknown" 24 | 25 | skip_next: bool = False 26 | 27 | def get_prompt(self): 28 | if self.sep_style == SeparatorStyle.SINGLE: 29 | ret = self.system + self.sep 30 | for role, message in self.messages: 31 | if message: 32 | if type(message) is tuple: 33 | message, _, _ = message 34 | ret += role + ": " + message + self.sep 35 | else: 36 | ret += role + ":" 37 | return ret 38 | elif self.sep_style == SeparatorStyle.TWO: 39 | seps = [self.sep, self.sep2] 40 | ret = self.system + seps[0] 41 | for i, (role, message) in enumerate(self.messages): 42 | if message: 43 | if type(message) is tuple: 44 | message, _, _ = message 45 | ret += role + ": " + message + seps[i % 2] 46 | else: 47 | ret += role + ":" 48 | return ret 49 | if self.sep_style == SeparatorStyle.MPT: 50 | ret = self.system + self.sep 51 | for role, message in self.messages: 52 | if message: 53 | if type(message) is tuple: 54 | message, _, _ = message 55 | ret += role + message + self.sep 56 | else: 57 | ret += role 58 | return ret 59 | else: 60 | raise ValueError(f"Invalid style: {self.sep_style}") 61 | 62 | def append_message(self, role, message): 63 | self.messages.append([role, message]) 64 | 65 | def get_protein(self): 66 | prot_seqs = [] 67 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 68 | if i % 2 == 0: 69 | if type(msg) is tuple: 70 | msg, protein = msg 71 | prot_seqs.append(protein) 72 | return prot_seqs 73 | 74 | def to_gradio_chatbot(self): 75 | ret = [] 76 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 77 | if i % 2 == 0: 78 | if type(msg) is tuple: 79 | msg, protein = msg 80 | ret.append([msg, None]) 81 | else: 82 | ret[-1][-1] = msg 83 | return ret 84 | # Hack to make the demo work 85 | try: 86 | if '' in ret[0][0]: 87 | ret[0][0] = ret[0][0].replace("", "") 88 | except Exception as e: 89 | pass 90 | 91 | return ret 92 | 93 | def copy(self): 94 | return Conversation( 95 | system=self.system, 96 | roles=self.roles, 97 | messages=[[x, y] for x, y in self.messages], 98 | offset=self.offset, 99 | sep_style=self.sep_style, 100 | sep=self.sep, 101 | sep2=self.sep2) 102 | 103 | def dict(self): 104 | if len(self.get_images()) > 0: 105 | return { 106 | "system": self.system, 107 | "roles": self.roles, 108 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 109 | "offset": self.offset, 110 | "sep": self.sep, 111 | "sep2": self.sep2, 112 | } 113 | return { 114 | "system": self.system, 115 | "roles": self.roles, 116 | "messages": self.messages, 117 | "offset": self.offset, 118 | "sep": self.sep, 119 | "sep2": self.sep2, 120 | } 121 | 122 | 123 | conv_graph_text_v1 = Conversation( 124 | system="You are GraphText, a large language model for graph machine learning. You are able to " 125 | "understand the graph feature in both continuous and discrete form that the user provides, " 126 | "and assist the user with graph machine learning tasks. 
Follow the instructions carefully " 127 | "and answer the questions.", 128 | # system="", 129 | roles=("USER", "ASSISTANT"), 130 | version="v1", 131 | messages=(), 132 | offset=0, 133 | sep_style=SeparatorStyle.TWO, 134 | sep=" ", 135 | sep2="", # Empty string? 136 | ) 137 | 138 | no_conv = Conversation( 139 | system="", 140 | # system="", 141 | roles=("USER", "ASSISTANT"), 142 | version="v1", 143 | messages=(), 144 | offset=0, 145 | sep_style=SeparatorStyle.TWO, 146 | sep=" ", 147 | sep2="", # Empty string? 148 | ) 149 | 150 | default_conversation = conv_graph_text_v1 151 | conv_templates = { 152 | "graph_text_v1": conv_graph_text_v1, 153 | "no_conv": no_conv, 154 | } 155 | 156 | if __name__ == "__main__": 157 | print(default_conversation.get_prompt()) 158 | -------------------------------------------------------------------------------- /src/graph_text/graph_instruction_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from collections import defaultdict 15 | from functools import partial 16 | 17 | import numpy as np 18 | import pandas as pd 19 | import torch as th 20 | from torch.utils.data import Dataset, DataLoader, Subset 21 | from itertools import chain 22 | 23 | from utils.data.textual_graph import TextualGraph 24 | from .samplers import DistributedBatchSampler 25 | 26 | 27 | def load_graph_sft_dataset(cfg, full_dataset, split, split_ids, batch_size, world_size=1, rank=0): 28 | dataset = Subset(full_dataset, split_ids) 29 | if split == "train": 30 | sampler = th.utils.data.RandomSampler(dataset) 31 | else: 32 | sampler = th.utils.data.SequentialSampler(dataset) 33 | if split == "train" and world_size > 1: 34 | batch_sampler = DistributedBatchSampler( 35 | sampler, batch_size, True, rank, world_size 36 | ) 37 | iter_ = DataLoader( 38 | dataset, 39 | batch_sampler=batch_sampler, 40 | num_workers=0, 41 | collate_fn=partial(full_dataset.collate), 42 | pin_memory=True, 43 | ) 44 | else: 45 | iter_ = DataLoader( 46 | dataset, 47 | sampler=sampler, 48 | batch_size=batch_size, 49 | num_workers=0, 50 | collate_fn=partial(full_dataset.collate), 51 | pin_memory=True, 52 | ) 53 | return dataset, iter_, sampler 54 | 55 | 56 | class GraphInstructionDataset(Dataset): 57 | """Dataset for supervised fine-tuning.""" 58 | 59 | def __init__(self, data: TextualGraph, cfg, mode): 60 | super(GraphInstructionDataset, self).__init__() 61 | self.data = data 62 | self.cfg = cfg 63 | self.mode = mode 64 | self.g = data.g 65 | 66 | def __len__(self): # number of instances 67 | return len(self.data) 68 | 69 | def get_link_pred_info(self, f, i): 70 | is_positive_edge = np.random.choice([True, False]) 71 | hop = 1 if len(f) == 1 else int(f[-1]) 72 | if is_positive_edge: 73 | return f"Yes" 74 | else: 75 | return f"No" 76 | 77 | def __getitem__(self, node_id): 78 | # ! 
Build Graph Trees 79 | support_tree_list = [] 80 | if self.cfg.use_demo: 81 | demo_center_nodes = self.data.select_demo(self.cfg.demo.select_method, node_id) 82 | support_tree_list = [ # No node drop out for demo nodes 83 | self.data.build_graph_tree(center_node, self.cfg.attr_mask, supervised=True) 84 | for center_node in demo_center_nodes] 85 | query_tree = self.data.build_graph_tree(node_id, self.cfg.attr_mask, supervised=False) 86 | graph_tree_list = support_tree_list + [query_tree] 87 | 88 | # ! Build Prompt 89 | demo = self.data.build_demo_prompt(support_tree_list) 90 | question = self.data.prompt.question(graph_info=query_tree.prompt) 91 | in_text = self.data.prompt.human(demo=demo, question=question) 92 | if self.mode == 'sft': 93 | out_text = self.data.prompt.gpt(answer=self.data.text.iloc[node_id][self.cfg.out_field]) 94 | else: 95 | out_text = None 96 | 97 | conversation = [ 98 | {"from": "human", "value": in_text}, 99 | {"from": "gpt", "value": out_text}, 100 | ] 101 | 102 | return node_id, graph_tree_list, in_text, out_text, demo, question, conversation 103 | 104 | def get_node_subgraph_info(self, node_id, subg_nodes, node_id_to_encode_id, encode_seq): 105 | subg_info = {} 106 | for f in self.data.in_text_fields: 107 | subg_info[f] = self.data.get_node_info(node_id, field=f) 108 | for f in self.data.in_cont_fields: 109 | # Add empty string to the continuous field, to be encoded in the model forward part 110 | subg_info[f] = "" 111 | # update cont-field to enable unique seq name: seq_name 112 | seq_names = [f"{f}-{_}" for _ in subg_nodes] 113 | node_id_to_encode_id[f].extend(seq_names) 114 | encode_seq[f].extend( 115 | self.data.get_node_info(n, field=f) for n in subg_nodes 116 | ) 117 | 118 | # subg_info = defaultdict(dict) 119 | # # Center node 120 | # for f in self.data.in_text_fields: 121 | # subg_info['center node'][f] = self.data.get_node_info(node_id, field=f) 122 | # # Neighborhood Subgraph Information 123 | # for f in self.data.in_cont_fields: 124 | # # Add empty string to the continuous field, to be encoded in the model forward part 125 | # if self.cfg.rel_info == 'ByOrder': 126 | # order_lookup = {1: 'first', 2: 'second', 3: 'third'} 127 | # subg_info['first order neighbor information'] = {f: ''} 128 | # subg_info['second order neighbor information'] = {f: ''} 129 | # else: 130 | # subg_info['neighbor graph information'][f] = '' 131 | # # update cont-field to enable unique seq name: seq_name 132 | # seq_names = [f'{f}-{_}' for _ in subg_nodes] 133 | # node_id_to_encode_id[f].extend(seq_names) 134 | # encode_seq[f].extend(self.data.get_node_info(n, field=f) for n in subg_nodes) 135 | return subg_info 136 | 137 | def collate(self, batch): 138 | # Key: field, Value: The list of continuous sequence to encode 139 | node_ids, graph_tree_lol, in_text_list, out_text_list, demo_list, question_list, conversation_list = zip(*batch) 140 | # ! 
Get continuous batch dataframe to be encoded 141 | batch_encode_cont_df = pd.concat([tree.encode_df for tree in chain.from_iterable(graph_tree_lol)]) 142 | if len(batch_encode_cont_df) > 0: 143 | grouped = batch_encode_cont_df.groupby('attr_type').agg({'nodes': list}) 144 | # encode_id: key: attr_type, value: node_id 145 | encode_ids = {f: list(set(chain.from_iterable(row.nodes))) for f, row in grouped.iterrows()} 146 | node_id_to_encode_id = { 147 | f: {node_id: encode_id for encode_id, node_id in enumerate(nodes)} 148 | for f, nodes in encode_ids.items() 149 | } 150 | encode_dict = {f: self.g.ndata[f][nodes] for f, nodes in encode_ids.items()} 151 | else: # No continuous attribute 152 | encode_dict, node_id_to_encode_id = None, None 153 | return node_ids, graph_tree_lol, encode_dict, node_id_to_encode_id, conversation_list 154 | -------------------------------------------------------------------------------- /src/graph_text/icl.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | 3 | from llm.llm import LLM 4 | from utils.data.textual_graph import TextualGraph 5 | 6 | 7 | class LLMForInContextLearning(object): 8 | def __init__(self, cfg: DictConfig, data: TextualGraph, llm: LLM, _logger, max_new_tokens=20, gen_mode="text", 9 | **kwargs, ): 10 | self.cfg = cfg 11 | self.gen_mode = gen_mode 12 | self.data = data 13 | self.text = data.text 14 | self.logger = _logger 15 | self.llm = llm 16 | self.max_new_tokens = max_new_tokens 17 | # ! Classification prompt 18 | 19 | self.text["dialog"] = "NA" 20 | self.text["demo"] = "NA" 21 | self.text["question"] = "NA" 22 | self.text["generated_text"] = "NA" 23 | 24 | def eval_and_save(self, step, sample_node_id, split): 25 | res_df = self.text.dropna() 26 | res_df["correctness"] = res_df.apply(lambda x: x["gold_choice"] in x["pred_choice"], axis=1) 27 | res_df.sort_values('correctness', inplace=True) 28 | save_file = self.cfg.save_file.format(split=split) 29 | res_df.to_csv(save_file) 30 | acc = res_df["correctness"].mean() 31 | self.logger.save_file_to_wandb(save_file, base_path=self.cfg.out_dir) 32 | valid_choice_rate = (res_df["pred_choice"].isin(self.data.choice_to_label_id.keys()).mean()) 33 | acc_in_valid_choice = acc / valid_choice_rate if valid_choice_rate > 0 else 0 34 | result = { 35 | "out_file": save_file, 36 | f"{split}_acc": acc, 37 | f"{split}_valid_choice_rate": valid_choice_rate, 38 | f"{split}_acc_in_valid_choice": acc_in_valid_choice, 39 | } 40 | if valid_choice_rate > 0: 41 | valid_df = res_df[res_df["pred_choice"].isin(self.data.choice_to_label_id.keys())] 42 | valid_df["true_choices"] = valid_df.apply(lambda x: self.data.label_info.choice[x["label_id"]], axis=1) 43 | result.update({f"PD/{choice}.{self.data.choice_to_label_name[choice]}": cnt / len(valid_df) 44 | for choice, cnt in valid_df.pred_choice.value_counts().to_dict().items()}) 45 | sample = {f"sample_{k}": v 46 | for k, v in self.data.text.iloc[sample_node_id].to_dict().items()} 47 | self.logger.info(sample) 48 | self.logger.wandb_metric_log({**result, "step": step}) 49 | 50 | # ! 
Save statistics to results 51 | # y_true, y_pred = [valid_df.apply(lambda x: self.data.l_choice_to_id[x[col]], axis=1).tolist() for col in 52 | # ('true_choices', 'pred_choice')] 53 | # result['cla_report'] = classification_report( 54 | # y_true, y_pred, output_dict=True, 55 | # target_names=self.data.label_info.label_name.tolist()) 56 | self.logger.info(result) 57 | self.logger.critical(f"Saved results to {save_file}") 58 | self.logger.wandb_summary_update({**result, **sample}) 59 | return result 60 | 61 | def __call__(self, node_id, prompt, demo, question, log_sample=False): 62 | # ! Classification 63 | prompt = prompt + " " if prompt.endswith(":") else prompt # ! Critical 64 | if self.gen_mode == "choice": 65 | generated = self.llm.generate_text(prompt, max_new_tokens=1, choice_only=True) 66 | pred_choice = generated[-1] if len(generated) > 0 else "NULL" 67 | else: 68 | generated = self.llm.generate_text(prompt, self.max_new_tokens) 69 | try: 70 | pred_choice = generated.split("")[-1][0] # Could be improved 71 | except: 72 | pred_choice = "" 73 | 74 | self.text.loc[node_id, "dialog"] = prompt + generated 75 | self.text.loc[node_id, "demo"] = demo 76 | self.text.loc[node_id, "question"] = question 77 | self.text.loc[node_id, "pred_choice"] = pred_choice 78 | self.text.loc[node_id, "generated_text"] = generated 79 | -------------------------------------------------------------------------------- /src/graph_text/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import transformers 6 | from einops import rearrange 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | """Input shape: Batch x Time x Channel 26 | 27 | attention_mask: [bsz, q_len] 28 | """ 29 | bsz, q_len, _ = hidden_states.size() 30 | 31 | query_states = ( 32 | self.q_proj(hidden_states) 33 | .view(bsz, q_len, self.num_heads, self.head_dim) 34 | .transpose(1, 2) 35 | ) 36 | key_states = ( 37 | self.k_proj(hidden_states) 38 | .view(bsz, q_len, self.num_heads, self.head_dim) 39 | .transpose(1, 2) 40 | ) 41 | value_states = ( 42 | self.v_proj(hidden_states) 43 | .view(bsz, q_len, self.num_heads, self.head_dim) 44 | .transpose(1, 2) 45 | ) 46 | # [bsz, q_len, nh, hd] 47 | # [bsz, nh, q_len, hd] 48 | 49 | kv_seq_len = key_states.shape[-2] 50 | assert past_key_value is None, "past_key_value is not supported" 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | # [bsz, nh, t, hd] 57 | assert not output_attentions, "output_attentions is not supported" 58 | assert not use_cache, "use_cache is not supported" 59 | 60 | # Flash attention codes from 
61 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 62 | 63 | # transform the data into the format required by flash attention 64 | qkv = torch.stack( 65 | [query_states, key_states, value_states], dim=2 66 | ) # [bsz, nh, 3, q_len, hd] 67 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 68 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 69 | # the attention_mask should be the same as the key_padding_mask 70 | key_padding_mask = attention_mask 71 | 72 | if key_padding_mask is None: 73 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 74 | max_s = q_len 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | output = flash_attn_unpadded_qkvpacked_func( 79 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 80 | ) 81 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 82 | else: 83 | nheads = qkv.shape[-2] 84 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 85 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 86 | x_unpad = rearrange( 87 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 88 | ) 89 | output_unpad = flash_attn_unpadded_qkvpacked_func( 90 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 91 | ) 92 | output = rearrange( 93 | pad_input( 94 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 95 | ), 96 | "b s (h d) -> b s h d", 97 | h=nheads, 98 | ) 99 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 100 | 101 | 102 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 103 | # requires the attention mask to be the same as the key_padding_mask 104 | def _prepare_decoder_attention_mask( 105 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 106 | ): 107 | # [bsz, seq_len] 108 | return attention_mask 109 | 110 | 111 | def replace_llama_attn_with_flash_attn(): 112 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 113 | if cuda_major < 8: 114 | logging.warning( 115 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
116 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 117 | ) 118 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 119 | _prepare_decoder_attention_mask 120 | ) 121 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 122 | -------------------------------------------------------------------------------- /src/graph_text/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | 5 | import torch.nn as nn 6 | import transformers 7 | from peft import LoraConfig, TaskType, get_peft_model 8 | from transformers import AutoTokenizer 9 | import pandas as pd 10 | 11 | logging.getLogger("transformers").setLevel(logging.WARNING) 12 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 13 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 14 | from transformers import StoppingCriteria, LlamaForCausalLM 15 | from typing import Dict 16 | from graph_text import conversation as conversation_lib 17 | from utils.basics.os_utils import time_logger 18 | from utils.pkg.hf_utils import download_hf_ckpt_to_local 19 | import torch as th 20 | from torch.nn.utils import rnn 21 | from bidict import bidict 22 | 23 | IGNORE_INDEX = -100 24 | import time 25 | 26 | 27 | def find_consecutive_subarrays(arr): 28 | if not arr: 29 | return [] 30 | 31 | subarrays = [] 32 | current_subarray = [arr[0]] 33 | 34 | for i in range(1, len(arr)): 35 | if arr[i] == arr[i - 1] + 1: 36 | current_subarray.append(arr[i]) 37 | else: 38 | subarrays.append(current_subarray) 39 | current_subarray = [arr[i]] 40 | 41 | subarrays.append(current_subarray) 42 | return subarrays 43 | 44 | 45 | def smart_tokenizer_and_embedding_resize( 46 | special_tokens_dict: Dict, 47 | tokenizer: transformers.PreTrainedTokenizer, 48 | model: transformers.PreTrainedModel, 49 | ): 50 | """Resize tokenizer and embedding. 51 | 52 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
53 | """ 54 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 55 | model.resize_token_embeddings(len(tokenizer)) 56 | 57 | if num_new_tokens > 0: 58 | input_embeddings = model.get_input_embeddings().weight.data 59 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 60 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 61 | if model.get_output_embeddings() is not None: 62 | output_embeddings = model.get_output_embeddings().weight.data 63 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 64 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 65 | 66 | 67 | class KeywordsStoppingCriteria(StoppingCriteria): 68 | def __init__(self, keywords, tokenizer, input_ids): 69 | self.keywords = keywords 70 | self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] 71 | self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if 72 | type(keyword_id) is list and len(keyword_id) == 1] 73 | self.tokenizer = tokenizer 74 | self.start_len = None 75 | self.input_ids = input_ids 76 | 77 | def __call__(self, output_ids: th.LongTensor, scores: th.FloatTensor, **kwargs) -> bool: 78 | if self.start_len is None: 79 | self.start_len = self.input_ids.shape[1] 80 | else: 81 | for keyword_id in self.keyword_ids: 82 | if output_ids[0, -1] == keyword_id: 83 | return True 84 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 85 | for keyword in self.keywords: 86 | if keyword in outputs: 87 | return True 88 | return False 89 | 90 | 91 | def build_one_instance_supervised(tokenizer, sources, conv_template): 92 | # ! The code is modified from LLaVA's code 93 | conv = conv_template.copy() 94 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 95 | 96 | # Apply prompt templates 97 | conversations = [] 98 | for i, source in enumerate(sources): 99 | if roles[source[0]["from"]] != conv.roles[0]: 100 | # Skip the first one if it is not from human 101 | source = source[1:] 102 | 103 | conv.messages = [] 104 | for j, sentence in enumerate(source): 105 | role = roles[sentence["from"]] 106 | assert role == conv.roles[j % 2], f"{i}" 107 | conv.append_message(role, sentence["value"]) 108 | conversations.append(conv.get_prompt()) 109 | 110 | # Tokenize conversations 111 | input_ids = tokenizer( 112 | conversations, 113 | return_tensors="pt", 114 | padding="longest", 115 | max_length=tokenizer.model_max_length, 116 | truncation=True, 117 | ).input_ids 118 | targets = input_ids.clone() 119 | 120 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 121 | # Mask targets 122 | role_sep = conv.sep + conv.roles[1] + ": " 123 | for conversation, target in zip(conversations, targets): 124 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 125 | rounds = conversation.split(conv.sep2) # 126 | cur_len = 1 # Currently processed length, start from masking BOS token 127 | target[:cur_len] = IGNORE_INDEX 128 | for i, round_text in enumerate(rounds): 129 | if round_text == "": 130 | break 131 | # ! 
Mask human instructions 132 | parts = round_text.split(role_sep) 133 | if len(parts) != 2: 134 | break 135 | parts[0] += role_sep 136 | round_len = len(tokenizer(round_text).input_ids) 137 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # BOS + space 138 | target[cur_len: cur_len + instruction_len] = IGNORE_INDEX 139 | 140 | cur_len += round_len 141 | target[cur_len:] = IGNORE_INDEX # The rest are masked 142 | # if cur_len < tokenizer.model_max_length: 143 | # if cur_len != total_len: 144 | # target[:] = IGNORE_INDEX 145 | # logger.debug(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored)") 146 | assert sum(target != -100) > 0 147 | 148 | return [], input_ids, targets 149 | 150 | 151 | def process_batch_instance(tokenizer, conversation_list, max_tgt_len, conv_template): 152 | _, batch_input_ids, batch_target_ids = build_one_instance_supervised(tokenizer, conversation_list, 153 | conv_template) 154 | input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) 155 | target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=IGNORE_INDEX) 156 | assert input_ids.size() == target_ids.size() 157 | input_ids = input_ids[:, :max_tgt_len] 158 | target_ids = target_ids[:, :max_tgt_len] 159 | attention_mask = input_ids.ne(tokenizer.pad_token_id) 160 | assert attention_mask.size() == input_ids.size() 161 | return input_ids, target_ids, attention_mask.long() 162 | 163 | 164 | def process_batch_instance_for_inference(left_tokenizer, batch_input_text): 165 | input_ids = left_tokenizer( 166 | batch_input_text, 167 | return_tensors="pt", 168 | padding="longest", 169 | max_length=left_tokenizer.model_max_length, 170 | truncation=True, 171 | add_special_tokens=True, 172 | ).input_ids 173 | attention_mask = input_ids.ne(left_tokenizer.pad_token_id) 174 | assert attention_mask.size() == input_ids.size() 175 | return input_ids, attention_mask.long() 176 | 177 | 178 | class LinearSeqEncoder(nn.Module): 179 | def __init__(self, in_dim, in_len, out_dim, out_len, dropout=0.3, norm='LN', input_norm=True, output_norm=True, 180 | output_dropout=True, input_dropout=True, **kwargs): 181 | super(LinearSeqEncoder, self).__init__() 182 | self.in_dim, self.in_len, self.out_dim, self.out_len = in_dim, in_len, out_dim, out_len 183 | self.proj = nn.Linear(input_seq_dim := in_dim * in_len, output_seq_dim := out_dim * out_len) 184 | norm_layer = nn.BatchNorm1d if norm == 'BN' else nn.LayerNorm 185 | if input_norm: 186 | self.input_norm = norm_layer(input_seq_dim) 187 | if output_norm: 188 | self.output_norm = norm_layer(output_seq_dim) 189 | if input_dropout: 190 | self.input_dropout = nn.Dropout(dropout) 191 | if output_dropout: 192 | self.output_dropout = nn.Dropout(dropout) 193 | 194 | def forward(self, input): 195 | # Encode input of [bsz, in_seq_len, in_dim] to [bsz] 196 | batch_size, input_seq_length, hidden_dim = input.shape 197 | input = input.view(batch_size, -1) 198 | if hasattr(self, 'input_norm'): 199 | input = self.input_norm(input) 200 | if hasattr(self, 'input_drop'): 201 | input = self.input_drop(input) 202 | if self.proj.weight.dtype != input.dtype: 203 | logging.error(f'weight {self.proj.weight.dtype}, input {input.dtype}') 204 | output = self.proj(input) 205 | if hasattr(self, 'output_norm'): 206 | output = self.output_norm(output) 207 | output = output.view((batch_size, self.out_len, self.out_dim)) 208 | if hasattr(self, 'output_drop'): 209 | output = self.output_drop(output) 210 | return output 211 | 212 | 213 | 
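# Illustrative usage sketch (not called anywhere in this file; the dimensions are
# made-up placeholders): LinearSeqEncoder flattens a [bsz, in_len, in_dim] feature
# sequence, projects it with a single linear layer, and reshapes the result to
# [bsz, out_len, out_dim], e.g. to map a fixed-length node-feature sequence into a
# fixed number of LLM-space token embeddings.
#
#   enc = LinearSeqEncoder(in_dim=768, in_len=4, out_dim=4096, out_len=1)
#   out = enc(th.randn(2, 4, 768))  # -> shape [2, 1, 4096]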
class MLPEncoder(nn.Module): 214 | """ An MLP Encoder with input/output dropout and input/output norm 215 | Since the output layer of projection layers is the input space of LLM, we need to add input and output layers norm 216 | and dropout too. 217 | """ 218 | 219 | def __init__(self, in_dim, out_dim, n_layers=1, hidden_dim=None, dropout=0.3, norm='LN', 220 | input_norm=True, output_norm=True, input_dropout=True, output_dropout=True, **kwargs): 221 | super(MLPEncoder, self).__init__() 222 | 223 | self.in_dim = in_dim 224 | self.out_dim = out_dim 225 | 226 | norm_layer = nn.BatchNorm1d if norm == 'BN' else nn.LayerNorm 227 | 228 | # Input normalization and dropout 229 | if input_norm: 230 | self.input_norm = norm_layer(in_dim) 231 | if input_dropout: 232 | self.input_dropout = nn.Dropout(dropout) 233 | 234 | # Initialize layers 235 | self.layers = nn.ModuleList() 236 | if n_layers > 1: 237 | self.layers.append(nn.Linear(in_dim, hidden_dim)) 238 | for _ in range(n_layers - 2): 239 | self.layers.append(nn.Linear(hidden_dim, hidden_dim)) 240 | self.layers.append(nn.Linear(hidden_dim, out_dim)) 241 | else: # Just a single layer from input to output (acts like LinearEncoder) 242 | self.layers.append(nn.Linear(in_dim, out_dim)) 243 | 244 | # Output normalization and dropout 245 | if output_norm: 246 | self.output_norm = norm_layer(out_dim) 247 | if output_dropout: 248 | self.output_dropout = nn.Dropout(dropout) 249 | 250 | # Activation function 251 | self.relu = nn.ReLU() 252 | 253 | def forward(self, input): 254 | # Input normalization and dropout 255 | if hasattr(self, 'input_norm'): 256 | input = self.input_norm(input) 257 | if hasattr(self, 'input_dropout'): 258 | input = self.input_dropout(input) 259 | 260 | # Hidden layers 261 | for i, layer in enumerate(self.layers[:-1]): 262 | input = layer(input) 263 | input = self.relu(input) 264 | 265 | # Output layer (no activation) 266 | output = self.layers[-1](input) 267 | 268 | # Output normalization and dropout 269 | if hasattr(self, 'output_norm'): 270 | output = self.output_norm(output) 271 | if hasattr(self, 'output_dropout'): 272 | output = self.output_dropout(output) 273 | 274 | return output 275 | 276 | 277 | class GraphText(nn.Module): 278 | '''LoRA for LLaMa model''' 279 | 280 | def __init__(self, cfg, data, logger): 281 | super(GraphText, self).__init__() 282 | self.cfg = cfg 283 | self.data = data 284 | self.logger = logger 285 | self.device = th.cuda.current_device() if th.cuda.is_available() else th.device('cpu') 286 | 287 | if self.cfg.ds.bf16.enable: 288 | self.float_type = th.bfloat16 289 | else: 290 | self.float_type = th.float32 291 | if self.cfg.ds.fp16.enabled: 292 | self.float_type = th.float16 293 | self.conv_template = conversation_lib.conv_templates[cfg.conv_template] 294 | max_tgt_len = cfg['max_tgt_len'] 295 | self.gpt_response_prompt = data.prompt.gpt.template.split('{answer}')[0] 296 | 297 | # # Load checkpoint 298 | download_hf_ckpt_to_local(cfg.llm.hf_name, cfg.llm.local_dir) 299 | self.tokenizer = AutoTokenizer.from_pretrained( 300 | cfg.llm.local_dir, 301 | use_fast=False, 302 | model_max_length=max_tgt_len, 303 | padding_side="right", 304 | ) 305 | # ! 
UNK and EOS token leads to error 306 | # self.tokenizer.pad_token = self.tokenizer.unk_token # Leads to error 307 | self.tokenizer.pad_token = '' # Deal with empty unk token bug 308 | with time_logger(f'initialization of LLM decoder from {cfg.llm.local_dir}'): 309 | self.llm = LlamaForCausalLM.from_pretrained(cfg.llm.local_dir) 310 | self.llm.config.use_cache = False 311 | self.cls_token_names = class_tokens = [f'' for l in range(data.n_labels)] 312 | field_tokens = [f'<{f} emb>' for f in data.in_cont_fields] 313 | fields_to_add = [pg for pg in cfg.rel_info.split('.')] + data.in_cont_fields + data.in_text_fields 314 | field_names = [cfg.tree_node_alias.get(f, f) for f in fields_to_add] 315 | field_tokens += sum([[f'<{f}>', f''] for f in field_names], []) 316 | special_tokens = [] 317 | if cfg.get('add_class_token', True): 318 | special_tokens += class_tokens 319 | if cfg.get('add_field_token', True): 320 | special_tokens += field_tokens 321 | if cfg.get('add_pad_token', True): 322 | special_tokens += [''] 323 | if cfg.get('add_info_token', True): 324 | special_tokens += ['', ''] 325 | if len(special_tokens) > 0: 326 | smart_tokenizer_and_embedding_resize( 327 | special_tokens_dict={'additional_special_tokens': special_tokens}, 328 | tokenizer=self.tokenizer, 329 | model=self.llm, 330 | ) 331 | self.choice_ids = [self.tokenizer([_]).input_ids[0][1] for _ in class_tokens] 332 | self.tok_to_id = bidict({t: self.tokenizer.convert_tokens_to_ids(t) for t in special_tokens}) 333 | self.id_to_tok = self.tok_to_id.inverse 334 | self.cls_tokens = self.tokenizer.convert_tokens_to_ids(class_tokens) 335 | 336 | self.left_tokenizer = copy.deepcopy(self.tokenizer) 337 | self.left_tokenizer.padding_side = 'left' 338 | 339 | # Data related 340 | for id, _ in data.label_info.iterrows(): 341 | data.label_info.loc[id]['label_name'] = self.tokenizer.decode(self.tokenizer(_.label_name).input_ids[1:]) 342 | 343 | self.lid_to_lname = bidict({_.label_id: _.label_name 344 | for id, _ in data.label_info.iterrows()}) 345 | self.lname_to_lid = self.lid_to_lname.inverse 346 | 347 | if self.cfg.lora.r > 0: 348 | # add the lora module 349 | peft_config = LoraConfig( 350 | task_type=TaskType.CAUSAL_LM, 351 | inference_mode=False, 352 | r=self.cfg.lora.r, 353 | lora_alpha=self.cfg.lora.alpha, 354 | lora_dropout=self.cfg.lora.dropout, 355 | target_modules=self.cfg.lora.target_modules, 356 | ) 357 | self.llm = get_peft_model(self.llm, peft_config) 358 | self.llm.print_trainable_parameters() 359 | 360 | # Graph Encoder 361 | 362 | self.encoder = nn.ModuleDict() # Token Encoder 363 | for f in data.in_cont_fields: 364 | self.encoder[f] = MLPEncoder( 365 | in_dim=cfg.hidden_dim[f.lower()], 366 | out_dim=self.llm.config.hidden_size, 367 | **cfg.encoder, 368 | ) 369 | if cfg.frozen_encoder: 370 | for name, param in self.encoder.named_parameters(): 371 | param.requires_grad = False 372 | logging.info('LLAMA proj is frozen.') 373 | else: 374 | for name, param in self.encoder.named_parameters(): 375 | param.requires_grad = True 376 | logging.info('LLAMA proj is not frozen.') 377 | logger.info('LLAMA proj initialized.') 378 | 379 | if cfg.frozen_ori_llm_parameters: 380 | for name, param in self.llm.named_parameters(): 381 | param.requires_grad = False 382 | 383 | # ! 
Since new tokens are added, it is vital to train them 384 | for p in self.llm.get_input_embeddings().parameters(): 385 | p.requires_grad = True 386 | for p in self.llm.get_output_embeddings().parameters(): 387 | p.requires_grad = True 388 | logging.info('The LLM LLAMA is frozen except input and output embeddings.') 389 | self.max_tgt_len = max_tgt_len 390 | 391 | def build_continuous_fields(self, token_ids, cont_fields, graph_tree_list, node_id_to_encode_id): 392 | # build up continuous field information, e.g. , 393 | # Returns cont_fields: List of tuple of (field, text_position, encode_ids) 394 | encode_df = pd.concat([tree.encode_df for tree in graph_tree_list]).reset_index() 395 | field_tokens = self.tokenizer.convert_tokens_to_ids([f'<{f} emb>' for f in cont_fields]) 396 | cont_text_locations = th.where(th.isin(token_ids.cpu(), th.tensor(field_tokens)))[0].numpy() 397 | cont_fields_positions = find_consecutive_subarrays(cont_text_locations.tolist()) 398 | assert len(encode_df) == len(cont_fields_positions), 'Error in processing continuous feature.' 399 | 400 | cont_fields = [] # Field, text_pos, encdoe_ids 401 | for i, text_position in enumerate(cont_fields_positions): 402 | f = encode_df.iloc[i].attr_type 403 | encode_nodes = encode_df.iloc[i].nodes 404 | assert len(text_position) == len(encode_nodes), 'Error in processing continuous feature.' 405 | encode_ids = [node_id_to_encode_id[f][n] for n in encode_nodes] 406 | start, end = text_position[0], text_position[-1] + 1 407 | cont_fields.append((f, range(start, end), encode_ids)) 408 | 409 | return cont_fields 410 | 411 | def prompt_wrap(self, graph_emb, node_ids, graph_tree_lol, input_tok_ids, node_id_to_encode_id): 412 | input_tok_ids = input_tok_ids.to(self.device) # bsz x s2 413 | batch_size = input_tok_ids.shape[0] 414 | # Lookup text embeddings 415 | if self.llm.base_model.__class__.__name__ == 'LlamaModel': 416 | inputs_embeds = self.llm.model.embed_tokens( 417 | input_tok_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim 418 | else: 419 | inputs_embeds = self.llm.model.model.embed_tokens( 420 | input_tok_ids).expand(batch_size, -1, -1) # bsz x s2 x embed_dim 421 | if graph_emb is not None: 422 | # Construct graph embeddings to override text embeddings 423 | new_input_embeds = [] 424 | for node_id, graph_tree_list, cur_input_ids, _cur_input_embeds in zip( 425 | node_ids, graph_tree_lol, input_tok_ids, inputs_embeds): 426 | cur_input_embeds = _cur_input_embeds.clone() # Clone the old embedding 427 | continuous_fields = self.build_continuous_fields(cur_input_ids, graph_emb.keys(), graph_tree_list, 428 | node_id_to_encode_id) 429 | for field, text_pos, encdoe_ids in continuous_fields: 430 | # lookup batch encoded node embeddings 431 | g_emb = graph_emb[field][encdoe_ids] 432 | cur_input_embeds[text_pos] = g_emb 433 | new_input_embeds.append(cur_input_embeds) 434 | inputs_embeds = th.stack(new_input_embeds, dim=0) 435 | return inputs_embeds 436 | 437 | def forward(self, inputs): 438 | node_ids, graph_tree_lol, encode_dict, node_id_to_encode_id, conversation_list = inputs 439 | # ! Get Graph Language 440 | # ! Tokenization: batch instance to input and target IDs. 
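# The conversations are first rendered with the conversation template and tokenized;
# human turns are masked with IGNORE_INDEX (-100) in target_ids so the loss is computed
# only on the assistant responses. Continuous node features (if any) are then encoded by
# the per-field encoders and written over the embeddings of the corresponding
# `<field emb>` placeholder tokens inside prompt_wrap() below.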
441 | input_ids, target_ids, attention_mask = process_batch_instance(self.tokenizer, conversation_list, 442 | self.max_tgt_len, self.conv_template) 443 | if encode_dict is not None: 444 | graph_emb = {f: self.encoder[f](seq.to(self.float_type).to(self.device)) for f, seq in encode_dict.items()} 445 | else: 446 | graph_emb = None 447 | inputs_embeds = self.prompt_wrap(graph_emb, node_ids, graph_tree_lol, input_ids, node_id_to_encode_id) 448 | target_ids = target_ids.to(self.device) 449 | attention_mask = attention_mask.to(self.device) 450 | 451 | outputs = self.llm( 452 | inputs_embeds=inputs_embeds, 453 | attention_mask=attention_mask, 454 | return_dict=True, 455 | labels=target_ids, 456 | ) 457 | 458 | return outputs, target_ids 459 | 460 | def match_label_from_text(self, text): 461 | if self.cfg.add_class_token: 462 | splited = text.replace(':', '').rstrip('.').split(' ') 463 | matched = [cls for cls in self.cls_token_names if cls in splited] 464 | else: 465 | text = text.replace('', '') 466 | matched = [label_name for label_id, label_name in self.lid_to_lname.items() if label_name in text] 467 | if len(matched) == 0: 468 | return text, False 469 | elif len(matched) == 1: 470 | return matched[0], True 471 | else: 472 | return f'Multiple labels matched {matched}', False 473 | 474 | def generate(self, inputs, choice_only=False): 475 | # ! Prepare input 476 | node_ids, graph_tree_lol, encode_dict, node_id_to_encode_id, conversation_list = inputs['batch'] 477 | # [1286, 72, 19] -> [3, 768] emb 478 | if encode_dict is not None: 479 | graph_emb = {f: self.encoder[f](seq.to(self.float_type).to(self.device)) for f, seq in encode_dict.items()} 480 | else: 481 | graph_emb = None 482 | batch_input_text = [] 483 | for c in conversation_list: 484 | conv = self.conv_template.copy() 485 | conv.append_message(conv.roles[0], c[0]['value']) 486 | conv.append_message(conv.roles[1], self.gpt_response_prompt) # ASSISTANT: The answer is: 487 | # conv.append_message(conv.roles[1], None) # ASSISTANT: 488 | # Remove Gold response 489 | _prompt = conv.get_prompt().strip(conv.sep2) 490 | batch_input_text.append(_prompt) 491 | readout_pos = self.cfg.get('choice_readout_pos', 0) 492 | 493 | start_time = time.time() 494 | batch_input_ids, attention_mask = process_batch_instance_for_inference( 495 | self.left_tokenizer, batch_input_text) 496 | batch_inputs_embeds = self.prompt_wrap(graph_emb, node_ids, graph_tree_lol, batch_input_ids, 497 | node_id_to_encode_id) 498 | attention_mask = attention_mask.to(self.device) 499 | # Mask embedding attn_mask=0 to zeros 500 | masked_batch_embedding = batch_inputs_embeds * attention_mask.unsqueeze(-1).to(batch_inputs_embeds.dtype) 501 | # Run model inference 502 | with th.inference_mode(): 503 | batch_output = self.llm.generate( 504 | inputs_embeds=masked_batch_embedding, 505 | attention_mask=attention_mask, 506 | max_new_tokens=inputs['max_tgt_len'] if not choice_only else 3, 507 | temperature=max(float(inputs['temperature']), 0.01), 508 | # Too low temp leads to inf prob error. 
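# output_scores / return_dict_in_generate are enabled only for choice decoding, so the
# class-token logits can be read out from batch_output.scores below.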
509 | output_scores=choice_only, 510 | use_cache=True, 511 | return_dict_in_generate=choice_only, 512 | ) 513 | if choice_only: # The answer is: 514 | batch_preds = batch_output.scores[readout_pos][:, self.choice_ids].argmax(-1).cpu().tolist() 515 | batch_out_text = [self.cls_token_names[_] for _ in batch_preds] 516 | else: 517 | batch_out_text = self.tokenizer.batch_decode(batch_output, skip_special_tokens=False) 518 | outputs = {'dialog': [p + o for p, o in zip(batch_input_text, batch_out_text)], 519 | 'generated_text': batch_out_text} 520 | if self.cfg.add_loop_inference: 521 | self.logger.info(f"BATCH inference time: {time.time() - start_time:.2f} seconds") 522 | input_id_list = self.tokenizer(batch_input_text).input_ids 523 | loop_outputs = [] 524 | # ! Generate one by one as batch generation requires adding tokens to prompt and leads to confusion 525 | start_time = time.time() 526 | for i, (node_id, input_ids) in enumerate(zip(node_ids, input_id_list)): 527 | input_ids = th.as_tensor(input_ids).view(1, -1).to(self.device) 528 | input_embeds = self.prompt_wrap(graph_emb, [node_id], input_ids, node_id_to_encode_id) 529 | # Run model inference 530 | with th.inference_mode(): 531 | output = self.llm.generate( 532 | inputs_embeds=input_embeds, 533 | max_new_tokens=inputs['max_tgt_len'] if not choice_only else 3, 534 | temperature=max(float(inputs['temperature']), 0.01), # Too low temp leads to inf prob error. 535 | output_scores=choice_only, 536 | use_cache=True, 537 | return_dict_in_generate=choice_only, 538 | ) 539 | 540 | # Decode output tokens 541 | if not choice_only: 542 | out_text = self.tokenizer.decode(output[0], skip_special_tokens=False) 543 | # out_text = out_text.strip().rstrip(stop_str).strip() 544 | loop_outputs.append(out_text) 545 | else: 546 | # out_topk_choices = [self.tokenizer.convert_ids_to_tokens(s.topk(3).indices.squeeze()) 547 | # for s in output.scores] 548 | # logger.debug(f"Gold {inputs['gold_text'][i]}. Generated: {out_topk_choices}") 549 | class_logits = output.scores[readout_pos].squeeze()[self.choice_ids] 550 | out_text = self.cls_token_names[class_logits.argmax().item()] 551 | loop_outputs.append(out_text) 552 | outputs['loop_generated_text'] = loop_outputs 553 | self.logger.info(f"LOOP inference time: {time.time() - start_time:.2f} seconds") 554 | return outputs 555 | 556 | def generate_prob(self, inputs): 557 | # ! Prepare input 558 | node_ids, graph_tree_lol, encode_seq, node_id_to_encode_id, conversation_list = inputs['batch'] 559 | # [1286, 72, 19] -> [3, 768] emb 560 | emb = {f: self.encoder[f](seq.to(self.float_type).to(self.device)) for f, seq in encode_seq.items()} 561 | prompt = [] 562 | for c in conversation_list: 563 | conv = self.conv_template.copy() 564 | conv.append_message(conv.roles[0], c[0]['value']) 565 | # conv.append_message(conv.roles[1], self.gpt_response_prompt) # ASSISTANT: The answer is: 566 | # conv.append_message(conv.roles[1], None) # ASSISTANT: 567 | # Remove Gold response 568 | _prompt = conv.get_prompt().strip(conv.sep2) 569 | prompt.append(_prompt) 570 | 571 | input_id_list = self.tokenizer(prompt).input_ids 572 | outputs = [] 573 | 574 | # ! 
Generate one by one as batch generation requires adding tokens to prompt and leads to confusion 575 | for i, (node_id, input_ids) in enumerate(zip(node_ids, input_id_list)): 576 | input_ids = th.as_tensor(input_ids).view(1, -1).to(self.device) 577 | input_embeds = self.prompt_wrap(emb, [node_id], input_ids, node_id_to_encode_id) 578 | # Define stopping criteria for generation 579 | conv = self.conv_template.copy() 580 | stop_str = conv.sep if conv.sep_style != conversation_lib.SeparatorStyle.TWO else conv.sep2 581 | stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids) 582 | # Run model inference 583 | with th.inference_mode(): 584 | output = self.llm.generate( 585 | inputs_embeds=input_embeds, 586 | max_new_tokens=inputs['max_tgt_len'], 587 | temperature=max(float(inputs['temperature']), 0.01), # Too low temp leads to inf prob error. 588 | do_sample=True, 589 | use_cache=True, 590 | stopping_criteria=[stopping_criteria], 591 | ) 592 | 593 | # Decode output tokens 594 | out_text = self.tokenizer.decode(output[0], skip_special_tokens=False) 595 | out_text = out_text.strip().rstrip(stop_str).strip() 596 | outputs.append(out_text) 597 | 598 | return {'dialog': [p + o for p, o in zip(prompt, outputs)], 599 | 'generated_text': outputs, 600 | } 601 | -------------------------------------------------------------------------------- /src/graph_text/prompts.py: -------------------------------------------------------------------------------- 1 | from string import Formatter 2 | 3 | from omegaconf import OmegaConf, DictConfig 4 | from pandas import DataFrame 5 | 6 | 7 | def get_string_args(s): 8 | return [fn for _, fn, _, _ in Formatter().parse(s) if fn is not None] 9 | 10 | 11 | def preprocess_prompt_config(prompt_cfg, lookup_dict): 12 | cfg_dict = OmegaConf.to_object(prompt_cfg) 13 | processed_dict = {} 14 | for k, v in cfg_dict.items(): 15 | if not k.startswith('_') and k in lookup_dict: # Is a prompt template 16 | retrieved_str = lookup_dict[k][v] 17 | processed_dict[k] = preprocess_yaml_fstring(retrieved_str) 18 | return processed_dict 19 | 20 | 21 | def preprocess_yaml_fstring(s): 22 | s = s.replace('\n', '') # Remove original \n 23 | s = s.replace('\\n ', '\n') # Note that there is a SPACE after \\n 24 | s = s.replace('\\n', '\n') 25 | return s 26 | 27 | 28 | def init_prompt_from_cfg(prompt_cfg: DictConfig, template_lookup_dict, **kwargs): 29 | cfg_dict = preprocess_prompt_config(prompt_cfg, template_lookup_dict) 30 | template = cfg_dict.pop('prompt') 31 | cfg_dict.update({k: v for k, v in kwargs.items() if not k.startswith('_')}) 32 | return Prompt(template, **cfg_dict) 33 | 34 | 35 | class Prompt: 36 | # A simpler class than Langchain.PromptTemplate 37 | # With some tweaks 38 | def __init__(self, template: str, **prompt_args): 39 | self.template = preprocess_yaml_fstring(template) 40 | self._variables = {k: preprocess_yaml_fstring(v) for k, v in prompt_args.items()} 41 | self._var_set = set(get_string_args(self.template)) 42 | 43 | def update(self, **kwargs): 44 | self._variables.update(kwargs) 45 | 46 | @property 47 | def filled_template(self, **kwargs): 48 | return self.__call__(**kwargs, assert_vars=False) 49 | 50 | @property 51 | def unfilled_fields(self): 52 | return list(set(self._variables.keys()) - self._var_set) 53 | 54 | def __call__(self, assert_vars=True, **kwargs): 55 | args = {**self._variables, **kwargs} 56 | if assert_vars: 57 | assert len(set(args.keys()) - self._var_set) >= 0, f'{self._var_set - set(args.keys())} not given.' 
58 | else: # For unknown args, keep the {arg} instead. 59 | args = {k: args.get(k, f'{{{k}}}') for k in self._var_set} 60 | return self.template.format(**args) 61 | 62 | def __str__(self): 63 | return self.filled_template 64 | 65 | def __repr__(self): 66 | return f'PromptTemplate: <<{self.__str__()}>>' 67 | 68 | -------------------------------------------------------------------------------- /src/graph_text/samplers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """batch samplers that work with either random or sequential data samplers""" 16 | 17 | import torch 18 | from torch.utils import data 19 | 20 | 21 | class RandomSampler(data.sampler.Sampler): 22 | r""" 23 | Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, 24 | but this class lets the user set an epoch like DistributedSampler 25 | Samples elements randomly. If without replacement, then sample from a shuffled dataset. 26 | If with replacement, then user can specify ``num_samples`` to draw. 27 | Arguments: 28 | data_source (Dataset): dataset to sample from 29 | num_samples (int): number of samples to draw, default=len(dataset) 30 | replacement (bool): samples are drawn with replacement if ``True``, default=False 31 | """ 32 | 33 | def __init__(self, data_source, replacement=False, num_samples=None): 34 | super(RandomSampler, self).__init__(data_source) 35 | self.data_source = data_source 36 | self.replacement = replacement 37 | self._num_samples = num_samples 38 | self.epoch = -1 39 | 40 | if self._num_samples is not None and replacement is False: 41 | raise ValueError("With replacement=False, num_samples should not be specified, " 42 | "since a random permute will be performed.") 43 | 44 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 45 | raise ValueError("num_samples should be a positive integer " 46 | "value, but got num_samples={}".format(self.num_samples)) 47 | if not isinstance(self.replacement, bool): 48 | raise ValueError("replacement should be a boolean value, but got " 49 | "replacement={}".format(self.replacement)) 50 | 51 | @property 52 | def num_samples(self): 53 | # dataset size might change at runtime 54 | if self._num_samples is None: 55 | return len(self.data_source) 56 | return self._num_samples 57 | 58 | def __iter__(self): 59 | n = len(self.data_source) 60 | g = torch.Generator() 61 | if self.epoch >= 0: 62 | g.manual_seed(self.epoch) 63 | if self.replacement: 64 | for _ in range(self.num_samples // 32): 65 | yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist() 66 | yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, 67 | generator=g).tolist() 68 | else: 69 | yield from torch.randperm(n, generator=self.generator).tolist() 70 | 71 | def __len__(self): 72 | return self.num_samples 73 
| 74 | def set_epoch(self, epoch): 75 | self.epoch = epoch 76 | 77 | 78 | class DistributedSequentialSampler(data.sampler.Sampler): 79 | def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2): 80 | super().__init__(num_samples) 81 | if rank == -1: 82 | rank = 0 83 | world_size = 1 84 | self.num_samples = num_samples 85 | self.rank = rank 86 | self.world_size = world_size 87 | self.start_iter = 0 88 | self.train_iters = train_iters 89 | self.batch_size = batch_size 90 | self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)] 91 | 92 | def __iter__(self): 93 | for idx in range(self.start_iter, self.train_iters * 10): 94 | batch = [(idx + bias) % self.num_samples for bias in self.batch_bias] 95 | tbatch = self._batch(batch) 96 | yield tbatch 97 | 98 | def __len__(self): 99 | return self.train_iters 100 | 101 | def _batch(self, batch): 102 | """extracts samples only pertaining to this worker's batch""" 103 | start = self.rank * self.batch_size // self.world_size 104 | end = (self.rank + 1) * self.batch_size // self.world_size 105 | return batch[start:end] 106 | 107 | 108 | class DistributedBatchSampler(data.sampler.BatchSampler): 109 | """ 110 | similar to normal implementation of distributed sampler, except implementation is at the 111 | batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary 112 | data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. 113 | """ 114 | 115 | def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, 116 | gradient_accumulation_steps=None): 117 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) 118 | if rank == -1: 119 | assert False, 'should not be here' 120 | self.rank = rank 121 | self.world_size = world_size 122 | self.sampler.wrap_around = 0 123 | self.wrap_around = 0 124 | self.wrap_last = wrap_last 125 | self.start_iter = 0 126 | self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * \ 127 | gradient_accumulation_steps 128 | 129 | def __iter__(self): 130 | batch = [] 131 | i = 0 132 | for idx in self.data_iterator(self.sampler, wrap_around=False): 133 | batch.append(idx) 134 | if len(batch) == self.batch_size: 135 | tbatch = self._batch(batch) 136 | if i >= self.start_iter * self.effective_batch_size: 137 | yield tbatch 138 | self.start_iter = 0 139 | i += len(batch) 140 | batch = [] 141 | batch_len = len(batch) 142 | if batch_len > 0 and not self.drop_last: 143 | if self.wrap_last: 144 | self.sampler.wrap_around -= (self.batch_size) 145 | self.wrap_around += (len(batch)) 146 | self.wrap_around %= self.batch_size 147 | yield self._batch(batch) 148 | if self.wrap_last: 149 | self.sampler.wrap_around += self.batch_size 150 | 151 | def data_iterator(self, _iter, wrap_around=False): 152 | """iterates through data and handles wrap around""" 153 | for i, idx in enumerate(_iter): 154 | if i < self.wrap_around % self.batch_size: 155 | continue 156 | if wrap_around: 157 | self.wrap_around += 1 158 | self.wrap_around %= self.batch_size 159 | yield idx 160 | 161 | def _batch(self, batch): 162 | """extracts samples only pertaining to this worker's batch""" 163 | start = self.rank * self.batch_size // self.world_size 164 | end = (self.rank + 1) * self.batch_size // self.world_size 165 | return batch[start:end] 166 | -------------------------------------------------------------------------------- /src/llm/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .fake_llm import CpuFakeDebugraph_text 2 | -------------------------------------------------------------------------------- /src/llm/fake_llm.py: -------------------------------------------------------------------------------- 1 | from langchain.llms.fake import FakeListLLM 2 | 3 | seq_default_list = ['C' for _ in range(20)] + ['?' for _ in range(200)] + ['C' for _ in range(20000)] 4 | 5 | 6 | class CpuFakeDebugraph_text: 7 | fake_llm = FakeListLLM(responses=seq_default_list) # Choose C as Default 8 | 9 | def generate_text(self, prompt, max_new_tokens=1, choice_only=False): 10 | return self.fake_llm(prompt)[:max_new_tokens] 11 | -------------------------------------------------------------------------------- /src/llm/gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import sleep 3 | 4 | import openai 5 | from tenacity import retry, stop_after_attempt, wait_random_exponential 6 | 7 | from utils.basics import logger 8 | from .llm import LLM 9 | 10 | 11 | class GPT(LLM): 12 | def __init__(self, openai_name="gpt-3.5-turbo", temperature=0, top_p=1, max_tokens=200, sleep_time=0, **kwargs): 13 | assert 'OPENAI_API_KEY' in os.environ, 'Please set OPENAI_API_KEY as an environment variable.' 14 | openai.api_key = os.environ["OPENAI_API_KEY"] 15 | self.model = openai_name 16 | self.temperature = temperature 17 | self.top_p = top_p 18 | self.max_tokens = max_tokens 19 | self.sleep_time = sleep_time 20 | logger.critical(f'Using OPENAI {openai_name.upper()}') 21 | # logger.critical(f'OPENAI-API-Key= {os.environ["OPENAI_API_KEY"]}') 22 | 23 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5)) 24 | def generate_text(self, prompt, max_new_tokens=10, choice_only=False): 25 | response = openai.ChatCompletion.create( 26 | model=self.model, 27 | messages=[{"role": "user", "content": prompt}], 28 | temperature=0 if choice_only else self.temperature, 29 | top_p=self.top_p, 30 | max_tokens=1 if choice_only else self.max_tokens 31 | ) 32 | sleep(self.sleep_time) 33 | return response["choices"][0]["message"]["content"] 34 | -------------------------------------------------------------------------------- /src/llm/llama_icl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from transformers import AutoTokenizer 5 | import logging 6 | import os 7 | 8 | from transformers import AutoTokenizer 9 | 10 | logging.getLogger("transformers").setLevel(logging.WARNING) 11 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 12 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 13 | from transformers import LlamaForCausalLM 14 | from utils.basics.os_utils import time_logger 15 | from utils.pkg.hf_utils import download_hf_ckpt_to_local 16 | 17 | from .llm import LLM 18 | 19 | 20 | class LLaMA_ICL(LLM): 21 | def __init__(self, hf_name, local_dir, max_tgt_len, **kwargs): 22 | # # Load checkpoint 23 | download_hf_ckpt_to_local(hf_name, local_dir) 24 | self.tokenizer = AutoTokenizer.from_pretrained( 25 | local_dir, 26 | use_fast=False, 27 | model_max_length=max_tgt_len, 28 | padding_side="right", 29 | ) 30 | # ! 
UNK and EOS token leads to error 31 | # self.tokenizer.pad_token = self.tokenizer.unk_token # Leads to error 32 | with time_logger(f'initialization of LLM decoder from {local_dir}'): 33 | self.llm = LlamaForCausalLM.from_pretrained(local_dir) 34 | self.llm.config.use_cache = False 35 | self.llm.half().cuda() 36 | 37 | def generate_text(self, prompt, max_new_tokens, **kwargs): 38 | """Generates text from the model. 39 | Parameters: 40 | prompt: The prompt to use. This can be a string or a list of strings. 41 | Returns: 42 | A list of strings. 43 | :param **kwargs: 44 | """ 45 | inputs = self.tokenizer(prompt, return_tensors="pt") 46 | generated_ids = self.llm.generate(inputs.input_ids.cuda(), attention_mask=inputs.attention_mask.cuda(), 47 | max_new_tokens=max_new_tokens) 48 | conversation = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, 49 | clean_up_tokenization_spaces=False)[0] 50 | out_text = conversation.split(prompt[-5:])[-1] 51 | return out_text 52 | -------------------------------------------------------------------------------- /src/llm/llm.py: -------------------------------------------------------------------------------- 1 | """Contains classes for querying large language models. Modified from APE repo""" 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class LLM(ABC): 6 | """Abstract base class for large language models.""" 7 | 8 | @abstractmethod 9 | def generate_text(self, prompt, max_new_tokens=1, choice_only=False): 10 | """Generates text from the model. 11 | Parameters: 12 | prompt: The prompt to use. This can be a string or a list of strings. 13 | Returns: 14 | A list of strings. 15 | """ 16 | pass 17 | 18 | # @abstractmethod 19 | # def log_probs(self, text, log_prob_range): 20 | # """Returns the log probs of the text. 21 | # Parameters: 22 | # text: The text to get the log probs of. This can be a string or a list of strings. 23 | # log_prob_range: The range of characters within each string to get the log_probs of. 24 | # This is a list of tuples of the form (start, end). 25 | # Returns: 26 | # A list of log probs. 
27 | # """ 28 | # pass 29 | -------------------------------------------------------------------------------- /src/scripts/run_icl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | root_path = os.path.abspath(os.path.dirname(__file__)).split("src")[0] 5 | os.chdir(root_path) 6 | sys.path.append(root_path + "src") 7 | 8 | from utils.basics import init_env_variables, time_logger, wandb_finish 9 | from tqdm import tqdm 10 | 11 | init_env_variables() 12 | 13 | from utils.project.exp import init_experiment 14 | import logging 15 | import hydra 16 | 17 | logging.getLogger("transformers").setLevel(logging.WARNING) 18 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | from graph_text.icl import LLMForInContextLearning 22 | from utils.data.textual_graph import TextualGraph 23 | from llm import CpuFakeDebugraph_text 24 | from graph_text.graph_instruction_dataset import GraphInstructionDataset 25 | from torch.utils.data import Subset 26 | import numpy as np 27 | 28 | 29 | @time_logger() 30 | @hydra.main(config_path=f"{root_path}/configs", config_name="main", version_base=None) 31 | def run_inference(cfg): 32 | cfg, logger = init_experiment(cfg) 33 | data = TextualGraph(cfg=cfg) 34 | full_dataset = GraphInstructionDataset(data, cfg, cfg.mode) 35 | eval_splits = cfg.get('eval_sets', ['val', 'test']) 36 | results = {} 37 | for split in eval_splits: 38 | data.text["pred_choice"] = np.nan 39 | dataset = Subset(full_dataset, data.split_ids[split][:cfg.data.max_test_samples]) 40 | if cfg.get("debug", False): 41 | llm = CpuFakeDebugraph_text() # Use local CPU for faster debugging 42 | else: 43 | llm = hydra.utils.instantiate(cfg.llm) 44 | 45 | model = LLMForInContextLearning(cfg, data, llm, logger, **cfg.model) 46 | for i, item in tqdm(enumerate(dataset), "Evaluating..."): 47 | node_id, graph_tree_list, in_text, out_text, demo, question, _ = item 48 | is_evaluate = i % cfg.eval_freq == 0 and i != 0 49 | model(node_id, in_text, demo, question, log_sample=is_evaluate) 50 | if is_evaluate: 51 | model.eval_and_save(step=i, sample_node_id=node_id, split=split) 52 | 53 | results.update(model.eval_and_save(step=i, sample_node_id=node_id, split=split)) 54 | logger.info("Evaluation finished") 55 | wandb_finish(results) 56 | 57 | 58 | if __name__ == "__main__": 59 | run_inference() 60 | -------------------------------------------------------------------------------- /src/scripts/run_sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | root_path = os.path.abspath(os.path.dirname(__file__)).split('src')[0] 5 | os.chdir(root_path) 6 | sys.path.append(root_path + 'src') 7 | 8 | from utils.basics import init_env_variables, print_important_cfg, time_logger 9 | from tqdm import tqdm 10 | from math import ceil 11 | 12 | init_env_variables() 13 | 14 | from utils.pkg.distributed import initialize_deepspeed, initialize_distributed 15 | from utils.project.exp import init_experiment 16 | import logging 17 | import hydra 18 | 19 | logging.getLogger("transformers").setLevel(logging.WARNING) 20 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 21 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 22 | 23 | from graph_text.agent import DeepSpeedAgent, Agent 24 | from graph_text.graph_instruction_dataset import GraphInstructionDataset, load_graph_sft_dataset 25 | from 
graph_text.model import GraphText 26 | from utils.data.textual_graph import TextualGraph 27 | import torch as th 28 | 29 | 30 | @time_logger() 31 | @hydra.main(config_path=f'{root_path}/configs', config_name='main', version_base=None) 32 | def train_graph_text_sft(cfg): 33 | cfg, logger = init_experiment(cfg) 34 | data = TextualGraph(cfg=cfg) 35 | 36 | cfg.hidden_dim = {f: data.g.ndata[f].shape[-1] for f in data.g.ndata.keys()} 37 | is_cpu_debug = not th.cuda.is_available() 38 | if is_cpu_debug: 39 | cfg.llm.base_model = 'tinygpt' 40 | cfg.use_bf16 = False 41 | else: 42 | cfg.use_bf16 = th.cuda.is_bf16_supported() and cfg.use_bf16 43 | 44 | initialize_distributed(cfg) 45 | initialize_deepspeed(cfg) 46 | if cfg.get('use_flash_attn', False): # and ( == 'Ampere': # CPU Debug only 47 | logger.critical('Using FlashAttn2 for training') 48 | from graph_text.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 49 | replace_llama_attn_with_flash_attn() 50 | else: 51 | logger.critical('FlashAttn2 disabled for training') 52 | logger.critical(f'eq_batch_size={cfg.eq_batch_size}, bsz_per_gpu={cfg.bsz_per_gpu}, ' 53 | f'grad_acc_steps={cfg.grad_acc_steps}') 54 | model = GraphText(cfg, data, logger) 55 | if cfg.use_deepspeed: 56 | logger.critical('Using DeepSpeed agent for training') 57 | agent = DeepSpeedAgent(model, cfg, data, logger) 58 | else: 59 | model = model.to(model.device) 60 | logger.critical(f'Using normal agent for training.') 61 | agent = Agent(model, cfg, data, logger) 62 | 63 | print_important_cfg(cfg, logger.debug) 64 | # Initialize DataLoaders 65 | batch_size = cfg.world_size * cfg.ds['train_micro_batch_size_per_gpu'] 66 | full_dataset = GraphInstructionDataset(data, cfg, cfg.mode) 67 | # ! Full data for link prediction 68 | train_ids = data.split_ids['train'][:cfg.data.max_train_samples] 69 | train_data, train_iter, sampler = load_graph_sft_dataset( 70 | cfg, 71 | full_dataset=full_dataset, split_ids=train_ids, 72 | batch_size=batch_size, 73 | split='train', world_size=cfg.world_size, rank=cfg.local_rank 74 | ) 75 | 76 | eval_iter_dict = {split: load_graph_sft_dataset( 77 | cfg, 78 | full_dataset=full_dataset, 79 | batch_size=cfg.inf_batch_size, # Full test evaluate 80 | split_ids=data.split_ids[split][:cfg.data.max_eval_samples if split != 'test' else cfg.data.max_test_samples], 81 | split=split, world_size=cfg.world_size, rank=cfg.local_rank 82 | )[1] for split in cfg.get('eval_sets', ['train', 'val', 'test'])} 83 | 84 | epochs = min(cfg.get('max_epochs', 1000), ceil(ceil(cfg.total_steps / (len(train_data) / cfg.eq_batch_size)))) 85 | logger.warning(f'Begin training {cfg.total_steps} steps ({epochs} epochs).') 86 | current_step = 0 87 | is_eval = cfg.local_rank == 0 and 'c' in cfg.out_field 88 | pbar_refresh_freq = max(agent.total_batch_steps // 100, 10) 89 | pbar = tqdm(total=agent.total_batch_steps, desc="Training", dynamic_ncols=True, disable=cfg.local_rank > 0) 90 | for epoch_i in range(epochs): 91 | logger.critical(f'Started epoch {epoch_i}.') 92 | for batch in train_iter: 93 | results = agent.train_model_batch(batch, current_step=current_step) 94 | if is_eval and current_step % cfg.eval_freq == 0 and current_step >= cfg.min_eval_step: 95 | eval_results = agent.evaluate(eval_iter_dict, logger) 96 | results.update(eval_results) 97 | logger.wandb_metric_log({**results, **{'train/epoch': epoch_i}}) 98 | agent.torch_distributed_barrier() 99 | 100 | if current_step % cfg.save_freq == 0 and epoch_i > 0 and not is_cpu_debug: 101 | agent.save_model(cfg.save_path, 
current_step) 102 | if current_step % pbar_refresh_freq == 0: 103 | pbar.update(pbar_refresh_freq) 104 | 105 | current_step += 1 # Every gradient update or every batch forward 106 | if current_step >= agent.total_batch_steps: 107 | break 108 | pbar.close() 109 | # save at the end of the training 110 | agent.save_model(cfg.save_path, current_step, is_final=True) 111 | # update final valid and test acc 112 | final_results = logger.lookup_metric_checkpoint_by_best_eval('val_acc', out_metrics=['val_acc', 'test_acc']) 113 | logger.wandb_summary_update(final_results, finish_wandb=True) 114 | 115 | 116 | if __name__ == "__main__": 117 | import cProfile 118 | import pstats 119 | 120 | with cProfile.Profile() as pr: 121 | train_graph_text_sft() 122 | stats = pstats.Stats(pr) 123 | stats.sort_stats(pstats.SortKey.TIME) 124 | stats.dump_stats(filename='profiling.prof') 125 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.basics.os_utils import * 2 | -------------------------------------------------------------------------------- /src/utils/basics/__init__.py: -------------------------------------------------------------------------------- 1 | from .cfg_utils import * 2 | from .iterables import * 3 | from .logging import * 4 | from .np_utils import * 5 | from .os_utils import * 6 | 7 | PROJ_CFG = {} 8 | -------------------------------------------------------------------------------- /src/utils/basics/cfg_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | from easydict import EasyDict 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from .logging import logger 9 | from .os_utils import append_header_to_file, subset_dict_by_condition 10 | 11 | UNIMPORTANT_CFG = EasyDict( 12 | fields=['gpus', 'debug', 'wandb', 'env', 'uid', 13 | 'local_rank', 'cmd', 'file_prefix'], 14 | prefix=['_'], 15 | postfix=['_path', '_file', '_dir'] 16 | ) 17 | 18 | 19 | def cfg_dict(cfg): 20 | if isinstance(cfg, DictConfig): 21 | return EasyDict(OmegaConf.to_object(cfg)) 22 | elif isinstance(cfg, dict): 23 | return cfg 24 | else: 25 | raise ValueError(f'Unsupported config type for {type(cfg)}') 26 | 27 | 28 | def cfg_to_file_name(cfg, compress=False): 29 | """To avoid conflict while launching multiple parallel runs together, 30 | here we map the config to a unique string file name. 
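Unimportant fields (paths, debug/wandb flags, etc.) are dropped first via get_important_cfg; if the remaining name exceeds 255 characters it is either folded into at most 255 characters (when ``compress`` is set) or split into 255-character path segments.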
31 | """ 32 | MAX_FNAME = 255 33 | 34 | def _cfg_to_str(cfg): 35 | if isinstance(cfg, dict): 36 | if '_file_' in cfg: 37 | return cfg['_file_'] 38 | return ''.join([_cfg_to_str(_) for _ in cfg.values()]) 39 | else: 40 | return str(cfg) 41 | 42 | _cfg = cfg_dict(cfg) if isinstance(cfg, DictConfig) else cfg 43 | imp_cfg = get_important_cfg(cfg) 44 | s = f"{_cfg_to_str(imp_cfg)}" 45 | 46 | for _ in ['/', ' ', ':', 'class', 'None', '~', 'yaml', 47 | '\'', '[', ']', '(', ')', '{', '}', '.', ',']: 48 | s = s.replace(_, '') 49 | 50 | for k, v in {'True': 'T', 'False': 'F'}.items(): 51 | s = s.replace(k, v) 52 | 53 | if compress: 54 | _map = lambda list_of_ascii: 61 + sum(list_of_ascii) % 29 55 | if len(s) > MAX_FNAME: 56 | # Post-processing to map to length <255 57 | # map to devision compression method, the prime is selected as 29 58 | ascii_code = np.array(list(s.encode('ascii'))) 59 | compressed_ascii = [_map(_) for _ in np.array_split(ascii_code, MAX_FNAME)] 60 | s = ''.join(chr(_) for _ in compressed_ascii) 61 | else: 62 | s = '/'.join([s[i:i + 255] for i in range(0, len(s), 255)]) 63 | return s 64 | 65 | 66 | # ! Get config 67 | 68 | def add_cmd_line_args_to_hydra_cfg(cfg: DictConfig): 69 | # Previously, we need to access choices, e.g. exp, in default list 70 | # Abandoned. Using ${hydra:runtime.choices.exp} instead 71 | OmegaConf.set_struct(cfg, False) 72 | cmd_arg = {} 73 | for _ in sys.argv: 74 | if '.py' not in _: 75 | sep = '=' if '=' in _ else ' ' 76 | k, v = _.split(sep) 77 | cmd_arg[k] = v 78 | cfg.cmd = cmd_arg 79 | 80 | 81 | def save_cfg(cfg: DictConfig, path, as_global=True): 82 | processed_cfg = get_important_cfg(cfg) 83 | OmegaConf.save(config=DictConfig(processed_cfg), f=path) 84 | if as_global: 85 | append_header_to_file(path, header='# @package _global_\n') 86 | return cfg 87 | 88 | 89 | def get_important_cfg(cfg: DictConfig, reserve_file_cfg=True, unimportant_cfg=UNIMPORTANT_CFG): 90 | uimp_cfg = cfg.get('_unimportant_cfg', unimportant_cfg) 91 | imp_cfg = OmegaConf.to_object(cfg) 92 | 93 | def is_preserve(k: str): 94 | judge_file_setting = k == '_file_' and reserve_file_cfg 95 | prefix_allowed = (not any([k.startswith(_) for _ in uimp_cfg.prefix])) or judge_file_setting 96 | postfix_allowed = not any([k.endswith(_) for _ in uimp_cfg.postfix]) 97 | field_allowed = k not in uimp_cfg.fields 98 | return prefix_allowed and postfix_allowed and field_allowed 99 | 100 | imp_cfg = subset_dict_by_condition(imp_cfg, is_preserve) 101 | return imp_cfg 102 | 103 | 104 | def print_important_cfg(cfg, log_func=logger.info): 105 | log_func(OmegaConf.to_yaml(get_important_cfg(cfg, reserve_file_cfg=False))) 106 | 107 | 108 | # ! Custom OmegaConf Resolvers 109 | 110 | def replace_dot_by_underscore(input_str): 111 | return input_str.replace('.', '_') 112 | 113 | 114 | def calc_bsz_and_grad_acc_steps(eq_batch_size, max_bsz_per_gpu, min_bsz=2): 115 | if (gpus := os.environ.get('CUDA_VISIBLE_DEVICES', '')) is not None: 116 | n_gpus = len(gpus.split(',')) if gpus != '' else 1 117 | else: # CPU Running 118 | return eq_batch_size, 1 119 | 120 | def _get_bsz_per_gpu(bsz_per_gpu): 121 | # Find batch_size and grad_acc_steps combination that are DIVISIBLE! 
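        # For illustration (example values, not taken from any config): with eq_batch_size=16,
        # n_gpus=2 and max_bsz_per_gpu=6, neither 6 nor 5 divides 16 / 2 = 8 evenly, so the
        # recursion settles on bsz_per_gpu=4 with grad_acc_steps=2. bsz_per_gpu is decremented
        # until the division is exact, or a ValueError is raised once it falls below min_bsz.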
122 | grad_acc_steps = eq_batch_size / bsz_per_gpu / n_gpus 123 | if grad_acc_steps.is_integer(): 124 | return bsz_per_gpu, int(grad_acc_steps) 125 | elif grad_acc_steps: 126 | if bsz_per_gpu >= min_bsz: 127 | return _get_bsz_per_gpu(bsz_per_gpu - 1) 128 | else: 129 | raise ValueError( 130 | f'Cannot find grad_acc_step with integer batch_size greater than {min_bsz}, ' 131 | f'eq_bsz={eq_batch_size}, n_gpus={n_gpus}') 132 | 133 | batch_size, grad_acc_steps = _get_bsz_per_gpu(max_bsz_per_gpu) 134 | # print(f'Eq_batch_size = {eq_batch_size}, bsz={batch_size}, grad_acc_steps={grad_acc_steps}, ngpus={n_gpus}') 135 | return batch_size, grad_acc_steps 136 | 137 | 138 | def get_bsz_per_gpu(eq_batch_size, max_bsz_per_gpu, min_bsz=2): 139 | return calc_bsz_and_grad_acc_steps(eq_batch_size, max_bsz_per_gpu, min_bsz)[0] 140 | 141 | 142 | def get_grad_acc_steps(eq_batch_size, max_bsz_per_gpu, min_bsz=2): 143 | return calc_bsz_and_grad_acc_steps(eq_batch_size, max_bsz_per_gpu, min_bsz)[1] 144 | 145 | 146 | def int_divide(a, b): 147 | # python integer division 148 | return a // b 149 | 150 | 151 | def ternary_operator(condition, val_if_true, val_if_false): 152 | return val_if_true if condition else val_if_false 153 | 154 | 155 | def round_mult(a, b): 156 | return round(a * b) 157 | 158 | 159 | # Register resolvers 160 | OmegaConf.register_new_resolver('get_bsz_per_gpu', get_bsz_per_gpu) 161 | OmegaConf.register_new_resolver('get_grad_acc_steps', get_grad_acc_steps) 162 | OmegaConf.register_new_resolver('condition', ternary_operator) 163 | OmegaConf.register_new_resolver('round_mult', round_mult) 164 | OmegaConf.register_new_resolver('replace_dot_by_underscore', replace_dot_by_underscore) 165 | -------------------------------------------------------------------------------- /src/utils/basics/iterables.py: -------------------------------------------------------------------------------- 1 | def has_intersection(iterable1, iterable2): 2 | return len(set(iterable1), set(iterable2)) > 0 3 | 4 | 5 | # * ============================= Itertool Related ============================= 6 | 7 | def lot_to_tol(list_of_tuple): 8 | # list of tuple to tuple lists 9 | # Note: zip(* zipped_file) is an unzip operation 10 | return list(map(list, zip(*list_of_tuple))) 11 | -------------------------------------------------------------------------------- /src/utils/basics/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | 4 | import hydra 5 | import wandb 6 | from rich.console import Console 7 | from rich.logging import RichHandler 8 | from rich.traceback import install 9 | 10 | install() 11 | logging.basicConfig( 12 | level="INFO", format="%(message)s", datefmt="[%X]", 13 | handlers=[RichHandler( 14 | rich_tracebacks=False, tracebacks_suppress=[hydra], 15 | console=Console(width=165), 16 | enable_link_path=False 17 | )], 18 | ) 19 | # Default logger 20 | logger = rich_logger = logging.getLogger("rich") 21 | # from rich.traceback import install 22 | # install(show_locals=True, width=150, suppress=[hydra]) 23 | logger.info("Rich Logger initialized.") 24 | 25 | NonPercentageFloatMetrics = ['loss', 'time'] 26 | 27 | 28 | def get_best_by_val_perf(res_list, prefix, metric): 29 | results = max(res_list, key=lambda x: x[f'val_{metric}']) 30 | return {f'{prefix}_{k}': v for k, v in results.items()} 31 | 32 | 33 | def judge_by_partial_match(k, match_dict, case_sensitive=False): 34 | k = k if case_sensitive else k.lower() 35 | return 
len([m for m in match_dict if m in k]) > 0 36 | 37 | 38 | def metric_processing(log_dict): 39 | # Round floats and process percentage 40 | for k, v in log_dict.items(): 41 | if isinstance(v, float): 42 | is_percentage = not judge_by_partial_match(k, NonPercentageFloatMetrics) 43 | if is_percentage: 44 | log_dict[k] *= 100 45 | log_dict[k] = round(log_dict[k], 4) 46 | return log_dict 47 | 48 | 49 | def get_split(metric): 50 | split = 'train' 51 | if 'val' in metric: 52 | split = 'val' 53 | elif 'test' in metric: 54 | split = 'test' 55 | return split 56 | 57 | 58 | class WandbExpLogger: 59 | '''Wandb Logger with experimental metric saving logics''' 60 | 61 | def __init__(self, cfg): 62 | self.wandb = cfg.wandb 63 | self.wandb_on = cfg.wandb.id is not None 64 | self.local_rank = cfg.local_rank 65 | self.logger = rich_logger # Rich logger 66 | self.logger.setLevel(getattr(logging, cfg.logging.level.upper())) 67 | self.info = self.logger.info 68 | self.critical = self.logger.critical 69 | self.warning = self.logger.warning 70 | self.debug = self.logger.debug 71 | self.info = self.logger.info 72 | self.error = self.logger.error 73 | self.log_metric_to_stdout = (not self.wandb_on and cfg.local_rank <= 0) or \ 74 | cfg.logging.log_wandb_metric_to_stdout 75 | # ! Experiment Metrics 76 | self.results = defaultdict(list) 77 | 78 | # ! Log functions 79 | def log(self, *args, level='', **kwargs): 80 | if self.local_rank <= 0: 81 | self.logger.log(getattr(logging, level.upper()), *args, **kwargs) 82 | 83 | def log_fig(self, fig_name, fig_file): 84 | if wandb.run is not None and self.local_rank <= 0: 85 | wandb.log({fig_name: wandb.Image(fig_file)}) 86 | else: 87 | self.error('Figure not logged to Wandb since Wandb is off.', 'ERROR') 88 | 89 | def wandb_metric_log(self, metric_dict, level='info'): 90 | # Preprocess metric 91 | metric_dict = metric_processing(metric_dict) 92 | for metric, value in metric_dict.items(): 93 | self.results[metric].append(value) 94 | 95 | if wandb.run is not None and self.local_rank <= 0: 96 | wandb.log(metric_dict) 97 | if self.log_metric_to_stdout: 98 | self.log(metric_dict, level=level) 99 | 100 | def lookup_metric_checkpoint_by_best_eval(self, eval_metric, out_metrics=None): 101 | if len(self.results[eval_metric]) == 0: 102 | return {} 103 | best_val_ind = self.results[eval_metric].index(max(self.results[eval_metric])) 104 | out_metrics = out_metrics or self.results.keys() 105 | return {m: self.results[m][best_val_ind] for m in out_metrics} 106 | 107 | # ! Experiment metrics functions 108 | def wandb_summary_update(self, result): 109 | # ! 
Finish wandb 110 | if wandb.run is not None and self.local_rank <= 0: 111 | wandb.summary.update(result) 112 | 113 | def save_file_to_wandb(self, file, base_path, policy='now', **kwargs): 114 | if wandb.run is not None and self.local_rank <= 0: 115 | wandb.save(file, base_path=base_path, policy=policy, **kwargs) 116 | 117 | 118 | def wandb_finish(result=None): 119 | if wandb.run is not None: 120 | wandb.summary.update(result or {}) 121 | wandb.finish() 122 | -------------------------------------------------------------------------------- /src/utils/basics/np_utils.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import numpy as np 4 | from tqdm import tqdm as tqdm 5 | 6 | 7 | def _judge_type(data): 8 | min_val, max_val = data.min(), data.max() 9 | _dtype = type(min_val) 10 | if np.issubdtype(_dtype, np.integer): 11 | if max_val <= 1 and min_val >= 0: 12 | _dtype = np._bool 13 | if max_val <= 255 and min_val >= 0: 14 | _dtype = np.uint8 15 | elif max_val <= 65535 and min_val >= 0: 16 | _dtype = np.uint16 17 | elif max_val <= 2147483647 and min_val >= -2147483647: 18 | _dtype = np.int32 19 | elif np.issubdtype(_dtype, np.float): 20 | _dtype = np.float16 21 | return _dtype 22 | 23 | 24 | def save_memmap(data: np.ndarray, path, dtype=None, node_chunk_size=200000, log=print): 25 | # ! Determine the least memory cost type 26 | 27 | dtype = _judge_type(data) if dtype is None else dtype 28 | 29 | # ! Store memory map 30 | x = np.memmap(path, dtype=dtype, mode='w+', 31 | shape=data.shape) 32 | 33 | for i in tqdm(range(0, data.shape[0], node_chunk_size)): 34 | j = min(i + node_chunk_size, data.shape[0]) 35 | x[i:j] = data[i:j] 36 | log(f'Saved {path} as {dtype}...') 37 | del x 38 | gc.collect() 39 | log('releas x') 40 | return # EasyDict(type=dtype, path=path, shape=data.shape) 41 | -------------------------------------------------------------------------------- /src/utils/basics/os_utils.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import json 3 | import logging 4 | import os 5 | import pickle 6 | import socket 7 | import subprocess 8 | import sys 9 | import time 10 | from collections import OrderedDict 11 | from contextlib import ContextDecorator 12 | from datetime import datetime 13 | from functools import wraps 14 | from pprint import pformat 15 | 16 | import numpy as np 17 | import pytz 18 | from omegaconf import OmegaConf 19 | 20 | from utils.basics.logging import logger 21 | 22 | cur_path = os.path.abspath(os.path.dirname(__file__)) 23 | root_path = cur_path.split('src')[0] 24 | 25 | 26 | def init_env_variables(cfg=None, env_cfg_file=f'{root_path}configs/user/env.yaml'): 27 | if cfg is None and os.path.exists(env_cfg_file): 28 | cfg = OmegaConf.load(env_cfg_file) 29 | if 'env' in cfg and 'vars' in cfg.env: 30 | for k, v in cfg.env.vars.items(): 31 | k = k.upper() 32 | os.environ[k] = v 33 | # ! Insert conda path to the first place 34 | if (conda_path := os.environ.get('CONDA_EXE')) is not None: 35 | conda_bin_dir = conda_path.rstrip('conda') 36 | os.environ['PATH'] = f"{conda_bin_dir}:{os.environ['PATH']}" 37 | 38 | return cfg 39 | 40 | 41 | def run_command(cmd, gpus=[], log_func=print, parallel=True): 42 | if parallel and len(gpus) > 1: 43 | # ! 
Generate parallel commands with torchrun 44 | _ = cmd.split('python ') 45 | env_path, variables = _[0], _[1] 46 | gpus_ = ",".join([str(_) for _ in gpus]) 47 | cmd = f'CUDA_VISIBLE_DEVICES={gpus_} {env_path}' \ 48 | f'torchrun ' \ 49 | f'--master_port={find_free_port()} --nproc_per_node={len(gpus)} {variables}' 50 | 51 | log_func(f'Running command:\n{cmd}') 52 | ret_value = os.system(cmd) 53 | cmd_to_print = 'python' + cmd.split("python")[-1] 54 | if ret_value != 0: 55 | raise ValueError(f'Failed to operate {cmd_to_print}') 56 | 57 | 58 | def mkdir_p(path, enable_log=True): 59 | """Create a directory for the specified path. 60 | Parameters 61 | ---------- 62 | path : str 63 | Path name 64 | enable_log : bool 65 | Whether to print result for directory creation 66 | """ 67 | import errno 68 | if os.path.exists(path): return 69 | # logger.info(path) 70 | # path = path.replace('\ ',' ') 71 | # logger.info(path) 72 | try: 73 | os.makedirs(path) 74 | if enable_log: 75 | logger.info('Created directory {}'.format(path)) 76 | except OSError as exc: 77 | if exc.errno == errno.EEXIST and os.path.isdir(path) and logger: 78 | logger.info('Directory {} already exists.'.format(path)) 79 | else: 80 | raise 81 | 82 | 83 | def mkdir_list(p_list, use_relative_path=True, enable_log=True): 84 | """Create directories for the specified path lists. 85 | Parameters 86 | ---------- 87 | p_list :Path lists or a single path 88 | 89 | """ 90 | # ! Note that the paths MUST END WITH '/' !!! 91 | root_path = os.path.abspath(os.path.dirname(__file__)).split('src')[0] 92 | p_list = p_list if isinstance(p_list, list) else [p_list] 93 | for p in p_list: 94 | p = os.path.join(root_path, p) if use_relative_path else p 95 | p = get_dir_of_file(p) 96 | mkdir_p(p, enable_log) 97 | 98 | 99 | def find_free_port(): 100 | from contextlib import closing 101 | import socket 102 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 103 | s.bind(('', 0)) 104 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 105 | return s.getsockname()[1] 106 | 107 | 108 | def check_path_dict_exists(path_dict): 109 | # Check if all paths in path_dict already exists. 110 | try: 111 | for k, p in path_dict.items(): 112 | assert os.path.exists(p), f'{k} not found.' 
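        # If every path exists we fall through to return True; a missing path raises
        # AssertionError, which the bare except below converts into False.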
113 | return True 114 | except: 115 | return False 116 | 117 | 118 | def init_path(dir_or_file_list): 119 | if isinstance(dir_or_file_list, list): 120 | return [_init_path(_) for _ in dir_or_file_list] 121 | else: # single file 122 | return _init_path(dir_or_file_list) 123 | 124 | 125 | def _init_path(dir_or_file): 126 | if dir_or_file.startswith('~'): 127 | dir_or_file = os.path.expanduser(dir_or_file) 128 | path = get_dir_of_file(dir_or_file) 129 | if not os.path.exists(path): 130 | mkdir_p(path) 131 | return dir_or_file.replace('//', '/') 132 | 133 | 134 | def list_dir(dir_name, error_msg=None): 135 | try: 136 | f_list = os.listdir(dir_name) 137 | return f_list 138 | except FileNotFoundError: 139 | if error_msg is not None: 140 | logger.info(f'{error_msg}') 141 | return [] 142 | 143 | 144 | def remove_file_or_path(file_or_path, enable_log=True): 145 | # Modified from 'https://stackoverflow.com/questions/10840533/most-pythonic-way-to-delete-a-file-which-may-not 146 | # -exist' 147 | import shutil 148 | try: 149 | if file_or_path[-1] == '/': 150 | shutil.rmtree(file_or_path) 151 | else: 152 | os.remove(file_or_path) 153 | if enable_log: 154 | logger.warning(f'{file_or_path} removed!') 155 | except OSError as e: # this would be "except OSError, e:" before Python 2.6 156 | if e.errno != errno.ENOENT: # errno.ENOENT = no such file or directory 157 | raise # re-raise exception if a different error occurred 158 | 159 | 160 | def remove_file(f_list): 161 | 'Remove file or file list' 162 | f_list = f_list if isinstance(f_list, list) else [f_list] 163 | for f_name in f_list: 164 | remove_file_or_path(f_name) 165 | 166 | 167 | def get_dir_of_file(f_name): 168 | return os.path.dirname(f_name) + '/' 169 | 170 | 171 | def get_grand_parent_dir(f_name): 172 | from pathlib import Path 173 | if '.' in f_name.split('/')[-1]: # File 174 | return get_grand_parent_dir(get_dir_of_file(f_name)) 175 | else: # Path 176 | return f'{Path(f_name).parent}/' 177 | 178 | 179 | def get_abs_path(f_name, style='command_line'): 180 | # python 中的文件目录对空格的处理为空格,命令行对空格的处理为'\ '所以命令行相关需 replace(' ','\ ') 181 | if style == 'python': 182 | cur_path = os.path.abspath(os.path.dirname(__file__)) 183 | else: # style == 'command_line': 184 | cur_path = os.path.abspath(os.path.dirname(__file__)).replace(' ', '\ ') 185 | 186 | root_path = cur_path.split('src')[0] 187 | return os.path.join(root_path, f_name) 188 | 189 | 190 | def pickle_save(var, f_name): 191 | init_path(f_name) 192 | pickle.dump(var, open(f_name, 'wb')) 193 | logger.info(f'Saved {f_name}') 194 | 195 | 196 | def pickle_load(f_name): 197 | return pickle.load(open(f_name, 'rb')) 198 | 199 | 200 | def strtobool(val): 201 | """Convert a string representation of truth to true (1) or false (0). 202 | True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values 203 | are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if 204 | 'val' is anything else. 
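    Examples (illustrative):
        strtobool('Yes') -> True
        strtobool('off') -> False
        strtobool(True)  -> True  (bool inputs are returned unchanged)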
205 | """ 206 | if isinstance(val, bool): 207 | return val 208 | val = val.lower() 209 | if val in ('y', 'yes', 't', 'true', 'on', '1'): 210 | return True 211 | elif val in ('n', 'no', 'f', 'false', 'off', '0'): 212 | return False 213 | else: 214 | raise ValueError("invalid truth value %r" % (val,)) 215 | 216 | 217 | # * <<<<<<<<<<<<<<<<<<<< GIT >>>>>>>>>>>>>>>>>>>> 218 | 219 | def get_git_hash(): 220 | return subprocess.run(['git', 'rev-parse', '--short', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8').strip( 221 | '\n') 222 | 223 | 224 | # * <<<<<<<<<<<<<<<<<<<< PROJ SHARED UTILS >>>>>>>>>>>>>>>>>>>> 225 | def floor_quantize(val, to_values): 226 | """Quantize a value with regard to a set of allowed values. 227 | 228 | Examples: 229 | quantize(49.513, [0, 45, 90]) -> 45 230 | quantize(17, [0, 10, 20, 30]) -> 10 # FLOORED 231 | 232 | Note: function doesn't assume to_values to be sorted and 233 | iterates over all values (i.e. is rather slow). 234 | 235 | Args: 236 | val The value to quantize 237 | to_values The allowed values 238 | Returns: 239 | Closest value among allowed values. 240 | """ 241 | best_match = None 242 | best_match_diff = None 243 | assert min(to_values) <= val 244 | for other_val in to_values: 245 | if other_val <= val: # Floored (only smaller values are matched) 246 | diff = abs(other_val - val) 247 | if best_match is None or diff < best_match_diff: 248 | best_match = other_val 249 | best_match_diff = diff 250 | return best_match 251 | 252 | 253 | def json_save(data, file_name, log_func=print): 254 | with open(init_path(file_name), 'w', encoding='utf-8') as f: 255 | try: 256 | json.dumps(data) 257 | except: 258 | log_func(f"{data['Static logs']} failed to save in json format.") 259 | json.dump(data, f, ensure_ascii=False, indent=4) 260 | # log_func(f'Successfully saved to {file_name}') 261 | 262 | 263 | def json_load(file_name): 264 | with open(file_name) as data_file: 265 | return json.load(data_file) 266 | 267 | 268 | # * ============================= Init ============================= 269 | 270 | def exp_init(args): 271 | """ 272 | Functions: 273 | - Set GPU 274 | - Initialize Seeds 275 | - Set log level 276 | """ 277 | from warnings import simplefilter 278 | simplefilter(action='ignore', category=DeprecationWarning) 279 | # if not hasattr(args, 'local_rank'): 280 | if args.gpus is not None and args.gpus != '-1': 281 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 282 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 283 | if hasattr(args, 'local_rank') and args.local_rank > 1: block_log() 284 | # Torch related packages should be imported afterward setting 285 | init_random_state(args.seed) 286 | os.chdir(root_path) 287 | 288 | 289 | def init_random_state(seed=0): 290 | # Libraries using GPU should be imported after specifying GPU-ID 291 | import torch 292 | import random 293 | # import dgl 294 | # dgl.seed(seed) 295 | # dgl.random.seed(seed) 296 | random.seed(seed) 297 | np.random.seed(seed) 298 | torch.manual_seed(seed) 299 | torch.cuda.manual_seed_all(seed) 300 | 301 | 302 | def is_runing_on_local(): 303 | try: 304 | host_name = socket.gethostname() 305 | if 'MacBook' in host_name: 306 | return True 307 | except: 308 | logger.info("Unable to get Hostname and IP") 309 | return False 310 | 311 | 312 | def set_kwargs_default(kwargs, default_dict): 313 | for k, v in default_dict.items(): 314 | kwargs.setdefault(k, v) 315 | 316 | 317 | # * ============================= Print Related ============================= 318 | def subset_dict(d, sub_keys): 319 | return {k: 
d[k] for k in sub_keys if k in d} 320 | 321 | 322 | def subset_dict_by_condition(d, is_preserve=lambda x: True): 323 | # Filter keys in current dictionary 324 | if isinstance(d, dict): 325 | d = {k: v for k, v in d.items() if is_preserve(k)} 326 | # Filter keys in sub dictionary 327 | for key in d.keys(): 328 | if isinstance(d[key], dict) and is_preserve(key): 329 | d[key] = subset_dict_by_condition(d[key], is_preserve) 330 | return d 331 | 332 | 333 | def ordered_dict_str(d): 334 | return '\n'.join(f'{k}: {pformat(v)}' for k, v in OrderedDict(d).items()) 335 | 336 | 337 | def block_log(): 338 | sys.stdout = open(os.devnull, 'w') 339 | logger = logging.getLogger() 340 | logger.disabled = True 341 | 342 | 343 | def enable_logs(): 344 | # Restore 345 | sys.stdout = sys.__stdout__ 346 | logger = logging.getLogger() 347 | logger.disabled = False 348 | 349 | 350 | # def print_log(log_dict): 351 | # log_ = lambda log: f'{log:.4f}' if isinstance(log, float) else f'{log:04d}' 352 | # logger.info(' | '.join([f'{k} {log_(v)}' for k, v in log_dict.items()])) 353 | 354 | 355 | def mp_list_str(mp_list): 356 | return '_'.join(mp_list) 357 | 358 | 359 | # * ============================= Time Related ============================= 360 | 361 | def time2str(t): 362 | if t > 86400: 363 | return '{:.2f}day'.format(t / 86400) 364 | if t > 3600: 365 | return '{:.2f}h'.format(t / 3600) 366 | elif t > 60: 367 | return '{:.2f}min'.format(t / 60) 368 | else: 369 | return '{:.2f}s'.format(t) 370 | 371 | 372 | def get_cur_time(timezone='Asia/Shanghai', t_format='%m-%d %H:%M:%S'): 373 | return datetime.fromtimestamp(int(time.time()), pytz.timezone(timezone)).strftime(t_format) 374 | 375 | 376 | class time_logger(ContextDecorator): 377 | def __init__(self, name=None, log_func=logger.info): 378 | self.name = name 379 | self.log_func = log_func 380 | 381 | def __enter__(self): 382 | self.start_time = time.time() 383 | self.log_func(f'Started {self.name} at {get_cur_time()}') 384 | return self 385 | 386 | def __exit__(self, *exc): 387 | self.log_func(f'Finished {self.name} at {get_cur_time()}, running time = ' 388 | f'{time2str(time.time() - self.start_time)}.') 389 | return False 390 | 391 | def __call__(self, func): 392 | self.name = self.name or func.__name__ 393 | self.start_time = None 394 | 395 | @wraps(func) 396 | def decorator(*args, **kwargs): 397 | with self: 398 | return func(*args, **kwargs) 399 | 400 | return decorator 401 | 402 | 403 | # * ============================= Parser Related ============================= 404 | def parse_conf(parser, input): 405 | """Update parser by input (Dictionary or namespace)""" 406 | # Get default parser and update 407 | args = parser.parse_args([]) 408 | d = input if type(input) == dict else input.__dict__ 409 | args.__dict__.update({k: v for k, v in d.items() if k in args.__dict__}) 410 | return args 411 | 412 | 413 | def args_to_cmd(parser, input, allow_unknown_args=False, to_str=True): 414 | """Convert parser and input to args""" 415 | default = vars(parser.parse_args([])) 416 | d = input if type(input) == dict else input.__dict__ 417 | type_spec_parse_func = { 418 | **{_: lambda k, v: f'--{k}={v}' for _ in (int, float, str)}, 419 | bool: lambda k, v: f'--{k}' if default[k] != v else '', 420 | list: lambda k, v: f'--{k}={" ".join([str(_) for _ in v])}', 421 | } 422 | 423 | is_parse = lambda k: True if allow_unknown_args else lambda k: k in default 424 | parse_func = lambda k, v: type_spec_parse_func[type(v)](k, v) if is_parse(k) else '' 425 | rm_empty = lambda input_list: 
[_ for _ in input_list if len(_) > 0] 426 | cmd_list = rm_empty([parse_func(k, v) for k, v in d.items()]) 427 | if to_str: 428 | return ' '.join(cmd_list) 429 | else: 430 | return cmd_list 431 | 432 | 433 | def append_header_to_file(file_path, header): 434 | with open(file_path, 'r') as input_file: 435 | original_content = input_file.read() 436 | 437 | with open(file_path, 'w') as output_file: 438 | output_file.write(header + original_content) 439 | # def assert_folder_size_limit(folder='temp/', limit=15): 440 | # ''' 441 | # Parameters 442 | # ---------- 443 | # folder : The folder to limit 444 | # limit: The maximum size of a folder (in Gigabytes) 445 | # ------- 446 | # 447 | # ''' 448 | # now = time.time() 449 | # # !LINUX COMMAND: du -lh --max-depth=1 450 | # while os.path.getsize(folder): 451 | # files = [os.path.join(folder, filename) for filename in os.listdir(folder)] 452 | # for filename in files: 453 | # if (now - os.stat(filename).st_mtime) > 1800: 454 | # command = "rm {0}".format(filename) 455 | # subprocess.call(command, shell=True) 456 | -------------------------------------------------------------------------------- /src/utils/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndyJZhao/GraphText/73bde25b2fa9bf89b37b041062a4adfe42363652/src/utils/data/__init__.py -------------------------------------------------------------------------------- /src/utils/data/graph_tree.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | from omegaconf import DictConfig 3 | import pandas as pd 4 | import json 5 | from copy import deepcopy 6 | from utils.pkg.dict2xml import dict2xml 7 | import re 8 | 9 | 10 | class GraphTree: 11 | def __init__(self, data, df, center_node, subg_nodes, hierarchy, name_alias, style='xml', label=None): 12 | self.df = df 13 | self.text = data.text 14 | self.style = style 15 | self.subg_nodes = subg_nodes 16 | self.center_node = center_node 17 | self.hierarchy = hierarchy 18 | self.label = label 19 | self.tree_dict = {} 20 | self.encode_df = pd.DataFrame() 21 | if len(self.df): 22 | grouped_tree_df = self.df.groupby(hierarchy).agg({'nodes': list}).reset_index() 23 | grouped_tree_df['center_node'] = center_node 24 | self.prompt = self.traverse(data, grouped_tree_df, hierarchy, name_alias) 25 | else: 26 | self.prompt = '' 27 | return 28 | 29 | def traverse(self, data, grouped_tree_df, hierarchy, name_alias): 30 | self.continuous_row_dict = [] 31 | 32 | # ! 
Convert the grouped DataFrame to a nested dictionary (to be traversed) 33 | 34 | def extract_indices(s): 35 | # Find all occurrences of the pattern "" 36 | matches = re.findall(r'', s) 37 | 38 | # Convert the extracted indices to integers and return as a list 39 | return [int(match) for match in matches] 40 | 41 | cont_field_str_template = '' 42 | for index, row in grouped_tree_df.iterrows(): 43 | current_dict = self.tree_dict # Pointer to the current dictionary 44 | 45 | # Traverse through hierarchy levels 46 | for level in hierarchy[:-1]: 47 | level_key = name_alias.get(row[level], row[level]) 48 | 49 | if level_key not in current_dict: 50 | current_dict[level_key] = {} 51 | 52 | current_dict = current_dict[level_key] # Move pointer down 53 | 54 | # Final hierarchy level for leaf nodes 55 | field = row[hierarchy[-1]] 56 | 57 | if row.attr_type in data.in_cont_fields: 58 | if len(row.nodes) > 0: 59 | current_dict[name_alias.get(field, field)] = cont_field_str_template.format(index=index) 60 | else: # Text 61 | content = [data.text.iloc[_][row.attr_type] for _ in row['nodes'] if _ != -1] 62 | if isinstance(content, list): 63 | content = [_ for _ in content if _ != 'NA'] 64 | if len(content) > 0: 65 | current_dict[name_alias.get(field, field)] = str(content) 66 | 67 | if self.style == 'xml': 68 | graph_str = dict2xml(self.tree_dict, wrap="information", indent="\t") 69 | # Post-processing continuous feature 70 | cont_indices = extract_indices(graph_str) 71 | if len(cont_indices) > 0: # Process continuous feature to encode 72 | # ! Store df to encode (only continuous methods needs further encoding) 73 | self.encode_df = grouped_tree_df.loc[cont_indices] 74 | for index, row in self.encode_df.iterrows(): 75 | placeholder = "".join([f"<{row.attr_type} emb>" for _ in row['nodes']]) 76 | placeholder = f"<{row.attr_type}> {placeholder} " 77 | graph_str = graph_str.replace(cont_field_str_template.format(index=index), placeholder) 78 | assert len(extract_indices(graph_str)) == 0 79 | else: 80 | raise ValueError(f'Unsupported prompt style {self.style}') 81 | return graph_str 82 | 83 | def __str__(self): 84 | return self.prompt 85 | 86 | def __repr__(self): 87 | return self.prompt 88 | 89 | 90 | if __name__ == "__main__": 91 | # Create a sample DataFrame 92 | # df = pd.DataFrame({ 93 | # 'node_id': [0, 1, 2, 3, 4, 5] + [1, 3, 5], 94 | # 'SPD': [0, 1, 1, 2, 2, 2] + [1, 2, 2], 95 | # 'feature_type': ['x'] * 6 + ['y'] * 3 96 | # }) 97 | # 98 | # # Display the nested dictionary 99 | # print(f"Nested Dictionary: {tree_as_dict}\nFlattened graph:") 100 | # print() 101 | # print(json.dumps(tree_as_dict, indent=4)) 102 | 103 | df = pd.DataFrame({ 104 | 'node_id': [1, 2, 3, 4, 5, 6, 7, 8], 105 | 'SPD': [10, 20, 10, 30, 20, 10, 30, 30], 106 | 'feature_type': ['A', 'B', 'A', 'B', 'A', 'A', 'B', 'C'], 107 | 'attribute1': ['X', 'X', 'Y', 'Z', 'W', 'Y', 'Z', 'V'], 108 | 'attribute2': [100, 200, 100, 300, 200, 100, 300, 300] 109 | }) 110 | 111 | # Define your hierarchy and aggregation 112 | hierarchy = ['SPD', 'feature_type'] 113 | 114 | agg_dict = {col: 'first' for col in df.columns if col not in hierarchy} 115 | agg_dict['node_id'] = list # Assuming you want to group 'node_id' 116 | 117 | # Group the DataFrame 118 | grouped_df = df.groupby(hierarchy).agg(agg_dict).reset_index() 119 | 120 | print(grouped_df) 121 | -------------------------------------------------------------------------------- /src/utils/data/ppr.py: -------------------------------------------------------------------------------- 1 | # Code modified 
from https://github.com/TUM-DAML/pprgo_pytorch/blob/master/pprgo/ppr.py 2 | import dgl 3 | import numba 4 | import numpy as np 5 | import scipy.sparse as sp 6 | from tqdm import tqdm 7 | 8 | import utils.basics as uf 9 | from utils.pkg.graph_utils import k_hop_nb_graph 10 | from utils.pkg.distributed import process_on_master_and_sync_by_pickle 11 | 12 | 13 | def get_row_rank_from_sparse_matrix(A, k): 14 | """ 15 | Row rank for each row in a sparse matrix A 16 | Note that: 17 | 1. only the top k ranks are computed for each row 18 | 2. The rank starts at 1, 0 means no rank (rank > k) 19 | :param A: Sparse matrix to rank 20 | :param k: Top K rank to be computed. 21 | :return: R: Sparse matrix with ranks 22 | 23 | Example: 24 | A = sp.csr_matrix([[5, 3, 2, 10, 1], [0, 8, 0, 4, 0], [3, 0, 0, 6, 7]]) 25 | R = get_row_rank_from_sparse_matrix(A, 3) 26 | print("Ranking matrix in CSR format: {R.toarray()}") 27 | [[2 3 0 1 0] 28 | [0 1 0 2 3] 29 | [3 0 0 2 1]] 30 | """ 31 | rows_list = [] # List to collect the sparse rows before stacking them to form R 32 | 33 | # Loop through the rows using the indptr array to find the start and end of each row in the indices and data arrays 34 | for start, end in zip(A.indptr[:-1], A.indptr[1:]): 35 | # Extract the non-zero data and their column indices for this row 36 | row_data = A.data[start:end] 37 | col_indices = A.indices[start:end] 38 | 39 | # Sort the non-zero elements and get their sorted indices 40 | sorted_indices = np.argsort(row_data)[::-1] 41 | 42 | # Get the top k indices among the non-zero elements 43 | top_k_indices = col_indices[sorted_indices[:k]] 44 | 45 | # Only set ranks for the top k non-zero elements, and add 1 to each rank 46 | top_k_ranks = np.arange(len(top_k_indices), dtype=int) + 1 47 | 48 | # Construct a sparse row using top_k_indices and top_k_ranks 49 | sparse_row = sp.csr_matrix( 50 | (top_k_ranks, (np.zeros(len(top_k_indices)), top_k_indices)), 51 | shape=(1, A.shape[1]), 52 | ) 53 | rows_list.append(sparse_row) 54 | 55 | # Stack sparse rows to create the final sparse ranking matrix 56 | R = sp.vstack(rows_list) 57 | return R 58 | 59 | 60 | @numba.njit( 61 | cache=True, 62 | locals={"_val": numba.float32, "res": numba.float32, "res_vnode": numba.float32}, 63 | ) 64 | def _calc_ppr_node(inode, indptr, indices, deg, alpha, epsilon): 65 | alpha_eps = alpha * epsilon 66 | f32_0 = numba.float32(0) 67 | p = {inode: f32_0} 68 | r = {} 69 | r[inode] = alpha 70 | q = [inode] 71 | while len(q) > 0: 72 | unode = q.pop() 73 | 74 | res = r[unode] if unode in r else f32_0 75 | if unode in p: 76 | p[unode] += res 77 | else: 78 | p[unode] = res 79 | r[unode] = f32_0 80 | for vnode in indices[indptr[unode] : indptr[unode + 1]]: 81 | _val = (1 - alpha) * res / deg[unode] 82 | if vnode in r: 83 | r[vnode] += _val 84 | else: 85 | r[vnode] = _val 86 | 87 | res_vnode = r[vnode] if vnode in r else f32_0 88 | if res_vnode >= alpha_eps * deg[vnode]: 89 | if vnode not in q: 90 | q.append(vnode) 91 | 92 | return list(p.keys()), list(p.values()) 93 | 94 | 95 | @numba.njit(cache=True) 96 | def calc_ppr(indptr, indices, deg, alpha, epsilon, nodes): 97 | js = [] 98 | vals = [] 99 | for i, node in enumerate(nodes): 100 | j, val = _calc_ppr_node(node, indptr, indices, deg, alpha, epsilon) 101 | js.append(j) 102 | vals.append(val) 103 | return js, vals 104 | 105 | 106 | @numba.njit(cache=True, parallel=True) 107 | def calc_ppr_topk_parallel(indptr, indices, deg, alpha, epsilon, nodes, topk): 108 | js = [np.zeros(0, dtype=np.int64)] * len(nodes) 109 | vals = 
[np.zeros(0, dtype=np.float32)] * len(nodes) 110 | for i in numba.prange(len(nodes)): 111 | j, val = _calc_ppr_node(nodes[i], indptr, indices, deg, alpha, epsilon) 112 | j_np, val_np = np.array(j), np.array(val) 113 | idx_topk = np.argsort(val_np)[-topk:] 114 | js[i] = j_np[idx_topk] 115 | vals[i] = val_np[idx_topk] 116 | return js, vals 117 | 118 | 119 | def ppr_topk(adj_matrix, alpha, epsilon, nodes, topk): 120 | """Calculate the PPR matrix approximately using Anderson.""" 121 | 122 | out_degree = np.sum(adj_matrix > 0, axis=1).A1 123 | nnodes = adj_matrix.shape[0] 124 | 125 | neighbors, weights = calc_ppr_topk_parallel( 126 | adj_matrix.indptr, 127 | adj_matrix.indices, 128 | out_degree, 129 | numba.float32(alpha), 130 | numba.float32(epsilon), 131 | nodes, 132 | topk, 133 | ) 134 | 135 | ppr_mat = construct_sparse(neighbors, weights, (len(nodes), nnodes)) 136 | return ppr_mat 137 | 138 | 139 | def ppr_topk_batch(adj_matrix, alpha, epsilon, nodes, topk, batch_size=1000): 140 | out_degree = np.sum(adj_matrix > 0, axis=1).A1 141 | nnodes = adj_matrix.shape[0] 142 | 143 | with tqdm(total=len(nodes), desc="Calculating PPR TopK") as pbar: 144 | 145 | def batch_calc(nodes_batch): 146 | neighbors, weights = calc_ppr_topk_parallel( 147 | adj_matrix.indptr, 148 | adj_matrix.indices, 149 | out_degree, 150 | numba.float32(alpha), 151 | numba.float32(epsilon), 152 | nodes_batch, 153 | topk, 154 | ) 155 | pbar.update(len(nodes_batch)) 156 | return neighbors, weights 157 | 158 | neighbors, weights = [], [] 159 | for i in range(0, len(nodes), batch_size): 160 | nodes_batch = nodes[i : i + batch_size] 161 | n, w = batch_calc(nodes_batch) 162 | neighbors.extend(n) 163 | weights.extend(w) 164 | 165 | ppr_mat = construct_sparse(neighbors, weights, (len(nodes), nnodes)) 166 | return ppr_mat 167 | 168 | 169 | def construct_sparse(neighbors, weights, shape): 170 | i = np.repeat( 171 | np.arange(len(neighbors)), np.fromiter(map(len, neighbors), dtype=np.int32) 172 | ) 173 | j = np.concatenate(neighbors) 174 | return sp.coo_matrix((np.concatenate(weights), (i, j)), shape) 175 | 176 | 177 | def calc_approximate_ppr_rank( 178 | g: dgl.DGLGraph, alpha, n_rank, cache_template, topk, eps=1e-4, **kwargs 179 | ): 180 | ppr_mat = topk_approximate_ppr_matrix( 181 | g, alpha=alpha, eps=eps, topk=topk, cache_template=cache_template 182 | ) 183 | ppr_rank = get_row_rank_from_sparse_matrix(ppr_mat, n_rank) 184 | return ppr_rank 185 | 186 | 187 | def topk_approximate_ppr_matrix( 188 | g: dgl.DGLGraph, alpha, eps, topk, cache_template, normalization="row" 189 | ): 190 | @process_on_master_and_sync_by_pickle(cache_arg=0) 191 | def _topk_approximate_ppr_matrix(cache_file): 192 | row, col = dgl.to_bidirected(g).edges() 193 | # build sparse csr matrix from row and col 194 | adj = sp.coo_matrix( 195 | (np.ones(len(row)), (row.numpy(), col.numpy())), 196 | shape=(g.num_nodes(), g.num_nodes()), 197 | ).tocsr() 198 | 199 | """Create a sparse matrix where each node has up to the topk PPR neighbors and their weights.""" 200 | idx = g.nodes().cpu().numpy() # All node IDs 201 | # topk_matrix = ppr_topk(adj, alpha, eps, idx, topk).tocsr() 202 | topk_matrix = ppr_topk_batch(adj, alpha, eps, idx, topk).tocsr() 203 | if normalization == "sym": 204 | # Assume undirected (symmetric) adjacency matrix 205 | deg = adj.sum(1).A1 206 | deg_sqrt = np.sqrt(np.maximum(deg, 1e-12)) 207 | deg_inv_sqrt = 1.0 / deg_sqrt 208 | 209 | row, col = topk_matrix.nonzero() 210 | # assert np.all(deg[idx[row]] > 0) 211 | # assert np.all(deg[col] > 0) 212 | 
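            # Rescales entry (i, j) to sqrt(deg_i) * ppr_ij / sqrt(deg_j), i.e. the
            # symmetric D^{1/2} Pi D^{-1/2} form of the row-normalized PPR scores.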
topk_matrix.data = deg_sqrt[idx[row]] * topk_matrix.data * deg_inv_sqrt[col] 213 | elif normalization == "col": 214 | # Assume undirected (symmetric) adjacency matrix 215 | deg = adj.sum(1).A1 216 | deg_inv = 1.0 / np.maximum(deg, 1e-12) 217 | 218 | row, col = topk_matrix.nonzero() 219 | # assert np.all(deg[idx[row]] > 0) 220 | # assert np.all(deg[col] > 0) 221 | topk_matrix.data = deg[idx[row]] * topk_matrix.data * deg_inv[col] 222 | elif normalization == "row": 223 | pass 224 | else: 225 | raise ValueError(f"Unknown PPR normalization: {normalization}") 226 | uf.pickle_save(topk_matrix, cache_file) 227 | 228 | _cache_file = cache_template.format( 229 | alpha=alpha, eps=eps, topk=topk, normalization=normalization 230 | ) 231 | return _topk_approximate_ppr_matrix(_cache_file) 232 | 233 | 234 | @uf.time_logger() 235 | def find_top_k_neighbors_within_khop_ego_subgraph_iter( 236 | g, importance_mat, max_hops, k, padding, ordered=True 237 | ): 238 | """ 239 | Function to find n-hop neighbors and sort them 240 | :param g: DGL Graph 241 | :param importance_mat: [i, j] stands for the importance of j for node i 242 | :param max_hops: max hops to construct 243 | :param k: At most k neighbors are selected 244 | :param padding: Whether pad neighbors to k (by adding -1) 245 | :return: 246 | """ 247 | nb_list = [[] for _ in range(g.number_of_nodes())] 248 | n_neighbors = 0 249 | for i in range(g.number_of_nodes()): 250 | # Initialize with self 251 | current_neighbors = set([i]) 252 | 253 | # Find n-hop neighbors 254 | for h in range(max_hops): 255 | new_neighbors = set() 256 | for neighbor in current_neighbors: 257 | new_neighbors = new_neighbors.union(g.successors(neighbor).tolist()) 258 | current_neighbors = current_neighbors.union(new_neighbors) 259 | 260 | # Remove self-loop if exists 261 | current_neighbors.discard(i) 262 | if len(nb_list[i]) > 0: 263 | uf.logger.warning(f"Node {i} has no neighbors") 264 | # Sort the n-hop neighbors based on importance from CSR matrix A 265 | nb_list[i] = sorted(current_neighbors, key=lambda x: importance_mat[i, x])[:k] 266 | if not ordered: # Permute 267 | nb_list[i] = list(np.random.permutation(nb_list[i])) 268 | n_neighbors += len(nb_list[i]) 269 | if padding: # Padding with -1 270 | nb_list[i] = nb_list[i] + [-1] * (k - len(nb_list[i])) 271 | 272 | uf.logger.info( 273 | f"Average number of subgraph neighbors = {n_neighbors / g.number_of_nodes()}" 274 | ) 275 | return nb_list 276 | 277 | 278 | @uf.time_logger() 279 | @process_on_master_and_sync_by_pickle(cache_kwarg="cache_file") 280 | def find_top_k_neighbors_within_khop_ego_subgraph( 281 | g, score_mat, max_hops, k, padding, cache_file=None, ordered=True 282 | ): 283 | """ 284 | Parameters 285 | ---------- 286 | g : DGL Graph 287 | score_mat : csr_matrix 288 | [i, j] stands for the importance of j for node i 289 | max_hops : int 290 | max hops to construct 291 | k : int 292 | At most k neighbors are selected 293 | padding : bool 294 | Whether pad neighbors to k (by adding -1) 295 | ordered : bool 296 | Whether to keep the neighbors sorted by importance 297 | cache_file: str 298 | The temp cache file for multi-agent to save and load 299 | Returns 300 | ------- 301 | nb_list : list of list 302 | Sorted neighbors for each node 303 | """ 304 | if k == 0: # Save empty list if no neighbors 305 | uf.pickle_save([[] for _ in g.nodes()], cache_file) 306 | 307 | # Step 1: Use dgl.khop_graph to find max_hops neighbors for all nodes 308 | uf.logger.info(f"Start sorting top building PPR sorted neighbors within {max_hops} 
hop") 309 | src, dst = [_.numpy() for _ in g.edges()] # Init first hop 310 | for hop in range(2, max_hops + 1): # Start from 2 hop 311 | k_hop_g = k_hop_nb_graph(g, hop) 312 | new_src, new_dst = [_.numpy() for _ in k_hop_g.edges()] 313 | src = np.concatenate((src, new_src)) 314 | dst = np.concatenate((dst, new_dst)) 315 | 316 | # Step 2: Create a masked sparse importance matrix based on k_hop_g 317 | valid_indices = np.vstack((src, dst)) 318 | 319 | min_connect_prob = 1e-6 * np.ones( 320 | len(src) 321 | ) # Add minimum prob to neighbors within khop 322 | khop_connectedness = sp.csr_matrix( 323 | (min_connect_prob, valid_indices), shape=score_mat.shape 324 | ) 325 | nb_score = khop_connectedness + score_mat 326 | 327 | # Only neighbors within k-hop should be selected (Mask out PPR distant neighbors) 328 | data = np.array(nb_score[valid_indices[0], valid_indices[1]]) 329 | nb_score = sp.csr_matrix((data.reshape(-1), valid_indices), shape=score_mat.shape) 330 | 331 | # Remove self-loop from neighbor 332 | nb_score.setdiag(0) 333 | nb_score.eliminate_zeros() 334 | uf.logger.info(f"Created score matrix for ranking") 335 | 336 | # Step 3: Iterate through masked sparse matrix to get sorted neighbors 337 | nb_list = [] 338 | n_neighbors = 0 339 | for i in tqdm(range(nb_score.shape[0]), "Ranking neighbors"): 340 | row = nb_score.getrow(i) 341 | # Get top-k neighbors based on importance 342 | sorted_nb_indices = np.argsort(row.data)[:k] 343 | sorted_neighbors = row.indices[sorted_nb_indices].tolist() 344 | 345 | if not ordered: # Permute 346 | sorted_neighbors = list(np.random.permutation(sorted_neighbors)) 347 | n_neighbors += len(sorted_neighbors) 348 | 349 | if padding: # Padding with -1 350 | sorted_neighbors = sorted_neighbors + [-1] * (k - len(sorted_neighbors)) 351 | 352 | nb_list.append(sorted_neighbors) 353 | 354 | print(f"Average number of subgraph neighbors = {n_neighbors / g.number_of_nodes()}") 355 | uf.pickle_save(nb_list, cache_file) 356 | -------------------------------------------------------------------------------- /src/utils/data/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import dgl 5 | import numpy as np 6 | import pandas as pd 7 | import torch as th 8 | from dgl import save_graphs, load_graphs 9 | from dgl.data.utils import save_info, load_info 10 | from tqdm import tqdm 11 | 12 | import utils.basics as uf 13 | from utils.pkg.graph_utils import sample_nodes 14 | 15 | 16 | def plot_length_distribution(node_text, tokenizer, g): 17 | sampled_ids = np.random.permutation(g.nodes())[:10000] 18 | get_text = lambda n: node_text.iloc[n]['text'].tolist() 19 | tokenized = tokenizer(get_text(sampled_ids), padding='do_not_pad').data['input_ids'] 20 | node_text['text_length'] = node_text.apply(lambda x: len(x['text'].split(' ')), axis=1) 21 | pd.Series([len(_) for _ in tokenized]).hist(bins=20) 22 | import matplotlib.pyplot as plt 23 | plt.show() 24 | 25 | 26 | def load_ogb_graph_structure_only(ogb_name, raw_data_path, save_path='NA'): 27 | graph_path = os.path.join(save_path, 'dgl_graph.bin') 28 | info_path = os.path.join(save_path, 'graph_info.pkl') 29 | if save_path == 'NA' or not os.path.exists(save_path): 30 | from ogb.nodeproppred import DglNodePropPredDataset 31 | data = DglNodePropPredDataset(ogb_name, root=uf.init_path(raw_data_path)) 32 | g, labels = data[0] 33 | split_idx = data.get_idx_split() 34 | labels = labels.squeeze().numpy() 35 | if save_path is not None: 36 | # Save 37 | 
save_graphs(graph_path, g) 38 | info_dict = {'split_idx': split_idx, 'labels': labels, 'meta_info': data.meta_info} 39 | save_info(info_path, info_dict) 40 | else: 41 | g, info_dict = load_graphs(graph_path)[0][0], load_info(info_path) 42 | split_idx, labels = info_dict['split_idx'], info_dict['labels'] 43 | 44 | g.ndata['label'] = th.tensor(labels).to(int) 45 | return g, labels, split_idx 46 | 47 | 48 | def process_raw_arxiv(labels, mode, ogb_name, raw_data_path, raw_text_url, max_seq_len, 49 | processed_text_file, chunk_size=50000, _label_info=None, **kwargs): 50 | def merge_by_ids(meta_data, node_ids, label_info): 51 | meta_data.columns = ['node_id', "Title", "Abstract"] 52 | # meta_data.drop([0, meta_data.shape[0] - 1], axis=0, inplace=True) # Drop first and last in Arxiv full 53 | # dataset processing 54 | meta_data['node_id'] = meta_data['node_id'].astype(np.int64) 55 | meta_data.columns = ["mag_id", "title", "abstract"] 56 | data = pd.merge(node_ids, meta_data, how="left", on="mag_id") 57 | data = pd.merge(data, label_info, how="left", on="label_id") 58 | return data 59 | 60 | def read_ids_and_labels(): 61 | _ = f'{raw_data_path}{ogb_name.replace("-", "_")}/mapping/' 62 | category_path_csv = f"{_}labelidx2arxivcategeory.csv.gz" 63 | paper_id_path_csv = f"{_}nodeidx2paperid.csv.gz" # 64 | paper_ids = pd.read_csv(paper_id_path_csv) 65 | label_info = pd.read_csv(category_path_csv) 66 | paper_ids.columns = ['node_id', "mag_id"] 67 | label_info.columns = ["label_id", "label_raw_name"] 68 | paper_ids["label_id"] = labels[paper_ids['node_id']] 69 | label_info['label_raw_name'] = label_info.apply(lambda x: x['label_raw_name'].split('arxiv cs ')[1].upper(), 70 | axis=1) 71 | label_info['label_name'] = label_info.apply(lambda x: _label_info[x['label_raw_name']].split(' - ')[0], axis=1) 72 | label_info['label_alias'] = label_info.apply(lambda x: f"cs.{x['label_raw_name']}", axis=1) 73 | label_info['label_alias+name'] = label_info.apply(lambda x: f"{x['label_alias']} ({x['label_name']})", axis=1) 74 | label_info['label_description'] = label_info.apply(lambda x: _label_info[x['label_raw_name']], axis=1) 75 | return label_info, paper_ids # 返回类别和论文ID 76 | 77 | def process_raw_text_df(meta_data, node_ids, label_info): 78 | data = merge_by_ids(meta_data.dropna(), node_ids, label_info) 79 | data = data[~data['title'].isnull()] 80 | text_func = { 81 | 'TA': lambda x: f"Title: {x['title']}. Abstract: {x['abstract']}", 82 | 'T': lambda x: x['title'], 83 | } 84 | # Merge title and abstract 85 | data['text'] = data.apply(text_func[mode], axis=1) 86 | data['text'] = data.apply(lambda x: ' '.join(x['text'].split(' ')[:max_seq_len]), axis=1) 87 | return data 88 | 89 | from ogb.utils.url import download_url 90 | # Get Raw text path 91 | print(f'Loading raw text for {ogb_name}') 92 | raw_text_path = download_url(raw_text_url, raw_data_path) 93 | 94 | label_info, node_ids = read_ids_and_labels() 95 | df_list = [] 96 | for meta_data in tqdm(pd.read_table(raw_text_path, header=None, chunksize=chunk_size, skiprows=[0])): 97 | # Load part of the dataframe to prevent OOM. 
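        # Each chunk is merged with the OGB node ids and label info inside process_raw_text_df;
        # the per-chunk frames are concatenated and re-sorted by node id right after this loop.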
98 | df_list.append(process_raw_text_df(meta_data, node_ids, label_info)) 99 | processed_df = pd.concat(df_list).sort_index() 100 | assert sum(processed_df.node_id == np.arange(len(labels))) == len(labels) 101 | uf.pickle_save((processed_df, label_info), processed_text_file) 102 | return processed_df 103 | -------------------------------------------------------------------------------- /src/utils/pkg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AndyJZhao/GraphText/73bde25b2fa9bf89b37b041062a4adfe42363652/src/utils/pkg/__init__.py -------------------------------------------------------------------------------- /src/utils/pkg/dict2xml.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import collections.abc 3 | import re 4 | 5 | 6 | def dict2xml(data, *args, **kwargs): 7 | """Return an xml string of a Python dict object.""" 8 | return Converter(*args, **kwargs).build(data) 9 | 10 | 11 | start_ranges = "|".join( 12 | "[{0}]".format(r) 13 | for r in [ 14 | "\xC0-\xD6", 15 | "\xD8-\xF6", 16 | "\xF8-\u02FF", 17 | "\u0370-\u037D", 18 | "\u037F-\u1FFF", 19 | "\u200C-\u200D", 20 | "\u2070-\u218F", 21 | "\u2C00-\u2FEF", 22 | "\u3001-\uD7FF", 23 | "\uF900-\uFDCF", 24 | "\uFDF0-\uFFFD", 25 | ] 26 | ) 27 | 28 | NameStartChar = re.compile(r"(:|[A-Z]|_|[a-z]|{0})".format(start_ranges)) 29 | NameChar = re.compile(r"(\-|\.|[0-9]|\xB7|[\u0300-\u036F]|[\u203F-\u2040])") 30 | 31 | 32 | ######################## 33 | ### NODE 34 | ######################## 35 | 36 | 37 | class Node(object): 38 | """ 39 | Represents each tag in the tree 40 | 41 | Each node has _either_ a single value or one or more children 42 | If it has a value: 43 | The serialized result is <%(tag)s>%(value)s 44 | 45 | If it has children: 46 | The serialized result is 47 | <%(wrap)s> 48 | %(children)s 49 | 50 | 51 | Which one it is depends on the implementation of self.convert 52 | """ 53 | 54 | # A mapping of characters to treat as escapable entities and their replacements 55 | # ! 
We don't want them to replace entity 56 | entities = [("&", "&"), ("<", "<"), (">", ">")] 57 | 58 | def __init__( 59 | self, wrap="", tag="", data=None, iterables_repeat_wrap=True, closed_tags_for=None 60 | ): 61 | self.tag = self.sanitize_element(tag) 62 | self.wrap = self.sanitize_element(wrap) 63 | self.data = data 64 | self.type = self.determine_type() 65 | self.closed_tags_for = closed_tags_for 66 | self.iterables_repeat_wrap = iterables_repeat_wrap 67 | 68 | # if self.type == "flat" and isinstance(self.data, str): 69 | # # Make sure we deal with entities 70 | # for entity, replacement in self.entities: 71 | # self.data = self.data.replace(entity, replacement) 72 | 73 | def serialize(self, indenter): 74 | """Returns the Node serialized as an xml string""" 75 | # Determine the start and end of this node 76 | wrap = self.wrap 77 | end, start = "", "" 78 | if wrap: 79 | end = "".format(wrap) 80 | start = "<{0}>".format(wrap) 81 | 82 | if self.closed_tags_for and self.data in self.closed_tags_for: 83 | return "<{0}/>".format(self.wrap) 84 | 85 | # Convert the data attached in this node into a value and children 86 | value, children = self.convert() 87 | 88 | # Determine the content of the node (essentially the children as a string value) 89 | content = "" 90 | if children: 91 | if self.type != "iterable": 92 | # Non-iterable wraps all it's children in the same tag 93 | content = indenter((c.serialize(indenter) for c in children), wrap) 94 | else: 95 | if self.iterables_repeat_wrap: 96 | # Iterables repeat the wrap for each child 97 | result = [] 98 | for c in children: 99 | content = c.serialize(indenter) 100 | if c.type == "flat": 101 | # Child with value, it already is surrounded by the tag 102 | result.append(content) 103 | else: 104 | # Child with children of it's own, they need to be wrapped by start and end 105 | content = indenter([content], True) 106 | result.append("".join((start, content, end))) 107 | 108 | # We already have what we want, return the indented result 109 | return indenter(result, False) 110 | else: 111 | result = [] 112 | for c in children: 113 | result.append(c.serialize(indenter)) 114 | return "".join([start, indenter(result, True), end]) 115 | 116 | # If here, either: 117 | # * Have a value 118 | # * Or this node is not an iterable 119 | return "".join((start, value, content, end)) 120 | 121 | def determine_type(self): 122 | """ 123 | Return the type of the data on this node as an identifying string 124 | 125 | * Iterable : Supports "for item in data" 126 | * Mapping : Supports "for key in data: value = data[key]" 127 | * flat : A string or something that isn't iterable or a mapping 128 | """ 129 | data = self.data 130 | if isinstance(data, str): 131 | return "flat" 132 | elif isinstance(data, collections.abc.Mapping): 133 | return "mapping" 134 | elif isinstance(data, collections.abc.Iterable): 135 | return "iterable" 136 | else: 137 | return "flat" 138 | 139 | def convert(self): 140 | """ 141 | Convert data on this node into a (value, children) tuple depending on the type of the data 142 | If the type is : 143 | * flat : Use self.tag to surround the value. 
value 144 | * mapping : Return a list of tags where the key for each child is the wrap for that node 145 | * iterable : Return a list of Nodes where self.wrap is the tag for that node 146 | """ 147 | val = "" 148 | typ = self.type 149 | data = self.data 150 | children = [] 151 | 152 | if typ == "mapping": 153 | sorted_data = data 154 | if not isinstance(data, collections.OrderedDict): 155 | sorted_data = sorted(data) 156 | 157 | for key in sorted_data: 158 | item = data[key] 159 | children.append( 160 | Node( 161 | key, 162 | "", 163 | item, 164 | iterables_repeat_wrap=self.iterables_repeat_wrap, 165 | closed_tags_for=self.closed_tags_for, 166 | ) 167 | ) 168 | 169 | elif typ == "iterable": 170 | for item in data: 171 | children.append( 172 | Node( 173 | "", 174 | self.wrap, 175 | item, 176 | iterables_repeat_wrap=self.iterables_repeat_wrap, 177 | closed_tags_for=self.closed_tags_for, 178 | ) 179 | ) 180 | 181 | else: 182 | val = str(data) 183 | if self.tag: 184 | val = "<{0}>{1}".format(self.tag, val, self.tag) 185 | 186 | return val, children 187 | 188 | @staticmethod 189 | def sanitize_element(wrap): 190 | """ 191 | Convert `wrap` into a valid tag name applying the xml Naming Rules. 192 | 193 | * Names can contain letters, numbers, and other characters 194 | * Names cannot start with a number or punctuation character 195 | * Names cannot start with the letters xml (or xml, or Xml, etc) 196 | * Names cannot contain spaces 197 | * Any name can be used, no words are reserved. 198 | 199 | :ref: http://www.w3.org/TR/REC-xml/#NT-NameChar 200 | """ 201 | if wrap and isinstance(wrap, str): 202 | if wrap.lower().startswith("xml"): 203 | wrap = "_" + wrap 204 | return "".join( 205 | ["_" if not NameStartChar.match(wrap) else ""] 206 | + ["_" if not (NameStartChar.match(c) or NameChar.match(c)) else c for c in wrap] 207 | ) 208 | else: 209 | return wrap 210 | 211 | 212 | ######################## 213 | ### CONVERTER 214 | ######################## 215 | 216 | 217 | class Converter(object): 218 | """Logic for creating a Node tree and serialising that tree into a string""" 219 | 220 | def __init__(self, wrap=None, indent=" ", newlines=True): 221 | """ 222 | wrap: The tag that the everything else will be contained within 223 | indent: The string that is multiplied at the start of each new line, to represent each level of nesting 224 | newlines: A boolean specifying whether we want each tag on a new line. 
225 | 226 | Note that indent only works if newlines is True 227 | """ 228 | self.wrap = wrap 229 | self.indent = indent 230 | self.newlines = newlines 231 | 232 | def _make_indenter(self): 233 | """Returns a function that given a list of strings, will return that list as a single, indented, string""" 234 | indent = self.indent 235 | newlines = self.newlines 236 | 237 | if not newlines: 238 | # No newlines, don't care about indentation 239 | ret = lambda nodes, wrapped: "".join(nodes) 240 | else: 241 | if not indent: 242 | indent = "" 243 | 244 | def eachline(nodes): 245 | """Yield each line in each node""" 246 | for node in nodes: 247 | for line in node.split("\n"): 248 | yield line 249 | 250 | def ret(nodes, wrapped): 251 | """ 252 | Indent nodes depending on value of wrapped and indent 253 | If not wrapped, then don't indent 254 | Otherwise, 255 | Seperate each child by a newline 256 | and indent each line in the child by one indent unit 257 | """ 258 | if wrapped: 259 | seperator = "\n{0}".format(indent) 260 | surrounding = "\n{0}{{0}}\n".format(indent) 261 | else: 262 | seperator = "\n" 263 | surrounding = "{0}" 264 | return surrounding.format(seperator.join(eachline(nodes))) 265 | 266 | return ret 267 | 268 | def build(self, data, iterables_repeat_wrap=True, closed_tags_for=None): 269 | """Create a Node tree from the data and return it as a serialized xml string""" 270 | indenter = self._make_indenter() 271 | return Node( 272 | wrap=self.wrap, 273 | data=data, 274 | iterables_repeat_wrap=iterables_repeat_wrap, 275 | closed_tags_for=closed_tags_for, 276 | ).serialize(indenter) 277 | -------------------------------------------------------------------------------- /src/utils/pkg/distributed.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from collections import defaultdict 4 | from functools import wraps 5 | 6 | import torch as th 7 | from torch import distributed as dist 8 | 9 | from utils.basics import pickle_load, logger 10 | 11 | cpu_group = None 12 | gpu_group = None 13 | 14 | 15 | def initialize_deepspeed(cfg): 16 | cfg.use_deepspeed = cfg.get('use_deepspeed', False) and th.cuda.is_available() 17 | cfg.use_fp16 = cfg.use_fp16 if cfg.use_deepspeed else False 18 | cfg.use_bf16 = cfg.use_bf16 if cfg.use_deepspeed else False 19 | if cfg.use_deepspeed: 20 | import deepspeed 21 | deepspeed.init_distributed(dist_backend='nccl') 22 | 23 | 24 | def initialize_distributed(cfg): 25 | cfg.is_distributed = False 26 | cfg.world_size = int(os.getenv('WORLD_SIZE', '1')) 27 | if th.cuda.is_available(): 28 | cfg.master_ip = os.getenv('MASTER_ADDR', 'localhost') 29 | cfg.master_port = os.getenv('MASTER_PORT', '6000') 30 | cfg.local_rank = int(os.getenv('RANK', '0')) % th.cuda.device_count() 31 | device = cfg.local_rank % th.cuda.device_count() 32 | th.cuda.set_device(device) 33 | cfg.is_distributed = th.cuda.device_count() > 1 34 | else: 35 | cfg.local_rank = 0 36 | 37 | 38 | def process_on_master_and_sync_by_pickle(cache_arg=None, cache_kwarg=None, log_func=logger.info): 39 | def decorator(func): 40 | @wraps(func) 41 | def wrapper(*args, **kwargs): 42 | if cache_kwarg is not None: 43 | filename = kwargs[cache_kwarg] 44 | elif cache_arg is not None: 45 | filename = args[cache_arg] 46 | else: 47 | log_func('No cache file specified') 48 | skip_cache = kwargs.pop('skip_cache', False) 49 | if not os.path.exists(filename) or skip_cache: 50 | if get_rank() == 0: # Master process 51 | func(*args, **kwargs) 52 | else: 53 | if 
get_rank() == 0: # Master process 54 | log_func(f'Loaded cache {filename}, skipped {func.__name__}') 55 | return pickle_load(filename) 56 | synchronize() 57 | assert os.path.exists(filename), f'The {filename} must be saved in the {func.__name__}' 58 | return pickle_load(filename) 59 | 60 | return wrapper 61 | 62 | return decorator 63 | 64 | 65 | def master_process_only(func): 66 | @wraps(func) 67 | def wrapper(*args, **kwargs): 68 | if get_rank() == 0: 69 | result = func(*args, **kwargs) 70 | return result 71 | synchronize() 72 | 73 | return wrapper 74 | 75 | 76 | def get_rank(): 77 | """ 78 | Get the rank of this process in distributed processes. 79 | 80 | Return 0 for single process case. 81 | """ 82 | if dist.is_initialized(): 83 | return dist.get_rank() 84 | if "RANK" in os.environ: 85 | return int(os.environ["RANK"]) 86 | return 0 87 | 88 | 89 | def get_world_size(): 90 | """ 91 | Get the total number of distributed processes. 92 | 93 | Return 1 for single process case. 94 | """ 95 | if dist.is_initialized(): 96 | return dist.get_world_size() 97 | if "WORLD_SIZE" in os.environ: 98 | return int(os.environ["WORLD_SIZE"]) 99 | return 1 100 | 101 | 102 | def get_group(device): 103 | """ 104 | Get the process group corresponding to the given device. 105 | 106 | Parameters: 107 | device (th.device): query device 108 | """ 109 | group = cpu_group if device.type == "cpu" else gpu_group 110 | if group is None: 111 | raise ValueError("%s group is not initialized. Use comm.init_process_group() to initialize it" 112 | % device.type.upper()) 113 | return group 114 | 115 | 116 | def init_process_group(backend, init_method=None, **kwargs): 117 | """ 118 | Initialize CPU and/or GPU process groups. 119 | 120 | Parameters: 121 | backend (str): Communication backend. Use ``nccl`` for GPUs and ``gloo`` for CPUs. 122 | init_method (str, optional): URL specifying how to initialize the process group 123 | """ 124 | global cpu_group 125 | global gpu_group 126 | 127 | dist.init_process_group(backend, init_method, **kwargs) 128 | gpu_group = dist.group.WORLD 129 | if backend == "nccl": 130 | cpu_group = dist.new_group(backend="gloo") 131 | else: 132 | cpu_group = gpu_group 133 | 134 | 135 | def get_cpu_count(): 136 | """ 137 | Get the number of CPUs on this node. 138 | """ 139 | return multiprocessing.cpu_count() 140 | 141 | 142 | def synchronize(): 143 | """ 144 | Synchronize among all distributed processes. 
145 | """ 146 | if get_world_size() > 1: 147 | dist.barrier() 148 | 149 | 150 | def _recursive_read(obj): 151 | values = defaultdict(list) 152 | sizes = defaultdict(list) 153 | if isinstance(obj, th.Tensor): 154 | values[obj.dtype] += [obj.flatten()] 155 | sizes[obj.dtype] += [th.tensor([obj.numel()], device=obj.device)] 156 | elif isinstance(obj, dict): 157 | for v in obj.values(): 158 | child_values, child_sizes = _recursive_read(v) 159 | for k, v in child_values.items(): 160 | values[k] += v 161 | for k, v in child_sizes.items(): 162 | sizes[k] += v 163 | elif isinstance(obj, list) or isinstance(obj, tuple): 164 | for v in obj: 165 | child_values, child_sizes = _recursive_read(v) 166 | for k, v in child_values.items(): 167 | values[k] += v 168 | for k, v in child_sizes.items(): 169 | sizes[k] += v 170 | else: 171 | raise ValueError("Unknown type `%s`" % type(obj)) 172 | return values, sizes 173 | 174 | 175 | def _recursive_write(obj, values, sizes=None): 176 | if isinstance(obj, th.Tensor): 177 | if sizes is None: 178 | size = th.tensor([obj.numel()], device=obj.device) 179 | else: 180 | s = sizes[obj.dtype] 181 | size, s = s.split([1, len(s) - 1]) 182 | sizes[obj.dtype] = s 183 | v = values[obj.dtype] 184 | new_obj, v = v.split([size, v.shape[-1] - size], dim=-1) 185 | # compatible with reduce / stack / cat 186 | new_obj = new_obj.view(new_obj.shape[:-1] + (-1,) + obj.shape[1:]) 187 | values[obj.dtype] = v 188 | return new_obj, values 189 | elif isinstance(obj, dict): 190 | new_obj = {} 191 | for k, v in obj.items(): 192 | new_obj[k], values = _recursive_write(v, values, sizes) 193 | elif isinstance(obj, list) or isinstance(obj, tuple): 194 | new_obj = [] 195 | for v in obj: 196 | new_v, values = _recursive_write(v, values, sizes) 197 | new_obj.append(new_v) 198 | else: 199 | raise ValueError("Unknown type `%s`" % type(obj)) 200 | return new_obj, values 201 | 202 | 203 | def reduce(obj, op="sum", dst=None): 204 | """ 205 | Reduce any nested container of tensors. 206 | 207 | Parameters: 208 | obj (Object): any container object. Can be nested list, tuple or dict. 209 | op (str, optional): element-wise reduction operator. 210 | Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. 211 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 212 | 213 | Example:: 214 | 215 | >>> # assume 4 workers 216 | >>> rank = comm.get_rank() 217 | >>> x = th.rand(5) 218 | >>> obj = {"polynomial": x ** rank} 219 | >>> obj = comm.reduce(obj) 220 | >>> assert th.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1) 221 | """ 222 | values = _recursive_read(obj)[0] 223 | values = {k: th.cat(v) for k, v in values.items()} 224 | 225 | is_mean = op == "mean" 226 | if is_mean: 227 | op = "sum" 228 | op = getattr(dist.ReduceOp, op.upper()) 229 | 230 | reduced = {} 231 | for k, v in values.items(): 232 | dtype = v.dtype 233 | # NCCL can't solve bool. Cast them to byte 234 | if dtype == th.bool: 235 | v = v.byte() 236 | group = get_group(v.device) 237 | if dst is None: 238 | dist.all_reduce(v, op=op, group=group) 239 | else: 240 | dist.reduce(v, op=op, dst=dst, group=group) 241 | if is_mean: 242 | v = v / get_world_size() 243 | reduced[k] = v.type(dtype) 244 | 245 | return _recursive_write(obj, reduced)[0] 246 | 247 | 248 | def stack(obj, dst=None): 249 | """ 250 | Stack any nested container of tensors. The new dimension will be added at the 0-th axis. 251 | 252 | Parameters: 253 | obj (Object): any container object. 
Can be nested list, tuple or dict. 254 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 255 | 256 | Example:: 257 | 258 | >>> # assume 4 workers 259 | >>> rank = comm.get_rank() 260 | >>> x = th.rand(5) 261 | >>> obj = {"exponent": x ** rank} 262 | >>> obj = comm.stack(obj) 263 | >>> truth = th.stack([th.ones_like(x), x, x ** 2, x ** 3] 264 | >>> assert th.allclose(obj["exponent"], truth)) 265 | """ 266 | values = _recursive_read(obj)[0] 267 | values = {k: th.cat(v) for k, v in values.items()} 268 | 269 | stacked = {} 270 | for k, v in values.items(): 271 | dtype = v.dtype 272 | # NCCL can't solve bool. Cast them to byte 273 | if dtype == th.bool: 274 | dtype = th.uint8 275 | s = th.zeros(get_world_size(), *v.shape, dtype=dtype, device=v.device) 276 | s[get_rank()] = v 277 | group = get_group(s.device) 278 | if dst is None: 279 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) 280 | else: 281 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) 282 | stacked[k] = s.type(v.dtype) 283 | 284 | return _recursive_write(obj, stacked)[0] 285 | 286 | 287 | def cat(obj, dst=None): 288 | """ 289 | Concatenate any nested container of tensors along the 0-th axis. 290 | 291 | Parameters: 292 | obj (Object): any container object. Can be nested list, tuple or dict. 293 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 294 | 295 | Example:: 296 | 297 | >>> # assume 4 workers 298 | >>> rank = comm.get_rank() 299 | >>> rng = th.arange(10) 300 | >>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]} 301 | >>> obj = comm.cat(obj) 302 | >>> assert th.allclose(obj["range"], rng) 303 | """ 304 | values, sizes = _recursive_read(obj) 305 | sizes = {k: th.cat(v) for k, v in sizes.items()} 306 | 307 | sizes = stack(sizes) 308 | cated = {} 309 | for k, value in values.items(): 310 | size = sizes[k].t().flatten() # sizes[k]: (num_worker, num_obj) 311 | dtype = value[0].dtype 312 | # NCCL can't solve bool. 
Cast them to byte 313 | if dtype == th.bool: 314 | dtype = th.uint8 315 | s = th.zeros(size.sum(), dtype=dtype, device=value[0].device) 316 | obj_id = get_rank() 317 | world_size = get_world_size() 318 | offset = size[:obj_id].sum() 319 | for v in value: 320 | assert offset + v.numel() <= len(s) 321 | s[offset: offset + v.numel()] = v 322 | offset += size[obj_id: obj_id + world_size].sum() 323 | obj_id += world_size 324 | group = get_group(s.device) 325 | if dst is None: 326 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) 327 | else: 328 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) 329 | cated[k] = s.type(value[0].dtype) 330 | sizes = {k: v.sum(dim=0) for k, v in sizes.items()} 331 | 332 | return _recursive_write(obj, cated, sizes)[0] 333 | -------------------------------------------------------------------------------- /src/utils/pkg/graph_utils.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import numpy 3 | import numpy as np 4 | import torch as th 5 | 6 | from utils.basics import init_random_state, time_logger, logger, pickle_save 7 | from utils.pkg.distributed import process_on_master_and_sync_by_pickle 8 | import scipy.sparse as sp 9 | from tqdm import tqdm 10 | from sklearn.metrics.pairwise import cosine_similarity 11 | from collections import defaultdict 12 | 13 | 14 | def k_hop_nb_graph(g, k): 15 | return dgl.to_simple(dgl.remove_self_loop(dgl.khop_graph(g, k))) 16 | 17 | 18 | def sample_nodes(g, subset_nodes, fanout_list, to_numpy=True): 19 | # E.g. fanout_list = [2, 2, 1] ->> 2 first-hop 20 | # subset_nodes = th.tensor(subset_nodes).to(g.device) if isinstance(subset_nodes, int) else subset_nodes 21 | subset_nodes = th.tensor(subset_nodes) 22 | induced_nodes = {0: (cur_nodes := subset_nodes.view(-1))} 23 | init_random_state(0) 24 | for l, fanout in enumerate(fanout_list): 25 | frontier = dgl.sampling.sample_neighbors(g, cur_nodes, fanout) 26 | cur_nodes = frontier.edges()[0].unique() 27 | induced_nodes[l + 1] = cur_nodes 28 | sampled_nodes = th.cat(list(induced_nodes.values())).unique() 29 | if to_numpy: 30 | sampled_nodes = sampled_nodes.cpu().numpy() 31 | induced_nodes = {hop: hop_nodes.cpu().numpy() for hop, hop_nodes in induced_nodes.items()} 32 | return sampled_nodes, induced_nodes 33 | 34 | 35 | def get_neighbors_within_k_hop(graph, node_id, k, remove_center_node=False): 36 | """ 37 | Function to get the neighbors within k-hop for a given node in the graph. 38 | 39 | Parameters: 40 | graph (dgl.DGLGraph): The input graph. 41 | node_id (int): The ID of the target node. 42 | k (int): The number of hops to consider. 43 | 44 | Returns: 45 | set: A set of node IDs representing neighbors within k-hop. 
46 | """ 47 | # Use dgl.khop_in_subgraph / dgl.khop_out_subgraph to get the in/out subgraphs within k hops 48 | neighbors_khop_in = dgl.khop_in_subgraph(graph, node_id, k)[0] 49 | neighbors_khop_out = dgl.khop_out_subgraph(graph, node_id, k)[0] 50 | 51 | # Get the nodes in the subgraph and add them to the set 52 | neighbors_within_k_in = set(neighbors_khop_in.ndata[dgl.NID].numpy().tolist()) 53 | neighbors_within_k_out = set(neighbors_khop_out.ndata[dgl.NID].numpy().tolist()) 54 | neighbors_within_k = neighbors_within_k_in | neighbors_within_k_out 55 | 56 | # Remove the target node from the set as it is not considered a neighbor 57 | if remove_center_node: 58 | neighbors_within_k.remove(node_id) 59 | 60 | return np.array(list(neighbors_within_k)) 61 | 62 | 63 | def get_edge_set(g: dgl.DGLGraph): 64 | """Convert graph edges to a set of (row_id, col_id) tuples. 65 | """ 66 | 67 | return set(map(tuple, np.column_stack([_.cpu().numpy() for _ in g.edges()]).tolist())) 68 | 69 | 70 | def edge_set_to_inds(edge_set): 71 | """ Unpack edge set to row_ids, col_ids""" 72 | return list(map(list, zip(*edge_set))) 73 | 74 | 75 | def get_spd_by_sp_matrix(spd_sp_mat, i, j): 76 | # ! Note that the default value of a sp matrix is always zero 77 | # ! which conflicts with the self-loop SPD (0) 78 | if i == j: # Self loop 79 | return 0 80 | elif spd_sp_mat[i, j] == 0: # Out of max hop 81 | return -1 82 | else: 83 | return spd_sp_mat[i, j] 84 | 85 | 86 | @time_logger() 87 | @process_on_master_and_sync_by_pickle(cache_kwarg="cache_file") 88 | def get_spd_matrices(g: dgl.DGLGraph, max_hops, cache_file=None): 89 | # ! Calculate SPD at scale (supports OGB data) 90 | # Initialize the CSR sparse matrix with zeros 91 | sp_mat_shape = (g.number_of_nodes(), g.number_of_nodes()) 92 | # Residue matrix stores the residue to unreachable, i.e. RESIDUE = MAX_HOPS + 1 - SPD (hop) 93 | residue_mat = sp.csr_matrix(([], ([], [])), shape=sp_mat_shape, dtype=np.int64) 94 | 95 | for hop in tqdm(range(max_hops, 0, -1), 'building SPD matrices'): 96 | new_src, new_dst = k_hop_nb_graph(g, hop).edges() 97 | new_indices = np.vstack((new_src.numpy(), new_dst.numpy())) 98 | new_residue = sp.csr_matrix((np.full(new_src.shape, 1, dtype=np.int64), new_indices), shape=sp_mat_shape) 99 | new_residue.data.fill(max_hops + 1 - hop) 100 | 101 | # Add the new CSR matrix to the final CSR matrix 102 | residue_mat = residue_mat.maximum(new_residue) 103 | # SPD = MAX_HOPS + 1 - RESIDUE 104 | spd_mat = residue_mat.copy() 105 | spd_mat.data = max_hops + 1 - residue_mat.data 106 | 107 | # !
Convert to SPD neighbor list dictionary 108 | spd_nb_list = defaultdict(list) 109 | spd_nb_list[0] = [[n] for n in range(g.num_nodes())] 110 | # Iterate through the rows 111 | for row in tqdm(range(spd_mat.shape[0]), 'building SPD neighbors'): 112 | start_idx = spd_mat.indptr[row] 113 | end_idx = spd_mat.indptr[row + 1] 114 | 115 | row_cols = spd_mat.indices[start_idx:end_idx] 116 | row_data = spd_mat.data[start_idx:end_idx] 117 | 118 | row_dict = {k: [] for k in range(1, max_hops + 1)} 119 | 120 | for col, value in zip(row_cols, row_data): 121 | row_dict[value].append(col) 122 | 123 | for value, positions in row_dict.items(): 124 | spd_nb_list[value].append(positions) 125 | 126 | pickle_save((spd_mat, spd_nb_list), cache_file) 127 | 128 | 129 | def k_hop_nb_graph(g, k): 130 | return dgl.remove_self_loop(dgl.khop_graph(g, k)) 131 | 132 | 133 | def get_sparse_numpy_adj(g): 134 | row, col = dgl.to_bidirected(g).edges() 135 | return sp.coo_matrix( 136 | (np.ones(len(row)), (row.numpy(), col.numpy())), 137 | shape=(g.num_nodes(), g.num_nodes()), 138 | ) 139 | 140 | 141 | def get_propagated_feature(g, x, k): 142 | # Compute the cosine similarity matrix 143 | if isinstance(x, th.Tensor): 144 | x = x.cpu().numpy() 145 | adj = get_sparse_numpy_adj(g).toarray() 146 | for _ in range(1, k + 1): 147 | x = adj @ x 148 | return x 149 | 150 | 151 | @process_on_master_and_sync_by_pickle(cache_kwarg="cache_file") 152 | @time_logger() 153 | def get_pairwise_topk_sim_mat_scipy(x, k=20, cache_file=None): # Preserve at most 20 neighbors 154 | # Set diagonal and zero-values to a very negative number 155 | sim_mat = cosine_similarity(x) 156 | np.fill_diagonal(sim_mat, -float('inf')) 157 | # Find the top-k similar graph 158 | nb_list = [] 159 | for i in tqdm(range(sim_mat.shape[0]), desc=f'Building top-{k} similarity graph'): 160 | nonzero_indices = np.where(sim_mat[i] > 0)[0] 161 | nonzero_values = sim_mat[i][nonzero_indices] 162 | # Sort the non-zero values in descending order and get the top-k 163 | sorted_nonzero_indices = np.argsort(-nonzero_values)[:k] 164 | 165 | # Map it back to the original indices 166 | selected = nonzero_indices[sorted_nonzero_indices].tolist() 167 | nb_list.append(selected) 168 | pickle_save(nb_list, cache_file) 169 | 170 | 171 | @process_on_master_and_sync_by_pickle(cache_kwarg="cache_file") 172 | @time_logger() 173 | def get_pairwise_topk_sim_mat_chunkdot(x, k=20, max_mem_in_gb=5, cache_file=None): # Preserve at most 20 neighbors 174 | from chunkdot import cosine_similarity_top_k 175 | # Set diagonal and zero-values to a very negative number 176 | sim_mat = cosine_similarity_top_k(x, top_k=k + 1, max_memory=max_mem_in_gb * 1e9, show_progress=True) 177 | sim_mat.setdiag(-float('inf')) 178 | nb_list = [] 179 | for row in tqdm(range(sim_mat.shape[0]), f'building similarity to {cache_file}'): 180 | start_idx = sim_mat.indptr[row] 181 | end_idx = sim_mat.indptr[row + 1] 182 | 183 | row_cols = sim_mat.indices[start_idx:end_idx] 184 | row_data = sim_mat.data[start_idx:end_idx] 185 | 186 | # Sort the non-zero values in descending order and get the top-k 187 | sorted_nonzero_indices = np.argsort(-row_data)[:k + 1].tolist() 188 | # Map it back to the original indices 189 | selected = row_cols[sorted_nonzero_indices[:k]] 190 | nb_list.append(selected.tolist()) 191 | 192 | pickle_save(nb_list, cache_file) 193 | -------------------------------------------------------------------------------- /src/utils/pkg/hf_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from huggingface_hub import snapshot_download 4 | from transformers import AutoModel, AutoTokenizer, AutoConfig 5 | 6 | from utils.basics import logger, init_path, time_logger 7 | from utils.pkg.distributed import master_process_only 8 | 9 | @time_logger() 10 | @master_process_only 11 | def download_hf_ckpt_to_local(hf_name, local_dir): 12 | hf_token = os.environ.get('HF_ACCESS_TOKEN', False) 13 | if not os.path.exists(f'{local_dir}config.json'): 14 | logger.critical(f'Downloading {hf_name} ckpt to {local_dir}') 15 | # Resolves Proxy error: https://github.com/huggingface/transformers/issues/17611 16 | os.environ['CURL_CA_BUNDLE'] = '' 17 | snapshot_download(repo_id=hf_name, local_dir=init_path(local_dir), token=hf_token) 18 | 19 | 20 | def load_hf_auto_model_and_tokenizer(hf_name, local_dir): 21 | download_hf_ckpt_to_local(hf_name, local_dir) 22 | bert = AutoModel.from_pretrained(local_dir) 23 | tokenizer = AutoTokenizer.from_pretrained(local_dir) 24 | model_cfg = AutoConfig.from_pretrained(local_dir) 25 | return bert, tokenizer, model_cfg 26 | -------------------------------------------------------------------------------- /src/utils/project/__init__.py: -------------------------------------------------------------------------------- 1 | from .exp import * 2 | -------------------------------------------------------------------------------- /src/utils/project/exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from uuid import uuid4 4 | 5 | import numpy as np 6 | import torch 7 | import wandb 8 | from omegaconf import OmegaConf 9 | from torch import distributed as dist 10 | 11 | from utils.basics import init_env_variables, save_cfg, WandbExpLogger, print_important_cfg, init_path, \ 12 | get_important_cfg, logger 13 | from utils.pkg.distributed import get_rank, get_world_size, init_process_group 14 | 15 | proj_path = os.path.abspath(os.path.dirname(__file__)).split('src')[0] 16 | PROJ_CONFIG_FILE = 'config/proj.yaml' 17 | 18 | 19 | def set_seed(seed): 20 | # dgl.seed(seed) 21 | # dgl.random.seed(seed) 22 | np.random.seed(seed) 23 | os.environ['PYTHONHASHSEED'] = str(seed) 24 | torch.manual_seed(seed + get_rank()) 25 | if torch.cuda.is_available(): 26 | torch.cuda.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | torch.backends.cudnn.deterministic = True 29 | torch.backends.cudnn.benchmark = False 30 | 31 | 32 | def device_init(gpus): 33 | import torch as th 34 | device = th.device('cpu') 35 | if gpus != '-1' and th.cuda.is_available(): # GPU 36 | if get_rank() >= 0: # DDP 37 | th.cuda.set_device(get_rank()) 38 | device = th.device(get_rank()) 39 | else: # Single GPU 40 | device = th.device("cuda:0") 41 | return device 42 | 43 | 44 | def generate_unique_id(cfg): 45 | """Generate a Unique ID (UID) for (1) File system (2) Communication between submodules 46 | By default, we use time and UUID4 as UID. UIDs could be overwritten by wandb or UID specification. 47 | """ 48 | # 49 | if cfg.get('uid') is not None and cfg.wandb.id is not None: 50 | assert cfg.get('uid') == cfg.wandb.id, 'Confliction: Wandb and uid mismatch!' 
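# UID precedence below: reuse cfg.wandb.id when it is set, otherwise an explicitly given `uid`, otherwise a freshly generated "<time><uuid4-prefix>" string.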
51 | cur_time = datetime.now().strftime("%b%-d-%-H:%M-") 52 | given_uid = cfg.wandb.id or cfg.get('uid') 53 | uid = given_uid if given_uid else cur_time + str(uuid4()).split('-')[0] 54 | return uid 55 | 56 | 57 | def init_experiment(cfg): 58 | OmegaConf.set_struct(cfg, False) # Prevent ConfigKeyError when accessing non-existing keys 59 | cfg = init_env_variables(cfg) # Update environment args defined in cfg 60 | wandb_init(cfg) 61 | set_seed(cfg.seed) 62 | world_size = get_world_size() 63 | if world_size > 1 and not dist.is_initialized(): 64 | # init_process_group("nccl", init_method="proj://") 65 | init_process_group("nccl", init_method="env://") 66 | 67 | # In mplm working directory is initialized by mplm and shared by LM and GNN submodules. 68 | cfg.uid = generate_unique_id(cfg) 69 | init_path([cfg.out_dir, cfg.working_dir]) 70 | cfg_out_file = cfg.out_dir + 'hydra_cfg.yaml' 71 | save_cfg(cfg, cfg_out_file, as_global=True) 72 | # Add global attribute to reproduce hydra configs at ease. 73 | cfg.local_rank = get_rank() 74 | _logger = WandbExpLogger(cfg) 75 | _logger.save_file_to_wandb(cfg_out_file, base_path=cfg.out_dir, policy='now') 76 | _logger.info(f'Local_rank={cfg.local_rank}, working_dir = {cfg.working_dir}') 77 | print_important_cfg(cfg, _logger.debug) 78 | return cfg, _logger 79 | 80 | 81 | def wandb_init(cfg) -> None: 82 | os.environ["WANDB_WATCH"] = "false" 83 | if cfg.get('use_wandb', False) and get_rank() <= 0: 84 | try: 85 | WANDB_DIR, WANDB_PROJ, WANDB_ENTITY = ( 86 | cfg.env.vars[k.lower()] for k in ['WANDB_DIR', 'WANDB_PROJ', 'WANDB_ENTITY']) 87 | wandb_dir = os.path.join(proj_path, WANDB_DIR) 88 | 89 | # ! Create wandb session 90 | if cfg.wandb.id is None: 91 | # First time running, create new wandb 92 | init_path([wandb_dir, cfg.get('wandb_cache_dir', '')]) 93 | wandb.init(project=WANDB_PROJ, entity=WANDB_ENTITY, dir=wandb_dir, 94 | reinit=True, config=get_important_cfg(cfg), name=cfg.wandb.name) 95 | else: # Resume from previous 96 | logger.critical(f'Resume from previous wandb run {cfg.wandb.id}') 97 | wandb.init(project=WANDB_PROJ, entity=WANDB_ENTITY, reinit=True, 98 | resume='must', id=cfg.wandb.id) 99 | cfg.wandb.id, cfg.wandb.name, cfg.wandb.sweep_id = wandb.run.id, wandb.run.name, wandb.run.sweep_id 100 | cfg.wandb_on = True 101 | return 102 | except Exception as e: 103 | # Code to run if an exception is raised 104 | logger.critical(f"An error occurred during wandb initialization: {e}\n'WANDB NOT INITIALIZED.'") 105 | os.environ["WANDB_DISABLED"] = "true" 106 | cfg.wandb_on = False 107 | return 108 | --------------------------------------------------------------------------------
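Usage note on the caching decorator: `process_on_master_and_sync_by_pickle` (defined in `src/utils/pkg/distributed.py` and applied to `get_spd_matrices` and the top-k similarity builders in `graph_utils.py`) expects the wrapped function to write its own result to the file passed through `cache_kwarg`; after synchronizing, the wrapper returns `pickle_load(cache_file)` on every rank. The sketch below illustrates that contract; the function name and cache path are hypothetical placeholders, not part of the repository.

```python
# Minimal sketch (hypothetical function and path) of the pickle-cache contract used by
# get_spd_matrices / get_pairwise_topk_sim_mat_* above: rank 0 computes and saves,
# then every rank loads the pickle after the barrier.
from utils.basics import pickle_save
from utils.pkg.distributed import process_on_master_and_sync_by_pickle


@process_on_master_and_sync_by_pickle(cache_kwarg='cache_file')
def build_toy_artifact(values, cache_file=None):
    result = [v * v for v in values]   # stand-in for an expensive preprocessing step
    pickle_save(result, cache_file)    # save to cache_file; the wrapper ignores the return value


squares = build_toy_artifact([1, 2, 3], cache_file='/tmp/toy_artifact.pkl')  # -> [1, 4, 9]
```

This is why `get_spd_matrices` and the `get_pairwise_topk_sim_mat_*` functions end with `pickle_save(..., cache_file)` rather than a `return` statement: the decorator, not the function body, delivers the loaded result to the caller.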