├── cgm ├── utils │ ├── __init__.py │ ├── metrics.py │ ├── common_utils.py │ ├── loss.py │ ├── arguments.py │ └── train_utils.py ├── launch │ ├── zero3.sh │ └── zero2.sh ├── config │ └── template.json ├── inference │ └── layer.py ├── data │ ├── encode.py │ └── preprocess.py ├── train │ └── train.py ├── modeling │ └── cgm.py └── models │ └── qwen2 │ └── _4_46_1 │ └── modeling_attn_mask_utils.py ├── assets ├── pipeline.png ├── cgm_method_v3.png ├── SWE-Bench-Lite.png ├── framework_1126.jpeg ├── cgm_framework_0123.png ├── swe-bench-lite-ow.png ├── swe-bench-modified.png └── github-codefuse-logo-update.jpg ├── reranker ├── codegraph_parser │ ├── java │ │ └── __pycache__ │ │ │ └── codegraph_java_local.cpython-38.pyc │ └── python │ │ └── __pycache__ │ │ └── codegraph_python_local.cpython-38.pyc ├── qwen_api.py ├── prompt.py └── reranker.py ├── LEGAL.md ├── preprocess_embedding ├── generate_rewriter_embedding.py ├── generate_code_content.py └── generate_code_embedding.py ├── retriever ├── utils.py ├── serialize_subgraph.py ├── locate_anchor_node.py └── subgraph.py ├── rewriter ├── inference_rewriter.py ├── prompt.py ├── rewriter_output_post_process.py └── generate_rewriter_prompt.py └── README.md /cgm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/pipeline.png -------------------------------------------------------------------------------- /assets/cgm_method_v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/cgm_method_v3.png -------------------------------------------------------------------------------- /assets/SWE-Bench-Lite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/SWE-Bench-Lite.png -------------------------------------------------------------------------------- /assets/framework_1126.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/framework_1126.jpeg -------------------------------------------------------------------------------- /assets/cgm_framework_0123.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/cgm_framework_0123.png -------------------------------------------------------------------------------- /assets/swe-bench-lite-ow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/swe-bench-lite-ow.png -------------------------------------------------------------------------------- /assets/swe-bench-modified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/swe-bench-modified.png -------------------------------------------------------------------------------- /assets/github-codefuse-logo-update.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/github-codefuse-logo-update.jpg -------------------------------------------------------------------------------- /reranker/codegraph_parser/java/__pycache__/codegraph_java_local.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/reranker/codegraph_parser/java/__pycache__/codegraph_java_local.cpython-38.pyc -------------------------------------------------------------------------------- /reranker/codegraph_parser/python/__pycache__/codegraph_python_local.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/reranker/codegraph_parser/python/__pycache__/codegraph_python_local.cpython-38.pyc -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /cgm/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 2 | 3 | def calculate_metrics(y_true, y_pred, average='binary'): 4 | 5 | accuracy = accuracy_score(y_true, y_pred) 6 | precision = precision_score(y_true, y_pred, average=average) 7 | recall = recall_score(y_true, y_pred, average=average) 8 | f1 = f1_score(y_true, y_pred, average=average) 9 | 10 | metrics = { 11 | 'accuracy': accuracy, 12 | 'precision': precision, 13 | 'recall': recall, 14 | 'f1_score': f1 15 | } 16 | 17 | return metrics -------------------------------------------------------------------------------- /cgm/launch/zero3.sh: -------------------------------------------------------------------------------- 1 | accelerate launch \ 2 | --num_machines $N_NODE \ 3 | --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \ 4 | --use_deepspeed \ 5 | --deepspeed_multinode_launcher 'standard' \ 6 | --zero_stage 3 \ 7 | --offload_optimizer_device 'none' \ 8 | --offload_param_device 'none' \ 9 | --gradient_accumulation_steps 32 \ 10 | --gradient_clipping 1.0 \ 11 | --zero3_init_flag true \ 12 | --zero3_save_16bit_model true \ 13 | --main_training_function 'main' \ 14 | --mixed_precision 'bf16' \ 15 | --dynamo_backend 'no' \ 16 | --same_network \ 17 | --machine_rank $RANK \ 18 | --main_process_ip $MASTER_ADDR \ 19 | --main_process_port $MASTER_PORT \ 20 | --rdzv_backend 'static' \ 21 | train/train.py --c config/$TRAIN_CONFIG 22 | -------------------------------------------------------------------------------- /cgm/launch/zero2.sh: -------------------------------------------------------------------------------- 1 | accelerate launch \ 2 | --num_machines $N_NODE \ 3 | --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \ 4 | --use_deepspeed \ 5 | --deepspeed_multinode_launcher 'standard' \ 6 | --zero_stage 2 \ 7 | --offload_optimizer_device 'none' \ 8 | --offload_param_device 
'none' \ 9 | --gradient_accumulation_steps 32 \ 10 | --gradient_clipping 1.0 \ 11 | --zero3_init_flag false \ 12 | --zero3_save_16bit_model false \ 13 | --main_training_function 'main' \ 14 | --mixed_precision 'bf16' \ 15 | --dynamo_backend 'no' \ 16 | --same_network \ 17 | --machine_rank $RANK \ 18 | --main_process_ip $MASTER_ADDR \ 19 | --main_process_port $MASTER_PORT \ 20 | --rdzv_backend 'static' \ 21 | train/train.py --c config/$TRAIN_CONFIG 22 | 23 | -------------------------------------------------------------------------------- /reranker/qwen_api.py: -------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | from vllm.sampling_params import BeamSearchParams 3 | from vllm.entrypoints.chat_utils import apply_hf_chat_template 4 | import os 5 | 6 | os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' 7 | 8 | class QwenAPI: 9 | def __init__(self, model_path, tensor_parallel_size=4): 10 | self.llm = LLM(model=model_path, tensor_parallel_size=tensor_parallel_size) 11 | 12 | sample_params = dict( 13 | temperature=0.1, 14 | repetition_penalty=1.1, 15 | top_k=50, 16 | top_p=0.98, 17 | max_tokens=1024 18 | ) 19 | 20 | self.sampling_params = SamplingParams(**sample_params) 21 | print(f"sampling_params: {sample_params}") 22 | 23 | def get_response(self, system_prompt: str, user_prompt: str): 24 | conversation = [ 25 | { 26 | "role": "system", 27 | "content": system_prompt 28 | }, 29 | { 30 | "role": "user", 31 | "content": user_prompt 32 | } 33 | ] 34 | outputs = self.llm.chat(conversation, sampling_params=self.sampling_params, use_tqdm=False) 35 | response = outputs[0].outputs[0].text 36 | return response 37 | 38 | 39 | if __name__ == "__main__": 40 | llm = QwenAPI("Qwen/Qwen2.5-1.5B-Instruct") 41 | user_prompt = "Where is the capital of China?" 42 | system_prompt = "You are a helpful assistant." 
43 | llm.get_response(system_prompt, user_prompt) -------------------------------------------------------------------------------- /preprocess_embedding/generate_rewriter_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate embedding for Queries from Rewriter's Inferer 3 | """ 4 | 5 | from transformers import AutoTokenizer, AutoModel 6 | import torch 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import tqdm 11 | import json 12 | import pickle 13 | 14 | # custom 15 | import argparse, logging 16 | 17 | # input path 18 | rewriter_output_path = "rewriter_output.json" 19 | 20 | # save path 21 | rewriter_embedding_path = "rewriter_embedding.pkl" 22 | 23 | # load model 24 | model_name_or_path = "xxx/CodeFuse-CGE-Large" 25 | model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) 26 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right') 27 | 28 | if torch.cuda.is_available(): 29 | device = 'cuda' 30 | else: 31 | device = 'cpu' 32 | model.to(device) 33 | 34 | if __name__ == "__main__": 35 | 36 | with open(rewriter_output_path, 'r') as file: 37 | rewriter_output_dict = json.load(file) 38 | 39 | query_embedding_dict = {} 40 | 41 | for instance_id in tqdm.tqdm(rewriter_output_dict): 42 | query = rewriter_output_dict[instance_id]["query"] 43 | 44 | if len(query) == 0: 45 | continue 46 | 47 | query_embedding_dict[instance_id] = model.encode(tokenizer, query) 48 | 49 | with open(rewriter_embedding_path, 'wb') as f: 50 | pickle.dump(query_embedding_dict, f) -------------------------------------------------------------------------------- /cgm/config/template.json: -------------------------------------------------------------------------------- 1 | { 2 | "graph_dir": [], 3 | "train_files":[], 4 | "valid_files":[], 5 | "output_dir": "", 6 | "tb_dir": "", 7 | 8 | "embedding_dim": 256, 9 | "load_pretrained_encoder": false, 10 | "pretrained_encoder_path": null, 11 | 12 | "load_pretrained_adapter": false, 13 | "pretrained_adapter_path": null, 14 | "adapter_hidden_dim": 4096, 15 | "adapter_num_layers": 1, 16 | "adapter_num_heads": 8, 17 | 18 | "load_pretrained_tokenizer": true, 19 | "pretrained_tokenizer_path": "Qwen/Qwen2.5-Coder-7B-Instruct", 20 | 21 | "pretrained_model_path": "Qwen/Qwen2.5-Coder-7B-Instruct", 22 | "self_defined": false, 23 | "framework_type": "T1", 24 | "model_type": "Qwen", 25 | 26 | "pretrained_lora_path": null, 27 | "quantization": "4bit", 28 | 29 | "mode": "eal", 30 | "task": "unit_test", 31 | "use_chat": false, 32 | "use_adj": true, 33 | 34 | "peft": "LoRA", 35 | "lora_rank": 32, 36 | "lora_alpha": 32, 37 | "lora_dropout": 0.05, 38 | "lora_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 39 | 40 | "enc_peft": "LoRA", 41 | "enc_lora_rank": 32, 42 | "enc_lora_alpha": 32, 43 | "enc_lora_dropout": 0.05, 44 | "enc_lora_modules": "all-linear", 45 | 46 | "graph_pad_token": "<|graph_pad|>", 47 | "graph_pad_id": 32022, 48 | "graph_token_num": 512, 49 | 50 | "learning_rate": 1e-4, 51 | "min_lr": 1e-7, 52 | "weight_decay": 0.1, 53 | "lr_scheduler_type": "reduce_lr_on_plateau", 54 | 55 | "gradient_accumulation_steps": 1, 56 | "num_warmup_steps": 0, 57 | "adapter_warmup": false, 58 | "adapter_warmup_steps": 500, 59 | "num_train_epochs": 20, 60 | 61 | "data_split": "0.98,0.02", 62 | "max_train_steps": null, 63 | "max_train_samples": null, 64 | "max_valid_samples": 2048, 65 | 
"per_device_train_batch_size": 1, 66 | "per_device_eval_batch_size": 1, 67 | 68 | "seed": 42, 69 | "seq_length": 8192, 70 | "log_interval": 5, 71 | 72 | 73 | 74 | "step_checkpointing": false, 75 | "checkpointing_steps": 500, 76 | "step_evaluation": true, 77 | "evaluation_steps": 5000, 78 | "epoch_evaluation": false, 79 | "epoch_checkpointing": false, 80 | 81 | "early_stopping": true, 82 | "early_stopping_stall_num": 6, 83 | 84 | "attn_implementation": "sdpa" 85 | } 86 | -------------------------------------------------------------------------------- /cgm/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # print out arguments in a nice way 4 | def print_args(args, accelerator): 5 | # 计算所有键的最大字符串长度 6 | max_key_length = max(len(str(key)) for key in vars(args).keys()) 7 | 8 | message = "" 9 | message += "====" * 40 + "\n" 10 | message += '\n'.join([f'{k:<{max_key_length}} : {v}' for k, v in vars(args).items()]) + "\n" 11 | message += "====" * 40 + "\n" 12 | accelerator.print(message) 13 | 14 | def count_parameters(model): 15 | return sum(p.numel() for p in model.parameters()) 16 | 17 | def print_with_rank(accelerator, msg): 18 | print(accelerator.process_index, msg) 19 | 20 | def print_rank_0(*message): 21 | """If distributed is initialized print only on rank 0.""" 22 | if torch.distributed.is_initialized(): 23 | if torch.distributed.get_rank() == 0: 24 | print(*message, flush=True) 25 | else: 26 | print(*message, flush=True) 27 | 28 | def print_rank_0_highlight(*message): 29 | """If distributed is initialized print only on rank 0.""" 30 | if torch.distributed.is_initialized(): 31 | if torch.distributed.get_rank() == 0: 32 | print('=='*100) 33 | print(*message, flush=True) 34 | print('=='*100) 35 | else: 36 | print('=='*100) 37 | print(*message, flush=True) 38 | print('=='*100) 39 | 40 | def print_highlight(*message): 41 | print('=='*100) 42 | print(*message) 43 | print('=='*100) 44 | 45 | def get_computation_speed(batch_size_per_device, seq_len, step_time): 46 | return batch_size_per_device * seq_len / (step_time + 1e-12) 47 | 48 | def touch_print(accelerator, batch, num_tokens=10): 49 | """touch first and last tokens and labels for debugging usage""" 50 | accelerator.print(f"step 1 batch shape: {batch['input_ids'].shape},\n" 51 | f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}" 52 | f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}") 53 | accelerator.print(f"first {num_tokens} input_ids and loss_mask") 54 | for pt in range(1): 55 | accelerator.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") 56 | accelerator.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}") -------------------------------------------------------------------------------- /cgm/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss 3 | 4 | def loss_CGM(output_logits, labels, loss_mask): 5 | 6 | lm_logits = output_logits.contiguous() 7 | labels = labels.to(device=lm_logits.device).contiguous() 8 | loss_mask = loss_mask.to(device=lm_logits.device) 9 | # logits: (bs, l, v); labels, loss_mask: (bs, l) 10 | 11 | # lm loss 12 | bsz = labels.shape[0] 13 | loss_func = CrossEntropyLoss(reduction='none') 14 | losses = loss_func(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1)) # logits: (bs * l, v); labels: (bs * l,) 15 | # losses 
-> (bs, l) 16 | losses = losses.contiguous().view(bsz, -1) 17 | 18 | loss_mask = loss_mask.view(-1) 19 | losses = losses.view(-1) 20 | if loss_mask.sum() < 1: 21 | loss_lm = torch.sum(losses * loss_mask) 22 | else: 23 | loss_lm = torch.sum(losses * loss_mask) / loss_mask.sum() 24 | 25 | return loss_lm 26 | 27 | def acc_lp(logits,labels): 28 | predictions = torch.sigmoid(logits) 29 | acc = ((predictions > 0.5) == labels.bool()).float().mean() 30 | return acc.item() 31 | 32 | def loss_lp(outputs, edge_label_dict): 33 | loss_func = BCEWithLogitsLoss(reduction='mean') 34 | losses = [] 35 | edge_loss = {} 36 | edge_acc = {} 37 | total_acc = 0 38 | total_edges = 0 39 | for edge_type in edge_label_dict.keys(): 40 | lm_logits = outputs[edge_type].view(-1) 41 | labels = edge_label_dict[edge_type].to(device=lm_logits.device).view(-1) 42 | loss = loss_func(lm_logits,labels) 43 | losses.append(loss) 44 | acc = acc_lp(lm_logits, labels) 45 | edge_loss[edge_type] = loss.item() 46 | edge_acc[edge_type] = acc 47 | total_acc += len(labels) * acc 48 | total_edges += len(labels) 49 | # del lm_logits, labels, loss 50 | loss1 = torch.sum(torch.stack(losses)) 51 | total_acc = total_acc / total_edges 52 | # del losses, loss_func 53 | return loss1, edge_loss, edge_acc, total_acc 54 | 55 | def loss_ng(outputs, y_dict, mask_dict): 56 | loss_func = MSELoss(reduction='sum') 57 | loss2 = loss_func(outputs,y_dict['Method']) 58 | return loss2 59 | 60 | def loss_lpng(lp_outputs, ng_outputs, edge_label_dict, y_dict, mask_dict): 61 | loss1, edge_loss, edge_acc, total_acc = loss_lp(lp_outputs, edge_label_dict) 62 | loss2 = loss_ng(ng_outputs, y_dict, mask_dict) 63 | loss = loss1 + loss2 64 | return loss, loss1.item(), loss2.item(), edge_loss, edge_acc, total_acc 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /preprocess_embedding/generate_code_content.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate node content for all nodes in code graph 3 | """ 4 | from os.path import isfile 5 | import os, sys 6 | import pandas as pd 7 | import json 8 | import tqdm 9 | import re 10 | 11 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType 12 | 13 | def extract_code_and_doc(content): 14 | """ 15 | split code and doc 16 | """ 17 | # match docstring 18 | docstring_pattern = r'"""(.*?)"""|\'\'\'(.*?)\'\'\'' 19 | docstrings = re.findall(docstring_pattern, content, re.DOTALL) 20 | 21 | # extract pure code 22 | code_without_docstring = re.sub(docstring_pattern, '', content, flags=re.DOTALL) 23 | # merge docstring 24 | extracted_docstrings = "\n\n".join([d[0] or d[1] for d in docstrings]) 25 | return code_without_docstring, extracted_docstrings 26 | 27 | def get_graph_file_name(item): 28 | """ 29 | return graph_file_name 30 | """ 31 | 32 | raise NotImplementedError 33 | 34 | if __name__ == "__main__": 35 | 36 | graph_basic_df = pd.read_json("test_lite_basic_info.json") 37 | 38 | graph_data_path = "codegraph/" 39 | node_content_path = "xx/node_content/" 40 | graph_list = os.listdir(graph_data_path) 41 | 42 | # get the graph_file path 43 | graph_basic_df["graph_file"] = graph_basic_df.apply(lambda item: get_graph_file_name(item), axis=1) 44 | 45 | # generate code content for each repo 46 | for idx, item in tqdm.tqdm(graph_basic_df.iterrows()): 47 | 48 | instance_id = item.instance_id 49 | graph_file = item.graph_file 50 | # get the graph path 51 | tmp_graph_data_path = graph_data_path + graph_file 52 | 53 
| # skip files which have been processed 54 | if os.path.isfile(node_content_path + '{}.json'.format(instance_id)): 55 | continue 56 | 57 | graph = parse(tmp_graph_data_path) 58 | 59 | try: 60 | nodes = graph.get_nodes() 61 | except: 62 | print(f"========= parse error: {tmp_graph_data_path} =========") 63 | continue 64 | node_code_dict = {} 65 | node_doc_dict = {} 66 | for node in nodes: 67 | node_id = node.node_id 68 | content = node.get_content() 69 | 70 | code, doc = extract_code_and_doc(content) 71 | 72 | node_code_dict[node_id] = code 73 | 74 | if doc.strip(): 75 | node_doc_dict[node_id] = doc 76 | 77 | # save the result 78 | with open(node_content_path + '{}.json'.format(instance_id), 'w', encoding='utf-8') as json_file: 79 | 80 | node_content_dict = { 81 | "code": node_code_dict, 82 | "doc": node_doc_dict 83 | } 84 | 85 | json.dump(node_content_dict, json_file, ensure_ascii=False, indent=4) -------------------------------------------------------------------------------- /preprocess_embedding/generate_code_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate embedding for code 3 | """ 4 | 5 | from transformers import AutoTokenizer, AutoModel 6 | import torch 7 | import os 8 | import numpy as np 9 | import tqdm 10 | import json 11 | import pickle 12 | 13 | # custom 14 | import argparse, logging 15 | 16 | # input and output path 17 | node_content_path = "xx/node_content/" 18 | node_embedding_path = "xx/tmp_node_embedding/" 19 | 20 | # load model 21 | model_name_or_path = "xxx/CodeFuse-CGE-Large" 22 | model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) 23 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right') 24 | 25 | if torch.cuda.is_available(): 26 | device = 'cuda' 27 | else: 28 | device = 'cpu' 29 | model.to(device) 30 | 31 | if __name__ == "__main__": 32 | 33 | node_embedding_dict = None 34 | node_embedding_list = os.listdir(node_embedding_path) 35 | node_embedding_list = [item.split('.')[0] for item in node_embedding_list] 36 | candidate_graphs = os.listdir(node_content_path) 37 | 38 | for filename in tqdm.tqdm(candidate_graphs): 39 | 40 | instance_id = filename.split('.')[0] 41 | 42 | # skip samples which have been processed 43 | if instance_id in node_embedding_list: 44 | continue 45 | 46 | with open(node_content_path + filename, 'r', encoding='utf-8') as file: 47 | node_content_dict = json.load(file) 48 | 49 | node_list = list(node_content_dict['code'].keys()) 50 | 51 | node_embedding_dict = {} 52 | node_code_embedding_dict = {} 53 | node_doc_embedding_dict = {} 54 | for node in node_list: 55 | code_content = node_content_dict['code'][node] 56 | if node in node_content_dict['doc']: 57 | doc_content = node_content_dict['doc'][node] 58 | code_content = code_content if code_content else " " 59 | doc_content = doc_content if doc_content else " " 60 | # batch process 61 | text = [code_content, doc_content] 62 | node_code_embedding_dict[node], node_doc_embedding_dict[node] = model.encode(tokenizer, text) 63 | 64 | else: 65 | # for node without doc 66 | code_content = code_content if code_content else " " 67 | node_code_embedding_dict[node] = model.encode(tokenizer, code_content) 68 | 69 | node_embedding_dict = { 70 | "code": node_code_embedding_dict, 71 | "doc": node_doc_embedding_dict 72 | } 73 | 74 | with open(node_embedding_path + '{}.pkl'.format(instance_id), 'wb') as f: 75 | pickle.dump(node_embedding_dict, f) 
76 | 77 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /retriever/utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | 3 | def codegraph_to_nxgraph(graph): 4 | """ 5 | Convert a CodeGraph object into a networkx graph object 6 | :param graph: a CodeGraph object 7 | :return graph_nx: an nx.MultiDiGraph object 8 | """ 9 | 10 | # create a directed graph that allows multiple edges between two nodes 11 | G = nx.MultiDiGraph() 12 | 13 | # add nodes 14 | for node in graph.nodes: 15 | G.add_node(graph.nodes[node]) 16 | 17 | # add edges 18 | for edge in graph.edges: 19 | 20 | src_id = edge.source 21 | tgt_id = edge.target 22 | # the current data may reference nodes that are not in the node list 23 | # such edges and nodes are removed for now 24 | try: 25 | G.add_edge(graph.nodes[src_id], graph.nodes[tgt_id], type=edge.edge_type) 26 | except: 27 | pass 28 | 29 | print(f"nx Graph data parsed, nodes: {G.number_of_nodes():,}, edges: {G.number_of_edges():,}") 30 | 31 | return G 32 | 33 | def codegraph_to_nxgraph_lite(graph): 34 | """ 35 | Lite version: convert a CodeGraph object into a networkx graph object 36 | Node and edge types are decoupled from the parser definitions 37 | :param graph: a CodeGraph object 38 | :return graph_nx: an nx.MultiDiGraph object 39 | """ 40 | 41 | # create a directed graph that allows multiple edges between two nodes 42 | G = nx.MultiDiGraph() 43 | 44 | # add nodes 45 | for node in graph.nodes: 46 | G.add_node(node.node_id) 47 | 48 | # add edges 49 | for edge in graph.edges: 50 | 51 | src_id = edge.source 52 | tgt_id = edge.target 53 | try: 54 | # the current data may reference nodes that are not in the node list 55 | # such edges and nodes are removed for now 56 | graph.nodes[src_id] 57 | graph.nodes[tgt_id] 58 | # add the edge only after confirming both source and target nodes exist 59 | G.add_edge(src_id, tgt_id, type=edge.edge_type.name) 60 | except: 61 | pass 62 | 63 | print(f"nx Graph lite data parsed, nodes: {G.number_of_nodes():,}, edges: {G.number_of_edges():,}") 64 | 65 | return G 66 | 67 | def codegraph_to_nxgraph_analysis(graph): 68 | """ 69 | Analysis version: convert a CodeGraph object into a networkx graph object 70 | A variant dedicated to path analysis; Repo and Package nodes are removed; 71 | returns both a directed and an undirected graph 72 | - directed graph: used to analyze transition probabilities between nodes; two nodes may be connected by multiple edges 73 | - undirected graph: used to analyze the shortest paths between (File) nodes 74 | :param graph: a CodeGraph object 75 | :return graph_nx: an nx.MultiDiGraph object 76 | """ 77 | 78 | # create a directed graph that allows multiple edges between two nodes 79 | G_d = nx.MultiDiGraph() 80 | G_u = nx.Graph() 81 | 82 | # add nodes 83 | for node_id in graph.nodes: 84 | node = graph.get_node_by_id(node_id) 85 | node_type = node.get_type().name 86 | if node_type in ['REPO', 'PACKAGE']: 87 | continue 88 | G_d.add_node(node.node_id) 89 | G_u.add_node(node.node_id) 90 | 91 | # add edges 92 | for edge in graph.edges: 93 | 94 | src_id = edge.source 95 | tgt_id = edge.target 96 | 97 | if G_d.has_node(src_id) and G_d.has_node(tgt_id): 98 | G_d.add_edge(src_id, tgt_id, type=edge.edge_type.name) 99 | G_u.add_edge(src_id, tgt_id, type=edge.edge_type.name) 100 | 101 | print(f"nx Graph analysis data parsed, nodes: {G_d.number_of_nodes():,}, edges: {G_d.number_of_edges():,}") 102 | 103 | return G_d, G_u -------------------------------------------------------------------------------- /retriever/serialize_subgraph.py: -------------------------------------------------------------------------------- 1 | """ 2 | serialize subgraph to json file 3 | Here we provide two versions 4 | ✅ Directly serialize the original subgraph 5 | ✅ Serialize the file-level subgraph 6 | """ 7 | import sys 8 | import json 9 | import os 10 | import tqdm 11 | import pandas as pd 12 | import networkx as nx 13 | 14 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType 15 | from utils import codegraph_to_nxgraph 16 | 17 | ############################# utils ############################# 18 | 19 | def 
get_contained_node(graph_nx, node): 20 | 21 | c_node_list = [] 22 | for suc_node in graph_nx.successors(node): 23 | if graph_nx[node][suc_node][0]['type'] == EdgeType.CONTAINS: 24 | c_node_list.append(suc_node) 25 | 26 | return c_node_list 27 | 28 | def get_inner_nodes(graph_nx, node): 29 | 30 | inner_nodes = get_contained_node(graph_nx, node) 31 | inner_nodes_all = [] 32 | 33 | while len(inner_nodes) != 0: 34 | 35 | tmp_inner_nodes = inner_nodes.copy() 36 | inner_nodes = [] 37 | for node in tmp_inner_nodes: 38 | inner_nodes_all.append(node) 39 | inner_nodes.extend(get_contained_node(graph_nx, node)) 40 | 41 | return list(set(inner_nodes_all)) 42 | 43 | def serialize_subgraph(graph_nx, file_name): 44 | 45 | node_list = [node.to_dict() for node in graph_nx.nodes()] 46 | 47 | # 获取子图中的连边关系 48 | edge_list = [] 49 | for edge in graph_nx.edges(): 50 | edge_type = graph_nx[edge[0]][edge[1]][0]['type'] 51 | tmp_edge_dict = { 52 | "edgeType": edge_type.name.lower(), 53 | "source": edge[0].node_id, 54 | "target": edge[1].node_id 55 | } 56 | edge_list.append(tmp_edge_dict) 57 | # 对所有的边,获取对应边类型 58 | graph_json = { 59 | "nodes": node_list, 60 | "edges": edge_list 61 | } 62 | 63 | with open(file_name + '.json', 'w') as json_file: 64 | json.dump(graph_json, json_file, indent=4) 65 | 66 | return True 67 | 68 | ############################# utils ############################# 69 | 70 | 71 | if __name__ == "__main__": 72 | 73 | test_basic_df = pd.read_json("test_basic_info.json") 74 | graph_data_path = "codegraph/" 75 | subgraph_dict_path = "subgraph_nodes.json" 76 | save_path = "subgraph/" 77 | 78 | with open(subgraph_dict_path, "r", encoding="utf-8") as file: 79 | one_hop_dict = json.load(file) 80 | file.close() 81 | 82 | for idx, item in tqdm.tqdm(test_basic_df.iterrows()): 83 | 84 | instance_id = item.instance_id 85 | graph_file = item.graph_file 86 | tmp_graph_data_path = graph_data_path + graph_file 87 | 88 | if not os.path.exists(tmp_graph_data_path): 89 | continue 90 | 91 | filename = save_path + instance_id 92 | 93 | if os.path.exists(filename + '.json'): 94 | continue 95 | 96 | graph = parse(tmp_graph_data_path) 97 | graph_nx = codegraph_to_nxgraph(graph) 98 | 99 | # Version 1: Directly Serialization 100 | # all_nodes = one_hop_dict[instance_id] 101 | 102 | # Version 1: Serialization in File-level 103 | all_nodes = one_hop_dict[instance_id] 104 | for node_id in all_nodes: 105 | node = graph.get_node_by_id(node_id) 106 | if node.get_type() == NodeType.FILE: 107 | inner_node = get_inner_nodes(graph_nx, node) 108 | for i_node in inner_node: 109 | all_nodes.append(i_node.node_id) 110 | 111 | all_nodes = [graph.get_node_by_id(node_id) for node_id in all_nodes] 112 | subgraph = graph_nx.subgraph(list(all_nodes)) 113 | 114 | serialize_subgraph(subgraph, filename) -------------------------------------------------------------------------------- /cgm/utils/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | import argparse, json 3 | from typing import List, Union 4 | import torch 5 | 6 | 7 | @dataclass 8 | class TrainArgs: 9 | graph_dir: Union[str, List[str]] 10 | train_files: Union[str, List[str]] 11 | valid_files: Union[str, List[str]] 12 | output_dir: str 13 | tb_dir: str 14 | 15 | embedding_dim: int = 2304 16 | 17 | load_pretrained_encoder: bool = False 18 | pretrained_encoder_path: Union[None, str] = None 19 | load_pretrained_adapter: bool = False 20 | pretrained_adapter_path: Union[None, str] = None 21 | 
adapter_hidden_dim: int = 4096 22 | adapter_num_layers: int = 1 23 | adapter_num_heads: int = 8 24 | 25 | self_defined: bool = False 26 | pretrained_model_path: Union[None, str] = None 27 | lm_hidden_dim: int = 4096 28 | quantization: Union[None, str] = None 29 | framework_type: Union[None, str] = "default" 30 | model_type: Union[None, str] = None 31 | 32 | load_pretrained_tokenizer: bool = True 33 | pretrained_tokenizer_path: Union[None, str] = None 34 | 35 | # for evaluation 36 | pretrained_lora_path: Union[None, str] = None 37 | 38 | # training mode: 39 | # "e" 1, "a" 2, "l" 3 40 | mode: str = "a" 41 | task: str = "align" 42 | use_chat: bool = True 43 | use_adj: bool = False 44 | 45 | # lora rank, the bigger, the more trainalbe parameters 46 | peft: Union[None, str] = None 47 | lora_rank: int = 32 48 | lora_alpha: int = 32 49 | lora_dropout: float = 0.05 50 | lora_modules: Union[str, List[str]] = "all-linear" 51 | 52 | enc_peft: Union[None, str] = None 53 | enc_lora_rank: int = 32 54 | enc_lora_alpha: int = 32 55 | enc_lora_dropout: float = 0.05 56 | enc_lora_modules: Union[str, List[str]] = "all-linear" 57 | 58 | graph_pad_token: str = "<|graph_pad|>" 59 | graph_pad_id: int = 32022 60 | graph_token_num: int = 512 61 | 62 | learning_rate: float = 5e-5 63 | min_lr: float = 5e-6 64 | weight_decay: float = 0.1 65 | lr_scheduler_type: str = "cosine" 66 | 67 | gradient_accumulation_steps: int = 1 68 | num_warmup_steps: int = 300 69 | adapter_warmup: bool = False 70 | adapter_warmup_steps: int = 500 71 | num_train_epochs: int = 2 72 | 73 | # train/valid split 74 | data_split: str = "0.98,0.02" 75 | max_train_samples: Union[None, int] = None 76 | max_valid_samples: Union[None, int] = None 77 | 78 | per_device_train_batch_size: int = 1 79 | per_device_eval_batch_size: int = 1 80 | 81 | seed: int = 42 82 | 83 | seq_length: int = 4096 84 | log_interval: int = 10 85 | step_checkpointing: bool = False 86 | checkpointing_steps: int = 100 87 | 88 | step_evaluation: bool = False 89 | evaluation_steps: int = 100 90 | 91 | # max train steps, if None, depends on num_train_epochs 92 | max_train_steps: Union[None, int] = None 93 | 94 | # if checkpointing every epoch, maybe True in sst 95 | epoch_checkpointing: bool = False 96 | epoch_evaluation: bool = False 97 | 98 | early_stopping: bool = False 99 | early_stopping_stall_num: int = 5 100 | 101 | attn_implementation: str = "flash_attention_2" 102 | 103 | def dict(self): 104 | return {k: str(v) for k, v in asdict(self).items()} 105 | 106 | 107 | def prepare_args(args_type="Train"): 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument("--c", type=str, default=None) 110 | parsed = parser.parse_args() 111 | with open(parsed.c, 'r') as f: 112 | c = json.load(f) 113 | if args_type == "Train": 114 | args = TrainArgs(**c) 115 | else: 116 | raise ValueError("args_type must be Train") 117 | if not torch.cuda.is_available(): 118 | args.attn_implementation = 'eager' 119 | 120 | return args 121 | -------------------------------------------------------------------------------- /rewriter/inference_rewriter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rewriter Inference 3 | """ 4 | import os 5 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:3950" 6 | 7 | import torch 8 | import sys 9 | import time 10 | from tqdm import tqdm 11 | import re 12 | 13 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GenerationConfig 14 | import transformers 15 | 16 | 
transformers.logging.set_verbosity_error() 17 | import json 18 | from copy import deepcopy 19 | 20 | # custom 21 | import argparse, logging 22 | from torch.utils.data import Dataset, DataLoader 23 | import pandas as pd 24 | 25 | class PromptDataset(Dataset): 26 | def __init__(self, data): 27 | self.data = data 28 | 29 | def __getitem__(self, index): 30 | return self.data[index] 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | ################################################### 36 | 37 | def parse_args(): 38 | """ 39 | Parses the arguments 40 | """ 41 | 42 | parser = argparse.ArgumentParser(description="Run Inference Model.") 43 | 44 | parser.add_argument('--prompt_path', nargs='?', default='', 45 | help='Specify the prompts file') 46 | 47 | return parser.parse_args() 48 | 49 | def inference_LLM_patch(prompt_path): 50 | 51 | pretrained_path = 'xxx/Qwen2.5-72B-Instruct' 52 | 53 | model = AutoModelForCausalLM.from_pretrained(pretrained_path, device_map="auto", torch_dtype="auto") 54 | tokenizer = AutoTokenizer.from_pretrained(pretrained_path) 55 | 56 | model.half() 57 | 58 | data = [] 59 | test_basic_info = pd.read_json(prompt_path) 60 | data = test_basic_info["extractor_prompt"].tolist() + test_basic_info["inferer_prompt"].tolist() # get prompts: extractor prompts first, then inferer prompts 61 | instance_num = len(test_basic_info["extractor_prompt"].tolist()) 62 | 63 | dataset = PromptDataset(data) 64 | dataloader = DataLoader(dataset, batch_size=8, shuffle=False) 65 | 66 | response_list = [] 67 | for batch in tqdm(dataloader): 68 | try: 69 | text_batch = [] 70 | for prompt in batch: 71 | messages = [ 72 | {"role": "system", "content": "You are a helpful assistant."}, 73 | {"role": "user", "content": prompt} 74 | ] 75 | text = tokenizer.apply_chat_template( 76 | messages, 77 | add_generation_prompt=True, 78 | tokenize=False 79 | ) 80 | text_batch.append(text) 81 | 82 | model_inputs = tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt").to(model.device) 83 | 84 | generated_ids = model.generate( 85 | **model_inputs, 86 | max_new_tokens=512 87 | ) 88 | generated_ids = [ 89 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 90 | ] 91 | 92 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 93 | response_list.extend(response) 94 | torch.cuda.empty_cache() 95 | except Exception as e: 96 | print(f"Out of Memory. {e}") 97 | response_list.extend(["error"] * len(batch)) 98 | torch.cuda.empty_cache() 99 | 100 | #### save output: responses follow the prompt order (extractor first, then inferer) #### 101 | test_basic_info["rewriter_extractor"] = response_list[:instance_num] 102 | test_basic_info["rewriter_inferer"] = response_list[instance_num:] 103 | 104 | test_basic_info.to_json("test_rewriter_output.json", index=False) 105 | 106 | return True 107 | 108 | if __name__ == "__main__": 109 | 110 | args = parse_args() 111 | 112 | print("Start Rewriter Running:") 113 | inference_LLM_patch(args.prompt_path) 114 | print("Rewriter Running Ended.") -------------------------------------------------------------------------------- /rewriter/prompt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prompt Template for Rewriter 3 | """ 4 | 5 | def generate_prompt_for_extractor(problem_statement, repo_name): 6 | """ 7 | generate prompt for extractor 8 | """ 9 | 10 | prompt = """ 11 | 12 | {} 13 | 14 | This is an issue related to repository '{}'. 15 | Instructions: 16 | 1. Analysis: 17 | ○ Analyze the provided issue description. Identify the relevant File, Class, or Function involved. 
18 | ○ Determine the specific problem or error encountered and note any clues that may assist in locating the relevant or problematic area. 19 | 2. Extraction: 20 | ○ After the analysis, extract ALL the mentioned code entities (File, Class, or Function), especially Files. 21 | ○ Then extract three potential and meaningful keywords, responding in the following format: 22 | 23 | [start_of_analysis] 24 | 25 | [end_of_analysis] 26 | 27 | [start_of_related_code_entities] 28 | 29 | [end_of_related_code_entities] 30 | 31 | [start_of_related_keywords] 32 | 33 | [end_of_related_keywords] 34 | 35 | Notes: 36 | - Pay attention to the information in the error logs (if exists). 37 | - The buggy code exists solely in the project described in the issue (e.g., django, sklearn). Buggy location is usually not in the tests files or external packages. 38 | - Your extracted entities should be CONCISE, ACCURATE and INFORMATIVE. 39 | - Provide the relative path for code entities if specified (e.g., package/foo.py). Relative path is relative to the repository itself, do not include suffix like '/home/username/', '/etc/service/' or '/tree/master'. 40 | - Do not include any additional information such as line numbers or explanations in your extraction result. 41 | 42 | Preferred extraction Examples of Code Entities: 43 | - repo/cart.py 44 | - Class User() 45 | - def getData() 46 | Preferred extraction Examples of Keywords: 47 | - train_loop 48 | - hooks 49 | - docker 50 | 51 | Unpreferred extraction Examples of keywords: 52 | - something wrong 53 | - input validation 54 | - TypeError 55 | """.format(problem_statement, repo_name) 56 | 57 | return prompt 58 | 59 | def generate_prompt_for_inferer(problem_statement, repo_name): 60 | """ 61 | 为 inferer 生成 prompt 62 | """ 63 | 64 | prompt = """ 65 | 66 | {} 67 | 68 | This is an issue related to repository '{}'. 69 | Task: 70 | Based on the issue description provided, identify the characteristics of code entities (files, functions, class) that might need to be modified. 71 | For each characteristic, generate a search query that could help locate relevant code entities in a codebase. 72 | Instructions: 73 | First, analyze the issue description and identify keywords, features, and functionalities that are likely relevant to the modification of code entities. 74 | Then, create queries that capture these characteristics, focusing on: 75 | ● File names that may implement relevant functionalities. 76 | ● Functions or methods that are related to the features described in the issue. 77 | ● Any patterns or structures that might be relevant to the functionalities mentioned. 78 | For example: 79 | ● File related to the initialization of a neural network. 80 | ● Function related to the training process. 81 | ● Code used to configure the service. 82 | Please answer in the following format: 83 | 84 | [start_of_analysis] 85 | 86 | [end_of_analysis] 87 | 88 | [start_of_related_queries] 89 | query 1: 90 | query 2: 91 | ... 92 | [end_of_related_queries] 93 | 94 | Notes: 95 | - Your queries should be DETAILED, ACCURATE and INFORMATIVE. 96 | - Your queries should be a complete sentences and do not include additional explanation. 97 | - The number of queries is up to five, so be focus on the important characteristics. 98 | - Your queries should focus on the repository code itself, rather than other information like commit history. 99 | - Pay attention to the information in the error logs (if exists). 
100 | 101 | Preferred Query Examples: 102 | - Look for references to "tqdm" or "progress_bar" within the training loop files to find where progress bars are currently updated. 103 | - Code snippets where 'gethostbyname' function from 'socket' module is called. 104 | - File name containing 'mysql.py' AND functions related to 'MySQLStatementSamples' initialization. 105 | - Functions or methods handling hostname resolution or encoding within 'datadog_checks' directory. 106 | - Find all occurrences of "early_stopping" within files that also mention "Trainer" to identify where early stopping logic is implemented and potentially needs adjustment for non-default 'val_check_interval'. 107 | """.format(problem_statement, repo_name) 108 | 109 | return prompt 110 | 111 | -------------------------------------------------------------------------------- /cgm/inference/layer.py: -------------------------------------------------------------------------------- 1 | """Attention layer.""" 2 | from typing import Any, Dict, List, Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from vllm.attention import AttentionMetadata, AttentionType 8 | from vllm.attention.selector import get_attn_backend 9 | from vllm.config import CacheConfig 10 | from vllm.model_executor.layers.quantization.base_config import ( 11 | QuantizationConfig) 12 | from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod 13 | 14 | 15 | class Attention(nn.Module): 16 | """Attention layer. 17 | 18 | This class takes query, key, and value tensors as input. The input tensors 19 | can either contain prompt tokens or generation tokens. 20 | The class does the following: 21 | 22 | 1. Store the input key and value tensors in the KV cache. 23 | 2. Perform (multi-head/multi-query/grouped-query) attention. 24 | 3. Return the output tensor. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | num_heads: int, 30 | head_size: int, 31 | scale: float, 32 | num_kv_heads: Optional[int] = None, 33 | alibi_slopes: Optional[List[float]] = None, 34 | cache_config: Optional[CacheConfig] = None, 35 | quant_config: Optional[QuantizationConfig] = None, 36 | blocksparse_params: Optional[Dict[str, Any]] = None, 37 | logits_soft_cap: Optional[float] = None, 38 | prefix: str = "", 39 | ) -> None: 40 | super().__init__() 41 | if cache_config is not None: 42 | kv_cache_dtype = cache_config.cache_dtype 43 | block_size = cache_config.block_size 44 | sliding_window = cache_config.sliding_window 45 | is_attention_free = cache_config.is_attention_free 46 | else: 47 | kv_cache_dtype = "auto" 48 | block_size = 16 49 | sliding_window = None 50 | is_attention_free = False 51 | if num_kv_heads is None: 52 | num_kv_heads = num_heads 53 | 54 | # The default k/v_scale is set to 1.0. This is ignored 55 | # when kv-cache is not fp8, and should be used with 56 | # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we 57 | # expect the pre-quantized k/v_scale to be loaded along 58 | # with the model weights. 
59 | self.kv_cache_dtype = kv_cache_dtype 60 | self._k_scale = 1.0 61 | self._v_scale = 1.0 62 | quant_method = quant_config.get_quant_method( 63 | self, prefix=prefix) if quant_config else None 64 | if quant_method is not None: 65 | assert isinstance(quant_method, BaseKVCacheMethod) 66 | # TODO (mgoin): kv cache dtype should be specified in the FP8 67 | # checkpoint config and become the "auto" behavior 68 | if self.kv_cache_dtype == "fp8_e5m2": 69 | raise ValueError("fp8_e5m2 kv-cache is not supported with " 70 | "fp8 checkpoints.") 71 | # If quantization is enabled, we make "k_scale" and "v_scale" 72 | # parameters so that it can be loaded from the model checkpoint. 73 | # The k/v_scale will then be converted back to native float32 74 | # values after weight loading. 75 | self.quant_method = quant_method 76 | self.quant_method.create_weights(self) 77 | 78 | # During model initialization, the default dtype is set as the model 79 | # weight and activation dtype. 80 | dtype = torch.get_default_dtype() 81 | attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, 82 | block_size, is_attention_free, 83 | blocksparse_params is not None) 84 | impl_cls = attn_backend.get_impl_cls() 85 | self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, 86 | alibi_slopes, sliding_window, kv_cache_dtype, 87 | blocksparse_params, logits_soft_cap) 88 | 89 | def forward( 90 | self, 91 | query: torch.Tensor, 92 | key: torch.Tensor, 93 | value: torch.Tensor, 94 | kv_cache: torch.Tensor, 95 | attn_metadata: AttentionMetadata, 96 | attn_type: AttentionType = AttentionType.DECODER, 97 | ) -> torch.Tensor: 98 | 99 | return self.impl.forward(query, 100 | key, 101 | value, 102 | kv_cache, 103 | attn_metadata, 104 | self._k_scale, 105 | self._v_scale, 106 | attn_type=attn_type) 107 | 108 | def extra_repr(self) -> str: 109 | s = f"head_size={self.impl.head_size}" # type: ignore 110 | s += f", num_heads={self.impl.num_heads}" # type: ignore 111 | s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore 112 | s += f", scale={self.impl.scale}" # type: ignore 113 | s += f", backend={self.impl.__class__.__name__}" 114 | return s -------------------------------------------------------------------------------- /rewriter/rewriter_output_post_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post-processing: Extract Key information from Rewriter's Output 3 | """ 4 | import json 5 | import re 6 | import pandas as pd 7 | import os 8 | 9 | def extract_code_entities_from_rewriter(text): 10 | # find code_entities 11 | pattern_code_entities = r'\[start_of_related_code_entities\]\s*(.*?)\s*\[end_of_related_code_entities\]' 12 | match_code_entities = re.search(pattern_code_entities, text, re.DOTALL) 13 | if match_code_entities: 14 | code_entities = match_code_entities.group(1).strip().split('\n') 15 | else: # execute exceptional output 16 | pattern_1 = r'\s*(.*?)\s*' 17 | match_1 = re.search(pattern_1, text, re.DOTALL) 18 | pattern_2 = r'\[Start_of_Related_Code_Entities\]\s*(.*?)\s*\[End_of_Related_Code_Entities\]' 19 | match_2 = re.search(pattern_2, text, re.DOTALL) 20 | 21 | if match_1: 22 | code_entities = match_1.group(1).strip().split('\n') 23 | elif match_2: 24 | code_entities = match_2.group(1).strip().split('\n') 25 | else: 26 | code_entities = [] 27 | 28 | # add post processing 29 | for idx, entity in enumerate(code_entities): 30 | if entity.startswith("- "): 31 | code_entities[idx] = entity[2:] 32 | 33 | return code_entities 34 | 35 | def 
extract_related_keywords_from_rewriter(text): 36 | # find related_keywords 37 | pattern_related_keywords = r'\[start_of_related_keywords\]\s*(.*?)\s*\[end_of_related_keywords\]' 38 | match_related_keywords = re.search(pattern_related_keywords, text, re.DOTALL) 39 | if match_related_keywords: 40 | related_keywords = match_related_keywords.group(1).strip().split('\n') 41 | else: 42 | pattern_1 = r'\s*(.*?)\s*' 43 | match_1 = re.search(pattern_1, text, re.DOTALL) 44 | pattern_2 = r'\[Start_of_Related_Keywords\]\s*(.*?)\s*\[End_of_Related_Keywords\]' 45 | match_2 = re.search(pattern_2, text, re.DOTALL) 46 | 47 | if match_1: 48 | related_keywords = match_1.group(1).strip().split('\n') 49 | elif match_2: 50 | related_keywords = match_2.group(1).strip().split('\n') 51 | else: 52 | related_keywords = [] 53 | 54 | # add post processing 55 | for idx, keyword in enumerate(related_keywords): 56 | if keyword.startswith("- "): 57 | related_keywords[idx] = keyword[2:] 58 | 59 | return related_keywords 60 | 61 | def extract_query_from_rewriter(text): 62 | # match query 63 | pattern_query = r'\[start_of_related_queries\]\s*(.*?)\s*\[end_of_related_queries\]' 64 | match_query = re.search(pattern_query, text, re.DOTALL) 65 | if match_query: 66 | queries = match_query.group(1).strip().split('\n') 67 | else: 68 | pattern_1 = r'\s*(.*?)\s*' 69 | match_1 = re.search(pattern_1, text, re.DOTALL) 70 | pattern_2 = r'\[Start_of_Related_Queries\]\s*(.*?)\s*\[End_of_Related_Queries\]' 71 | match_2 = re.search(pattern_2, text, re.DOTALL) 72 | 73 | if match_1: 74 | queries = match_1.group(1).strip().split('\n') 75 | elif match_2: 76 | queries = match_2.group(1).strip().split('\n') 77 | else: 78 | queries = [] 79 | 80 | # add post processing 81 | for idx, query in enumerate(queries): 82 | if query.startswith("query"): 83 | queries[idx] = query[9:] 84 | elif query.startswith("-"): 85 | queries[idx] = query[2:] 86 | queries = [query for query in queries if len(query)>0] 87 | return queries 88 | 89 | 90 | if __name__ == "__main__": 91 | 92 | test_basic_info = pd.read_json("test_rewriter_output.json") 93 | 94 | # start post processing 95 | test_basic_info["rewriter_inferer_output"] = test_basic_info["rewriter_inferer"].apply(lambda item:extract_query_from_rewriter(item)) 96 | test_basic_info["rewriter_extractor_output_entity"] = test_basic_info["rewriter_extractor"].apply(lambda item:extract_code_entities_from_rewriter(item)) 97 | test_basic_info["rewriter_extractor_output_keyword"] = test_basic_info["rewriter_extractor"].apply(lambda item:extract_related_keywords_from_rewriter(item)) 98 | 99 | rewriter_output_dict = {} 100 | error_case = [] 101 | for idx, item in test_basic_info.iterrows(): 102 | instance_id = item.instance_id 103 | entity = item.rewriter_extractor_output_entity 104 | keyword = item.rewriter_extractor_output_keyword 105 | query = item.rewriter_inferer_output 106 | # if entity or keyword or query: 107 | if entity and keyword and query: 108 | rewriter_output_dict[instance_id] = { 109 | "code_entity": entity, 110 | "keyword": keyword, 111 | "query": query 112 | } 113 | else: 114 | error_case.append(instance_id) 115 | 116 | with open("rewriter_output.json", 'w', encoding='utf-8') as file: 117 | json.dump(rewriter_output_dict, file) 118 | 119 | # save trajs 120 | test_basic_info.to_json("test_rewriter_output.json", index=False) 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /reranker/prompt.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Prompt Template for Reranker 3 | """ 4 | 5 | reranker_stage_1_system_prompt = """ 6 | You are an experienced software developer who specializes in extracting the most relevant files for solving issues from many reference files. 7 | 8 | Task: 9 | Based on the information received about the issue from a repository, find the most likely few files from among those that may be able to resolve the issue. 10 | 11 | Instructions: 12 | 1. Analysis: 13 | - Analyze the provided issue description and files, and pay attention to the relevance of the provided files with the given issue, especially those might be modified during fixing the issue. 14 | - Determine the specific problem or error mentioned in the issue and note any clues that could help your judgment. 15 | 2. Extraction: 16 | - Based on your analysis, choose the Top **1** relevant files which might be used in fixing the issue. 17 | - You should choose files from the provided files, and should not modify their name in any way. 18 | 19 | Respond in the following format: 20 | [start_of_analysis] 21 | 22 | [end_of_analysis] 23 | 24 | [start_of_relevant_files] 25 | 1. 26 | 2. 27 | 3. ... 28 | [end_of_relevant_files] 29 | 30 | Notes: 31 | - You can refer to to the information in the error logs (if exists). 32 | - The relevant file usually exists in the project described in the issue (e.g., django, sklearn). File need modification is usually not in the tests files or external packages. 33 | - The file you choose should be contained in the provided files. 34 | - Provide the file path with files. Do not include redundant suffix like '/home/username/', '/etc/service/' or '/tree/master'. 35 | - Do not include any additional information such as line numbers or explanations in your extraction result. 36 | - Files for initialization and configuration might be modified during changing the code. 37 | 38 | Preferred extraction Examples of Related Files: 39 | 1. src/utils/file_handler.py 40 | 2. core/services/service_manager.py 41 | 3. ... 42 | """.strip() 43 | 44 | reranker_stage_1_user_prompt_template = """ 45 | 46 | {} 47 | 48 | 49 | 50 | {} 51 | 52 | 53 | 54 | {} 55 | 56 | 57 | 58 | {} 59 | 60 | """ 61 | 62 | reranker_stage_2_system_prompt = """ 63 | You are an experienced software developer who specializes in assessing the relevance of the file for solving the issue in software repositories. 64 | 65 | Task: 66 | For a file provided, evaluate the likelihood that modifying this file would resolve the given issue, and assign a score based on specific criteria. 67 | 68 | Instructions: 69 | 1. Analysis: 70 | - Analyze the provided issue description and the content of the single relevant file, pay attention to any keywords, error messages, or specific functionalities mentioned that relate to the file. 71 | - Determine how closely the contents and functionality of the file are tied to the problem or error described in the issue. 72 | - Consider the role of the file in the overall project structure (e.g., configuration files, core logic files versus test files, or utility scripts). 73 | 2. Scoring: 74 | - Based on your analysis, assign a score from 1 to 5 that represents the relevance of modifying the given file in order to solve the issue. 75 | 76 | Score Specifications: 77 | 1. **Score 1**: The file is almost certainly unrelated to the issue, with no apparent connection to the functionality or error described in the issue. 78 | 2. 
**Score 2**: The file may be tangentially related, but modifying it is unlikely to resolve the issue directly; possible in rare edge cases. 79 | 3. **Score 3**: The file has some relevance to the issue; it might interact with the affected functionality indirectly and tweaking it could be part of a broader fix. 80 | 4. **Score 4**: The file is likely related to the issue; it includes code that interacts directly with the functionality in question and could plausibly contain bugs that lead to the issue. 81 | 5. **Score 5**: The file is very likely the root cause or heavily involved in the issue and modifying it should directly address the error or problem mentioned. 82 | 83 | Respond in the following format: 84 | [start_of_analysis] 85 | 86 | [end_of_analysis] 87 | 88 | [start_of_score] 89 | Score 90 | [end_of_score] 91 | 92 | Notes: 93 | - The content of the file shows only the structure of this file, including the names of the classes and functions defined in this file. 94 | - You can refer to to the information in the error logs (if exists). 95 | """.strip() 96 | 97 | reranker_stage_2_user_prompt_template = """ 98 | 99 | {} 100 | 101 | 102 | 103 | {} 104 | 105 | 106 | 107 | {} 108 | 109 | 110 | 111 | {} 112 | 113 | """ 114 | 115 | def generate_prompt_for_reranker_stage_1(problem_statement, repo_name, py_file, other_file): 116 | """ 117 | problem_statement: issue内容 118 | repo_name: repo名 119 | py_file: 可能相关的py文件名list 120 | other_file: 其他可能相关的文件名list 121 | """ 122 | return reranker_stage_1_system_prompt, reranker_stage_1_user_prompt_template.format(repo_name, problem_statement, py_file, other_file) 123 | 124 | def generate_prompt_for_reranker_stage_2(problem_statement, repo_name, file_name, file_content): 125 | """ 126 | problem_statement: issue内容 127 | repo_name: repo名 128 | file_name: 文件名 129 | file_content: 文件内容,只包含类和函数的声明(class xxx和def xxx) 130 | """ 131 | return reranker_stage_2_system_prompt, reranker_stage_2_user_prompt_template.format(repo_name, problem_statement, file_name, file_content) 132 | -------------------------------------------------------------------------------- /rewriter/generate_rewriter_prompt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate Prompts for Rewriter 3 | """ 4 | 5 | import pandas as pd 6 | 7 | def generate_prompt_for_extractor_v1(problem_statement, repo_name): 8 | """ 9 | generate prompt for extractor 10 | """ 11 | 12 | prompt = """ 13 | 14 | {} 15 | 16 | This is an issue related to repository '{}'. 17 | Instructions: 18 | 1. Analysis: 19 | ○ Analyze the provided issue description. Identify the relevant File, Class, or Function involved. 20 | ○ Determine the specific problem or error encountered and note any clues that may assist in locating the relevant or problematic area. 21 | 2. Extraction: 22 | ○ After the analysis, extract ALL the mentioned code entities (File, Class, or Function), especially Files. 23 | ○ Then extract three potential and meaningful keywords, responding in the following format: 24 | 25 | [start_of_analysis] 26 | 27 | [end_of_analysis] 28 | 29 | [start_of_related_code_entities] 30 | 31 | [end_of_related_code_entities] 32 | 33 | [start_of_related_keywords] 34 | 35 | [end_of_related_keywords] 36 | 37 | Notes: 38 | - Pay attention to the information in the error logs (if exists). 39 | - The buggy code exists solely in the project described in the issue (e.g., django, sklearn). Buggy location is usually not in the tests files or external packages. 
40 | - Your extracted entities should be CONCISE, ACCURATE and INFORMATIVE. 41 | - Provide the relative path for code entities if specified (e.g., package/foo.py). Relative path is relative to the repository itself, do not include suffix like '/home/username/', '/etc/service/' or '/tree/master'. 42 | - Do not include any additional information such as line numbers or explanations in your extraction result. 43 | 44 | Preferred extraction Examples of Code Entities: 45 | - repo/cart.py 46 | - Class User() 47 | - def getData() 48 | Preferred extraction Examples of Keywords: 49 | - train_loop 50 | - hooks 51 | - docker 52 | 53 | Unpreferred extraction Examples of keywords: 54 | - something wrong 55 | - input validation 56 | - TypeError 57 | """.format(problem_statement, repo_name) 58 | 59 | return prompt 60 | 61 | def generate_prompt_for_inferer_v1(problem_statement, repo_name): 62 | """ 63 | generate prompt for inferer 64 | """ 65 | 66 | prompt = """ 67 | 68 | {} 69 | 70 | This is an issue related to repository '{}'. 71 | Task: 72 | Based on the issue description provided, identify the characteristics of code entities (files, functions, class) that might need to be modified. 73 | For each characteristic, generate a search query that could help locate relevant code entities in a codebase. 74 | Instructions: 75 | First, analyze the issue description and identify keywords, features, and functionalities that are likely relevant to the modification of code entities. 76 | Then, create queries that capture these characteristics, focusing on: 77 | ● File names that may implement relevant functionalities. 78 | ● Functions or methods that are related to the features described in the issue. 79 | ● Any patterns or structures that might be relevant to the functionalities mentioned. 80 | For example: 81 | ● File related to the initialization of a neural network. 82 | ● Function related to the training process. 83 | ● Code used to configure the service. 84 | Please answer in the following format: 85 | 86 | [start_of_analysis] 87 | 88 | [end_of_analysis] 89 | 90 | [start_of_related_queries] 91 | query 1: 92 | query 2: 93 | ... 94 | [end_of_related_queries] 95 | 96 | Notes: 97 | - Your queries should be DETAILED, ACCURATE and INFORMATIVE. 98 | - Your queries should be a complete sentences and do not include additional explanation. 99 | - The number of queries is up to five, so be focus on the important characteristics. 100 | - Your queries should focus on the repository code itself, rather than other information like commit history. 101 | - Pay attention to the information in the error logs (if exists). 102 | 103 | Preferred Query Examples: 104 | - Look for references to "tqdm" or "progress_bar" within the training loop files to find where progress bars are currently updated. 105 | - Code snippets where 'gethostbyname' function from 'socket' module is called. 106 | - File name containing 'mysql.py' AND functions related to 'MySQLStatementSamples' initialization. 107 | - Functions or methods handling hostname resolution or encoding within 'datadog_checks' directory. 108 | - Find all occurrences of "early_stopping" within files that also mention "Trainer" to identify where early stopping logic is implemented and potentially needs adjustment for non-default 'val_check_interval'. 
109 | """.format(problem_statement, repo_name) 110 | 111 | return prompt 112 | 113 | if __name__ == "__main__": 114 | 115 | # please refer to the format of SWE-bench lite dataset 116 | # test_basic_info should have keys: {"instance_id", "problem_statement", "repo"} 117 | test_basic_info = pd.read_json("test_basic_info.json") 118 | 119 | extractor_prompt_list = [] 120 | inferer_prompt_list = [] 121 | for idx, item in test_basic_info.iterrows(): 122 | # generate prompt for each instance 123 | tmp_extractor_prompt = generate_prompt_for_extractor_v1(item.problem_statement, item.repo) 124 | tmp_inferer_prompt = generate_prompt_for_inferer_v1(item.problem_statement, item.repo) 125 | extractor_prompt_list.append(tmp_extractor_prompt) 126 | inferer_prompt_list.append(tmp_inferer_prompt) 127 | 128 | test_basic_info["extractor_prompt"] = extractor_prompt_list 129 | test_basic_info["inferer_prompt"] = inferer_prompt_list 130 | 131 | test_basic_info.to_json("test_rewriter_prompt.json", index=False) 132 | 133 | -------------------------------------------------------------------------------- /retriever/locate_anchor_node.py: -------------------------------------------------------------------------------- 1 | """ 2 | 基于 rapidfuzz + faiss 进行 anchor node 定位 3 | """ 4 | 5 | from rapidfuzz import process, fuzz 6 | import pandas as pd 7 | import json 8 | import tqdm 9 | import sys 10 | import pickle 11 | import numpy as np 12 | import faiss 13 | 14 | 15 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType 16 | from utils import codegraph_to_nxgraph 17 | 18 | def extract_info(item): 19 | """ 20 | 抽取需要匹配的字符部分 21 | """ 22 | return item[1] 23 | 24 | ################################# Extractor ################################# 25 | def get_extractor_anchor(graph, entity_query, keywords_query): 26 | """ 27 | 获取 关键词匹配结果 28 | """ 29 | 30 | all_nodes = graph.get_nodes() 31 | 32 | cand_name_list = [] 33 | cand_path_name_list = [] 34 | 35 | for node in all_nodes: 36 | node_type = node.get_type() 37 | if node_type in [NodeType.REPO, NodeType.PACKAGE]: 38 | continue 39 | 40 | try: 41 | node.name 42 | except: 43 | continue 44 | 45 | cand_name_list.append((node.node_id, node.name)) 46 | 47 | if node_type == NodeType.FILE: 48 | if node.path: 49 | name_with_path = node.path + "/" + node.name 50 | else: 51 | name_with_path = node.name 52 | cand_path_name_list.append((node.node_id, name_with_path)) 53 | 54 | cand_name_all = [] 55 | cand_path_name_all = [] 56 | 57 | for query in entity_query + keywords_query: 58 | 59 | if "/" in query: 60 | cand_path_name = process.extract((-1, query), cand_path_name_list, scorer=fuzz.WRatio, limit=3, processor=extract_info) 61 | cand_path_name_all.append(cand_path_name) 62 | 63 | query_wo_path = query.split('/')[-1] 64 | cand_name = process.extract((-1, query_wo_path), cand_name_list, scorer=fuzz.WRatio, limit=3, processor=extract_info) 65 | cand_name_all.append(cand_name) 66 | 67 | 68 | res = set() 69 | for query in cand_name_all: 70 | for item in query: 71 | res.add(item[0][0]) 72 | for query in cand_path_name_all: 73 | for item in query: 74 | res.add(item[0][0]) 75 | 76 | return res 77 | 78 | ################################# Extractor ################################# 79 | 80 | ################################# Inferer ################################# 81 | def get_inferer_anchor(query_emb, node_embedding, k=15): 82 | """ 83 | 根据 embedding 进行语义检索 84 | """ 85 | 86 | node2id_dict = {} 87 | id2node_dict = {} 88 | cand_vec = [] 89 | 90 | raw_node_embedding = 
node_embedding["code"] 91 | for i, node_id in enumerate(raw_node_embedding): 92 | node2id_dict[node_id] = i 93 | id2node_dict[i] = node_id 94 | cand_vec.append(raw_node_embedding[node_id]) 95 | 96 | cand_vec_np = np.array(cand_vec) 97 | 98 | ######### search ######### 99 | d = 1024 100 | nb = len(cand_vec_np) 101 | nq = 5 102 | 103 | index = faiss.IndexFlatL2(d) 104 | index.add(cand_vec_np) 105 | D, I = index.search(cand_vec_np[:5], k) 106 | D, I = index.search(query_emb, k) 107 | 108 | anchor_node = [] 109 | for query in I: 110 | tmp_node_list = [] 111 | for trans_id in query: 112 | tmp_node_list.append(int(id2node_dict[trans_id])) 113 | anchor_node.append(tmp_node_list) 114 | 115 | return anchor_node 116 | 117 | 118 | ################################# Inferer ################################# 119 | 120 | ################################# 辅助函数 ################################# 121 | def get_graph_file_name(item): 122 | """ 123 | return graph_file_name 124 | """ 125 | 126 | raise NotImplementedError 127 | ################################# 辅助函数 ################################# 128 | 129 | 130 | if __name__ == "__main__": 131 | 132 | # 数据变量定义 133 | test_basic_df = pd.read_json("test_basic_info.json") 134 | test_basic_df["graph_file"] = test_basic_df.apply(lambda item: get_graph_file_name(item), axis=1) 135 | 136 | graph_data_path = "codegraph/" # the path of codegraphs 137 | 138 | # 读入 rewriter 提取结果 和 node embedding 139 | rewriter_output_path = "/rewriter_output.json" 140 | query_embedding_path = "/rewriter_embedding.pkl" 141 | node_embedding_path = "/node_embedding/" 142 | with open(rewriter_output_path, "r", encoding="utf-8") as file: 143 | rewriter_output = json.load(file) 144 | file.close() 145 | 146 | with open(query_embedding_path, "rb") as file: 147 | query_embedding = pickle.load(file) 148 | file.close() 149 | 150 | # save path 151 | anchor_node_dict = {} 152 | 153 | for idx, item in tqdm.tqdm(test_basic_df.iterrows()): 154 | 155 | instance_id = item.instance_id 156 | graph_file = item.graph_file 157 | tmp_graph_data_path = graph_data_path + graph_file 158 | query_emb = query_embedding[instance_id] 159 | 160 | # 解析图数据 161 | graph = parse(tmp_graph_data_path) 162 | graph_nx = codegraph_to_nxgraph(graph) 163 | 164 | # 获取 rewriter 输出 165 | entity_query = rewriter_output[instance_id]["code_entity"] 166 | keyword_query = rewriter_output[instance_id]["keyword"] 167 | 168 | # 读入 node_embedding 169 | tmp_node_embedding = node_embedding_path + "{}.pkl".format(instance_id) 170 | with open(tmp_node_embedding, "rb") as file: 171 | tmp_node_embedding = pickle.load(file) 172 | file.close() 173 | 174 | # 定位 anchor nodes 175 | res_extractor = get_extractor_anchor(graph, entity_query, keyword_query) 176 | res_inferer = get_inferer_anchor(query_emb, tmp_node_embedding) 177 | 178 | anchor_node = { 179 | "extractor_anchor_nodes": list(res_extractor), 180 | "inferer_anchor_nodes": list(res_inferer), 181 | } 182 | 183 | anchor_node_dict[instance_id] = anchor_node 184 | 185 | # save 186 | with open("anchor_node.json", 'w', encoding='utf-8') as file: 187 | json.dump(anchor_node_dict, file) 188 | -------------------------------------------------------------------------------- /retriever/subgraph.py: -------------------------------------------------------------------------------- 1 | """ 2 | 启发式搜索逻辑 3 | - 对于 anchor node 进行一跳扩展 4 | - 对于扩展后的结果进行 连通 5 | """ 6 | import sys 7 | import json 8 | import os 9 | import tqdm 10 | import pandas as pd 11 | 12 | from codegraph_parser.python.codegraph_python_local import 
parse, NodeType, EdgeType 13 | from utils import codegraph_to_nxgraph 14 | 15 | ################################# 子图重构代码 ################################# 16 | def get_path_to_repo(node, pre_node_dict, graph_nx): 17 | """获取该节点到 repo 的路径 18 | :param node -> CodeGraph Node 采样出的子图节点 19 | :param pre_node_dict -> list(Node) 每个节点 20 | :return 21 | """ 22 | if node.get_type() == NodeType.REPO: 23 | return [node] 24 | 25 | pre_nodes = list() 26 | if node.node_id in pre_node_dict: 27 | pre_nodes = pre_node_dict[node.node_id] 28 | else: 29 | for pre_node in graph_nx.predecessors(node): 30 | # 判断 Edge 类型 - contains 31 | if graph_nx[pre_node][node][0]['type'] == EdgeType.CONTAINS: 32 | pre_nodes.append(pre_node) 33 | if pre_node.get_type() != NodeType.REPO: 34 | pre_nodes.extend(get_path_to_repo(pre_node, pre_node_dict, graph_nx)) 35 | pre_node_dict[node.node_id] = pre_nodes 36 | break 37 | 38 | return pre_nodes 39 | 40 | def reconstruct_graph(subgraph_nodes, graph_nx, pre_node_dict): 41 | """ 42 | 根据所给节点重构 连通 的 CodeGraph 43 | pre_node_dict 全局复用 44 | """ 45 | 46 | nodes = subgraph_nodes 47 | all_nodes = set(nodes) 48 | for node in nodes: 49 | pre_nodes = get_path_to_repo(node, pre_node_dict, graph_nx) 50 | all_nodes |= set(pre_nodes) 51 | 52 | # 根据节点裁剪子图 53 | subgraph = graph_nx.subgraph(list(all_nodes)) 54 | 55 | return subgraph 56 | 57 | ################################# 子图重构代码 ################################# 58 | 59 | ################################# BFS代码 ################################# 60 | def bfs_expand(graph_nx, subgraph_nodes, hops=1): 61 | """ 62 | 通过 bfs 扩展 63 | - 最笼统的版本:不区分方向的1-hop 64 | :param graph_nx nx格式的原图 65 | :param subgraph_nodes 需要扩展的节点 66 | :param hops 需要扩展的跳数 67 | """ 68 | 69 | seed_node = subgraph_nodes 70 | visited_node = set() 71 | # 记录所有被 nhop 覆盖的节点 72 | nhops_neighbors = set([node.node_id for node in seed_node]) 73 | 74 | for hop_idx in range(hops): 75 | tmp_seed_node = [] 76 | for node in seed_node: 77 | if node.node_id in visited_node: 78 | continue 79 | visited_node.add(node.node_id) 80 | suc_nodes = graph_nx.successors(node) 81 | pre_nodes = graph_nx.predecessors(node) 82 | for node in suc_nodes: 83 | tmp_seed_node.append(node) 84 | nhops_neighbors.add(node.node_id) 85 | 86 | for node in pre_nodes: 87 | tmp_seed_node.append(node) 88 | nhops_neighbors.add(node.node_id) 89 | 90 | seed_node = tmp_seed_node 91 | return nhops_neighbors 92 | 93 | def bfs_expand_file(graph_nx, subgraph_nodes, hops=1): 94 | """ 95 | 通过 bfs 扩展 96 | - 限制 File 遍历2跳 97 | :param graph_nx nx格式的原图 98 | :param subgraph_nodes 需要扩展的节点 99 | :param hops 需要扩展的跳数 100 | """ 101 | 102 | seed_node = subgraph_nodes 103 | visited_node = set() 104 | nhops_neighbors = set([node.node_id for node in seed_node]) 105 | 106 | for hop_idx in range(hops): 107 | tmp_seed_node = [] 108 | for node in seed_node: 109 | if node.node_id in visited_node: 110 | continue 111 | visited_node.add(node.node_id) 112 | suc_nodes = graph_nx.successors(node) 113 | pre_nodes = graph_nx.predecessors(node) 114 | for node in suc_nodes: 115 | if node.get_type() == NodeType.FILE: 116 | tmp_seed_node.append(node) 117 | nhops_neighbors.add(node.node_id) 118 | 119 | for node in pre_nodes: 120 | if node.get_type() == NodeType.FILE: 121 | tmp_seed_node.append(node) 122 | nhops_neighbors.add(node.node_id) 123 | 124 | seed_node = tmp_seed_node 125 | return nhops_neighbors 126 | ################################# BFS代码 ################################# 127 | 128 | ################################# 辅助函数 ################################# 129 | def 
get_graph_file_name(item): 130 | """ 131 | 生成 graph_file_name 132 | """ 133 | repo = item.repo 134 | repo = repo.replace("/", "#", 1) 135 | base_commit = item.base_commit 136 | return repo + "#" + base_commit + ".graph.json" 137 | ################################# 辅助函数 ################################# 138 | 139 | if __name__ == "__main__": 140 | 141 | # 数据变量定义 142 | test_basic_df = pd.read_json("/test_lite_basic_info.json") 143 | test_basic_df["graph_file"] = test_basic_df.apply(lambda item: get_graph_file_name(item), axis=1) 144 | 145 | graph_data_path = "/swe-bench-lite3/" 146 | anchor_node_path = "/anchor_nodes.json" 147 | 148 | with open(anchor_node_path, "r", encoding="utf-8") as file: 149 | anchor_node_dict = json.load(file) 150 | 151 | subgraph_id_dict = {} 152 | 153 | for idx, item in tqdm.tqdm(test_basic_df.iterrows()): 154 | 155 | instance_id = item.instance_id 156 | graph_file = item.graph_file 157 | tmp_graph_data_path = graph_data_path + graph_file 158 | 159 | # 解析图数据 160 | graph = parse(tmp_graph_data_path) 161 | graph_nx = codegraph_to_nxgraph(graph) 162 | 163 | # 获取 anchor_nodes 164 | anchor_nodes_raw = anchor_node_dict[instance_id] 165 | extractor_anchors = anchor_nodes_raw["extractor_anchor_nodes"] 166 | inferer_anchors = [node for node_list in anchor_nodes_raw["inferer_anchor_nodes"] for node in node_list] 167 | anchor_nodes = list(set(extractor_anchors + inferer_anchors)) 168 | 169 | # 先 bfs 再 reconstruct 170 | anchor_nodes = [graph.get_node_by_id(node_id) for node_id in anchor_nodes] 171 | expanded_nodes = bfs_expand_file(graph_nx, anchor_nodes, hops=2) 172 | 173 | expanded_nodes = [graph.get_node_by_id(node_id) for node_id in expanded_nodes] 174 | 175 | pre_node_dict = {} 176 | subgraph = reconstruct_graph(expanded_nodes, graph_nx, pre_node_dict) 177 | 178 | result_nodes = subgraph.nodes() 179 | 180 | # 获取子图的节点id 181 | result_nodes = [node.node_id for node in result_nodes if node.get_type() == NodeType.FILE] 182 | 183 | subgraph_id_dict[instance_id] = list(result_nodes) 184 | 185 | # save 186 | with open("subgraph_nodes.json", 'w', encoding='utf-8') as file: 187 | json.dump(subgraph_id_dict, file) -------------------------------------------------------------------------------- /cgm/data/encode.py: -------------------------------------------------------------------------------- 1 | 2 | # Template: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' 3 | # CodeQwen1.5-7B-Chat, Qwen2-72B-Instruct 4 | QwenTokenizerConfig = { 5 | 'seq_length': 8192, 6 | 'SYSTEM': 'system', 7 | 'HUMAN': 'user', 8 | 'BOT': 'assistant', 9 | 'SENTENCE_START_MARKER': '', 10 | 'SENTENCE_END_MARKER': '', 11 | 'SYSTEM_START_MARKER': '<|im_start|>', 12 | 'SYSTEM_START_MARKER_2': '\n', 13 | 'SYSTEM_END_MARKER': '<|im_end|>\n', 14 | 'HUMAN_START_MARKER': '<|im_start|>', 15 | 'HUMAN_START_MARKER_2': '\n', 16 | 'HUMAN_END_MARKER': '<|im_end|>\n', 17 | 'BOT_START_MARKER': '<|im_start|>', 18 | 'BOT_START_MARKER_2': '\n', 19 | 'BOT_END_MARKER': '<|im_end|>\n', 20 | } 21 | 22 | # Template: <|begin▁of▁sentence|>{system_message}<|User|>{user_message_1}<|Assistant|>{assistant_message_1}<|end▁of▁sentence|><|User|>{user_message_2}<|Assistant|> 23 | # DeepSeek-V2.5 24 | DeepSeekTokenizerConfig = { 25 | 'seq_length': 8192, 26 | 'SYSTEM': 'system', 27 | 'HUMAN': 'user', 28 | 'BOT': 'assistant', 29 | 'SENTENCE_START_MARKER': '<|begin▁of▁sentence|>', 30 | 'SENTENCE_END_MARKER': '<|end▁of▁sentence|>', 31 | 'SYSTEM_START_MARKER': '', 32 | 'SYSTEM_START_MARKER_2': '', 33 | 'SYSTEM_END_MARKER': '', 34 |
'HUMAN_START_MARKER': '<|User|>', 35 | 'HUMAN_START_MARKER_2': '', 36 | 'HUMAN_END_MARKER': '', 37 | 'BOT_START_MARKER': '<|Assistant|>', 38 | 'BOT_START_MARKER_2': '', 39 | 'BOT_END_MARKER': '', 40 | } 41 | 42 | DeepSeekCoderTokenizerConfig = { 43 | 'seq_length': 8192, 44 | 'SYSTEM': 'system', 45 | 'HUMAN': 'user', 46 | 'BOT': 'assistant', 47 | 'SENTENCE_START_MARKER': '<|begin▁of▁sentence|>', 48 | 'SENTENCE_END_MARKER': '<|end▁of▁sentence|>', 49 | 'SYSTEM_START_MARKER': '', 50 | 'SYSTEM_START_MARKER_2': '', 51 | 'SYSTEM_END_MARKER': '\n\n', 52 | 'HUMAN_START_MARKER': 'User: ', 53 | 'HUMAN_START_MARKER_2': '', 54 | 'HUMAN_END_MARKER': '\n\n', 55 | 'BOT_START_MARKER': 'Assistant: ', 56 | 'BOT_START_MARKER_2': '', 57 | 'BOT_END_MARKER': '', 58 | } 59 | 60 | def format_eol(text): 61 | if not text.endswith("\n"): 62 | text += "\n" 63 | return text 64 | 65 | def get_template(data): 66 | template = [ 67 | {'role': 'system', 'content': ''}, 68 | {'role': 'user', 'content': data['prompt']}, 69 | {'role': 'assistant', 'content': data['answer']} 70 | ] 71 | return template 72 | def get_config(name): 73 | if name == 'Qwen': 74 | return QwenTokenizerConfig 75 | elif name == 'DeepSeek': 76 | return DeepSeekTokenizerConfig 77 | elif name == 'DeepSeek-Coder': 78 | return DeepSeekCoderTokenizerConfig 79 | else: 80 | raise NotImplementedError 81 | 82 | class BaseEncoder(object): 83 | def __init__(self, tokenizer, config_name): 84 | # self.args = args 85 | # seq_length - 1 for shifting 86 | # self.seq_length = args.seq_length - 1 87 | config = get_config(config_name) 88 | self.tokenizer = tokenizer 89 | # self.seq_length = tokenizer.model_max_length 90 | self.seq_length = config.get('seq_length') 91 | 92 | # TODO: default Qwen 93 | self.SYSTEM = config.get('SYSTEM') 94 | self.HUMAN = config.get('HUMAN') 95 | self.BOT = config.get('BOT') 96 | 97 | self.SENTENCE_START_MARKER = config.get('SENTENCE_START_MARKER') 98 | self.SENTENCE_END_MARKER = config.get('SENTENCE_END_MARKER') 99 | 100 | self.SYSTEM_START_MARKER = config.get('SYSTEM_START_MARKER') 101 | self.SYSTEM_START_MARKER_2 = config.get('SYSTEM_START_MARKER_2') 102 | self.SYSTEM_END_MARKER = config.get('SYSTEM_END_MARKER') 103 | 104 | self.HUMAN_START_MARKER = config.get('HUMAN_START_MARKER') 105 | self.HUMAN_START_MARKER_2 = config.get('HUMAN_START_MARKER_2') 106 | self.HUMAN_END_MARKER = config.get('HUMAN_END_MARKER') 107 | 108 | self.BOT_START_MARKER = config.get('BOT_START_MARKER') 109 | self.BOT_START_MARKER_2 = config.get('BOT_START_MARKER_2') 110 | self.BOT_END_MARKER = config.get('BOT_END_MARKER') 111 | 112 | self.sentence_start_ids = self.tokenizer.encode(f"{self.SENTENCE_START_MARKER}", add_special_tokens=False) if self.SENTENCE_START_MARKER != '' else [] 113 | self.sentence_end_ids = self.tokenizer.encode(f"{self.SENTENCE_END_MARKER}", add_special_tokens=False) if self.SENTENCE_END_MARKER != '' else [] 114 | 115 | self.system_start_ids = self.tokenizer.encode(f"{self.SYSTEM_START_MARKER}{self.SYSTEM}{self.SYSTEM_START_MARKER_2}", add_special_tokens=False) 116 | self.system_end_ids = self.tokenizer.encode(f"{self.SYSTEM_END_MARKER}", add_special_tokens=False) if self.SYSTEM_END_MARKER != '' else [] 117 | 118 | self.human_start_ids = self.tokenizer.encode(f"{self.HUMAN_START_MARKER}{self.HUMAN}{self.HUMAN_START_MARKER_2}", add_special_tokens=False) 119 | self.human_end_ids = self.tokenizer.encode(f"{self.HUMAN_END_MARKER}", add_special_tokens=False) if self.HUMAN_END_MARKER != '' else [] 120 | 121 | self.bot_start_ids =
self.tokenizer.encode(f"{self.BOT_START_MARKER}{self.BOT}{self.BOT_START_MARKER_2}", add_special_tokens=False) 122 | self.bot_end_ids = self.tokenizer.encode(f"{self.BOT_END_MARKER}", add_special_tokens=False) if self.BOT_END_MARKER != '' else [] 123 | 124 | self.end_ids = [self.tokenizer.eos_token_id] 125 | 126 | def padding(self, input_ids, loss_mask, qa_mask): 127 | pad_id = self.tokenizer.pad_token_id 128 | 129 | assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}" 130 | input_ids += [pad_id] * (self.seq_length - len(input_ids)) 131 | loss_mask += [0] * (self.seq_length - len(loss_mask)) 132 | qa_mask += [0] * (self.seq_length - len(loss_mask)) 133 | return { 134 | "input_ids": input_ids, 135 | "loss_mask": loss_mask, 136 | "qa_mask": qa_mask 137 | } 138 | 139 | class CGMEncoder(BaseEncoder): 140 | def __init__(self, tokenizer, config_name): 141 | super().__init__(tokenizer, config_name) 142 | 143 | def dataToInput(self, data, seg_role = None): 144 | input_ids, loss_mask, qa_mask = [], [], [] 145 | # TODO: expand 146 | message = [ 147 | {'role': self.HUMAN, 'content': 'prompt', 'marker': True, 'loss': 0}, 148 | {'role': self.BOT, 'content': 'answer', 'marker': True, 'loss': 1} 149 | ] 150 | if seg_role is not None: 151 | message = [item for item in message if item['role'] == seg_role] 152 | 153 | input_ids += self.sentence_start_ids 154 | loss_mask += [0] * len(self.sentence_start_ids) 155 | qa_mask += [0] * len(self.sentence_start_ids) 156 | 157 | for segment in message: 158 | role = segment['role'] 159 | content = segment['content'] 160 | marker = segment['marker'] 161 | loss = segment['loss'] 162 | 163 | if role == self.SYSTEM: 164 | system_ids = self.tokenizer.encode(str(data[content]), add_special_tokens=False) 165 | if marker: 166 | input_ids += self.system_start_ids + system_ids + self.system_end_ids 167 | loss_mask += [0] * len(self.system_start_ids) + [loss] * len(system_ids) + [0] * len(self.system_end_ids) 168 | qa_mask += [0] * len(self.system_start_ids) + [1] * len(system_ids) + [0] * len(self.system_end_ids) 169 | else: 170 | input_ids += system_ids 171 | loss_mask += [loss] * len(system_ids) 172 | qa_mask += [loss] * len(system_ids) 173 | 174 | elif role == self.HUMAN: 175 | 176 | human_ids = self.tokenizer.encode(str(data[content]), add_special_tokens=False) 177 | if marker: 178 | input_ids += self.human_start_ids + human_ids + self.human_end_ids 179 | loss_mask += [0] * len(self.human_start_ids) + [loss] * len(human_ids) + [0] * len(self.human_end_ids) 180 | qa_mask += [0] * len(self.human_start_ids) + [1] * len(human_ids) + [0] * len(self.human_end_ids) 181 | else: 182 | input_ids += human_ids 183 | loss_mask += [loss] * len(human_ids) 184 | qa_mask += [1] * len(human_ids) 185 | 186 | elif role == self.BOT: 187 | bot_ids = self.tokenizer.encode(str(data[content]), add_special_tokens=False) 188 | if marker: 189 | input_ids += self.bot_start_ids + bot_ids + self.bot_end_ids 190 | loss_mask += [0] * len(self.bot_start_ids) + [loss] * len(bot_ids) + [0] * len(self.bot_end_ids) 191 | qa_mask += [0] * len(self.bot_start_ids) + [0] * len(bot_ids) + [0] * len(self.bot_end_ids) 192 | else: 193 | input_ids += bot_ids 194 | loss_mask += [loss] * len(bot_ids) 195 | qa_mask += [0] * len(bot_ids) 196 | 197 | else: 198 | raise ValueError(f"wrong {role} for {config_name}") 199 | 200 | input_ids += self.sentence_end_ids 201 | loss_mask += [1] * len(self.sentence_end_ids) 202 | qa_mask += [0] * len(self.sentence_end_ids) 203 | 204 | assert 
len(input_ids) == len(loss_mask) 205 | 206 | if len(input_ids) <= self.seq_length: 207 | # features = self.padding(input_ids, loss_mask, qa_mask) 208 | features = {} 209 | features['input_ids'] = input_ids 210 | features['loss_mask'] = loss_mask 211 | features['qa_mask'] = qa_mask 212 | else: 213 | features = {} 214 | features['input_ids'] = input_ids[:self.seq_length - 1] 215 | features['loss_mask'] = loss_mask[:self.seq_length - 1] 216 | features['qa_mask'] = qa_mask[:self.seq_length - 1] 217 | 218 | features['input_ids'] += self.sentence_end_ids 219 | features['loss_mask'] += [1] * len(self.sentence_end_ids) 220 | features['qa_mask'] += [0] * len(self.sentence_end_ids) 221 | 222 | assert len(features['input_ids']) == len(features['loss_mask']) 223 | 224 | return features 225 | 226 | -------------------------------------------------------------------------------- /cgm/train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.getcwd()) 5 | 6 | import time, os, json, math, logging 7 | import numpy as np 8 | import torch 9 | import random 10 | from torch.utils.data import DataLoader, random_split, Subset 11 | import random 12 | # from deepspeed.ops.adam import FusedAdam as AdamW 13 | from torch.optim import AdamW 14 | from transformers import AutoModel, AutoTokenizer 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from transformers import ( 18 | set_seed, 19 | get_scheduler, 20 | ) 21 | from utils.arguments import prepare_args 22 | 23 | from modeling.cgm import CGM 24 | 25 | from data.encode import CGMEncoder 26 | 27 | from utils.common_utils import print_args, print_with_rank, print_rank_0 28 | from utils.train_utils import accelerate_train_CGM 29 | 30 | from datasets import load_dataset 31 | import datetime 32 | from peft import ( 33 | LoraConfig, 34 | TaskType, 35 | get_peft_model, 36 | prepare_model_for_kbit_training, 37 | PeftModel, 38 | ) 39 | 40 | from torch.optim.lr_scheduler import ReduceLROnPlateau 41 | 42 | 43 | # from load_hetero_dataset import load_dataset, perpare_dataloader 44 | 45 | def str_to_tuple(s): 46 | st = s.strip('()') 47 | return tuple(item.strip().strip("'") for item in st.split(',')) 48 | 49 | def getRawGraph(filename, suffix="json"): 50 | if os.path.exists(filename): 51 | if suffix == 'json': 52 | with open(filename) as f: 53 | example_graph = json.load(f) 54 | f.close() 55 | elif suffix == 'pt': 56 | with open(filename, 'rb') as f: 57 | example_graph = torch.load(f) 58 | # example_graph = torch.load(filename) 59 | f.close() 60 | return example_graph 61 | return None 62 | 63 | task_ids = { 64 | (0, 'graph_query'), 65 | (1, 'api'), 66 | (2, 'issue_fix'), 67 | (3, 'unit_test'), 68 | (4, 'readme_summary'), 69 | } 70 | 71 | task_to_id = {task: idx for idx, task in task_ids} 72 | 73 | def collate_cgm(graph_dir, encoder, qa_type='mft', seq_l=8192, use_chat=True): 74 | def collate(batches): 75 | result_batches = [] 76 | for batch in batches: 77 | result_batch = {} 78 | graph = getRawGraph(batch['repo'], suffix='json') 79 | 80 | if graph is not None: 81 | graph['reponame'] = batch['repo'].split('/')[-1].split('.')[0] 82 | graph['language'] = batch['language'] 83 | result_batch['graph'] = graph 84 | if use_chat: 85 | features = encoder.dataToInput(batch) 86 | input_ids = features['input_ids'] 87 | loss_mask = features['loss_mask'] 88 | qa_mask = features['qa_mask'] 89 | else: 90 | query_ids = encoder.tokenizer.encode(batch['prompt'], 
add_special_tokens=False) 91 | answer_ids = encoder.tokenizer.encode(batch['answer'], add_special_tokens=False) + [ 92 | encoder.tokenizer.eos_token_id] 93 | qa_mask = [1] * len(query_ids) + [0] * len(answer_ids) 94 | loss_mask = [0] * len(query_ids) + [1] * len(answer_ids) 95 | input_ids = query_ids + answer_ids 96 | 97 | min_seq = min(seq_l, len(input_ids)) 98 | result_batch['x'] = torch.tensor(input_ids, dtype=torch.int64)[:min_seq - 1].contiguous() 99 | result_batch['qa_mask'] = torch.tensor(qa_mask, dtype=torch.bool)[:min_seq - 1].contiguous() 100 | result_batch['y'] = torch.tensor(input_ids, dtype=torch.int64)[1:min_seq].contiguous() 101 | result_batch['loss_mask'] = torch.tensor(loss_mask, dtype=torch.bool)[1:min_seq].contiguous() 102 | 103 | if qa_type == 'mft': 104 | result_batch['task'] = task_to_id[batch['task']] 105 | else: 106 | raise ValueError(f"graph none for {batch['repo']}") 107 | 108 | result_batches.append(result_batch) 109 | 110 | final_result_batch = {} 111 | for key in result_batches[0].keys(): 112 | if key == 'task': 113 | final_result_batch[key] = torch.tensor([rb[key] for rb in result_batches]) 114 | elif key == 'graph': 115 | final_result_batch[key] = [rb[key] for rb in result_batches] 116 | else: 117 | final_result_batch[key] = torch.stack([rb[key] for rb in result_batches]) 118 | return final_result_batch 119 | 120 | return collate 121 | 122 | 123 | def train(args): 124 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) 125 | 126 | print_args(args, accelerator) 127 | 128 | # prepare logger 129 | logger = get_logger(__name__) 130 | logging.basicConfig( 131 | format="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", 132 | datefmt="%Y-%m-%d %H:%M:%S", 133 | level=logging.INFO, 134 | ) 135 | logger.info(accelerator.state, main_process_only=True) 136 | 137 | train_files = args.train_files 138 | valid_files = args.valid_files 139 | 140 | graph_dir = args.graph_dir 141 | 142 | dataset = load_dataset('json', data_files={'train': train_files, 'valid': valid_files}) 143 | 144 | train_dataset = dataset['train'] 145 | valid_dataset = dataset['valid'] 146 | 147 | epoch_train = len(train_dataset) 148 | 149 | if args.peft: 150 | save_suffix = args.framework_type + '_' + str( 151 | args.pretrained_model_path.split('/')[-1]) + '_' + args.task + '_M' + args.mode + '_LR' + str( 152 | args.learning_rate) + '_GA' + str(args.gradient_accumulation_steps) + '_' + str(args.peft) + '_r' + str( 153 | args.lora_rank) + '_alpha' + str(args.lora_alpha) + '_d' + str(args.lora_dropout) + '_m' + str( 154 | args.lora_modules) + str(datetime.datetime.now().strftime('%Y%m%d%H')) + '/' 155 | else: 156 | save_suffix = args.framework_type + '_' + str( 157 | args.pretrained_model_path.split('/')[-1]) + '_' + args.task + '_M' + args.mode + '_LR' + str( 158 | args.learning_rate) + '_GA' + str(args.gradient_accumulation_steps) + '_' + str( 159 | datetime.datetime.now().strftime('%Y%m%d%H')) + '/' 160 | 161 | args.output_dir = args.output_dir + save_suffix 162 | args.tb_dir = args.tb_dir + save_suffix 163 | if 'l' not in args.mode and args.peft: 164 | args.mode = args.mode + 'l' 165 | if args.peft == "QLoRA": 166 | if not args.quantization: 167 | args.quantization = "4bit" 168 | 169 | if accelerator.is_main_process: 170 | if not os.path.exists(args.output_dir): 171 | os.makedirs(args.output_dir) 172 | if not os.path.exists(args.tb_dir): 173 | os.makedirs(args.tb_dir) 174 | 175 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_tokenizer_path, 
trust_remote_code=False) 176 | 177 | encoder = CGMEncoder(tokenizer=tokenizer, config_name=args.model_type) 178 | 179 | collate_fn = collate_cgm( 180 | graph_dir, 181 | encoder, 182 | qa_type=args.task, 183 | seq_l=8192, 184 | use_chat=args.use_chat, 185 | ) 186 | 187 | train_unit_batch_size = accelerator.num_processes * args.gradient_accumulation_steps * args.per_device_train_batch_size 188 | total_train_samples = len(train_dataset) 189 | if args.max_train_samples: 190 | max_train_samples = args.max_train_samples 191 | if total_train_samples > max_train_samples: 192 | total_train_samples = max_train_samples 193 | 194 | max_divisible_samples = (total_train_samples // train_unit_batch_size) * train_unit_batch_size 195 | subset_indices = list(range(max_divisible_samples)) 196 | train_subset = Subset(train_dataset, subset_indices) 197 | train_dataloader = DataLoader(train_subset, batch_size=args.per_device_train_batch_size, collate_fn=collate_fn, 198 | shuffle=True) 199 | 200 | if args.max_valid_samples: 201 | max_valid_samples = args.max_valid_samples 202 | valid_unit_batch_size = accelerator.num_processes * args.per_device_eval_batch_size 203 | total_valid_samples = len(valid_dataset) 204 | if total_valid_samples > max_valid_samples: 205 | indices = list(range(max_valid_samples)) 206 | random.shuffle(indices) 207 | subset_indices = indices[:max_valid_samples] 208 | valid_subset = Subset(valid_dataset, subset_indices) 209 | else: 210 | max_divisible_samples = (total_valid_samples // valid_unit_batch_size) * valid_unit_batch_size 211 | subset_indices = list(range(max_divisible_samples)) 212 | valid_subset = Subset(valid_dataset, subset_indices) 213 | valid_dataloader = DataLoader(valid_subset, batch_size=args.per_device_eval_batch_size, collate_fn=collate_fn, 214 | shuffle=True) 215 | else: 216 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=collate_fn, 217 | shuffle=True) 218 | 219 | logger.info(f"Train Samples: {len(train_dataloader)}", main_process_only=True) 220 | logger.info(f"Valid Samples: {len(valid_dataloader)}", main_process_only=True) 221 | 222 | model = CGM(args) 223 | # Please disable checkpointing and re-enable use-cache for inference 224 | model.lm.gradient_checkpointing_enable() 225 | model.lm.config.use_cache = False 226 | 227 | if args.peft == "QLoRA": 228 | model.lm = prepare_model_for_kbit_training(model.lm) # use_gradient_checkpointing default is True 229 | else: 230 | model.lm.gradient_checkpointing_enable() 231 | 232 | if args.peft: 233 | peft_config = LoraConfig( 234 | task_type=TaskType.CAUSAL_LM, 235 | inference_mode=False, 236 | r=args.lora_rank, 237 | lora_alpha=args.lora_alpha, 238 | lora_dropout=args.lora_dropout, 239 | target_modules=args.lora_modules, 240 | bias="lora_only", 241 | ) 242 | model.lm = get_peft_model(model.lm, peft_config) 243 | 244 | encoder_peft_config = LoraConfig( 245 | task_type=TaskType.CAUSAL_LM, 246 | inference_mode=False, 247 | r=args.enc_lora_rank, 248 | lora_alpha=args.enc_lora_alpha, 249 | lora_dropout=args.enc_lora_dropout, 250 | target_modules=args.enc_lora_modules, 251 | bias="lora_only", 252 | ) 253 | model.encoder = get_peft_model(model.encoder, encoder_peft_config) 254 | 255 | if args.adapter_warmup: 256 | if 'l' in args.mode: 257 | for param in model.lm.parameters(): 258 | param.requires_grad = False 259 | 260 | encoder_params = list(model.encoder.parameters()) if 'e' in args.mode else [] 261 | pma_params = list(model.pma.parameters()) if 'p' in args.mode else [] 262 | 
adapter_params = list(model.adapter.parameters()) if 'a' in args.mode else [] 263 | lm_params = list(model.lm.parameters()) if 'l' in args.mode else [] 264 | 265 | trained_params = encoder_params + pma_params + adapter_params + lm_params 266 | # trained_params = adapter_params + lm_params 267 | if not trained_params: 268 | raise ValueError("No parameters to train. Please check the mode argument.") 269 | 270 | optimizer = AdamW( 271 | trained_params, 272 | weight_decay=args.weight_decay, 273 | lr=args.learning_rate, 274 | betas=(0.9, 0.95), 275 | ) 276 | overrode_max_train_steps = False 277 | num_update_steps_per_epoch = math.ceil(epoch_train / args.gradient_accumulation_steps) 278 | if args.max_train_steps is None: 279 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 280 | overrode_max_train_steps = True 281 | 282 | if args.lr_scheduler_type == "reduce_lr_on_plateau": 283 | lr_scheduler = ReduceLROnPlateau( 284 | optimizer, 285 | mode='min', 286 | factor=0.75, 287 | patience=3, 288 | threshold=0.0001, 289 | threshold_mode='rel', 290 | cooldown=0, 291 | min_lr=args.min_lr, 292 | eps=1e-08, 293 | ) 294 | else: 295 | lr_scheduler = get_scheduler( 296 | name=args.lr_scheduler_type, 297 | optimizer=optimizer, 298 | num_warmup_steps=args.num_warmup_steps * accelerator.num_processes, 299 | num_training_steps=args.max_train_steps, 300 | ) 301 | 302 | logger.info( 303 | f"{'==' * 100}\nbefore accelerator preparation: [dataloader: {epoch_train}][epochs: {args.num_train_epochs}][total steps: {args.max_train_steps}]\n{'==' * 100}") 304 | if torch.cuda.is_available(): 305 | model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare( 306 | model, train_dataloader, valid_dataloader, optimizer, lr_scheduler 307 | ) 308 | 309 | epoch_train = epoch_train / accelerator.num_processes 310 | num_update_steps_per_epoch = math.ceil(epoch_train / args.gradient_accumulation_steps) 311 | if overrode_max_train_steps: 312 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 313 | 314 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 315 | logger.info( 316 | f"{'==' * 100}\nafter accelerator preparation: [dataloader: {epoch_train}][epochs: {args.num_train_epochs}][total steps: {args.max_train_steps}]\n{'==' * 100}") 317 | 318 | logger.info(f"{'==' * 100}Training...") 319 | 320 | accelerate_train_CGM(accelerator, 321 | model, 322 | train_dataloader, 323 | valid_dataloader, 324 | optimizer, 325 | lr_scheduler, 326 | tokenizer, 327 | epoch_train, 328 | args) 329 | 330 | if __name__ == "__main__": 331 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 332 | args = prepare_args() 333 | set_seed(args.seed) 334 | train(args) -------------------------------------------------------------------------------- /cgm/modeling/cgm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from data.preprocess import getJavaSentence, getPythonSentence, getSentence 5 | from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer, BitsAndBytesConfig 6 | from utils.common_utils import count_parameters, print_rank_0 7 | import torch.nn.functional as F 8 | 9 | from models.qwen2._4_46_1.modeling_qwen2 import Qwen2ForCausalLM 10 | from models.qwen2._4_46_1.modeling_attn_mask_utils import AttentionMaskConverter 11 | 12 | def graph2embedding(self, data, model, tokenizer, reponame, language, save_adj, peft, return_type=None): 13 | node_embeddings 
= {} 14 | node_id_to_index = {} 15 | index_counter = 0 16 | 17 | device = model.device 18 | 19 | for node in data['nodes']: 20 | nodeType = node['nodeType'] 21 | 22 | if 'nodeId' in node.keys(): 23 | node_id = node['nodeId'] 24 | elif 'id' in node.keys(): 25 | node_id = node['id'] 26 | else: 27 | raise ValueError("No key named id/nodeId") 28 | 29 | sentence = getSentence(node, nodeType, reponame, 1024000) 30 | 31 | if sentence == "": 32 | node_embedding = torch.zeros((1, self.args.embedding_dim), dtype=torch.float32).to(device) 33 | node_embeddings[node_id] = [node_embedding] 34 | # sentence_dict[index_counter] = "" 35 | node_id_to_index[node_id] = [index_counter] 36 | index_counter += 1 37 | else: 38 | # 手动切词 39 | tokens = tokenizer.tokenize(sentence) 40 | num_tokens = len(tokens) 41 | num_segments = (num_tokens + 511) // 512 # Calculate number of segments 42 | embeddings = [] 43 | # segments = [] 44 | node_id_to_index[node_id] = list(range(index_counter, index_counter + num_segments)) 45 | for i in range(num_segments): 46 | start = i * 512 47 | end = min((i + 1) * 512, num_tokens) 48 | segment_tokens = tokens[start:end] 49 | segment_ids = torch.tensor(tokenizer.convert_tokens_to_ids(segment_tokens), device=device).unsqueeze(0) 50 | 51 | if peft: 52 | # return_type: ALL_256, ALL_768 53 | segment_embedding = model.model(segment_ids, return_type=return_type) 54 | else: 55 | segment_embedding = model(segment_ids) 56 | embeddings.append(segment_embedding) 57 | index_counter += 1 58 | 59 | node_embeddings[node_id] = embeddings 60 | 61 | num_nodes = index_counter 62 | 63 | # TODO: add sparse adj 64 | if save_adj: 65 | adj_matrix = torch.zeros((num_nodes, num_nodes)).to(device) 66 | 67 | for edge in data['edges']: 68 | source_id = edge['source'] 69 | target_id = edge['target'] 70 | source_indices = node_id_to_index.get(source_id) 71 | target_indices = node_id_to_index.get(target_id) 72 | if source_indices is None or target_indices is None: 73 | # if source_indices is None: 74 | # print(f"{source_id} not exists") 75 | # if target_indices is None: 76 | # print(f"{target_id} not exists") 77 | continue 78 | 79 | for source_index in source_indices: 80 | for target_index in target_indices: 81 | adj_matrix[source_index, target_index] = 1 82 | 83 | # Connect embeddings of the same node 84 | for node_id, indices in node_id_to_index.items(): 85 | for i in range(len(indices)): 86 | for j in range(i + 1, len(indices)): 87 | adj_matrix[indices[i], indices[j]] = 1 88 | adj_matrix[indices[j], indices[i]] = 1 89 | else: 90 | adj_matrix = None 91 | 92 | all_embeddings = [] 93 | for value in node_embeddings.values(): 94 | if isinstance(value, torch.Tensor): 95 | all_embeddings.append(value) 96 | elif isinstance(value, list): 97 | for tensor in value: 98 | all_embeddings.append(tensor) 99 | 100 | embeddings = torch.stack(all_embeddings, dim=0).squeeze(1) 101 | 102 | # embeddings = torch.stack(list(node_embeddings.values())) 103 | # embeddings = torch.stack(sum(node_embeddings.values(), [])) 104 | # embeddings = torch.cat(list(node_embeddings.values()), dim=0) 105 | 106 | return embeddings, adj_matrix # sentence_dict 107 | 108 | class adapter(nn.Module): 109 | def __init__(self, args): 110 | super(adapter, self).__init__() 111 | self.fc1 = nn.Linear(args.embedding_dim, args.adapter_hidden_dim) 112 | self.gelu = nn.GELU() 113 | self.fc2 = nn.Linear(args.adapter_hidden_dim, args.lm_hidden_dim) 114 | 115 | def forward(self, x): 116 | return self.fc2(self.gelu(self.fc1(x))) 117 | 118 | class CGM(nn.Module): 119 | def 
__init__(self, args): 120 | super(CGM, self).__init__() 121 | # text encoder 122 | self.encoder_tokenizer = AutoTokenizer.from_pretrained(args.pretrained_encoder_path, trust_remote_code=True) 123 | self.encoder = AutoModel.from_pretrained( 124 | args.pretrained_encoder_path, 125 | torch_dtype="auto", 126 | trust_remote_code=True 127 | ) 128 | 129 | if args.self_defined: 130 | if args.quantization == "8bit": 131 | self.lm = Qwen2ForCausalLM.from_pretrained( 132 | args.pretrained_model_path, 133 | attn_implementation=args.attn_implementation, 134 | torch_dtype="auto", 135 | trust_remote_code=False, 136 | quantization_config=( 137 | BitsAndBytesConfig( 138 | load_in_8bit=True, 139 | # NOTE: BitsAndBytesConfig has no bnb_8bit_* dtype/quant-type options 140 | # (those exist only for 4-bit), and torch.float8 is not a valid dtype, 141 | # so plain 8-bit loading is used here. 142 | 143 | ) 144 | if args.quantization == "8bit" 145 | else None 146 | ), 147 | ) 148 | elif args.quantization == "4bit": 149 | self.lm = Qwen2ForCausalLM.from_pretrained( 150 | args.pretrained_model_path, 151 | attn_implementation=args.attn_implementation, 152 | torch_dtype="auto", 153 | trust_remote_code=False, 154 | quantization_config=( 155 | BitsAndBytesConfig( 156 | load_in_4bit=(args.quantization == "4bit"), 157 | bnb_4bit_compute_dtype=torch.bfloat16, 158 | bnb_4bit_use_double_quant=True, 159 | bnb_4bit_quant_type="nf4", 160 | bnb_4bit_quant_storage=torch.bfloat16, 161 | ) 162 | if args.quantization == "4bit" 163 | else None 164 | ), 165 | ) 166 | elif not args.quantization: 167 | self.lm = Qwen2ForCausalLM.from_pretrained( 168 | args.pretrained_model_path, 169 | attn_implementation=args.attn_implementation, 170 | torch_dtype="auto", 171 | trust_remote_code=False, 172 | ) 173 | else: 174 | raise NotImplementedError(f"unrecognized args.quantization: {args.quantization}") 175 | else: 176 | self.lm = AutoModelForCausalLM.from_pretrained( 177 | args.pretrained_model_path, 178 | attn_implementation=args.attn_implementation, 179 | torch_dtype="auto", 180 | trust_remote_code=True, 181 | quantization_config=( 182 | BitsAndBytesConfig( 183 | load_in_4bit=(args.quantization == "4bit"), 184 | bnb_4bit_compute_dtype=torch.bfloat16, 185 | bnb_4bit_use_double_quant=True, 186 | bnb_4bit_quant_type="nf4", 187 | bnb_4bit_quant_storage=torch.bfloat16, 188 | ) 189 | if args.quantization == "4bit" 190 | else None 191 | ), 192 | ) 193 | 194 | args.lm_hidden_dim = self.lm.config.hidden_size 195 | self.args = args 196 | self.adapter = adapter(args) 197 | if args.load_pretrained_adapter: 198 | self.adapter.load_state_dict(torch.load(args.pretrained_adapter_path)) 199 | print_rank_0(f"Adapter loaded from {args.pretrained_adapter_path}") 200 | else: 201 | print_rank_0("Adapter initialized") 202 | print_rank_0(f"Parameters of Encoder: {count_parameters(self.encoder) / 1e6:.1f}M") 203 | print_rank_0(f"Parameters of Adapter: {count_parameters(self.adapter) / 1e6:.1f}M") 204 | print_rank_0(f"Parameters of LLM: {count_parameters(self.lm) / 1e9:.2f}B") 205 | 206 | def graph2embedding(self, data, reponame, return_type=None): 207 | node_embeddings = {} 208 | node_id_to_index = {} 209 | index_counter = 0 210 | 211 | model = self.encoder 212 | tokenizer = self.encoder_tokenizer 213 | save_adj = self.args.use_adj 214 | peft = self.args.peft 215 | 216 | device = model.device 217 | 218 | for node in data['nodes']: 219 | nodeType = node['nodeType'] 220 | 221 | if 'nodeId' in node.keys(): 222 | node_id = node['nodeId'] 223 | elif 'id' in node.keys(): 224 | node_id = node['id']
225 | else: 226 | raise ValueError("No key named id/nodeId") 227 | 228 | sentence = getSentence(node, nodeType, reponame, 1024000) 229 | 230 | if sentence == "": 231 | node_embedding = torch.zeros((1, self.args.embedding_dim), dtype=torch.float32).to(device) 232 | node_embeddings[node_id] = [node_embedding] 233 | node_id_to_index[node_id] = [index_counter] 234 | index_counter += 1 235 | else: 236 | tokens = tokenizer.tokenize(sentence) 237 | num_tokens = len(tokens) 238 | num_segments = (num_tokens + 511) // 512 # Calculate number of segments 239 | embeddings = [] 240 | node_id_to_index[node_id] = list(range(index_counter, index_counter + num_segments)) 241 | for i in range(num_segments): 242 | start = i * 512 243 | end = min((i + 1) * 512, num_tokens) 244 | segment_tokens = tokens[start:end] 245 | segment_ids = torch.tensor(tokenizer.convert_tokens_to_ids(segment_tokens), 246 | device=device).unsqueeze(0) 247 | 248 | if peft: 249 | # return_type: ALL_256, ALL_768 250 | segment_embedding = model.model(segment_ids, return_type=return_type) 251 | else: 252 | segment_embedding = model(segment_ids) 253 | embeddings.append(segment_embedding) 254 | index_counter += 1 255 | 256 | node_embeddings[node_id] = embeddings 257 | 258 | num_nodes = index_counter 259 | 260 | # TODO: add sparse adj 261 | if save_adj: 262 | adj_matrix = torch.zeros((num_nodes, num_nodes)).to(device) 263 | 264 | for edge in data['edges']: 265 | source_id = edge['source'] 266 | target_id = edge['target'] 267 | source_indices = node_id_to_index.get(source_id) 268 | target_indices = node_id_to_index.get(target_id) 269 | if source_indices is None or target_indices is None: 270 | continue 271 | 272 | for source_index in source_indices: 273 | for target_index in target_indices: 274 | adj_matrix[source_index, target_index] = 1 275 | 276 | # Connect embeddings of the same node 277 | for node_id, indices in node_id_to_index.items(): 278 | for i in range(len(indices)): 279 | for j in range(i + 1, len(indices)): 280 | adj_matrix[indices[i], indices[j]] = 1 281 | adj_matrix[indices[j], indices[i]] = 1 282 | else: 283 | adj_matrix = None 284 | 285 | all_embeddings = [] 286 | for value in node_embeddings.values(): 287 | if isinstance(value, torch.Tensor): 288 | all_embeddings.append(value) 289 | elif isinstance(value, list): 290 | for tensor in value: 291 | all_embeddings.append(tensor) 292 | 293 | embeddings = torch.stack(all_embeddings, dim=0).squeeze(1) 294 | 295 | return embeddings, adj_matrix # sentence_dict 296 | 297 | def forward(self, graph, qa_ids, qa_mask): 298 | graph_embeddings, adj_matrix = self.graph2embedding( 299 | data=graph, 300 | reponame=graph['reponame'], 301 | return_type="ALL_256", 302 | ) 303 | 304 | embeddings = self.adapter(graph_embeddings) 305 | 306 | if self.args.peft: 307 | inputs_embeds = self.lm.model.model.embed_tokens(qa_ids) 308 | else: 309 | inputs_embeds = self.lm.model.embed_tokens(qa_ids) 310 | 311 | input_embeddings = torch.cat((embeddings, inputs_embeds), dim=-2) 312 | input_embeddings = input_embeddings.unsqueeze(0) 313 | 314 | if adj_matrix is not None and self.args.use_adj: 315 | 316 | if len(adj_matrix.shape) == 2: 317 | adj_matrix = adj_matrix.unsqueeze(0) 318 | batch_size, seq_len_x, _ = adj_matrix.shape 319 | 320 | seq_len_q = inputs_embeds.size(-2) 321 | 322 | qa_matrix = torch.ones(batch_size, seq_len_q, seq_len_q, device=qa_mask.device) 323 | qa_matrix = torch.tril(qa_matrix) 324 | 325 | matrix_xq = qa_mask.unsqueeze(1) * torch.ones(batch_size, seq_len_x, seq_len_q, device=qa_mask.device) 326
| 327 | matrix_qx = torch.ones(batch_size, seq_len_q, seq_len_x, device=qa_mask.device) 328 | 329 | # Construct the full attention mask 330 | attention_mask = torch.cat([ 331 | torch.cat([adj_matrix, matrix_xq], dim=2), # x_embeddings part 332 | torch.cat([matrix_qx, qa_matrix], dim=2) # q_embeddings part 333 | ], dim=1).squeeze(1) 334 | 335 | outputs = self.lm(inputs_embeds=input_embeddings, 336 | attention_mask=attention_mask, 337 | return_dict=True) 338 | 339 | else: 340 | outputs = self.lm(inputs_embeds=input_embeddings, 341 | return_dict=True) 342 | 343 | return outputs 344 | 345 | -------------------------------------------------------------------------------- /reranker/reranker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | import pandas as pd 5 | import argparse 6 | 7 | from qwen_api import QwenAPI 8 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType 9 | 10 | stage_1_system_prompt = """ 11 | You are an experienced software developer who specializes in extracting the most relevant files for solving issues from many reference files. 12 | 13 | Task: 14 | Based on the information received about the issue from a repository, find the most likely few files from among those that may be able to resolve the issue. 15 | 16 | Instructions: 17 | 1. Analysis: 18 | - Analyze the provided issue description and files, and pay attention to the relevance of the provided files with the given issue, especially those might be modified during fixing the issue. 19 | - Determine the specific problem or error mentioned in the issue and note any clues that could help your judgment. 20 | 2. Extraction: 21 | - Based on your analysis, choose the Top **{}** relevant files which might be used in fixing the issue. 22 | - You should choose files from the provided files, and should not modify their name in any way. 23 | 24 | Respond in the following format: 25 | [start_of_analysis] 26 | 27 | [end_of_analysis] 28 | 29 | [start_of_relevant_files] 30 | 1. 31 | 2. 32 | 3. ... 33 | [end_of_relevant_files] 34 | 35 | Notes: 36 | - You can refer to to the information in the error logs (if exists). 37 | - The relevant file usually exists in the project described in the issue (e.g., django, sklearn). File need modification is usually not in the tests files or external packages. 38 | - The file you choose should be contained in the provided files. 39 | - Provide the file path with files. Do not include redundant suffix like '/home/username/', '/etc/service/' or '/tree/master'. 40 | - Do not include any additional information such as line numbers or explanations in your extraction result. 41 | - Files for initialization and configuration might be modified during changing the code. 42 | 43 | Preferred extraction Examples of Related Files: 44 | 1. src/utils/file_handler.py 45 | 2. core/services/service_manager.py 46 | 3. ... 47 | """.strip() 48 | 49 | stage_1_user_prompt_template = """ 50 | 51 | {} 52 | 53 | 54 | 55 | {} 56 | 57 | 58 | 59 | {} 60 | 61 | 62 | 63 | {} 64 | 65 | """ 66 | 67 | stage_2_system_prompt_v3 = """ 68 | You are an experienced software developer who specializes in assessing the relevance of the file for solving the issue in software repositories. 69 | 70 | Task: 71 | For a file provided, evaluate the likelihood that modifying this file would resolve the given issue, and assign a score based on specific criteria. 72 | 73 | Instructions: 74 | 1. 
Analysis: 75 | - Analyze the provided issue description and the content of the single relevant file, pay attention to any keywords, error messages, or specific functionalities mentioned that relate to the file. 76 | - Determine how closely the contents and functionality of the file are tied to the problem or error described in the issue. 77 | - Consider the role of the file in the overall project structure (e.g., configuration files, core logic files versus test files, or utility scripts). 78 | 2. Scoring: 79 | - Based on your analysis, assign a score from 1 to 5 that represents the relevance of modifying the given file in order to solve the issue. 80 | 81 | Score Specifications: 82 | 1. **Score 1**: The file is almost certainly unrelated to the issue, with no apparent connection to the functionality or error described in the issue. 83 | 2. **Score 2**: The file may be tangentially related, but modifying it is unlikely to resolve the issue directly; possible in rare edge cases. 84 | 3. **Score 3**: The file has some relevance to the issue; it might interact with the affected functionality indirectly and tweaking it could be part of a broader fix. 85 | 4. **Score 4**: The file is likely related to the issue; it includes code that interacts directly with the functionality in question and could plausibly contain bugs that lead to the issue. 86 | 5. **Score 5**: The file is very likely the root cause or heavily involved in the issue and modifying it should directly address the error or problem mentioned. 87 | 88 | Respond in the following format: 89 | [start_of_analysis] 90 | 91 | [end_of_analysis] 92 | 93 | [start_of_score] 94 | Score 95 | [end_of_score] 96 | 97 | Notes: 98 | - The content of the file shows only the structure of this file, including the names of the classes and functions defined in this file. 99 | - You can refer to to the information in the error logs (if exists). 
100 | """.strip() 101 | 102 | stage_2_user_prompt_template = """ 103 | 104 | {} 105 | 106 | 107 | 108 | {} 109 | 110 | 111 | 112 | {} 113 | 114 | 115 | 116 | {} 117 | 118 | """ 119 | 120 | 121 | def get_python_inner_class_and_function(graph, node_id, layer_cnt = 0): 122 | """ 123 | 寻找某个node的内部函数和类的node,返回list 124 | dfs,返回的列表每个item是(深度, node) 125 | """ 126 | ret_list = [] 127 | 128 | # 限制深度 129 | if layer_cnt > 5: 130 | return ret_list 131 | 132 | node = graph.get_node_by_id(node_id) 133 | inner_node_ids = graph.get_out_nodes(node_id) 134 | for inner_node_id in inner_node_ids: 135 | inner_node = graph.get_node_by_id(inner_node_id) 136 | 137 | if inner_node.get_type() == NodeType.FUNCTION and "def " + inner_node.name in node.text: 138 | ret_list.append((layer_cnt, inner_node)) 139 | ret_list.extend(get_python_inner_class_and_function(graph, inner_node.node_id, layer_cnt + 1)) 140 | elif inner_node.get_type() == NodeType.CLASS and "class " + inner_node.name in node.text: 141 | ret_list.append((layer_cnt, inner_node)) 142 | ret_list.extend(get_python_inner_class_and_function(graph, inner_node.node_id, layer_cnt + 1)) 143 | 144 | return ret_list 145 | 146 | def parse_reranker_stage_1(response): 147 | # relevant_file 148 | pattern = r"\[start_of_relevant_files\]\s*(.*?)\s*\[end_of_relevant_files\]" 149 | match = re.search(pattern, response, re.DOTALL) 150 | if match: 151 | relevant_files = match.group(1).strip().split("\n") 152 | else: 153 | pattern = r"\s*(.*?)\s*" 154 | match = re.search(pattern, response, re.DOTALL) 155 | if match: 156 | relevant_files = match.group(1).strip().split("\n") 157 | else: 158 | pattern = r"\[Start_of_Relevant_Files\]\s*(.*?)\s*\[End_of_Relevant_Files\]" 159 | match = re.search(pattern, response, re.DOTALL) 160 | if match: 161 | relevant_files = match.group(1).strip().split("\n") 162 | else: 163 | relevant_files = [] 164 | 165 | print(relevant_files) 166 | for idx, relevant_file in enumerate(relevant_files): 167 | new_relevant_file = relevant_file 168 | if new_relevant_file.startswith("- "): 169 | new_relevant_file = new_relevant_file[2:] 170 | 171 | pattern = r"\d+ *\.(.+)" 172 | match = re.search(pattern, new_relevant_file) 173 | if match: 174 | new_relevant_file = match.group(1).strip() 175 | relevant_files[idx] = new_relevant_file 176 | 177 | return relevant_files 178 | 179 | def parse_reranker_stage_2(response): 180 | # score 181 | pattern = r"\[start_of_score\]\s*(.*?)\s*\[end_of_score\]" 182 | match = re.search(pattern, response, re.DOTALL) 183 | if match: 184 | score = match.group(1).strip().split("\n") 185 | else: 186 | pattern = r"\s*(.*?)\s*" 187 | match = re.search(pattern, response, re.DOTALL) 188 | if match: 189 | score = match.group(1).strip().split("\n") 190 | else: 191 | pattern = r"\[Start_of_Score\]\s*(.*?)\s*\[End_of_Score\]" 192 | match = re.search(pattern, response, re.DOTALL) 193 | if match: 194 | score = match.group(1).strip().split("\n") 195 | else: 196 | score = ["0"] 197 | 198 | score = score[0] 199 | if score.startswith("- "): 200 | score = score[2:] 201 | 202 | pattern = r"Score (\d+)" 203 | match = re.search(pattern, score) 204 | if match: 205 | score = match.group(1) 206 | score = int(score) 207 | else: 208 | score = 0 209 | 210 | return score 211 | 212 | def extract_files_from_subgraph(subgraph_path, output_path): 213 | subgraph_list = os.listdir(subgraph_path) 214 | # print(len(subgraph_list)) 215 | 216 | for subgraph in subgraph_list: 217 | if not subgraph.endswith(".json"): 218 | continue 219 | # print(subgraph) 220 | try: 221 | 
with open(os.path.join(subgraph_path, subgraph), "r", encoding="utf-8") as f: 222 | subgraph_json = json.load(f) 223 | except: 224 | print(f"broken json file: {subgraph}") 225 | continue 226 | subgraph_nodes = subgraph_json["nodes"] 227 | file_nodes = [node for node in subgraph_nodes if node["nodeType"] == "File"] 228 | pred_files = [] 229 | for node in file_nodes: 230 | file_path = node["filePath"] 231 | file_name = node["fileName"] 232 | if file_path is None: 233 | file = file_name 234 | else: 235 | file = os.path.join(file_path, file_name) 236 | pred_files.append(file) 237 | 238 | subgraph_name = subgraph.split(".")[0] 239 | with open((os.path.join(output_path, subgraph_name + ".json")), "w", encoding="utf-8") as f: 240 | json.dump(pred_files, f, indent=4) 241 | 242 | def parse_args(): 243 | parser = argparse.ArgumentParser(description="Run Reranker.") 244 | 245 | parser.add_argument('--stage_1_k', type=int, default=10, help='Specify the k for stage 1') 246 | parser.add_argument('--stage_2_k', type=int, default=5, help='Specify the k for stage 2') 247 | 248 | return parser.parse_args() 249 | 250 | 251 | if __name__ == "__main__": 252 | llm = QwenAPI("Qwen/Qwen2.5-72B-Instruct") 253 | 254 | output_dir = "reranker_outputs/" 255 | 256 | subgraph_file_dir = "subgraph/" 257 | 258 | # retriever得到的file list 259 | retriever_filtered_files_dir = "subgraph_extracted_files/" 260 | os.makedirs(retriever_filtered_files_dir, exist_ok=True) 261 | extract_files_from_subgraph(subgraph_file_dir, retriever_filtered_files_dir) 262 | 263 | df = pd.read_json("test_basic_info.json") 264 | 265 | args = parse_args() 266 | 267 | # stage_1 268 | stage_1 = True 269 | stage_1_output_dir = os.path.join(output_dir, f"stage_1_top_{args.stage_1_k}") 270 | os.makedirs(os.path.join(stage_1_output_dir, "relevant_files"), exist_ok=True) 271 | os.makedirs(os.path.join(stage_1_output_dir, "response"), exist_ok=True) 272 | stage_1_system_prompt = stage_1_system_prompt.format(args.stage_1_k) 273 | if stage_1: 274 | reranker_stage_1_outputs = os.listdir(os.path.join(stage_1_output_dir, "relevant_files")) 275 | reranker_stage_1_outputs = [item.split(".")[0] for item in reranker_stage_1_outputs] 276 | for i, data in enumerate(df): 277 | repo, instance_id, base_commit, patch, test_patch, problem_statement, hints_text, created_at, version, fail_to_pass, pass_to_pass = data["repo"], data["instance_id"], data["base_commit"], data["patch"], data["test_patch"], data["problem_statement"], data["hints_text"], data["created_at"], data["version"], data["FAIL_TO_PASS"], data["PASS_TO_PASS"] 278 | 279 | if instance_id in reranker_stage_1_outputs: 280 | print(f"Stage 1 index {i} skip") 281 | continue 282 | 283 | if os.path.exists(os.path.join(retriever_filtered_files_dir, instance_id + ".json")): 284 | with open(os.path.join(retriever_filtered_files_dir, instance_id + ".json"), "r") as f: 285 | filtered_files = json.load(f) 286 | else: 287 | raise ValueError 288 | 289 | python_files = [item for item in filtered_files if item.endswith(".py")] 290 | other_files = [item for item in filtered_files if not item.endswith(".py")] 291 | user_prompt = stage_1_user_prompt_template.format(repo, problem_statement, "\n".join(python_files), "\n".join(other_files)) 292 | print(user_prompt) 293 | 294 | response = llm.get_response(stage_1_system_prompt, user_prompt) 295 | print(response) 296 | 297 | relevant_files = parse_reranker_stage_1(response) 298 | with open(os.path.join(stage_1_output_dir, "relevant_files", instance_id + ".json"), "w") as f: 299 | 
json.dump(relevant_files, f, indent=4) 300 | with open(os.path.join(stage_1_output_dir, "response", instance_id + ".txt"), "w", encoding="utf-8") as f: 301 | f.write(response) 302 | 303 | print(f"Stage 1 index {i} done") 304 | 305 | # stage_2 306 | stage_2 = True 307 | stage_2_output_dir = os.path.join(output_dir, f"stage_2_{args.stage_2_k}") 308 | os.makedirs(os.path.join(stage_2_output_dir, "relevant_files"), exist_ok=True) 309 | os.makedirs(os.path.join(stage_2_output_dir, "response"), exist_ok=True) 310 | if stage_2: 311 | reranker_stage_2_outputs = os.listdir(os.path.join(stage_2_output_dir, "relevant_files")) 312 | reranker_stage_2_outputs = [item.split(".")[0] for item in reranker_stage_2_outputs] 313 | for i, data in enumerate(df): 314 | repo, instance_id, base_commit, patch, test_patch, problem_statement, hints_text, created_at, version, fail_to_pass, pass_to_pass = data["repo"], data["instance_id"], data["base_commit"], data["patch"], data["test_patch"], data["problem_statement"], data["hints_text"], data["created_at"], data["version"], data["FAIL_TO_PASS"], data["PASS_TO_PASS"] 315 | 316 | if instance_id in reranker_stage_2_outputs: 317 | print(f"Stage 2 index {i} skip") 318 | continue 319 | 320 | with open(os.path.join(stage_1_output_dir, "relevant_files", instance_id + ".json"), "r") as f: 321 | stage_1_relevant_files = json.load(f) 322 | 323 | # 读取子图 324 | if os.path.exists(os.path.join(subgraph_file_dir, instance_id + ".json")): 325 | subgraph_file_path = os.path.join(subgraph_file_dir, instance_id + ".json") 326 | graph = parse(subgraph_file_path) 327 | else: 328 | raise ValueError 329 | 330 | relevant_file_score = {} 331 | relevant_file_response = {} 332 | for relevant_file in stage_1_relevant_files: 333 | relevant_file_content = "" 334 | find = False 335 | 336 | # 保留class和function 337 | for file_node in graph.get_nodes_by_type(NodeType.FILE): 338 | file_path = file_node.path 339 | file_name = file_node.name 340 | if file_path is None: 341 | file = file_name 342 | else: 343 | file = os.path.join(file_path, file_name) 344 | if file == relevant_file: 345 | class_and_function_list = get_python_inner_class_and_function(graph, file_node.node_id) 346 | 347 | relevant_file_content = "" 348 | for layer, node in class_and_function_list: 349 | # 添加缩进 350 | if node.get_type() == NodeType.CLASS: 351 | relevant_file_content += " " * layer + "class " + node.name # 接口文档中的字段是className,但实际parse时会处理到name中 352 | relevant_file_content += "\n" 353 | elif node.get_type() == NodeType.FUNCTION: 354 | if node.name != "": 355 | relevant_file_content += " " * layer + "def " + node.name 356 | relevant_file_content += "\n" 357 | 358 | find = True 359 | break 360 | 361 | # 未找到直接默认0,不打分 362 | if not find: 363 | relevant_file_score[relevant_file] = 0 364 | relevant_file_response[relevant_file] = "" 365 | else: 366 | user_prompt = stage_2_user_prompt_template.format(repo, problem_statement, relevant_file, relevant_file_content) 367 | print(user_prompt) 368 | 369 | response = llm.get_response(stage_2_system_prompt_v3, user_prompt) 370 | 371 | score = parse_reranker_stage_2(response) 372 | relevant_file_score[relevant_file] = score 373 | relevant_file_response[relevant_file] = response 374 | 375 | k = args.stage_2_k 376 | # 找出分数最高的topk个文件 377 | sorted_relevant_files = sorted(relevant_file_score.items(), key=lambda item: item[1], reverse=True) 378 | if k <= len(sorted_relevant_files): 379 | selected_relevant_files = sorted_relevant_files[:k] 380 | selected_relevant_files = [item[0] for item in 
selected_relevant_files] 381 | else: 382 | selected_relevant_files = [item[0] for item in sorted_relevant_files] 383 | 384 | with open(os.path.join(stage_2_output_dir, "relevant_files", instance_id + ".json"), "w") as f: 385 | json.dump(dict(relevant_file_score=relevant_file_score, selected_relevant_files=selected_relevant_files), f, indent=4) 386 | with open(os.path.join(stage_2_output_dir, "response", instance_id + ".json"), "w") as f: 387 | json.dump(relevant_file_response, f, indent=4) 388 | 389 | print(f"Stage 2 index {i} done") 390 | 391 | -------------------------------------------------------------------------------- /cgm/data/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import json 3 | # from codegraph import * 4 | 5 | from torch.utils.data.dataset import Dataset 6 | from transformers import AutoModel, AutoTokenizer 7 | 8 | # from FlagEmbedding import BGEM3FlagModel 9 | # from sentence_transformers import SentenceTransformer 10 | 11 | from datasets import Dataset as HFDataset 12 | from datasets import load_dataset 13 | 14 | import torch 15 | import numpy as np 16 | import logging 17 | import time 18 | import gc 19 | 20 | import random 21 | import string 22 | import os 23 | import sys 24 | 25 | import json 26 | from collections import defaultdict 27 | import random 28 | 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | 31 | def getJavaSentence(node, nodeType, reponame, max_len): 32 | def process_Repo(node): 33 | return reponame 34 | 35 | def process_Module(node): 36 | return node['name'] 37 | 38 | def process_Package(node): 39 | return node['name'] 40 | 41 | def process_File(node): 42 | path = node.get('path','') 43 | if len(path) > 0: 44 | path = path + '/' 45 | return f"{path}{node['name']}" 46 | 47 | def process_TextFile(node): 48 | return f"{node['name']}\n{node.get('text','')}" 49 | 50 | def process_Class(node): 51 | return f"{node.get('modifiers','')} {node['name']}\n{node.get('comment','')}".strip(' ') 52 | 53 | def process_Field(node): 54 | return f"{node.get('modifiers','')} {node['fieldType']} {node['name']}\n{node.get('comment','')}".strip(' ') 55 | 56 | def process_Method(node): 57 | className = node.get('className','') 58 | methodName = node.get('methodName', '') 59 | if len(methodName) == 0 or len(className) == 0: 60 | split = node['signature'].split('#') 61 | className = split[0] 62 | methodName = split[1].split('(')[0] 63 | name = className + '.' 
+ methodName 64 | comment = f"{node.get('comment','')}\n" if not node.get('comment','') == '' else '' 65 | text = f"{node.get('modifiers','')} {node.get('text','')}" if not node.get('modifiers','') == '' else node.get('text','') 66 | return f"{name}\n{comment}{text}" 67 | 68 | def process_default(node): 69 | raise ValueError(f"unrecognized nodeType for node.keys {node['nodeType']} {str(node.keys())}") 70 | return "" 71 | 72 | processors = { 73 | 'Repo': process_Repo, 74 | 'Module': process_Module, 75 | 'Package': process_Package, 76 | 'File': process_File, 77 | 'TextFile': process_TextFile, 78 | 'Textfile': process_TextFile, 79 | 'Class': process_Class, 80 | 'Field': process_Field, 81 | 'Method': process_Method 82 | } 83 | 84 | sentence = processors.get(nodeType, process_default)(node) 85 | 86 | # TODO: limit token not str size 87 | if len(sentence) > max_len: 88 | sentence = sentence[:max_len] 89 | 90 | return sentence 91 | 92 | def getPythonSentence(node, nodeType, reponame, max_len): 93 | def process_Repo(node): 94 | return reponame 95 | 96 | def process_Package(node): 97 | return node['name'] 98 | 99 | def process_File(node): 100 | path = node.get('filePath','') 101 | if len(path) > 0: 102 | path = path + '/' 103 | return f"{path}{node['fileName']}\n{node.get('text','')}" 104 | 105 | def process_TextFile(node): 106 | return f"{node['name']}\n{node.get('text','')}" 107 | 108 | def process_Class(node): 109 | return f"{node.get('classType','')} {node['className']}\n{node.get('comment','')}\n{node.get('text','')}".strip(' ') 110 | 111 | def process_Attribute(node): 112 | return f"{node.get('attributeType','')} {node['name']}\n{node.get('comment','')}\n{node.get('text','')}".strip(' ') 113 | 114 | def process_Function(node): 115 | comment = f"{node.get('comment','')}\n" if not node.get('comment','') == '' else '' 116 | return f"{node.get('header','')} {node['name']}\n{comment}{node.get('text','')}".strip(' ') 117 | 118 | def process_Lambda(node): 119 | return f"{node.get('text','')}".strip(' ') 120 | 121 | def process_default(node): 122 | raise ValueError(f"unrecognized nodeType for node.keys {node['nodeType']} {str(node.keys())}") 123 | return "" 124 | 125 | processors = { 126 | 'Repo': process_Repo, 127 | 'Package': process_Package, 128 | 'File': process_File, 129 | 'TextFile': process_TextFile, 130 | 'Textfile': process_TextFile, 131 | 'Class': process_Class, 132 | 'Attribute': process_Attribute, 133 | 'Function': process_Function, 134 | 'Lambda': process_Lambda 135 | } 136 | 137 | sentence = processors.get(nodeType, process_default)(node) 138 | 139 | # TODO: limit token not str size 140 | if len(sentence) > max_len: 141 | sentence = sentence[:max_len] 142 | 143 | return sentence 144 | 145 | def graph2embedding(data, model, tokenizor, reponame, language, save_adj): 146 | node_embeddings = {} 147 | sentence_dict = {} 148 | node_id_to_index = {} 149 | index_counter = 0 150 | 151 | for node in data['nodes']: 152 | nodeType = node['nodeType'] 153 | 154 | if 'nodeId' in node.keys(): 155 | node_id = node['nodeId'] 156 | elif 'id' in node.keys(): 157 | node_id = node['id'] 158 | else: 159 | raise ValueError("No key named id/nodeId") 160 | 161 | if language == 'java': 162 | sentence = getJavaSentence(node, nodeType, reponame, 1024000) 163 | elif language == 'python': 164 | sentence = getPythonSentence(node, nodeType, reponame, 1024000) 165 | else: 166 | raise ValueError(f"Language {language} not supported") 167 | 168 | if sentence == "": 169 | node_embedding = torch.zeros((1, 256), 
dtype=torch.float32).to(device) 170 | node_embeddings[node_id] = [node_embedding] 171 | sentence_dict[index_counter] = "" 172 | node_id_to_index[node_id] = [index_counter] 173 | index_counter += 1 174 | else: 175 | # 手动切词 176 | tokens = tokenizor.tokenize(sentence) 177 | num_tokens = len(tokens) 178 | num_segments = (num_tokens + 511) // 512 # Calculate number of segments 179 | embeddings = [] 180 | segments = [] 181 | node_id_to_index[node_id] = list(range(index_counter, index_counter + num_segments)) 182 | for i in range(num_segments): 183 | start = i * 512 184 | end = min((i + 1) * 512, num_tokens) 185 | segment_tokens = tokens[start:end] 186 | segment_sentence = tokenizor.convert_tokens_to_string(segment_tokens) 187 | segment_ids = tokenizor.encode(segment_sentence, return_tensors="pt").to(device) 188 | with torch.no_grad(): 189 | segment_embedding = model(segment_ids) 190 | embeddings.append(segment_embedding) 191 | segments.append(segment_sentence) 192 | sentence_dict[index_counter] = segment_sentence 193 | index_counter += 1 194 | 195 | node_embeddings[node_id] = embeddings 196 | 197 | num_nodes = index_counter 198 | 199 | if save_adj: 200 | adj_matrix = torch.zeros((num_nodes, num_nodes)) 201 | 202 | for edge in data['edges']: 203 | source_id = edge['source'] 204 | target_id = edge['target'] 205 | source_indices = node_id_to_index.get(source_id) 206 | target_indices = node_id_to_index.get(target_id) 207 | if source_indices is None or target_indices is None: 208 | # if source_indices is None: 209 | # print(f"{source_id} not exists") 210 | # if target_indices is None: 211 | # print(f"{target_id} not exists") 212 | continue 213 | 214 | for source_index in source_indices: 215 | for target_index in target_indices: 216 | adj_matrix[source_index, target_index] = 1 217 | 218 | # Connect embeddings of the same node 219 | for node_id, indices in node_id_to_index.items(): 220 | for i in range(len(indices)): 221 | for j in range(i + 1, len(indices)): 222 | adj_matrix[indices[i], indices[j]] = 1 223 | adj_matrix[indices[j], indices[i]] = 1 224 | else: 225 | adj_matrix = None 226 | 227 | all_embeddings = [] 228 | for value in node_embeddings.values(): 229 | if isinstance(value, torch.Tensor): 230 | all_embeddings.append(value) 231 | elif isinstance(value, list): 232 | for tensor in value: 233 | all_embeddings.append(tensor) 234 | 235 | embeddings = torch.stack(all_embeddings, dim=0) 236 | 237 | # embeddings = torch.stack(list(node_embeddings.values())) 238 | # embeddings = torch.stack(sum(node_embeddings.values(), [])) 239 | # embeddings = torch.cat(list(node_embeddings.values()), dim=0) 240 | 241 | return embeddings, adj_matrix, sentence_dict 242 | 243 | def preprocess_graph(graphdir, savedir, recdir, jsondir, language = 'java', model = None, tokenizor = None, filenum = 1, suffix = 'pt', node_limit = 20000, save_adj = True, save_rec = True): 244 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 245 | logger = logging.getLogger(__name__) 246 | logger.info(f"Parsing json file: {jsondir}{filenum}.json") 247 | 248 | if not os.path.exists(savedir): 249 | os.makedirs(savedir) 250 | 251 | if not os.path.exists(recdir): 252 | os.makedirs(recdir) 253 | 254 | if jsondir == graphdir: 255 | glist = os.listdir(graphdir) 256 | else: 257 | with open(f'{jsondir}{filenum}.json', 'r') as f: 258 | glist = json.load(f) 259 | f.close() 260 | 261 | for gname in glist: 262 | if gname.startswith('._'): 263 | gname = gname[2:] 264 | raw_file = 
os.path.join(graphdir, gname) 265 | print(gname) 266 | 267 | if len(gname.split('#')) == 2: 268 | appName = gname.split('#')[1].split('-graph.json')[0] 269 | repoName = appName 270 | groupName = gname.split('#')[0] 271 | commitId = '0' 272 | elif len(gname.split('#')) == 3: 273 | appName = gname.split('#')[1] 274 | repoName = appName 275 | groupName = gname.split('#')[0] 276 | commitId = gname.split('#')[2].split('.graph.json')[0] 277 | elif len(gname.split('___')) == 3: 278 | parts = gname.split('___') 279 | appName = parts[0] 280 | repoName = parts[1].split('__')[1] 281 | groupName = parts[1].split('__')[0] 282 | commitId = parts[2].split('.')[0] 283 | else: 284 | print(f"{gname} can't be renamed") 285 | continue 286 | tmp1 = f"{appName}___{repoName}___{groupName}___{commitId}.{suffix}" 287 | tmp2 = f"{appName}___{repoName}___{groupName}___{commitId}.json" 288 | print(tmp1) 289 | 290 | save_file = os.path.join(savedir, tmp1) 291 | rec_file = os.path.join(recdir, tmp2) 292 | 293 | if not os.path.exists(raw_file): 294 | continue 295 | if os.path.exists(save_file) and os.path.exists(rec_file): 296 | continue 297 | logger.info(f'Start {gname} transforming...') 298 | try: 299 | with open(raw_file, 'r') as f1: 300 | content = f1.read() 301 | data = json.loads(content) 302 | 303 | if len(data['nodes']) > node_limit: 304 | continue 305 | embeddings, adj_matrix, sentence_dict = graph2embedding(data, model, tokenizor, gname, language, save_adj) 306 | f1.close() 307 | 308 | if suffix == 'json': 309 | if save_adj: 310 | data_dict = { 311 | "embeddings": embeddings.tolist(), 312 | "adj_matrix": adj_matrix.tolist() 313 | } 314 | else: 315 | data_dict = { 316 | "embeddings": embeddings.tolist(), 317 | } 318 | 319 | with open(save_file, 'w') as f: 320 | json.dump(data_dict, f) 321 | f.close() 322 | 323 | elif suffix == 'pt': 324 | if save_adj: 325 | data_dict = { 326 | "embeddings": embeddings.detach(), 327 | "adj_matrix": adj_matrix.detach() 328 | } 329 | else: 330 | data_dict = { 331 | "embeddings": embeddings.detach(), 332 | } 333 | torch.save(data_dict, save_file) 334 | 335 | if save_rec: 336 | rec_dict = { 337 | "text": list(sentence_dict.values()) 338 | } 339 | 340 | with open(rec_file, 'w') as f: 341 | json.dump(rec_dict, f) 342 | f.close() 343 | except json.JSONDecodeError as e: 344 | print('Json Decode Error: '+ gname) 345 | 346 | def preprocess(graphdir, savedir, recdir, jsondir, language = 'java', mode = 'pretrain', filenum = 1, suffix = 'pt', node_limit = 20000, save_adj = True, save_rec = True): 347 | 348 | model1_path = "salesforce/codet5p-110m-embedding" 349 | tokenizer1 = AutoTokenizer.from_pretrained(model1_path, trust_remote_code=True, device = device) 350 | model1 = AutoModel.from_pretrained(model1_path, trust_remote_code=True, torch_dtype="auto").to(device).eval() 351 | 352 | if mode == 'pretrain': 353 | preprocess_graph( 354 | graphdir=graphdir, 355 | savedir=savedir, 356 | recdir=recdir, 357 | jsondir=jsondir, 358 | language=language, 359 | model=model1, 360 | tokenizor=tokenizer1, 361 | filenum=filenum, 362 | suffix=suffix, 363 | node_limit=node_limit, 364 | save_adj=save_adj, 365 | save_rec=save_rec) 366 | else: 367 | raise NotImplementedError 368 | 369 | def json_split(loaddirs, savedir, split_num=64): 370 | if not os.path.exists(savedir): 371 | os.makedirs(savedir) 372 | 373 | file_list = [] 374 | for loaddir in loaddirs: 375 | file_list += os.listdir(loaddir) 376 | total_num = len(file_list) 377 | sep_num = total_num // split_num 378 | print(f'total num: {total_num}, sep num: 
{sep_num}') 379 | 380 | for i in range(split_num): 381 | start = i * sep_num 382 | end = start + sep_num if i != split_num - 1 else total_num 383 | with open(f'{savedir}/{i+1}.json', 'w') as f: 384 | json.dump(file_list[start:end], f) 385 | f.close() 386 | 387 | def json_split_from_json(input_json, savedir, split_num=64): 388 | with open(input_json, 'r') as file: 389 | data = json.load(file) 390 | 391 | total_items = len(data) 392 | num_files = (total_items + split_num - 1) // split_num # 向上取整 393 | 394 | if not os.path.exists(savedir): 395 | os.makedirs(savedir) 396 | 397 | for i in range(num_files): 398 | start = i * split_num 399 | end = min(start + split_num, total_items) 400 | split_data = data[start:end] 401 | 402 | save_file = os.path.join(savedir, f"{i+1}.json") 403 | 404 | with open(save_file, 'w') as file: 405 | json.dump(split_data, file, indent=4) 406 | 407 | def detect_pt_file_errors(directory, output_json): 408 | error_files = [] 409 | 410 | for root, _, files in os.walk(directory): 411 | for file in files: 412 | if file.endswith('.pt'): 413 | file_path = os.path.join(root, file) 414 | try: 415 | with open(file_path, 'rb') as f: 416 | tmp = torch.load(f) 417 | del tmp 418 | gc.collect() 419 | except Exception as e: 420 | error_files.append(file_path) 421 | print(f"Error loading {file_path}: {e}") 422 | 423 | with open(output_json, 'w') as f: 424 | json.dump(error_files, f, indent=4) 425 | 426 | print(f"Detected {len(error_files)} error files. Details saved in {output_json}") 427 | 428 | def transfer_pt_file_errors(input_json, output_json): 429 | with open(input_json, 'r') as file: 430 | data = json.load(file) 431 | 432 | def transform_path(path): 433 | repo_str = path.split('/')[-1] 434 | appName = repo_str.split('___')[0] 435 | groupName = repo_str.split('___')[2] 436 | new_repo_str = f"{groupName}#{appName}-graph.json" 437 | return new_repo_str 438 | 439 | transformed_data = [transform_path(path) for path in data] 440 | 441 | with open(output_json, 'w') as file: 442 | json.dump(transformed_data, file, indent=4) 443 | 444 | def get_list(graph_dirs): 445 | all_files = [os.path.join(graph_dir, file) 446 | for graph_dir in graph_dirs 447 | for file in os.listdir(graph_dir)] 448 | return all_files 449 | 450 | def get_list_constrained(graph_dirs, size_limit = 500 * 1024 * 1024): 451 | filtered_files = [] 452 | for graph_dir in graph_dirs: 453 | glist = os.listdir(graph_dir) 454 | for file_name in glist: 455 | file_path = os.path.join(graph_dir, file_name) 456 | if os.path.isfile(file_path): 457 | file_size = os.path.getsize(file_path) 458 | if file_size < size_limit: 459 | filtered_files.append(file_path) 460 | 461 | return filtered_files 462 | 463 | def get_graph_path(glist, filename, suffix): 464 | sp = filename.split('___') 465 | if len(sp) == 4: 466 | appName = sp[0] 467 | repoName = sp[1] 468 | groupName = sp[2] 469 | commitId = sp[3].split('.')[0] 470 | 471 | matched_graphs = [] 472 | for graph in glist: 473 | graph_parts = graph.split('/')[-1].split('___') 474 | if len(graph_parts) == 4: 475 | graph_appName = graph_parts[0] 476 | graph_repoName = graph_parts[1] 477 | graph_groupName = graph_parts[2] 478 | graph_commitId = graph_parts[3].split('.')[0] 479 | 480 | if graph_appName == appName: 481 | matched_graphs.append((graph, graph_repoName, graph_groupName, graph_commitId)) 482 | 483 | if not matched_graphs: 484 | return None 485 | 486 | if not commitId == '0': 487 | for graph, graph_repoName, graph_groupName, graph_commitId in matched_graphs: 488 | if commitId == 
graph_commitId: 489 | return graph 490 | 491 | best_match = None 492 | best_match_score = -2 493 | for graph, graph_repoName, graph_groupName, _ in matched_graphs: 494 | score = (repoName == graph_repoName) + (groupName == graph_groupName) 495 | if score > best_match_score: 496 | best_match_score = score 497 | best_match = graph 498 | 499 | return best_match 500 | else: 501 | raise ValueError(f"{filename} to graph not supported") 502 | 503 | def split_jsonl_dataset(input_file, train_file, test_file, train_ratio=0.98): 504 | def read_jsonl(file_path): 505 | with open(file_path, 'r') as file: 506 | for line in file: 507 | yield json.loads(line) 508 | 509 | data = list(read_jsonl(input_file)) 510 | repo_dict = defaultdict(list) 511 | 512 | for item in data: 513 | repo_dict[item['repo']].append(item) 514 | 515 | repos = list(repo_dict.keys()) 516 | random.shuffle(repos) 517 | 518 | split_index = int(len(repos) * train_ratio) 519 | train_repos = repos[:split_index] 520 | test_repos = repos[split_index:] 521 | 522 | train_data = [] 523 | test_data = [] 524 | 525 | for repo in train_repos: 526 | train_data.extend(repo_dict[repo]) 527 | for repo in test_repos: 528 | test_data.extend(repo_dict[repo]) 529 | 530 | with open(train_file, 'w') as file: 531 | for item in train_data: 532 | file.write(json.dumps(item) + '\n') 533 | 534 | with open(test_file, 'w') as file: 535 | for item in test_data: 536 | file.write(json.dumps(item) + '\n') 537 | 538 | 539 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CGM: Code Graph LLM 2 | 3 | ![CodefuseLogo](./assets/github-codefuse-logo-update.jpg) 4 | 5 | ## Contents 6 | - [News](#news) 7 | - [Introduction](#introduction) 8 | - [Installation](#installation) 9 | - [Examples](#examples) 10 | - [Rewriter](#rewriter) 11 | - [Retriever](#retriever) 12 | - [Reranker](#reranker) 13 | - [Reader](#reader) 14 | - [Contributing](#contributing) 15 | - [Citation](#citation) 16 | - [Join Us](#join-us) 17 | 18 | ## News 19 | 20 | 🔥🔥🔥 [2025/09/19] Our paper [Code Graph Model (CGM): A Graph-Integrated Large Language Model for Repository-Level Software Engineering Tasks](https://arxiv.org/abs/2505.16901) has been accepted to NeurIPS 2025! 21 | 22 | ![SWE-Bench-Lite](./assets/swe-bench-modified.png) 23 | 24 | 🔥🔥🔥 [2025/01/15] We are pleased to announce the updated version of the CGM-72B-V1.2. The model further achieves a remarkable 44.00% resolve rate on the SWE-Bench-Lite leaderboard. 25 | 26 | 🔥🔥🔥 [2024/12/28] We are pleased to announce the updated version of the CGM-72B-V1.1. The model further achieves a remarkable 41.67% resolve rate on the SWE-Bench-Lite leaderboard. 27 | 28 | 🔥🔥🔥 [2024/10/28] We are pleased to announce that CGM-72B achieves a remarkable 35.67% resolve rate on the SWE-Bench-Lite leaderboard. 29 | 30 | 🔥🔥🔥 [2024/10/28] We released **CGM**, mainly for repository-level coding tasks. 31 | 32 | - 📜 **Paper**: [Code Graph Model (CGM): A Graph-Integrated Large Language Model for Repository-Level Software Engineering Tasks](https://arxiv.org/abs/2505.16901) 33 | - 🤖 **Model**: [codefuse-ai/CodeFuse-CGM-72B](https://huggingface.co/codefuse-ai/CodeFuse-CGM-72B) 34 | - 📊 **Data**: [codefuse-ai/CodeGraph](https://huggingface.co/datasets/codefuse-ai/CodeGraph) 35 | 36 | ## Introduction 37 | We propose a graph-based framework CGM for real-world SE tasks. 
Before CGM starts its work, a Code Graph Generator constructs a repository-level code graph that better represents the repository context and its structure. Inspired by the Retrieval-Augmented Generation (RAG) approach, the CGM framework is designed as a chain of four atomic nodes, termed the R4 chain (Rewriter, Retriever, Reranker, and Reader). Given an issue, the initial input to the CGM framework consists of the issue description and the corresponding code graph. The Rewriter first rewrites the original issue by extracting keywords and generating relevant queries over the code graph. The Retriever then retrieves a heuristic code subgraph based on the anchor nodes matched from the Rewriter output. Because the resulting subgraph provides a relatively broad context, the Reranker identifies the files most likely to be modified as a further hint. Finally, both the retrieved subgraph and the identified files are fed into a trainable, graph-based Reader to generate the corresponding code patch. 38 | 39 | ### Framework 40 | 41 | ![Framework](./assets/cgm_method_v3.png) 42 | 43 | ### Highlights 44 | :white_check_mark: **Code Graph**: Train models on multiple tasks while maintaining a balance between them. The models can even generalize to new, previously unseen tasks. 45 | 46 | :white_check_mark: **Multi-framework**: Provides support for Accelerate with both DeepSpeed and FSDP. 47 | 48 | :white_check_mark: **Efficient fine-tuning**: Supports LoRA, QLoRA as well as full-parameter training, enabling fine-tuning of large models with minimal resources. The training speed meets the demands of almost all fine-tuning scenarios. 49 | 50 | ## Installation 51 | ### Prerequisites 52 | - Python 3.8+ 53 | - pip 54 | 55 | ### Required Packages 56 | 57 | ```bash 58 | transformers==4.46.1 59 | tokenizers==0.20.0 60 | accelerate==1.0.1 61 | peft==0.13.2 62 | jinja2==2.11.3 63 | fuzzywuzzy==0.18.0 64 | python-Levenshtein==0.25.1 65 | networkx==3.0 66 | ``` 67 | 68 | ## Examples 69 | 70 | The following chart illustrates the whole processing pipeline of R3. 71 | ![R3 Pipeline](assets/pipeline.png) 72 | 73 | 74 | ### Pre-process for Retriever: Generate Node Embedding 75 | 76 | Before running Retriever, we need to embed 77 | - all the nodes in the Code Graph 78 | - the queries generated by Rewriter 79 | 80 | using [CGE-large](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Large) 81 | 82 | ```bash 83 | python generate_code_content.py # this step will preprocess each code graph and extract the content of each node (saves a .json file for each repo) 84 | python generate_code_embedding.py # this step will generate an embedding for each node with CGE-large (saves a .pkl file for each repo) 85 | python generate_rewriter_embedding.py # this step will generate an embedding for each query generated by Rewriter (saves a .pkl file) 86 | ``` 87 | 88 | Requirements for CGE-large: 89 | ```bash 90 | torch==2.1.0 91 | transformers==4.39.2 92 | tokenizers==0.15.2 93 | accelerate==0.28.0 94 | ``` 95 | 96 | ### Rewriter 97 | 98 | Given the issue metadata, execute the following scripts to generate the Rewriter results (for both the Inferer and the Extractor); a sketch of the expected metadata record is shown below, followed by the scripts.
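The exact metadata schema expected by `generate_rewriter_prompt.py` is not shown in this snapshot; as a rough sketch, the rest of the pipeline assumes SWE-Bench-style records with the fields that `reranker/reranker.py` reads from `test_basic_info.json` (all values below are placeholders):

```python
# Hypothetical example of one issue-metadata record; field names are taken from
# how reranker/reranker.py consumes test_basic_info.json, values are placeholders.
issue_record = {
    "repo": "owner/project",
    "instance_id": "owner__project-1234",
    "base_commit": "abc1234",                  # commit the repository is checked out at
    "patch": "<gold patch diff>",
    "test_patch": "<test patch diff>",
    "problem_statement": "<issue title and body>",
    "hints_text": "<issue discussion, may be empty>",
    "created_at": "2023-01-01T00:00:00Z",
    "version": "1.0",
    "FAIL_TO_PASS": ["tests/test_example.py::test_bug"],
    "PASS_TO_PASS": ["tests/test_example.py::test_ok"],
}
```

Of these, the prompt templates below use only the repository name and the problem statement; the remaining fields are consumed by later stages and by evaluation.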
99 | 100 | ```bash 101 | python generate_rewriter_prompt.py # this step will generate a json file "test_rewriter_prompt.json" containing the prompts used in Rewriter 102 | 103 | python inference_rewriter.py --prompt_path test_rewriter_prompt.json # this step will load the Qwen model to execute inference and generate Rewriter's output "test_rewriter_output.json" 104 | 105 | python rewriter_output_post_process.py # this step will load Rewriter's output "rewriter_output.json" and generate the post-processed output "test_rewriter_output.json" 106 | ``` 107 | 108 | Use the functions ```generate_prompt_for_extractor``` and ```generate_prompt_for_inferer``` in ```rewriter/prompt.py``` 109 | ```python 110 | def generate_prompt_for_extractor(problem_statement, repo_name): 111 | prompt = """ 112 | 113 | {} 114 | 115 | This is an issue related to repository '{}'. 116 | Instructions: 117 | 1. Analysis: 118 | ○ Analyze the provided issue description. Identify the relevant File, Class, or Function involved. 119 | ○ Determine the specific problem or error encountered and note any clues that may assist in locating the relevant or problematic area. 120 | 2. Extraction: 121 | ○ After the analysis, extract ALL the mentioned code entities (File, Class, or Function), especially Files. 122 | ○ Then extract three potential and meaningful keywords, responding in the following format: 123 | 124 | [start_of_analysis] 125 | 126 | [end_of_analysis] 127 | 128 | [start_of_related_code_entities] 129 | 130 | [end_of_related_code_entities] 131 | 132 | [start_of_related_keywords] 133 | 134 | [end_of_related_keywords] 135 | 136 | Notes: 137 | - Pay attention to the information in the error logs (if exists). 138 | - The buggy code exists solely in the project described in the issue (e.g., django, sklearn). The buggy location is usually not in the test files or external packages. 139 | - Your extracted entities should be CONCISE, ACCURATE and INFORMATIVE. 140 | - Provide the relative path for code entities if specified (e.g., package/foo.py). The relative path is relative to the repository itself; do not include prefixes like '/home/username/', '/etc/service/' or '/tree/master'. 141 | - Do not include any additional information such as line numbers or explanations in your extraction result. 142 | 143 | Preferred extraction Examples of Code Entities: 144 | - repo/cart.py 145 | - Class User() 146 | - def getData() 147 | Preferred extraction Examples of Keywords: 148 | - train_loop 149 | - hooks 150 | - docker 151 | 152 | Unpreferred extraction Examples of Keywords: 153 | - something wrong 154 | - input validation 155 | - TypeError 156 | """.format(problem_statement, repo_name) 157 | 158 | return prompt 159 | 160 | def generate_prompt_for_inferer(problem_statement, repo_name): 161 | prompt = """ 162 | 163 | {} 164 | 165 | This is an issue related to repository '{}'. 166 | Task: 167 | Based on the issue description provided, identify the characteristics of code entities (files, functions, classes) that might need to be modified. 168 | For each characteristic, generate a search query that could help locate relevant code entities in a codebase. 169 | Instructions: 170 | First, analyze the issue description and identify keywords, features, and functionalities that are likely relevant to the modification of code entities. 171 | Then, create queries that capture these characteristics, focusing on: 172 | ● File names that may implement relevant functionalities. 173 | ● Functions or methods that are related to the features described in the issue.
174 | ● Any patterns or structures that might be relevant to the functionalities mentioned. 175 | For example: 176 | ● File related to the initialization of a neural network. 177 | ● Function related to the training process. 178 | ● Code used to configure the service. 179 | Please answer in the following format: 180 | 181 | [start_of_analysis] 182 | 183 | [end_of_analysis] 184 | 185 | [start_of_related_queries] 186 | query 1: 187 | query 2: 188 | ... 189 | [end_of_related_queries] 190 | 191 | Notes: 192 | - Your queries should be DETAILED, ACCURATE and INFORMATIVE. 193 | - Your queries should be complete sentences and should not include additional explanation. 194 | - The number of queries is up to five, so focus on the most important characteristics. 195 | - Your queries should focus on the repository code itself, rather than other information like commit history. 196 | - Pay attention to the information in the error logs (if exists). 197 | 198 | Preferred Query Examples: 199 | - Look for references to "tqdm" or "progress_bar" within the training loop files to find where progress bars are currently updated. 200 | - Code snippets where 'gethostbyname' function from 'socket' module is called. 201 | - File name containing 'mysql.py' AND functions related to 'MySQLStatementSamples' initialization. 202 | - Functions or methods handling hostname resolution or encoding within 'datadog_checks' directory. 203 | - Find all occurrences of "early_stopping" within files that also mention "Trainer" to identify where early stopping logic is implemented and potentially needs adjustment for non-default 'val_check_interval'. 204 | """.format(problem_statement, repo_name) 205 | 206 | return prompt 207 | ``` 208 | You can use the rewriter prompts as follows: 209 | ```python 210 | from rewriter.prompt import generate_prompt_for_extractor, generate_prompt_for_inferer 211 | 212 | # Generate extraction prompt 213 | extraction_prompt = generate_prompt_for_extractor(problem_statement, repo_name) 214 | 215 | # Generate inference prompt 216 | inference_prompt = generate_prompt_for_inferer(problem_statement, repo_name) 217 | ``` 218 | 219 | ### Retriever 220 | Now we have: 221 | - Original Code Graph: `codegraph/` 222 | - Node embeddings of the Code Graph: `node_embedding/` 223 | - Query embeddings from Rewriter's Inferer: `rewriter_embedding.pkl` 224 | - Output of Rewriter's Extractor: `rewriter_output.json` 225 | 226 | Then we can execute Retriever: 227 | 228 | ```bash 229 | python locate_anchor_node.py # this step will use the above inputs and generate the anchor nodes "anchor_node.json" for all samples 230 | python subgraph.py # this step will then expand the anchor nodes into a connected subgraph (saved as a set of node ids) 231 | python serialize_subgraph.py # based on the above subgraph node ids, this step will serialize the subgraph into json format (which is the final output of Retriever) 232 | ``` 233 | 234 | Requirements for Retriever: 235 | ```bash 236 | RapidFuzz==1.5.0 237 | faiss-cpu 238 | ``` 239 | 240 | ### Reranker 241 | 242 | Reranker is used to determine the most relevant files from the subgraph generated by Retriever. The input is the subgraph json file produced by Retriever. 243 | 244 | ```bash 245 | python reranker.py --stage_1_k 10 --stage_2_k 5 # this step will load the subgraph json file and generate the output of Reranker.
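# Default input/output locations, as hard-coded in reranker/reranker.py:
#   inputs : subgraph/<instance_id>.json (Retriever output) and test_basic_info.json (issue metadata)
#   outputs: reranker_outputs/stage_1_top_<stage_1_k>/{relevant_files,response}/<instance_id>.{json,txt}
#            reranker_outputs/stage_2_<stage_2_k>/{relevant_files,response}/<instance_id>.json
#   an intermediate per-instance file list is also written to subgraph_extracted_files/<instance_id>.json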
246 | ``` 247 | 248 | Requirements for Reranker 249 | ```bash 250 | vllm>=0.8.5 251 | ``` 252 | 253 | Use function ```generate_prompt_for_reranker_stage_1``` and ```generate_prompt_for_reranker_stage_2``` in ```reranker/prompt.py``` 254 | ```python 255 | """ 256 | Prompt Template for Reranker 257 | """ 258 | 259 | reranker_stage_1_system_prompt = """ 260 | You are an experienced software developer who specializes in extracting the most relevant files for solving issues from many reference files. 261 | 262 | Task: 263 | Based on the information received about the issue from a repository, find the most likely few files from among those that may be able to resolve the issue. 264 | 265 | Instructions: 266 | 1. Analysis: 267 | - Analyze the provided issue description and files, and pay attention to the relevance of the provided files with the given issue, especially those might be modified during fixing the issue. 268 | - Determine the specific problem or error mentioned in the issue and note any clues that could help your judgment. 269 | 2. Extraction: 270 | - Based on your analysis, choose the Top **1** relevant files which might be used in fixing the issue. 271 | - You should choose files from the provided files, and should not modify their name in any way. 272 | 273 | Respond in the following format: 274 | [start_of_analysis] 275 | 276 | [end_of_analysis] 277 | 278 | [start_of_relevant_files] 279 | 1. 280 | 2. 281 | 3. ... 282 | [end_of_relevant_files] 283 | 284 | Notes: 285 | - You can refer to to the information in the error logs (if exists). 286 | - The relevant file usually exists in the project described in the issue (e.g., django, sklearn). File need modification is usually not in the tests files or external packages. 287 | - The file you choose should be contained in the provided files. 288 | - Provide the file path with files. Do not include redundant suffix like '/home/username/', '/etc/service/' or '/tree/master'. 289 | - Do not include any additional information such as line numbers or explanations in your extraction result. 290 | - Files for initialization and configuration might be modified during changing the code. 291 | 292 | Preferred extraction Examples of Related Files: 293 | 1. src/utils/file_handler.py 294 | 2. core/services/service_manager.py 295 | 3. ... 296 | """.strip() 297 | 298 | reranker_stage_1_user_prompt_template = """ 299 | 300 | {} 301 | 302 | 303 | 304 | {} 305 | 306 | 307 | 308 | {} 309 | 310 | 311 | 312 | {} 313 | 314 | """ 315 | 316 | reranker_stage_2_system_prompt = """ 317 | You are an experienced software developer who specializes in assessing the relevance of the file for solving the issue in software repositories. 318 | 319 | Task: 320 | For a file provided, evaluate the likelihood that modifying this file would resolve the given issue, and assign a score based on specific criteria. 321 | 322 | Instructions: 323 | 1. Analysis: 324 | - Analyze the provided issue description and the content of the single relevant file, pay attention to any keywords, error messages, or specific functionalities mentioned that relate to the file. 325 | - Determine how closely the contents and functionality of the file are tied to the problem or error described in the issue. 326 | - Consider the role of the file in the overall project structure (e.g., configuration files, core logic files versus test files, or utility scripts). 327 | 2. 
Scoring: 328 | - Based on your analysis, assign a score from 1 to 5 that represents the relevance of modifying the given file in order to solve the issue. 329 | 330 | Score Specifications: 331 | 1. **Score 1**: The file is almost certainly unrelated to the issue, with no apparent connection to the functionality or error described in the issue. 332 | 2. **Score 2**: The file may be tangentially related, but modifying it is unlikely to resolve the issue directly; possible in rare edge cases. 333 | 3. **Score 3**: The file has some relevance to the issue; it might interact with the affected functionality indirectly and tweaking it could be part of a broader fix. 334 | 4. **Score 4**: The file is likely related to the issue; it includes code that interacts directly with the functionality in question and could plausibly contain bugs that lead to the issue. 335 | 5. **Score 5**: The file is very likely the root cause or heavily involved in the issue and modifying it should directly address the error or problem mentioned. 336 | 337 | Respond in the following format: 338 | [start_of_analysis] 339 | 340 | [end_of_analysis] 341 | 342 | [start_of_score] 343 | Score 344 | [end_of_score] 345 | 346 | Notes: 347 | - The content of the file shows only the structure of this file, including the names of the classes and functions defined in this file. 348 | - You can refer to the information in the error logs (if exists). 349 | """.strip() 350 | 351 | reranker_stage_2_user_prompt_template = """ 352 | 353 | {} 354 | 355 | 356 | 357 | {} 358 | 359 | 360 | 361 | {} 362 | 363 | 364 | 365 | {} 366 | 367 | """ 368 | 369 | def generate_prompt_for_reranker_stage_1(problem_statement, repo_name, py_file, other_file): 370 | """ 371 | problem_statement: issue 372 | repo_name: repo 373 | py_file: py file list 374 | other_file: related file list 375 | """ 376 | return reranker_stage_1_system_prompt, reranker_stage_1_user_prompt_template.format(repo_name, problem_statement, py_file, other_file) 377 | 378 | def generate_prompt_for_reranker_stage_2(problem_statement, repo_name, file_name, file_content): 379 | """ 380 | problem_statement: issue 381 | repo_name: repo 382 | file_name: file 383 | file_content: file content (class xxx and def xxx) 384 | """ 385 | return reranker_stage_2_system_prompt, reranker_stage_2_user_prompt_template.format(repo_name, problem_statement, file_name, file_content) 386 | ``` 387 | You can use the reranker prompts as follows: 388 | ```python 389 | from reranker.prompt import generate_prompt_for_reranker_stage_1, generate_prompt_for_reranker_stage_2 390 | 391 | # Stage 1: Identify relevant files 392 | system_prompt, user_prompt = generate_prompt_for_reranker_stage_1( 393 | problem_statement, 394 | repo_name, 395 | py_file_list, 396 | other_file_list 397 | ) 398 | 399 | # Stage 2: Score file relevance 400 | system_prompt, user_prompt = generate_prompt_for_reranker_stage_2( 401 | problem_statement, 402 | repo_name, 403 | target_file, 404 | file_content 405 | ) 406 | ``` 407 | 408 | ### Reader 409 | Execute the Reader module with the following DeepSpeed configurations: 410 | ```bash 411 | # ZeRO-2 configuration 412 | export N_NODE={YOUR_MACHINE_NUM} && \ 413 | export N_GPU_PER_NODE={YOUR_GPU_NUM} && \ 414 | export TRAIN_CONFIG={TRAIN_CONFIG}.json && \ 415 | bash launch/zero2.sh 416 | 417 | # ZeRO-3 configuration 418 | export N_NODE={YOUR_MACHINE_NUM} && \ 419 | export N_GPU_PER_NODE={YOUR_GPU_NUM} && \ 420 | export TRAIN_CONFIG={TRAIN_CONFIG}.json && \ 421 | bash launch/zero3.sh 422 | ``` 423 | 424 | ## Contributing 425 |
Contributions are welcome! If you have any suggestions, ideas, bug reports, or new model/feature supported, please open an issue or submit a pull request. 426 | 427 | We welcome contributions from the community! Please follow these guidelines: 428 | 429 | 1. Fork the repository 430 | 431 | 2. Create your feature branch 432 | 433 | 3. Commit your changes 434 | 435 | 4. Push to the branch 436 | 437 | 5. Open a Pull Request 438 | 439 | For major changes, please open an issue first to discuss the proposed changes. 440 | 441 | 442 | ## Citation 443 | If you find our work useful or helpful for your R&D works, please feel free to cite our paper as below. 444 | ```bibtex 445 | @misc{tao2025codegraphmodelcgm, 446 | title={Code Graph Model (CGM): A Graph-Integrated Large Language Model for Repository-Level Software Engineering Tasks}, 447 | author={Hongyuan Tao and Ying Zhang and Zhenhao Tang and Hongen Peng and Xukun Zhu and Bingchang Liu and Yingguang Yang and Ziyin Zhang and Zhaogui Xu and Haipeng Zhang and Linchao Zhu and Rui Wang and Hang Yu and Jianguo Li and Peng Di}, 448 | year={2025}, 449 | eprint={2505.16901}, 450 | archivePrefix={arXiv}, 451 | primaryClass={cs.SE}, 452 | url={https://arxiv.org/abs/2505.16901}, 453 | } 454 | ``` 455 | ## Join-US 456 | 457 | We are the AI Native team within the Platform Technology Business Group at Ant Group, dedicated to the intelligentization of Ant Group's platform engineering. Established for over three years, our team has played a pivotal role in supporting the intelligent operation and maintenance of Ant Group's cloud computing infrastructure. Our mission is to build algorithm services and platforms with a wide user base through world-class technological innovation and impact, supporting the implementation of internal and external products and businesses. 458 | Embracing an innovation-driven ethos, our team not only supports business implementation but also propels technological influence. Over the past three years, we have published more than 20 papers at top conferences like ICLR, NeurIPS, KDD, and ACL. Our innovative business outcomes have earned us two Ant Technology's highest T-Star awards and one SuperMA award from Ant Group. Our open-source project CodeFuse has received 4K stars as of February 2024, and our models have been downloaded over 1.5 million times on Huggingface and Modelscope. 459 | 460 | We are on the lookout for top talents to join our vibrant team! If you're eager to develop your career in an environment filled with energy, innovation, and a culture of excellence, we welcome you to explore our career opportunities for both campus and experienced hires. Join us and be a part of creating the next milestone in the industry. 
461 | 462 | **Contact**: hyu.hugo@antgroup.com 463 | -------------------------------------------------------------------------------- /cgm/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.functional import one_hot 7 | from tqdm.auto import tqdm 8 | import time 9 | import datetime 10 | from collections import OrderedDict 11 | 12 | sys.path.append("..") 13 | from utils.common_utils import touch_print, print_rank_0 14 | from utils.loss import loss_CGM 15 | from torch.utils.tensorboard import SummaryWriter 16 | from accelerate.logging import get_logger 17 | from torch.cuda.amp import autocast 18 | 19 | logger = get_logger(__name__) 20 | 21 | task_ids = { 22 | (0, 'graph_query'), 23 | (1, 'api'), 24 | (2, 'issue_fix'), 25 | (3, 'unit_test'), 26 | (4, 'readme_summary'), 27 | } 28 | 29 | task_to_id = {task: idx for idx, task in task_ids} 30 | id_to_task = {idx: task for idx, task in task_ids} 31 | 32 | def check_weight_dtype(model): 33 | for name, param in model.named_parameters(): 34 | print_rank_0(f"Layer {name}: {param.dtype}") 35 | 36 | def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps): 37 | for key, value in log_dict.items(): 38 | summary_writer.add_scalar(f'{key}', value, completed_steps) 39 | 40 | def accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir: str, completed_steps: int, args): 41 | accelerator.wait_for_everyone() 42 | 43 | accelerator.print(f"[CHECKPOINT] Saving checkpoint") 44 | unwrapped_model = accelerator.unwrap_model(model) 45 | 46 | save_encoder = False 47 | save_adapter = False 48 | save_lm = False 49 | if 'e' in args.mode: 50 | save_encoder = True 51 | if 'a' in args.mode: 52 | save_adapter = True 53 | if 'l' in args.mode: 54 | save_lm = True 55 | 56 | if accelerator.is_main_process: 57 | if not os.path.exists(output_dir): 58 | os.makedirs(output_dir) 59 | tokenizer.save_pretrained(output_dir) 60 | 61 | if save_adapter: 62 | torch.save(accelerator.get_state_dict(model.adapter), f"{output_dir}/adapter.pth") 63 | 64 | 65 | if save_encoder: 66 | unwrapped_model.encoder.save_pretrained( 67 | f"{output_dir}/encoder", 68 | is_main_process=accelerator.is_main_process, 69 | save_function=accelerator.save, 70 | state_dict=accelerator.get_state_dict(model.encoder) 71 | ) 72 | 73 | if save_lm: 74 | unwrapped_model.lm.save_pretrained( 75 | output_dir, 76 | is_main_process=accelerator.is_main_process, 77 | save_function=accelerator.save, 78 | state_dict=accelerator.get_state_dict(model.lm) 79 | ) 80 | 81 | accelerator.print( 82 | f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved" 83 | ) 84 | 85 | accelerator.wait_for_everyone() 86 | 87 | def accelerate_evaluate_CGM(accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, min_eval_loss, stall_num, 88 | best_step, summary_writer): 89 | 90 | losses = [] 91 | end_eval = False 92 | eval_step = 0 93 | for batch in valid_dataloader: 94 | with torch.no_grad(): 95 | if args.task == 'mft': 96 | task = batch['task'] 97 | 98 | x = batch['x'] 99 | y = batch['y'] 100 | loss_mask = batch['loss_mask'] 101 | qa_mask = batch['qa_mask'] 102 | embeddings = batch['embeddings'] 103 | 104 | len_y = y.shape[1] 105 | 106 | outputs = model( 107 | graph_embeddings = embeddings, 108 | qa_embeddings = x, 109 | qa_mask = qa_mask 110 | ) 111 | output_logits = outputs['logits'][:,-len_y:,:] 112 | 113 
| if args.task == 'mft': 114 | loss_dict = {task_id: torch.tensor(0.0, device=output_logits.device) for task_id, _ in task_ids} 115 | 116 | task_id = task.item() 117 | loss_dict[task_id] += loss_CGM( 118 | output_logits = output_logits, 119 | labels = y, 120 | loss_mask = loss_mask, 121 | ) 122 | 123 | loss = sum(loss_dict.values()) 124 | else: 125 | loss = loss_CGM( 126 | output_logits = output_logits, 127 | labels = y, 128 | loss_mask = loss_mask, 129 | ) 130 | 131 | eval_step += 1 132 | 133 | losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) 134 | 135 | accelerator.wait_for_everyone() 136 | valid_batch_num = len(losses) 137 | gathered_size = losses[0].shape 138 | losses = torch.cat(losses) 139 | 140 | try: 141 | eval_loss = torch.mean(losses) 142 | if eval_loss <= min_eval_loss: 143 | min_eval_loss = eval_loss 144 | stall_num = 0 145 | best_step = completed_steps 146 | else: 147 | stall_num += 1 148 | perplexity = math.exp(eval_loss) 149 | except OverflowError: 150 | perplexity = float("inf") 151 | 152 | logger.info(f"[EVAL][global_steps={step + 1}][completed_steps={completed_steps}]" 153 | f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]" 154 | f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}]") 155 | eval_log_dict = { 156 | "valid/valid_loss": eval_loss.float(), 157 | "valid/perplexity": perplexity 158 | } 159 | 160 | if accelerator.is_main_process: 161 | write_tensorboard(summary_writer, eval_log_dict, completed_steps) 162 | 163 | return eval_loss, min_eval_loss, stall_num, best_step 164 | 165 | def accelerate_evaluate_CGM_mft(accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, min_eval_loss, stall_num, 166 | best_step, summary_writer): 167 | 168 | losses = [] 169 | task_eval_counts = {} 170 | task_losses = {} 171 | eval_step = 0 172 | 173 | for batch in valid_dataloader: 174 | with torch.no_grad(): 175 | if args.task == 'mft': 176 | task = batch['task'] 177 | task_id = task.item() 178 | if task_id not in task_eval_counts: 179 | task_eval_counts[task_id] = 0 180 | task_losses[task_id] = [] 181 | if task_eval_counts[task_id] >= 50: 182 | continue 183 | 184 | x = batch['x'] 185 | y = batch['y'] 186 | loss_mask = batch['loss_mask'] 187 | qa_mask = batch['qa_mask'] 188 | embeddings = batch['embeddings'] 189 | 190 | len_y = y.shape[1] 191 | 192 | outputs = model( 193 | graph_embeddings = embeddings, 194 | qa_embeddings = x, 195 | qa_mask = qa_mask 196 | ) 197 | output_logits = outputs['logits'][:,-len_y:,:] 198 | 199 | if args.task == 'mft': 200 | loss = loss_CGM( 201 | output_logits = output_logits, 202 | labels = y, 203 | loss_mask = loss_mask, 204 | ) 205 | task_losses[task_id].append(loss.item()) 206 | else: 207 | loss = loss_CGM( 208 | output_logits = output_logits, 209 | labels = y, 210 | loss_mask = loss_mask, 211 | ) 212 | 213 | eval_step += 1 214 | losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size))) 215 | task_eval_counts[task_id] += 1 216 | 217 | accelerator.wait_for_everyone() 218 | valid_batch_num = len(losses) 219 | gathered_size = losses[0].shape 220 | losses = torch.cat(losses) 221 | 222 | try: 223 | eval_loss = torch.mean(losses) 224 | if eval_loss <= min_eval_loss: 225 | min_eval_loss = eval_loss 226 | stall_num = 0 227 | best_step = completed_steps 228 | else: 229 | stall_num += 1 230 | perplexity = math.exp(eval_loss) 231 | except OverflowError: 232 | perplexity = float("inf") 233 | 234 | logger.info(f"[EVAL][global_steps={step + 
1}][completed_steps={completed_steps}]" 235 | f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]" 236 | f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}]") 237 | 238 | for task_id, task_loss_list in task_losses.items(): 239 | task_eval_loss = sum(task_loss_list) / len(task_loss_list) if task_loss_list else 0.0 240 | logger.info(f"[EVAL][task_id={task_id}][task_loss={task_eval_loss:.6f}]") 241 | eval_log_dict = { 242 | "valid/valid_loss": eval_loss.float(), 243 | "valid/perplexity": perplexity, 244 | f"valid/{id_to_task[task_id]}": task_eval_loss 245 | } 246 | 247 | if accelerator.is_main_process: 248 | write_tensorboard(summary_writer, eval_log_dict, completed_steps) 249 | 250 | return eval_loss, min_eval_loss, stall_num, best_step 251 | 252 | def accelerate_monitor_CGM_mft(accelerator, reduce_loss_dict, args, completed_steps, 253 | lr_scheduler, optimizer, summary_writer): 254 | 255 | """ 256 | gather reduce_loss from all N devices. 257 | train logging and tensorboarding. 258 | """ 259 | # gathered_loss_dict = {task_id: accelerator.gather(reduce_loss) for task_id, reduce_loss in reduce_loss_dict.items()} 260 | gathered_loss_dict = {task_id: reduce_loss for task_id, reduce_loss in reduce_loss_dict.items()} 261 | print_rank_0(f"*******************gathered_loss_dict*******************") 262 | print_rank_0(gathered_loss_dict) 263 | 264 | train_log_dict = { 265 | f"train/{id_to_task[task_id]}": torch.mean(gathered_loss) / max(reduce_loss_count_dict[task_id], 1) 266 | for task_id, gathered_loss in gathered_loss_dict.items() 267 | } 268 | train_log_dict["train/lr"] = optimizer.param_groups[0]['lr'] 269 | 270 | logger.info( 271 | f"[TRAIN][completed_steps={completed_steps}]" 272 | f"[lr={optimizer.param_groups[0]['lr']:.4e}]", 273 | ) 274 | for task_id, train_loss in train_log_dict.items(): 275 | if task_id != "train/lr": 276 | logger.info(f"{task_id}={train_loss:.6f}") 277 | 278 | if accelerator.is_main_process: 279 | write_tensorboard(summary_writer, train_log_dict, completed_steps) 280 | 281 | 282 | def accelerate_monitor_CGM(accelerator, reduce_loss, args, completed_steps, 283 | lr_scheduler, optimizer, summary_writer): 284 | 285 | """ 286 | gather reduce_loss from all N devices. 287 | train logging and tensorboarding. 288 | """ 289 | reduce_losses = accelerator.gather(reduce_loss) 290 | # reduce_losses = reduce_loss 291 | 292 | train_loss = torch.mean(reduce_losses) / (args.log_interval * args.gradient_accumulation_steps) 293 | 294 | 295 | logger.info( 296 | f"[TRAIN][complete_steps={completed_steps}][train_loss={train_loss:.6f}]" 297 | f"[gather shape={reduce_losses.shape}][lr={optimizer.param_groups[0]['lr']:.4e}]", 298 | ) 299 | 300 | train_log_dict = { 301 | "train/train_loss": train_loss, 302 | "train/lr": optimizer.param_groups[0]['lr'] 303 | } 304 | 305 | if accelerator.is_main_process: 306 | write_tensorboard(summary_writer, train_log_dict, completed_steps) 307 | 308 | 309 | def accelerate_train_CGM(accelerator, model, train_dataloader, valid_dataloader, optimizer, lr_scheduler, tokenizer, 310 | total_train_dataset_size, args): 311 | 312 | summary_writer = SummaryWriter(log_dir=args.tb_dir, filename_suffix=args.tb_dir.split('/')[-1]) if accelerator.is_main_process else None 313 | # Train! 
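    # Effective global batch size = per-device micro-batch * number of processes * gradient-accumulation steps (computed next; used only for the logging below).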
314 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 315 | logger.info("**************************************** Running training ****************************************") 316 | logger.info(f" Num examples = {total_train_dataset_size}") 317 | logger.info(f" Num Epochs = {args.num_train_epochs}") 318 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 319 | logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 320 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 321 | logger.info(f" Total optimization(update/completed) steps = {args.max_train_steps}") 322 | logger.info(f" Complete/Optimization steps per Epoch = {args.max_train_steps // args.num_train_epochs}") 323 | logger.info("***************************************************************************************************") 324 | 325 | # Only show the progress bar once on each machine. 326 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 327 | # check_weight_dtype(model.lm) 328 | # exit() 329 | # set starting_epoch, completed_steps and resume_step of train_dataloader 330 | completed_steps = 0 331 | starting_epoch = 0 332 | 333 | # monitor minimum eval_loss, stalling num, and best_step 334 | min_eval_loss = float('inf') 335 | eval_loss = 100.0 336 | checkpoint_eval_loss = float('inf') 337 | checkpoint_stall_num = 0 338 | stall_num = 0 339 | best_step = None 340 | 341 | reduce_loss = 0 342 | reduce_loss_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids) 343 | reduce_loss_count_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids) 344 | for epoch in range(starting_epoch, args.num_train_epochs): 345 | model.train() 346 | 347 | for step, batch in enumerate(train_dataloader): 348 | 349 | with accelerator.accumulate(model): 350 | 351 | graph = batch['graph'] 352 | x = batch['x'] 353 | y = batch['y'].unsqueeze(0) 354 | loss_mask = batch['loss_mask'].unsqueeze(0) 355 | qa_mask = batch['qa_mask'].unsqueeze(0) 356 | 357 | if args.task == 'mft': 358 | task = batch['task'] 359 | 360 | outputs = model( 361 | graph = graph, 362 | qa_ids = x, 363 | qa_mask = qa_mask, 364 | ) 365 | 366 | len_y = y.shape[1] 367 | output_logits = outputs['logits'][:,-len_y:,:] 368 | 369 | if args.task == 'mft': 370 | loss_dict = {task_id: torch.tensor(0.0, device=output_logits.device) for task_id, _ in task_ids} 371 | 372 | task_id = task.item() 373 | loss_dict[task_id] += loss_CGM( 374 | output_logits = output_logits, 375 | labels = y, 376 | loss_mask = loss_mask, 377 | ) 378 | 379 | loss = sum(loss_dict.values()) 380 | else: 381 | loss = loss_CGM( 382 | output_logits = output_logits, 383 | labels = y, 384 | loss_mask = loss_mask, 385 | ) 386 | 387 | accelerator.backward(loss) 388 | 389 | optimizer.step() 390 | if not args.lr_scheduler_type == "reduce_lr_on_plateau": 391 | lr_scheduler.step() 392 | optimizer.zero_grad() 393 | 394 | if optimizer.param_groups[0]['lr'] <= args.min_lr: 395 | optimizer.param_groups[0]['lr'] = args.min_lr 396 | 397 | if args.task == 'mft': 398 | for task_id in reduce_loss_dict.keys(): 399 | if not torch.isnan(loss_dict[task_id]): 400 | reduce_loss_dict[task_id] += loss_dict[task_id].detach().float() 401 | reduce_loss_count_dict[task_id] += 1 402 | else: 403 | 404 | if not torch.isnan(loss): 405 | reduce_loss += loss.detach().float() 406 | else: 407 | logger.info("loss nan") 408 | 409 
| if accelerator.sync_gradients: 410 | completed_steps += 1 411 | if args.task == 'mft': 412 | reduce_loss = sum(reduce_loss_dict.values()) 413 | logger.info(f"accelerator step (accumulate) {completed_steps}, loss: {reduce_loss}") 414 | 415 | if completed_steps % args.log_interval == 0: 416 | if args.task == 'mft': 417 | progress_bar.update(args.log_interval) 418 | accelerate_monitor_CGM_mft( 419 | accelerator, reduce_loss_dict, reduce_loss_count_dict, args, completed_steps, 420 | lr_scheduler, optimizer, summary_writer 421 | ) 422 | reduce_loss_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids) 423 | reduce_loss_count_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids) 424 | else: 425 | if isinstance(reduce_loss, torch.Tensor): 426 | progress_bar.update(args.log_interval) 427 | accelerate_monitor_CGM( 428 | accelerator, reduce_loss, args, completed_steps, 429 | lr_scheduler, optimizer, summary_writer 430 | ) 431 | reduce_loss = 0 432 | 433 | # steps checkpointing 434 | if args.step_checkpointing and completed_steps % args.checkpointing_steps == 0: 435 | output_dir = f"step_{completed_steps}" 436 | if args.output_dir is not None: 437 | output_dir = os.path.join(args.output_dir, output_dir) 438 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args) 439 | 440 | if args.step_evaluation and completed_steps % args.evaluation_steps == 0: 441 | logger.info(f"start evaluation...") 442 | model.eval() 443 | model.lm.gradient_checkpointing_disable() 444 | model.lm.config.use_cache = True 445 | if args.task == 'mft': 446 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM_mft( 447 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, 448 | min_eval_loss, stall_num, best_step, summary_writer 449 | ) 450 | else: 451 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM( 452 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, 453 | min_eval_loss, stall_num, best_step, summary_writer 454 | ) 455 | model.train() 456 | model.lm.gradient_checkpointing_enable() 457 | model.lm.config.use_cache = False 458 | 459 | if args.lr_scheduler_type == "reduce_lr_on_plateau": 460 | lr_scheduler.step(eval_loss) 461 | 462 | if eval_loss < checkpoint_eval_loss: 463 | checkpoint_eval_loss = eval_loss 464 | output_dir = f"step_{completed_steps}_stall_{checkpoint_stall_num}" 465 | if args.output_dir is not None: 466 | output_dir = os.path.join(args.output_dir, output_dir) 467 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args) 468 | checkpoint_stall_num = 0 469 | else: 470 | if checkpoint_stall_num < 2: 471 | output_dir = f"step_{completed_steps}_stall_{checkpoint_stall_num}" 472 | if args.output_dir is not None: 473 | output_dir = os.path.join(args.output_dir, output_dir) 474 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args) 475 | checkpoint_stall_num += 1 476 | 477 | if args.lr_scheduler_type == "reduce_lr_on_plateau": 478 | pass 479 | elif args.lr_scheduler_type == 'cosine': 480 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33 481 | lr_scheduler.base_lrs = optimizer.param_groups[0]['lr'] 482 | lr_scheduler.step() 483 | else: 484 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33 485 | lr_scheduler.step() 486 | 487 | # adapter warmup 488 | if args.adapter_warmup and completed_steps >= args.adapter_warmup_steps: 489 | if 
'l' in args.mode: 490 | for param in model.lm.parameters(): 491 | param.requires_grad = True 492 | args.adapter_warmup = False 493 | 494 | accelerator.wait_for_everyone() 495 | 496 | load_t = time.time() 497 | 498 | if args.epoch_evaluation: 499 | model.eval() 500 | model.lm.gradient_checkpointing_disable() 501 | model.lm.config.use_cache = True 502 | if args.task == 'mft': 503 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM_mft( 504 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, 505 | min_eval_loss, stall_num, best_step, summary_writer 506 | ) 507 | else: 508 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM( 509 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, 510 | min_eval_loss, stall_num, best_step, summary_writer 511 | ) 512 | model.train() 513 | model.lm.gradient_checkpointing_enable() 514 | model.lm.config.use_cache = False 515 | 516 | if args.lr_scheduler_type == "reduce_lr_on_plateau": 517 | lr_scheduler.step(eval_loss) 518 | 519 | if eval_loss < checkpoint_eval_loss: 520 | checkpoint_eval_loss = eval_loss 521 | output_dir = f"epoch_{epoch}" 522 | ckpt_tag = output_dir 523 | if args.output_dir is not None: 524 | output_dir = os.path.join(args.output_dir, output_dir) 525 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args) 526 | else: 527 | if args.lr_scheduler_type == "reduce_lr_on_plateau": 528 | pass 529 | # lr_scheduler.step(eval_loss) 530 | elif args.lr_scheduler_type == 'cosine': 531 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33 532 | lr_scheduler.base_lrs = optimizer.param_groups[0]['lr'] 533 | lr_scheduler.step() 534 | else: 535 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33 536 | lr_scheduler.step() 537 | 538 | # epoch checkpointing 539 | if args.epoch_checkpointing: 540 | output_dir = f"epoch_{epoch}" 541 | if args.output_dir is not None: 542 | output_dir = os.path.join(args.output_dir, output_dir) 543 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args) 544 | 545 | if summary_writer: 546 | summary_writer.close() 547 | 548 | output_dir = f"final_step_{completed_steps}" 549 | if args.output_dir is not None: 550 | output_dir = os.path.join(args.output_dir, output_dir) 551 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args) 552 | -------------------------------------------------------------------------------- /cgm/models/qwen2/_4_46_1/modeling_attn_mask_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
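# This module is a local copy of the attention-mask utilities from Hugging Face
# Transformers (pinned to transformers == 4.46.1, see the note below). Besides the
# upstream 2D helpers, it exposes 3D-mask variants (`_3d_to_4d`,
# `_make_causal_mask_3d`, `_expand_mask_3d`) that expand a per-sample
# [bsz, tgt_seq_len, src_seq_len] mask to the 4D [bsz, 1, tgt_seq_len, src_seq_len]
# form consumed by the attention layers.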
14 | from dataclasses import dataclass 15 | from typing import List, Optional, Tuple, Union 16 | 17 | import torch 18 | 19 | from transformers.utils.import_utils import is_torchdynamo_compiling 20 | 21 | # TODO: Copyied Version: transformers == 4.46.1 22 | 23 | @dataclass 24 | class AttentionMaskConverter: 25 | """ 26 | A utility attention mask class that allows one to: 27 | - Create a causal 4d mask 28 | - Create a causal 4d mask with slided window 29 | - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, 30 | key_value_length) that can be multiplied with attention scores 31 | 32 | Examples: 33 | 34 | ```python 35 | >>> import torch 36 | >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter 37 | 38 | >>> converter = AttentionMaskConverter(True) 39 | >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) 40 | tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], 41 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], 42 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], 43 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], 44 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) 45 | ``` 46 | 47 | Parameters: 48 | is_causal (`bool`): 49 | Whether the attention mask should be a uni-directional (causal) or bi-directional mask. 50 | 51 | sliding_window (`int`, *optional*): 52 | Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. 53 | """ 54 | 55 | is_causal: bool 56 | sliding_window: int 57 | 58 | def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): 59 | self.is_causal = is_causal 60 | self.sliding_window = sliding_window 61 | 62 | if self.sliding_window is not None and self.sliding_window <= 0: 63 | raise ValueError( 64 | f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" 65 | ) 66 | 67 | def to_causal_4d( 68 | self, 69 | batch_size: int, 70 | query_length: int, 71 | key_value_length: int, 72 | dtype: torch.dtype, 73 | device: Union[torch.device, "str"] = "cpu", 74 | ) -> Optional[torch.Tensor]: 75 | """ 76 | Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative 77 | bias to upper right hand triangular matrix (causal mask). 
78 | """ 79 | if not self.is_causal: 80 | raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") 81 | 82 | # If shape is not cached, create a new causal mask and cache it 83 | input_shape = (batch_size, query_length) 84 | past_key_values_length = key_value_length - query_length 85 | 86 | # create causal mask 87 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 88 | causal_4d_mask = None 89 | if input_shape[-1] > 1 or self.sliding_window is not None: 90 | causal_4d_mask = self._make_causal_mask( 91 | input_shape, 92 | dtype, 93 | device=device, 94 | past_key_values_length=past_key_values_length, 95 | sliding_window=self.sliding_window, 96 | ) 97 | 98 | return causal_4d_mask 99 | 100 | def to_4d( 101 | self, 102 | attention_mask_2d: torch.Tensor, 103 | query_length: int, 104 | dtype: torch.dtype, 105 | key_value_length: Optional[int] = None, 106 | ) -> torch.Tensor: 107 | """ 108 | Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, 109 | key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is 110 | causal, a causal mask will be added. 111 | """ 112 | input_shape = (attention_mask_2d.shape[0], query_length) 113 | 114 | # create causal mask 115 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 116 | causal_4d_mask = None 117 | if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: 118 | if key_value_length is None: 119 | raise ValueError( 120 | "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." 121 | ) 122 | 123 | past_key_values_length = key_value_length - query_length 124 | causal_4d_mask = self._make_causal_mask( 125 | input_shape, 126 | dtype, 127 | device=attention_mask_2d.device, 128 | past_key_values_length=past_key_values_length, 129 | sliding_window=self.sliding_window, 130 | ) 131 | elif self.sliding_window is not None: 132 | raise NotImplementedError("Sliding window is currently only implemented for causal masking") 133 | 134 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 135 | expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( 136 | attention_mask_2d.device 137 | ) 138 | 139 | if causal_4d_mask is not None: 140 | expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) 141 | 142 | # expanded_attn_mask + causal_4d_mask can cause some overflow 143 | expanded_4d_mask = expanded_attn_mask 144 | 145 | return expanded_4d_mask 146 | 147 | def _3d_to_4d( 148 | self, 149 | attention_mask_3d: torch.Tensor, 150 | query_length: int, 151 | dtype: torch.dtype, 152 | key_value_length: Optional[int] = None, 153 | ) -> torch.Tensor: 154 | """ 155 | Converts 3D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, 156 | key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is 157 | causal, a causal mask will be added. 158 | """ 159 | input_shape = (attention_mask_3d.shape[0], query_length, key_value_length) 160 | 161 | # create causal mask 162 | # [bsz, tgt_seq_len, src_seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 163 | causal_4d_mask = None 164 | if (input_shape[-2] > 1 or self.sliding_window is not None) and self.is_causal: 165 | if key_value_length is None: 166 | raise ValueError( 167 | "This attention mask converter is causal. 
Make sure to pass `key_value_length` to correctly create a causal mask." 168 | ) 169 | 170 | past_key_values_length = key_value_length - query_length 171 | causal_4d_mask = self._make_causal_mask( 172 | input_shape, 173 | dtype, 174 | device=attention_mask_3d.device, 175 | past_key_values_length=past_key_values_length, 176 | sliding_window=self.sliding_window, 177 | ) 178 | elif self.sliding_window is not None: 179 | raise NotImplementedError("Sliding window is currently only implemented for causal masking") 180 | 181 | # [bsz, tgt_seq_len, src_seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 182 | expanded_attn_mask = self._expand_mask_3d(attention_mask_3d, dtype, tgt_len=input_shape[-2]).to( 183 | attention_mask_3d.device 184 | ) 185 | 186 | if causal_4d_mask is not None: 187 | expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) 188 | 189 | # expanded_attn_mask + causal_4d_mask can cause some overflow 190 | expanded_4d_mask = expanded_attn_mask 191 | 192 | return expanded_4d_mask 193 | 194 | @staticmethod 195 | def _make_causal_mask( 196 | input_ids_shape: torch.Size, 197 | dtype: torch.dtype, 198 | device: torch.device, 199 | past_key_values_length: int = 0, 200 | sliding_window: Optional[int] = None, 201 | ): 202 | """ 203 | Make causal mask used for bi-directional self-attention. 204 | """ 205 | bsz, tgt_len = input_ids_shape 206 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 207 | mask_cond = torch.arange(mask.size(-1), device=device) 208 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 209 | 210 | mask = mask.to(dtype) 211 | 212 | if past_key_values_length > 0: 213 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 214 | 215 | # add lower triangular sliding window mask if necessary 216 | if sliding_window is not None: 217 | diagonal = past_key_values_length - sliding_window - 1 218 | 219 | context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) 220 | mask.masked_fill_(context_mask, torch.finfo(dtype).min) 221 | 222 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 223 | 224 | @staticmethod 225 | def _make_causal_mask_3d( 226 | input_ids_shape: torch.Size, 227 | dtype: torch.dtype, 228 | device: torch.device, 229 | past_key_values_length: int = 0, 230 | sliding_window: Optional[int] = None, 231 | ): 232 | """ 233 | Make causal mask used for bi-directional self-attention. 
234 | """ 235 | bsz, tgt_len, src_len = input_ids_shape 236 | mask = torch.full((tgt_len, src_len), torch.finfo(dtype).min, device=device) 237 | mask_cond = torch.arange(mask.size(-1), device=device) 238 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 239 | 240 | mask = mask.to(dtype) 241 | 242 | if past_key_values_length > 0: 243 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 244 | 245 | # add lower triangular sliding window mask if necessary 246 | if sliding_window is not None: 247 | diagonal = past_key_values_length - sliding_window - 1 248 | 249 | context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) 250 | mask.masked_fill_(context_mask, torch.finfo(dtype).min) 251 | 252 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, src_len + past_key_values_length) 253 | 254 | @staticmethod 255 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 256 | """ 257 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 258 | """ 259 | bsz, src_len = mask.size() 260 | tgt_len = tgt_len if tgt_len is not None else src_len 261 | 262 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 263 | 264 | inverted_mask = 1.0 - expanded_mask 265 | 266 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 267 | 268 | @staticmethod 269 | def _expand_mask_3d(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 270 | """ 271 | Expands attention_mask from `[bsz, tgt_seq_len, src_seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 272 | """ 273 | bsz, tgt_seq_len, src_seq_len = mask.size() 274 | tgt_len = tgt_len if tgt_len is not None else tgt_seq_len 275 | 276 | expanded_mask = mask[:, None, :, :].expand(bsz, 1, tgt_len, src_seq_len).to(dtype) 277 | 278 | inverted_mask = 1.0 - expanded_mask 279 | 280 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 281 | 282 | @staticmethod 283 | def _unmask_unattended( 284 | expanded_mask: torch.FloatTensor, 285 | min_dtype: float, 286 | ): 287 | # fmt: off 288 | """ 289 | Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when 290 | using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 291 | Details: https://github.com/pytorch/pytorch/issues/110213 292 | 293 | `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. 294 | `attention_mask` is [bsz, src_seq_len]. 295 | 296 | The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. 297 | 298 | For example, if `expanded_mask` is (e.g. here left-padding case) 299 | ``` 300 | [[[[0, 0, 0], 301 | [0, 0, 0], 302 | [0, 0, 1]]], 303 | [[[1, 0, 0], 304 | [1, 1, 0], 305 | [1, 1, 1]]], 306 | [[[0, 0, 0], 307 | [0, 1, 0], 308 | [0, 1, 1]]]] 309 | ``` 310 | then the modified `expanded_mask` will be 311 | ``` 312 | [[[[1, 1, 1], <-- modified 313 | [1, 1, 1], <-- modified 314 | [0, 0, 1]]], 315 | [[[1, 0, 0], 316 | [1, 1, 0], 317 | [1, 1, 1]]], 318 | [[[1, 1, 1], <-- modified 319 | [0, 1, 0], 320 | [0, 1, 1]]]] 321 | ``` 322 | """ 323 | # fmt: on 324 | if expanded_mask.dtype == torch.bool: 325 | raise ValueError( 326 | "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." 
327 | ) 328 | 329 | return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) 330 | 331 | @staticmethod 332 | def _ignore_causal_mask_sdpa( 333 | attention_mask: Optional[torch.Tensor], 334 | inputs_embeds: torch.Tensor, 335 | past_key_values_length: int, 336 | sliding_window: Optional[int] = None, 337 | is_training: bool = False, 338 | ) -> bool: 339 | """ 340 | Detects whether the optional user-specified attention_mask & the automatically created causal mask can be 341 | ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. 342 | 343 | In case no token is masked in the `attention_mask` argument, if `query_length == 1` or 344 | `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, 345 | allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is 346 | passed). 347 | """ 348 | 349 | _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] 350 | key_value_length = query_length + past_key_values_length 351 | 352 | is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling() 353 | 354 | ignore_causal_mask = False 355 | 356 | if attention_mask is None: 357 | # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input 358 | # shape, thus SDPA's `is_causal` argument is rightfully updated 359 | # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using 360 | # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is 361 | # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` 362 | # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). 363 | # Thus, we only set `ignore_causal_mask = True` if the model is set to training. 364 | # 365 | # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` 366 | # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). 367 | if ( 368 | (is_training or not is_tracing) 369 | and (query_length == 1 or key_value_length == query_length) 370 | and (sliding_window is None or key_value_length < sliding_window) 371 | ): 372 | ignore_causal_mask = True 373 | elif sliding_window is None or key_value_length < sliding_window: 374 | if len(attention_mask.shape) == 4: 375 | return False 376 | elif not is_tracing and torch.all(attention_mask == 1): 377 | if query_length == 1 or key_value_length == query_length: 378 | # For query_length == 1, causal attention and bi-directional attention are the same. 379 | ignore_causal_mask = True 380 | 381 | # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore 382 | # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in 383 | # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. 384 | # Reference: https://github.com/pytorch/pytorch/issues/108108 385 | # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. 
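        # In short, `ignore_causal_mask` becomes True only when no explicit masking is
        # required: the 2D attention_mask is absent or all ones, the query is a single
        # token or covers the full key/value length, and any sliding window is longer
        # than the key/value length. In those cases SDPA's `is_causal` flag is sufficient.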
386 | 387 | return ignore_causal_mask 388 | 389 | 390 | def _prepare_4d_causal_attention_mask( 391 | attention_mask: Optional[torch.Tensor], 392 | input_shape: Union[torch.Size, Tuple, List], 393 | inputs_embeds: torch.Tensor, 394 | past_key_values_length: int, 395 | sliding_window: Optional[int] = None, 396 | ): 397 | """ 398 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 399 | `(batch_size, key_value_length)` 400 | 401 | Args: 402 | attention_mask (`torch.Tensor` or `None`): 403 | A 2D attention mask of shape `(batch_size, key_value_length)` 404 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`): 405 | The input shape should be a tuple that defines `(batch_size, query_length)`. 406 | inputs_embeds (`torch.Tensor`): 407 | The embedded inputs as a torch Tensor. 408 | past_key_values_length (`int`): 409 | The length of the key value cache. 410 | sliding_window (`int`, *optional*): 411 | If the model uses windowed attention, a sliding window should be passed. 412 | """ 413 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) 414 | 415 | key_value_length = input_shape[-1] + past_key_values_length 416 | 417 | # 4d mask is passed through the layers 418 | if attention_mask is not None and len(attention_mask.shape) == 2: 419 | attention_mask = attn_mask_converter.to_4d( 420 | attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype 421 | ) 422 | elif attention_mask is not None and len(attention_mask.shape) == 4: 423 | expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) 424 | if tuple(attention_mask.shape) != expected_shape: 425 | raise ValueError( 426 | f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." 427 | ) 428 | else: 429 | # if the 4D mask has correct shape - invert it and fill with negative infinity 430 | inverted_mask = 1.0 - attention_mask 431 | attention_mask = inverted_mask.masked_fill( 432 | inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min 433 | ) 434 | else: 435 | attention_mask = attn_mask_converter.to_causal_4d( 436 | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device 437 | ) 438 | 439 | return attention_mask 440 | 441 | 442 | # Adapted from _prepare_4d_causal_attention_mask 443 | def _prepare_4d_causal_attention_mask_for_sdpa( 444 | attention_mask: Optional[torch.Tensor], 445 | input_shape: Union[torch.Size, Tuple, List], 446 | inputs_embeds: torch.Tensor, 447 | past_key_values_length: int, 448 | sliding_window: Optional[int] = None, 449 | ): 450 | """ 451 | Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. 452 | 453 | In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and 454 | `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, 455 | allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). 
456 | """ 457 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) 458 | 459 | key_value_length = input_shape[-1] + past_key_values_length 460 | 461 | # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` 462 | # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. 463 | # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). 464 | is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling() 465 | 466 | ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( 467 | attention_mask=attention_mask, 468 | inputs_embeds=inputs_embeds, 469 | past_key_values_length=past_key_values_length, 470 | sliding_window=sliding_window, 471 | ) 472 | 473 | if ignore_causal_mask: 474 | expanded_4d_mask = None 475 | elif attention_mask is None: 476 | expanded_4d_mask = attn_mask_converter.to_causal_4d( 477 | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device 478 | ) 479 | else: 480 | if attention_mask.dim() == 4: 481 | expanded_4d_mask = attention_mask 482 | else: 483 | expanded_4d_mask = attn_mask_converter.to_4d( 484 | attention_mask, 485 | input_shape[-1], 486 | dtype=inputs_embeds.dtype, 487 | key_value_length=key_value_length, 488 | ) 489 | 490 | # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when 491 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 492 | # Details: https://github.com/pytorch/pytorch/issues/110213 493 | if not is_tracing and expanded_4d_mask.device.type == "cuda": 494 | expanded_4d_mask = AttentionMaskConverter._unmask_unattended( 495 | expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min 496 | ) 497 | 498 | return expanded_4d_mask 499 | 500 | 501 | def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 502 | """ 503 | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 504 | `(batch_size, key_value_length)` 505 | 506 | Args: 507 | mask (`torch.Tensor`): 508 | A 2D attention mask of shape `(batch_size, key_value_length)` 509 | dtype (`torch.dtype`): 510 | The torch dtype the created mask shall have. 511 | tgt_len (`int`): 512 | The target length or query length the created mask shall have. 513 | """ 514 | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) 515 | 516 | 517 | def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 518 | """ 519 | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape 520 | `(batch_size, key_value_length)` 521 | 522 | Args: 523 | mask (`torch.Tensor`): 524 | A 2D attention mask of shape `(batch_size, key_value_length)` 525 | dtype (`torch.dtype`): 526 | The torch dtype the created mask shall have. 527 | tgt_len (`int`): 528 | The target length or query length the created mask shall have. 
529 | """ 530 | _, key_value_length = mask.shape 531 | tgt_len = tgt_len if tgt_len is not None else key_value_length 532 | 533 | is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling() 534 | 535 | # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. 536 | if not is_tracing and torch.all(mask == 1): 537 | return None 538 | else: 539 | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) 540 | 541 | 542 | def _create_4d_causal_attention_mask( 543 | input_shape: Union[torch.Size, Tuple, List], 544 | dtype: torch.dtype, 545 | device: torch.device, 546 | past_key_values_length: int = 0, 547 | sliding_window: Optional[int] = None, 548 | ) -> Optional[torch.Tensor]: 549 | """ 550 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` 551 | 552 | Args: 553 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`): 554 | The input shape should be a tuple that defines `(batch_size, query_length)`. 555 | dtype (`torch.dtype`): 556 | The torch dtype the created mask shall have. 557 | device (`int`): 558 | The torch device the created mask shall have. 559 | sliding_window (`int`, *optional*): 560 | If the model uses windowed attention, a sliding window should be passed. 561 | """ 562 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) 563 | 564 | key_value_length = past_key_values_length + input_shape[-1] 565 | attention_mask = attn_mask_converter.to_causal_4d( 566 | input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device 567 | ) 568 | 569 | return attention_mask --------------------------------------------------------------------------------