├── cgm
│   ├── utils
│   │   ├── __init__.py
│   │   ├── metrics.py
│   │   ├── common_utils.py
│   │   ├── loss.py
│   │   ├── arguments.py
│   │   └── train_utils.py
│   ├── launch
│   │   ├── zero3.sh
│   │   └── zero2.sh
│   ├── config
│   │   └── template.json
│   ├── inference
│   │   └── layer.py
│   ├── data
│   │   ├── encode.py
│   │   └── preprocess.py
│   ├── train
│   │   └── train.py
│   ├── modeling
│   │   └── cgm.py
│   └── models
│       └── qwen2
│           └── _4_46_1
│               └── modeling_attn_mask_utils.py
├── assets
│   ├── pipeline.png
│   ├── cgm_method_v3.png
│   ├── SWE-Bench-Lite.png
│   ├── framework_1126.jpeg
│   ├── cgm_framework_0123.png
│   ├── swe-bench-lite-ow.png
│   ├── swe-bench-modified.png
│   └── github-codefuse-logo-update.jpg
├── reranker
│   ├── codegraph_parser
│   │   ├── java
│   │   │   └── __pycache__
│   │   │       └── codegraph_java_local.cpython-38.pyc
│   │   └── python
│   │       └── __pycache__
│   │           └── codegraph_python_local.cpython-38.pyc
│   ├── qwen_api.py
│   ├── prompt.py
│   └── reranker.py
├── LEGAL.md
├── preprocess_embedding
│   ├── generate_rewriter_embedding.py
│   ├── generate_code_content.py
│   └── generate_code_embedding.py
├── retriever
│   ├── utils.py
│   ├── serialize_subgraph.py
│   ├── locate_anchor_node.py
│   └── subgraph.py
├── rewriter
│   ├── inference_rewriter.py
│   ├── prompt.py
│   ├── rewriter_output_post_process.py
│   └── generate_rewriter_prompt.py
└── README.md
/cgm/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/cgm_method_v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/cgm_method_v3.png
--------------------------------------------------------------------------------
/assets/SWE-Bench-Lite.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/SWE-Bench-Lite.png
--------------------------------------------------------------------------------
/assets/framework_1126.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/framework_1126.jpeg
--------------------------------------------------------------------------------
/assets/cgm_framework_0123.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/cgm_framework_0123.png
--------------------------------------------------------------------------------
/assets/swe-bench-lite-ow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/swe-bench-lite-ow.png
--------------------------------------------------------------------------------
/assets/swe-bench-modified.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/swe-bench-modified.png
--------------------------------------------------------------------------------
/assets/github-codefuse-logo-update.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/assets/github-codefuse-logo-update.jpg
--------------------------------------------------------------------------------
/reranker/codegraph_parser/java/__pycache__/codegraph_java_local.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/reranker/codegraph_parser/java/__pycache__/codegraph_java_local.cpython-38.pyc
--------------------------------------------------------------------------------
/reranker/codegraph_parser/python/__pycache__/codegraph_python_local.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/codefuse-ai/CodeFuse-CGM/HEAD/reranker/codegraph_parser/python/__pycache__/codegraph_python_local.cpython-38.pyc
--------------------------------------------------------------------------------
/LEGAL.md:
--------------------------------------------------------------------------------
1 | Legal Disclaimer
2 |
3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comments in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
4 |
5 | 法律免责声明
6 |
7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
--------------------------------------------------------------------------------
/cgm/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
2 |
3 | def calculate_metrics(y_true, y_pred, average='binary'):
4 |
5 | accuracy = accuracy_score(y_true, y_pred)
6 | precision = precision_score(y_true, y_pred, average=average)
7 | recall = recall_score(y_true, y_pred, average=average)
8 | f1 = f1_score(y_true, y_pred, average=average)
9 |
10 | metrics = {
11 | 'accuracy': accuracy,
12 | 'precision': precision,
13 | 'recall': recall,
14 | 'f1_score': f1
15 | }
16 |
17 | return metrics
--------------------------------------------------------------------------------
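A minimal usage sketch for calculate_metrics, assuming the cgm/ directory is on the Python path; the labels below are made-up values for illustration only:

    from utils.metrics import calculate_metrics

    y_true = [1, 0, 1, 1, 0]
    y_pred = [1, 0, 0, 1, 1]
    print(calculate_metrics(y_true, y_pred))
    # {'accuracy': 0.6, 'precision': 0.666..., 'recall': 0.666..., 'f1_score': 0.666...}
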
/cgm/launch/zero3.sh:
--------------------------------------------------------------------------------
1 | accelerate launch \
2 | --num_machines $N_NODE \
3 | --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \
4 | --use_deepspeed \
5 | --deepspeed_multinode_launcher 'standard' \
6 | --zero_stage 3 \
7 | --offload_optimizer_device 'none' \
8 | --offload_param_device 'none' \
9 | --gradient_accumulation_steps 32 \
10 | --gradient_clipping 1.0 \
11 | --zero3_init_flag true \
12 | --zero3_save_16bit_model true \
13 | --main_training_function 'main' \
14 | --mixed_precision 'bf16' \
15 | --dynamo_backend 'no' \
16 | --same_network \
17 | --machine_rank $RANK \
18 | --main_process_ip $MASTER_ADDR \
19 | --main_process_port $MASTER_PORT \
20 | --rdzv_backend 'static' \
21 | train/train.py --c config/$TRAIN_CONFIG
22 |
--------------------------------------------------------------------------------
/cgm/launch/zero2.sh:
--------------------------------------------------------------------------------
1 | accelerate launch \
2 | --num_machines $N_NODE \
3 | --num_processes $(($N_NODE*$N_GPU_PER_NODE)) \
4 | --use_deepspeed \
5 | --deepspeed_multinode_launcher 'standard' \
6 | --zero_stage 2 \
7 | --offload_optimizer_device 'none' \
8 | --offload_param_device 'none' \
9 | --gradient_accumulation_steps 32 \
10 | --gradient_clipping 1.0 \
11 | --zero3_init_flag false \
12 | --zero3_save_16bit_model false \
13 | --main_training_function 'main' \
14 | --mixed_precision 'bf16' \
15 | --dynamo_backend 'no' \
16 | --same_network \
17 | --machine_rank $RANK \
18 | --main_process_ip $MASTER_ADDR \
19 | --main_process_port $MASTER_PORT \
20 | --rdzv_backend 'static' \
21 | train/train.py --c config/$TRAIN_CONFIG
22 |
23 |
--------------------------------------------------------------------------------
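Both launch scripts read their distributed settings from environment variables (N_NODE, N_GPU_PER_NODE, RANK, MASTER_ADDR, MASTER_PORT, TRAIN_CONFIG) and resolve train/train.py and config/ relative to the working directory. A minimal single-node sketch; the values are assumptions, not recommended settings:

    # Illustrative values only; cwd="cgm" so launch/, train/ and config/ paths resolve.
    import os, subprocess

    env = dict(os.environ,
               N_NODE="1", N_GPU_PER_NODE="8", RANK="0",
               MASTER_ADDR="127.0.0.1", MASTER_PORT="29500",
               TRAIN_CONFIG="template.json")
    subprocess.run(["bash", "launch/zero2.sh"], env=env, check=True, cwd="cgm")
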
/reranker/qwen_api.py:
--------------------------------------------------------------------------------
1 | from vllm import LLM, SamplingParams
2 | from vllm.sampling_params import BeamSearchParams
3 | from vllm.entrypoints.chat_utils import apply_hf_chat_template
4 | import os
5 |
6 | os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
7 |
8 | class QwenAPI:
9 | def __init__(self, model_path, tensor_parallel_size=4):
10 | self.llm = LLM(model=model_path, tensor_parallel_size=tensor_parallel_size)
11 |
12 | sample_params = dict(
13 | temperature=0.1,
14 | repetition_penalty=1.1,
15 | top_k=50,
16 | top_p=0.98,
17 | max_tokens=1024
18 | )
19 |
20 | self.sampling_params = SamplingParams(**sample_params)
21 | print(f"sampling_params: {sample_params}")
22 |
23 | def get_response(self, system_prompt: str, user_prompt: str):
24 | conversation = [
25 | {
26 | "role": "system",
27 | "content": system_prompt
28 | },
29 | {
30 | "role": "user",
31 | "content": user_prompt
32 | }
33 | ]
34 | outputs = self.llm.chat(conversation, sampling_params=self.sampling_params, use_tqdm=False)
35 | response = outputs[0].outputs[0].text
36 | return response
37 |
38 |
39 | if __name__ == "__main__":
40 | llm = QwenAPI("Qwen/Qwen2.5-1.5B-Instruct")
41 | user_prompt = "Where is the capital of China?"
42 | system_prompt = "You are a helpful assistant."
43 | print(llm.get_response(system_prompt, user_prompt))
--------------------------------------------------------------------------------
/preprocess_embedding/generate_rewriter_embedding.py:
--------------------------------------------------------------------------------
1 | """
2 | generate embedding for Queries from Rewriter's Inferer
3 | """
4 |
5 | from transformers import AutoTokenizer, AutoModel
6 | import torch
7 | import os
8 | import numpy as np
9 | import pandas as pd
10 | import tqdm
11 | import json
12 | import pickle
13 |
14 | # custom
15 | import argparse, logging
16 |
17 | # input path
18 | rewriter_output_path = "rewriter_output.json"
19 |
20 | # save path
21 | rewriter_embedding_path = "rewriter_embedding.pkl"
22 |
23 | # load model
24 | model_name_or_path = "xxx/CodeFuse-CGE-Large"
25 | model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
26 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right')
27 |
28 | if torch.cuda.is_available():
29 | device = 'cuda'
30 | else:
31 | device = 'cpu'
32 | model.to(device)
33 |
34 | if __name__ == "__main__":
35 |
36 | with open(rewriter_output_path, 'r') as file:
37 | rewriter_output_dict = json.load(file)
38 |
39 | query_embedding_dict = {}
40 |
41 | for instance_id in tqdm.tqdm(rewriter_output_dict):
42 | query = rewriter_output_dict[instance_id]["query"]
43 |
44 | if len(query) == 0:
45 | continue
46 |
47 | query_embedding_dict[instance_id] = model.encode(tokenizer, query)
48 |
49 | with open(rewriter_embedding_path, 'wb') as f:
50 | pickle.dump(query_embedding_dict, f)
--------------------------------------------------------------------------------
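The script above assumes rewriter_output.json maps each instance id to the post-processed Rewriter fields written by rewriter/rewriter_output_post_process.py. A hypothetical entry (the id and values are made up):

    example_rewriter_output = {
        "example__repo-12345": {
            "code_entity": ["pkg/module.py", "Class Config()"],
            "keyword": ["config", "validation", "parser"],
            "query": ["Function that validates the configuration file before training starts."]
        }
    }
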
/cgm/config/template.json:
--------------------------------------------------------------------------------
1 | {
2 | "graph_dir": [],
3 | "train_files":[],
4 | "valid_files":[],
5 | "output_dir": "",
6 | "tb_dir": "",
7 |
8 | "embedding_dim": 256,
9 | "load_pretrained_encoder": false,
10 | "pretrained_encoder_path": null,
11 |
12 | "load_pretrained_adapter": false,
13 | "pretrained_adapter_path": null,
14 | "adapter_hidden_dim": 4096,
15 | "adapter_num_layers": 1,
16 | "adapter_num_heads": 8,
17 |
18 | "load_pretrained_tokenizer": true,
19 | "pretrained_tokenizer_path": "Qwen/Qwen2.5-Coder-7B-Instruct",
20 |
21 | "pretrained_model_path": "Qwen/Qwen2.5-Coder-7B-Instruct",
22 | "self_defined": false,
23 | "framework_type": "T1",
24 | "model_type": "Qwen",
25 |
26 | "pretrained_lora_path": null,
27 | "quantization": "4bit",
28 |
29 | "mode": "eal",
30 | "task": "unit_test",
31 | "use_chat": false,
32 | "use_adj": true,
33 |
34 | "peft": "LoRA",
35 | "lora_rank": 32,
36 | "lora_alpha": 32,
37 | "lora_dropout": 0.05,
38 | "lora_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
39 |
40 | "enc_peft": "LoRA",
41 | "enc_lora_rank": 32,
42 | "enc_lora_alpha": 32,
43 | "enc_lora_dropout": 0.05,
44 | "enc_lora_modules": "all-linear",
45 |
46 | "graph_pad_token": "<|graph_pad|>",
47 | "graph_pad_id": 32022,
48 | "graph_token_num": 512,
49 |
50 | "learning_rate": 1e-4,
51 | "min_lr": 1e-7,
52 | "weight_decay": 0.1,
53 | "lr_scheduler_type": "reduce_lr_on_plateau",
54 |
55 | "gradient_accumulation_steps": 1,
56 | "num_warmup_steps": 0,
57 | "adapter_warmup": false,
58 | "adapter_warmup_steps": 500,
59 | "num_train_epochs": 20,
60 |
61 | "data_split": "0.98,0.02",
62 | "max_train_steps": null,
63 | "max_train_samples": null,
64 | "max_valid_samples": 2048,
65 | "per_device_train_batch_size": 1,
66 | "per_device_eval_batch_size": 1,
67 |
68 | "seed": 42,
69 | "seq_length": 8192,
70 | "log_interval": 5,
71 |
72 |
73 |
74 | "step_checkpointing": false,
75 | "checkpointing_steps": 500,
76 | "step_evaluation": true,
77 | "evaluation_steps": 5000,
78 | "epoch_evaluation": false,
79 | "epoch_checkpointing": false,
80 |
81 | "early_stopping": true,
82 | "early_stopping_stall_num": 6,
83 |
84 | "attn_implementation": "sdpa"
85 | }
86 |
--------------------------------------------------------------------------------
/cgm/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # print out arguments in a nice way
4 | def print_args(args, accelerator):
5 | # compute the maximum string length over all argument keys
6 | max_key_length = max(len(str(key)) for key in vars(args).keys())
7 |
8 | message = ""
9 | message += "====" * 40 + "\n"
10 | message += '\n'.join([f'{k:<{max_key_length}} : {v}' for k, v in vars(args).items()]) + "\n"
11 | message += "====" * 40 + "\n"
12 | accelerator.print(message)
13 |
14 | def count_parameters(model):
15 | return sum(p.numel() for p in model.parameters())
16 |
17 | def print_with_rank(accelerator, msg):
18 | print(accelerator.process_index, msg)
19 |
20 | def print_rank_0(*message):
21 | """If distributed is initialized print only on rank 0."""
22 | if torch.distributed.is_initialized():
23 | if torch.distributed.get_rank() == 0:
24 | print(*message, flush=True)
25 | else:
26 | print(*message, flush=True)
27 |
28 | def print_rank_0_highlight(*message):
29 | """If distributed is initialized print only on rank 0."""
30 | if torch.distributed.is_initialized():
31 | if torch.distributed.get_rank() == 0:
32 | print('=='*100)
33 | print(*message, flush=True)
34 | print('=='*100)
35 | else:
36 | print('=='*100)
37 | print(*message, flush=True)
38 | print('=='*100)
39 |
40 | def print_highlight(*message):
41 | print('=='*100)
42 | print(*message)
43 | print('=='*100)
44 |
45 | def get_computation_speed(batch_size_per_device, seq_len, step_time):
46 | return batch_size_per_device * seq_len / (step_time + 1e-12)
47 |
48 | def touch_print(accelerator, batch, num_tokens=10):
49 | """touch first and last tokens and labels for debugging usage"""
50 | accelerator.print(f"step 1 batch shape: {batch['input_ids'].shape},\n"
51 | f"last {num_tokens} labels: {batch['labels'][:, -num_tokens:]}"
52 | f"last {num_tokens} loss mask: {batch['loss_mask'][:, -num_tokens:]}")
53 | accelerator.print(f"first {num_tokens} input_ids and loss_mask")
54 | for pt in range(1):
55 | accelerator.print(f"{batch['input_ids'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
56 | accelerator.print(f"{batch['loss_mask'][:, num_tokens * pt: num_tokens * pt + num_tokens]}")
--------------------------------------------------------------------------------
/cgm/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
3 |
4 | def loss_CGM(output_logits, labels, loss_mask):
5 |
6 | lm_logits = output_logits.contiguous()
7 | labels = labels.to(device=lm_logits.device).contiguous()
8 | loss_mask = loss_mask.to(device=lm_logits.device)
9 | # logits: (bs, l, v); labels, loss_mask: (bs, l)
10 |
11 | # lm loss
12 | bsz = labels.shape[0]
13 | loss_func = CrossEntropyLoss(reduction='none')
14 | losses = loss_func(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1)) # logits: (bs * l, v); labels: (bs * l,)
15 | # losses -> (bs, l)
16 | losses = losses.contiguous().view(bsz, -1)
17 |
18 | loss_mask = loss_mask.view(-1)
19 | losses = losses.view(-1)
20 | if loss_mask.sum() < 1:
21 | loss_lm = torch.sum(losses * loss_mask)
22 | else:
23 | loss_lm = torch.sum(losses * loss_mask) / loss_mask.sum()
24 |
25 | return loss_lm
26 |
27 | def acc_lp(logits,labels):
28 | predictions = torch.sigmoid(logits)
29 | acc = ((predictions > 0.5) == labels.bool()).float().mean()
30 | return acc.item()
31 |
32 | def loss_lp(outputs, edge_label_dict):
33 | loss_func = BCEWithLogitsLoss(reduction='mean')
34 | losses = []
35 | edge_loss = {}
36 | edge_acc = {}
37 | total_acc = 0
38 | total_edges = 0
39 | for edge_type in edge_label_dict.keys():
40 | lm_logits = outputs[edge_type].view(-1)
41 | labels = edge_label_dict[edge_type].to(device=lm_logits.device).view(-1)
42 | loss = loss_func(lm_logits,labels)
43 | losses.append(loss)
44 | acc = acc_lp(lm_logits, labels)
45 | edge_loss[edge_type] = loss.item()
46 | edge_acc[edge_type] = acc
47 | total_acc += len(labels) * acc
48 | total_edges += len(labels)
49 | # del lm_logits, labels, loss
50 | loss1 = torch.sum(torch.stack(losses))
51 | total_acc = total_acc / total_edges
52 | # del losses, loss_func
53 | return loss1, edge_loss, edge_acc, total_acc
54 |
55 | def loss_ng(outputs, y_dict, mask_dict):
56 | loss_func = MSELoss(reduction='sum')
57 | loss2 = loss_func(outputs,y_dict['Method'])
58 | return loss2
59 |
60 | def loss_lpng(lp_outputs, ng_outputs, edge_label_dict, y_dict, mask_dict):
61 | loss1, edge_loss, edge_acc, total_acc = loss_lp(lp_outputs, edge_label_dict)
62 | loss2 = loss_ng(ng_outputs, y_dict, mask_dict)
63 | loss = loss1 + loss2
64 | return loss, loss1.item(), loss2.item(), edge_loss, edge_acc, total_acc
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
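A minimal shape check for loss_CGM with dummy tensors, assuming the cgm/ directory is on the Python path; the sizes (batch 2, length 4, vocabulary 10) are arbitrary:

    # logits: (bs, l, v); labels and loss_mask: (bs, l)
    import torch
    from utils.loss import loss_CGM

    bs, l, v = 2, 4, 10
    logits = torch.randn(bs, l, v)
    labels = torch.randint(0, v, (bs, l))
    loss_mask = torch.ones(bs, l)        # 1.0 marks tokens that contribute to the loss
    print(loss_CGM(logits, labels, loss_mask))   # scalar tensor
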
/preprocess_embedding/generate_code_content.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate node content for all nodes in code graph
3 | """
4 | from os.path import isfile
5 | import os, sys
6 | import pandas as pd
7 | import json
8 | import tqdm
9 | import re
10 |
11 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType
12 |
13 | def extract_code_and_doc(content):
14 | """
15 | split code and doc
16 | """
17 | # match docstring
18 | docstring_pattern = r'"""(.*?)"""|\'\'\'(.*?)\'\'\''
19 | docstrings = re.findall(docstring_pattern, content, re.DOTALL)
20 |
21 | # extract pure code
22 | code_without_docstring = re.sub(docstring_pattern, '', content, flags=re.DOTALL)
23 | # merge docstring
24 | extracted_docstrings = "\n\n".join([d[0] or d[1] for d in docstrings])
25 | return code_without_docstring, extracted_docstrings
26 |
27 | def get_graph_file_name(item):
28 | """
29 | return graph_file_name
30 | """
31 |
32 | raise NotImplementedError
33 |
34 | if __name__ == "__main__":
35 |
36 | graph_basic_df = pd.read_json("test_lite_basic_info.json")
37 |
38 | graph_data_path = "codegraph/"
39 | node_content_path = "xx/node_content/"
40 | graph_list = os.listdir(graph_data_path)
41 |
42 | # get the graph_file path
43 | graph_basic_df["graph_file"] = graph_basic_df.apply(lambda item: get_graph_file_name(item), axis=1)
44 |
45 | # generate code content for each repo
46 | for idx, item in tqdm.tqdm(graph_basic_df.iterrows()):
47 |
48 | instance_id = item.instance_id
49 | graph_file = item.graph_file
50 | # get the graph path
51 | tmp_graph_data_path = graph_data_path + graph_file
52 |
53 | # skip files which have been processed
54 | if os.path.isfile(node_content_path + '{}.json'.format(instance_id)):
55 | continue
56 |
57 | graph = parse(tmp_graph_data_path)
58 |
59 | try:
60 | nodes = graph.get_nodes()
61 | except:
62 | print(f"========= parse error: {tmp_graph_data_path} =========")
63 | continue
64 | node_code_dict = {}
65 | node_doc_dict = {}
66 | for node in nodes:
67 | node_id = node.node_id
68 | content = node.get_content()
69 |
70 | code, doc = extract_code_and_doc(content)
71 |
72 | node_code_dict[node_id] = code
73 |
74 | if doc.strip():
75 | node_doc_dict[node_id] = doc
76 |
77 | # save the result
78 | with open(node_content_path + '{}.json'.format(instance_id), 'w', encoding='utf-8') as json_file:
79 |
80 | node_content_dict = {
81 | "code": node_code_dict,
82 | "doc": node_doc_dict
83 | }
84 |
85 | json.dump(node_content_dict, json_file, ensure_ascii=False, indent=4)
--------------------------------------------------------------------------------
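A quick illustration of extract_code_and_doc on a toy snippet, assuming it is imported from this file (the input string is made up):

    from generate_code_content import extract_code_and_doc

    sample = 'def add(a, b):\n    """Return a + b."""\n    return a + b\n'
    code, doc = extract_code_and_doc(sample)
    # code keeps the function with the docstring stripped out
    # doc == 'Return a + b.'
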
/preprocess_embedding/generate_code_embedding.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate embedding for code
3 | """
4 |
5 | from transformers import AutoTokenizer, AutoModel
6 | import torch
7 | import os
8 | import numpy as np
9 | import tqdm
10 | import json
11 | import pickle
12 |
13 | # custom
14 | import argparse, logging
15 |
16 | # input and output path
17 | node_content_path = "xx/node_content/"
18 | node_embedding_path = "xx/tmp_node_embedding/"
19 |
20 | # load model
21 | model_name_or_path = "xxx/CodeFuse-CGE-Large"
22 | model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
23 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, truncation_side='right', padding_side='right')
24 |
25 | if torch.cuda.is_available():
26 | device = 'cuda'
27 | else:
28 | device = 'cpu'
29 | model.to(device)
30 |
31 | if __name__ == "__main__":
32 |
33 | node_embedding_dict = None
34 | node_embedding_list = os.listdir(node_embedding_path)
35 | node_embedding_list = [item.split('.')[0] for item in node_embedding_list]
36 | candidate_graphs = os.listdir(node_content_path)
37 |
38 | for filename in tqdm.tqdm(candidate_graphs):
39 |
40 | instance_id = filename.split('.')[0]
41 |
42 | # skip samples which have been processed
43 | if instance_id in node_embedding_list:
44 | continue
45 |
46 | with open(node_content_path + filename, 'r', encoding='utf-8') as file:
47 | node_content_dict = json.load(file)
48 |
49 | node_list = list(node_content_dict['code'].keys())
50 |
51 | node_embedding_dict = {}
52 | node_code_embedding_dict = {}
53 | node_doc_embedding_dict = {}
54 | for node in node_list:
55 | code_content = node_content_dict['code'][node]
56 | if node in node_content_dict['doc']:
57 | doc_content = node_content_dict['doc'][node]
58 | code_content = code_content if code_content else " "
59 | doc_content = doc_content if doc_content else " "
60 | # batch process
61 | text = [code_content, doc_content]
62 | node_code_embedding_dict[node], node_doc_embedding_dict[node] = model.encode(tokenizer, text)
63 |
64 | else:
65 | # for node without doc
66 | code_content = code_content if code_content else " "
67 | node_code_embedding_dict[node] = model.encode(tokenizer, code_content)
68 |
69 | node_embedding_dict = {
70 | "code": node_code_embedding_dict,
71 | "doc": node_doc_embedding_dict
72 | }
73 |
74 | with open(node_embedding_path + '{}.pkl'.format(instance_id), 'wb') as f:
75 | pickle.dump(node_embedding_dict, f)
76 |
77 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/retriever/utils.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 |
3 | def codegraph_to_nxgraph(graph):
4 | """
5 | Convert a CodeGraph object into a networkx graph object
6 | :param graph: CodeGraph object
7 | :return graph_nx: nx.MultiDiGraph object
8 | """
9 |
10 | # create a directed graph that allows multiple edges between the same pair of nodes
11 | G = nx.MultiDiGraph()
12 |
13 | # add nodes
14 | for node in graph.nodes:
15 | G.add_node(graph.nodes[node])
16 |
17 | # add edges
18 | for edge in graph.edges:
19 |
20 | src_id = edge.source
21 | tgt_id = edge.target
22 | # the current data may reference nodes that are missing from the node list
23 | # such edges (and the missing nodes) are skipped for now
24 | try:
25 | G.add_edge(graph.nodes[src_id], graph.nodes[tgt_id], type=edge.edge_type)
26 | except:
27 | pass
28 |
29 | print(f"nx Graph data parsed, nodes: {G.number_of_nodes():,}, edges: {G.number_of_edges():,}")
30 |
31 | return G
32 |
33 | def codegraph_to_nxgraph_lite(graph):
34 | """
35 | Lightweight version: convert a CodeGraph object into a networkx graph object
36 | Node and edge types are decoupled from the parser definitions
37 | :param graph: CodeGraph object
38 | :return graph_nx: nx.MultiDiGraph object
39 | """
40 |
41 | # create a directed graph that allows multiple edges between the same pair of nodes
42 | G = nx.MultiDiGraph()
43 |
44 | # add nodes
45 | for node in graph.nodes:
46 | G.add_node(node.node_id)
47 |
48 | # add edges
49 | for edge in graph.edges:
50 |
51 | src_id = edge.source
52 | tgt_id = edge.target
53 | try:
54 | # the current data may reference nodes that are missing from the node list
55 | # such edges (and the missing nodes) are skipped for now
56 | graph.nodes[src_id]
57 | graph.nodes[tgt_id]
58 | # only add the edge after both source and target nodes are confirmed to exist
59 | G.add_edge(src_id, tgt_id, type=edge.edge_type.name)
60 | except:
61 | pass
62 |
63 | print(f"nx Graph lite data parsed, nodes: {G.number_of_nodes():,}, edges: {G.number_of_edges():,}")
64 |
65 | return G
66 |
67 | def codegraph_to_nxgraph_analysis(graph):
68 | """
69 | Analysis version: convert a CodeGraph object into a networkx graph object
70 | Intended for path analysis; Repo and Package nodes are removed
71 | Returns both a directed and an undirected graph:
72 | - directed graph: for analysing transition probabilities between nodes; two nodes may share multiple edges
73 | - undirected graph: for analysing shortest paths between (File) nodes
74 | :param graph: CodeGraph object
75 | :return: a directed nx.MultiDiGraph and an undirected nx.Graph
76 | """
77 |
78 | # create a directed graph that allows multiple edges between the same pair of nodes
79 | G_d = nx.MultiDiGraph()
80 | G_u = nx.Graph()
81 |
82 | # add nodes
83 | for node_id in graph.nodes:
84 | node = graph.get_node_by_id(node_id)
85 | node_type = node.get_type().name
86 | if node_type in ['REPO', 'PACKAGE']:
87 | continue
88 | G_d.add_node(node.node_id)
89 | G_u.add_node(node.node_id)
90 |
91 | # add edges
92 | for edge in graph.edges:
93 |
94 | src_id = edge.source
95 | tgt_id = edge.target
96 |
97 | if G_d.has_node(src_id) and G_d.has_node(tgt_id):
98 | G_d.add_edge(src_id, tgt_id, type=edge.edge_type.name)
99 | G_u.add_edge(src_id, tgt_id, type=edge.edge_type.name)
100 |
101 | print(f"nx Graph analysis data parsed, nodes: {G_d.number_of_nodes():,}, edges: {G_d.number_of_edges():,}")
102 |
103 | return G_d, G_u
--------------------------------------------------------------------------------
/retriever/serialize_subgraph.py:
--------------------------------------------------------------------------------
1 | """
2 | Serialize a subgraph to a JSON file
3 | Two versions are provided here:
4 | ✅ Directly serialize the original subgraph
5 | ✅ Serialize the file-level subgraph
6 | """
7 | import sys
8 | import json
9 | import os
10 | import tqdm
11 | import pandas as pd
12 | import networkx as nx
13 |
14 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType
15 | from utils import codegraph_to_nxgraph
16 |
17 | ############################# utils #############################
18 |
19 | def get_contained_node(graph_nx, node):
20 |
21 | c_node_list = []
22 | for suc_node in graph_nx.successors(node):
23 | if graph_nx[node][suc_node][0]['type'] == EdgeType.CONTAINS:
24 | c_node_list.append(suc_node)
25 |
26 | return c_node_list
27 |
28 | def get_inner_nodes(graph_nx, node):
29 |
30 | inner_nodes = get_contained_node(graph_nx, node)
31 | inner_nodes_all = []
32 |
33 | while len(inner_nodes) != 0:
34 |
35 | tmp_inner_nodes = inner_nodes.copy()
36 | inner_nodes = []
37 | for node in tmp_inner_nodes:
38 | inner_nodes_all.append(node)
39 | inner_nodes.extend(get_contained_node(graph_nx, node))
40 |
41 | return list(set(inner_nodes_all))
42 |
43 | def serialize_subgraph(graph_nx, file_name):
44 |
45 | node_list = [node.to_dict() for node in graph_nx.nodes()]
46 |
47 | # collect the edges of the subgraph
48 | edge_list = []
49 | for edge in graph_nx.edges():
50 | edge_type = graph_nx[edge[0]][edge[1]][0]['type']
51 | tmp_edge_dict = {
52 | "edgeType": edge_type.name.lower(),
53 | "source": edge[0].node_id,
54 | "target": edge[1].node_id
55 | }
56 | edge_list.append(tmp_edge_dict)
57 | # every edge is recorded together with its edge type
58 | graph_json = {
59 | "nodes": node_list,
60 | "edges": edge_list
61 | }
62 |
63 | with open(file_name + '.json', 'w') as json_file:
64 | json.dump(graph_json, json_file, indent=4)
65 |
66 | return True
67 |
68 | ############################# utils #############################
69 |
70 |
71 | if __name__ == "__main__":
72 |
73 | test_basic_df = pd.read_json("test_basic_info.json")
74 | graph_data_path = "codegraph/"
75 | subgraph_dict_path = "subgraph_nodes.json"
76 | save_path = "subgraph/"
77 |
78 | with open(subgraph_dict_path, "r", encoding="utf-8") as file:
79 | one_hop_dict = json.load(file)
80 | file.close()
81 |
82 | for idx, item in tqdm.tqdm(test_basic_df.iterrows()):
83 |
84 | instance_id = item.instance_id
85 | graph_file = item.graph_file
86 | tmp_graph_data_path = graph_data_path + graph_file
87 |
88 | if not os.path.exists(tmp_graph_data_path):
89 | continue
90 |
91 | filename = save_path + instance_id
92 |
93 | if os.path.exists(filename + '.json'):
94 | continue
95 |
96 | graph = parse(tmp_graph_data_path)
97 | graph_nx = codegraph_to_nxgraph(graph)
98 |
99 | # Version 1: serialize the retrieved nodes directly
100 | # all_nodes = one_hop_dict[instance_id]
101 |
102 | # Version 2: serialize at file level
103 | all_nodes = one_hop_dict[instance_id]
104 | for node_id in all_nodes:
105 | node = graph.get_node_by_id(node_id)
106 | if node.get_type() == NodeType.FILE:
107 | inner_node = get_inner_nodes(graph_nx, node)
108 | for i_node in inner_node:
109 | all_nodes.append(i_node.node_id)
110 |
111 | all_nodes = [graph.get_node_by_id(node_id) for node_id in all_nodes]
112 | subgraph = graph_nx.subgraph(list(all_nodes))
113 |
114 | serialize_subgraph(subgraph, filename)
--------------------------------------------------------------------------------
/cgm/utils/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 | import argparse, json
3 | from typing import List, Union
4 | import torch
5 |
6 |
7 | @dataclass
8 | class TrainArgs:
9 | graph_dir: Union[str, List[str]]
10 | train_files: Union[str, List[str]]
11 | valid_files: Union[str, List[str]]
12 | output_dir: str
13 | tb_dir: str
14 |
15 | embedding_dim: int = 2304
16 |
17 | load_pretrained_encoder: bool = False
18 | pretrained_encoder_path: Union[None, str] = None
19 | load_pretrained_adapter: bool = False
20 | pretrained_adapter_path: Union[None, str] = None
21 | adapter_hidden_dim: int = 4096
22 | adapter_num_layers: int = 1
23 | adapter_num_heads: int = 8
24 |
25 | self_defined: bool = False
26 | pretrained_model_path: Union[None, str] = None
27 | lm_hidden_dim: int = 4096
28 | quantization: Union[None, str] = None
29 | framework_type: Union[None, str] = "default"
30 | model_type: Union[None, str] = None
31 |
32 | load_pretrained_tokenizer: bool = True
33 | pretrained_tokenizer_path: Union[None, str] = None
34 |
35 | # for evaluation
36 | pretrained_lora_path: Union[None, str] = None
37 |
38 | # training mode:
39 | # "e" 1, "a" 2, "l" 3
40 | mode: str = "a"
41 | task: str = "align"
42 | use_chat: bool = True
43 | use_adj: bool = False
44 |
45 | # LoRA rank: the larger it is, the more trainable parameters
46 | peft: Union[None, str] = None
47 | lora_rank: int = 32
48 | lora_alpha: int = 32
49 | lora_dropout: float = 0.05
50 | lora_modules: Union[str, List[str]] = "all-linear"
51 |
52 | enc_peft: Union[None, str] = None
53 | enc_lora_rank: int = 32
54 | enc_lora_alpha: int = 32
55 | enc_lora_dropout: float = 0.05
56 | enc_lora_modules: Union[str, List[str]] = "all-linear"
57 |
58 | graph_pad_token: str = "<|graph_pad|>"
59 | graph_pad_id: int = 32022
60 | graph_token_num: int = 512
61 |
62 | learning_rate: float = 5e-5
63 | min_lr: float = 5e-6
64 | weight_decay: float = 0.1
65 | lr_scheduler_type: str = "cosine"
66 |
67 | gradient_accumulation_steps: int = 1
68 | num_warmup_steps: int = 300
69 | adapter_warmup: bool = False
70 | adapter_warmup_steps: int = 500
71 | num_train_epochs: int = 2
72 |
73 | # train/valid split
74 | data_split: str = "0.98,0.02"
75 | max_train_samples: Union[None, int] = None
76 | max_valid_samples: Union[None, int] = None
77 |
78 | per_device_train_batch_size: int = 1
79 | per_device_eval_batch_size: int = 1
80 |
81 | seed: int = 42
82 |
83 | seq_length: int = 4096
84 | log_interval: int = 10
85 | step_checkpointing: bool = False
86 | checkpointing_steps: int = 100
87 |
88 | step_evaluation: bool = False
89 | evaluation_steps: int = 100
90 |
91 | # max train steps, if None, depends on num_train_epochs
92 | max_train_steps: Union[None, int] = None
93 |
94 | # if checkpointing every epoch, maybe True in sst
95 | epoch_checkpointing: bool = False
96 | epoch_evaluation: bool = False
97 |
98 | early_stopping: bool = False
99 | early_stopping_stall_num: int = 5
100 |
101 | attn_implementation: str = "flash_attention_2"
102 |
103 | def dict(self):
104 | return {k: str(v) for k, v in asdict(self).items()}
105 |
106 |
107 | def prepare_args(args_type="Train"):
108 | parser = argparse.ArgumentParser()
109 | parser.add_argument("--c", type=str, default=None)
110 | parsed = parser.parse_args()
111 | with open(parsed.c, 'r') as f:
112 | c = json.load(f)
113 | if args_type == "Train":
114 | args = TrainArgs(**c)
115 | else:
116 | raise ValueError("args_type must be Train")
117 | if not torch.cuda.is_available():
118 | args.attn_implementation = 'eager'
119 |
120 | return args
121 |
--------------------------------------------------------------------------------
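Configs such as cgm/config/template.json map one-to-one onto TrainArgs fields. The sketch below is the direct equivalent of `python train/train.py --c config/template.json`, assuming cgm/ is the working directory:

    import json
    from utils.arguments import TrainArgs

    with open("config/template.json") as f:
        args = TrainArgs(**json.load(f))
    print(args.pretrained_model_path, args.lora_rank)   # Qwen/Qwen2.5-Coder-7B-Instruct 32
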
/rewriter/inference_rewriter.py:
--------------------------------------------------------------------------------
1 | """
2 | Rewriter Inference
3 | """
4 | import os
5 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:3950"
6 |
7 | import torch
8 | import sys
9 | import time
10 | from tqdm import tqdm
11 | import re
12 |
13 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GenerationConfig
14 | import transformers
15 |
16 | transformers.logging.set_verbosity_error()
17 | import json
18 | from copy import deepcopy
19 |
20 | # custom
21 | import argparse, logging
22 | from torch.utils.data import Dataset, DataLoader
23 | import pandas as pd
24 |
25 | class PromptDataset(Dataset):
26 | def __init__(self, data):
27 | self.data = data
28 |
29 | def __getitem__(self, index):
30 | return self.data[index]
31 |
32 | def __len__(self):
33 | return len(self.data)
34 |
35 | ###################################################
36 |
37 | def parse_args():
38 | """
39 | Parses the arguments
40 | """
41 |
42 | parser = argparse.ArgumentParser(description="Run Inference Model.")
43 |
44 | parser.add_argument('--prompt_path', nargs='?', default='',
45 | help='Specify the prompts file')
46 |
47 | return parser.parse_args()
48 |
49 | def inference_LLM_patch(prompt_path):
50 |
51 | pretrained_path = 'xxx/Qwen2.5-72B-Instruct'
52 |
53 | model = AutoModelForCausalLM.from_pretrained(pretrained_path, device_map="auto", torch_dtype="auto")
54 | tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
55 |
56 | model.half()
57 |
58 | data = []
59 | test_basic_info = pd.read_json(prompt_path)
60 | data = test_basic_info["extractor_prompt"].tolist() + test_basic_info["inferer_prompt"].tolist() # get prompt
61 | instance_num = len(test_basic_info["extractor_prompt"].tolist())
62 |
63 | dataset = PromptDataset(data)
64 | dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
65 |
66 | response_list = []
67 | for batch in tqdm(dataloader):
68 | try:
69 | text_batch = []
70 | for prompt in batch:
71 | messages = [
72 | {"role": "system", "content": "You are a helpful assistant."},
73 | {"role": "user", "content": prompt}
74 | ]
75 | text = tokenizer.apply_chat_template(
76 | messages,
77 | add_generation_prompt=True,
78 | tokenize=False
79 | )
80 | text_batch.append(text)
81 |
82 | model_inputs = tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt").to(model.device)
83 |
84 | generated_ids = model.generate(
85 | **model_inputs,
86 | max_new_tokens=512
87 | )
88 | generated_ids = [
89 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
90 | ]
91 |
92 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
93 | response_list.extend(response)
94 | torch.cuda.empty_cache()
95 | except Exception as e:
96 | print(f"Out of Memory. {e}")
97 | response_list.extend(["error"] * len(batch))
98 | torch.cuda.empty_cache()
99 |
100 | #### save output ####
101 | test_basic_info["rewriter_inferer"] = response_list[:instance_num]
102 | test_basic_info["rewriter_extractor"] = response_list[instance_num:]
103 |
104 | test_basic_info.to_json("test_rewriter_output.json", index=False)
105 |
106 | return True
107 |
108 | if __name__ == "__main__":
109 |
110 | args = parse_args()
111 |
112 | print("Start Rewiter Running:")
113 | inference_LLM_patch(args.prompt_path)
114 | print("Rewiter Running Ended.")
--------------------------------------------------------------------------------
/rewriter/prompt.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt Template for Rewriter
3 | """
4 |
5 | def generate_prompt_for_extractor(problem_statement, repo_name):
6 | """
7 | Generate the prompt for the extractor
8 | """
9 |
10 | prompt = """
11 |
12 | {}
13 |
14 | This is an issue related to repository '{}'.
15 | Instructions:
16 | 1. Analysis:
17 | ○ Analyze the provided issue description. Identify the relevant File, Class, or Function involved.
18 | ○ Determine the specific problem or error encountered and note any clues that may assist in locating the relevant or problematic area.
19 | 2. Extraction:
20 | ○ After the analysis, extract ALL the mentioned code entities (File, Class, or Function), especially Files.
21 | ○ Then extract three potential and meaningful keywords, responding in the following format:
22 |
23 | [start_of_analysis]
24 |
25 | [end_of_analysis]
26 |
27 | [start_of_related_code_entities]
28 |
29 | [end_of_related_code_entities]
30 |
31 | [start_of_related_keywords]
32 |
33 | [end_of_related_keywords]
34 |
35 | Notes:
36 | - Pay attention to the information in the error logs (if exists).
37 | - The buggy code exists solely in the project described in the issue (e.g., django, sklearn). Buggy location is usually not in the tests files or external packages.
38 | - Your extracted entities should be CONCISE, ACCURATE and INFORMATIVE.
39 | - Provide the relative path for code entities if specified (e.g., package/foo.py). Relative path is relative to the repository itself, do not include suffix like '/home/username/', '/etc/service/' or '/tree/master'.
40 | - Do not include any additional information such as line numbers or explanations in your extraction result.
41 |
42 | Preferred extraction Examples of Code Entities:
43 | - repo/cart.py
44 | - Class User()
45 | - def getData()
46 | Preferred extraction Examples of Keywords:
47 | - train_loop
48 | - hooks
49 | - docker
50 |
51 | Unpreferred extraction Examples of keywords:
52 | - something wrong
53 | - input validation
54 | - TypeError
55 | """.format(problem_statement, repo_name)
56 |
57 | return prompt
58 |
59 | def generate_prompt_for_inferer(problem_statement, repo_name):
60 | """
61 | Generate the prompt for the inferer
62 | """
63 |
64 | prompt = """
65 |
66 | {}
67 |
68 | This is an issue related to repository '{}'.
69 | Task:
70 | Based on the issue description provided, identify the characteristics of code entities (files, functions, class) that might need to be modified.
71 | For each characteristic, generate a search query that could help locate relevant code entities in a codebase.
72 | Instructions:
73 | First, analyze the issue description and identify keywords, features, and functionalities that are likely relevant to the modification of code entities.
74 | Then, create queries that capture these characteristics, focusing on:
75 | ● File names that may implement relevant functionalities.
76 | ● Functions or methods that are related to the features described in the issue.
77 | ● Any patterns or structures that might be relevant to the functionalities mentioned.
78 | For example:
79 | ● File related to the initialization of a neural network.
80 | ● Function related to the training process.
81 | ● Code used to configure the service.
82 | Please answer in the following format:
83 |
84 | [start_of_analysis]
85 |
86 | [end_of_analysis]
87 |
88 | [start_of_related_queries]
89 | query 1:
90 | query 2:
91 | ...
92 | [end_of_related_queries]
93 |
94 | Notes:
95 | - Your queries should be DETAILED, ACCURATE and INFORMATIVE.
96 | - Your queries should be complete sentences and should not include additional explanation.
97 | - The number of queries is up to five, so focus on the most important characteristics.
98 | - Your queries should focus on the repository code itself, rather than other information like commit history.
99 | - Pay attention to the information in the error logs (if exists).
100 |
101 | Preferred Query Examples:
102 | - Look for references to "tqdm" or "progress_bar" within the training loop files to find where progress bars are currently updated.
103 | - Code snippets where 'gethostbyname' function from 'socket' module is called.
104 | - File name containing 'mysql.py' AND functions related to 'MySQLStatementSamples' initialization.
105 | - Functions or methods handling hostname resolution or encoding within 'datadog_checks' directory.
106 | - Find all occurrences of "early_stopping" within files that also mention "Trainer" to identify where early stopping logic is implemented and potentially needs adjustment for non-default 'val_check_interval'.
107 | """.format(problem_statement, repo_name)
108 |
109 | return prompt
110 |
111 |
--------------------------------------------------------------------------------
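A hypothetical call showing how the two builders produce the extractor_prompt / inferer_prompt columns consumed by rewriter/inference_rewriter.py; the issue text and repository name are placeholders:

    from prompt import generate_prompt_for_extractor, generate_prompt_for_inferer

    issue = "TypeError raised in utils/parser.py when the config file is empty."
    repo = "example/repo"
    extractor_prompt = generate_prompt_for_extractor(issue, repo)
    inferer_prompt = generate_prompt_for_inferer(issue, repo)
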
/cgm/inference/layer.py:
--------------------------------------------------------------------------------
1 | """Attention layer."""
2 | from typing import Any, Dict, List, Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from vllm.attention import AttentionMetadata, AttentionType
8 | from vllm.attention.selector import get_attn_backend
9 | from vllm.config import CacheConfig
10 | from vllm.model_executor.layers.quantization.base_config import (
11 | QuantizationConfig)
12 | from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
13 |
14 |
15 | class Attention(nn.Module):
16 | """Attention layer.
17 |
18 | This class takes query, key, and value tensors as input. The input tensors
19 | can either contain prompt tokens or generation tokens.
20 | The class does the following:
21 |
22 | 1. Store the input key and value tensors in the KV cache.
23 | 2. Perform (multi-head/multi-query/grouped-query) attention.
24 | 3. Return the output tensor.
25 | """
26 |
27 | def __init__(
28 | self,
29 | num_heads: int,
30 | head_size: int,
31 | scale: float,
32 | num_kv_heads: Optional[int] = None,
33 | alibi_slopes: Optional[List[float]] = None,
34 | cache_config: Optional[CacheConfig] = None,
35 | quant_config: Optional[QuantizationConfig] = None,
36 | blocksparse_params: Optional[Dict[str, Any]] = None,
37 | logits_soft_cap: Optional[float] = None,
38 | prefix: str = "",
39 | ) -> None:
40 | super().__init__()
41 | if cache_config is not None:
42 | kv_cache_dtype = cache_config.cache_dtype
43 | block_size = cache_config.block_size
44 | sliding_window = cache_config.sliding_window
45 | is_attention_free = cache_config.is_attention_free
46 | else:
47 | kv_cache_dtype = "auto"
48 | block_size = 16
49 | sliding_window = None
50 | is_attention_free = False
51 | if num_kv_heads is None:
52 | num_kv_heads = num_heads
53 |
54 | # The default k/v_scale is set to 1.0. This is ignored
55 | # when kv-cache is not fp8, and should be used with
56 | # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
57 | # expect the pre-quantized k/v_scale to be loaded along
58 | # with the model weights.
59 | self.kv_cache_dtype = kv_cache_dtype
60 | self._k_scale = 1.0
61 | self._v_scale = 1.0
62 | quant_method = quant_config.get_quant_method(
63 | self, prefix=prefix) if quant_config else None
64 | if quant_method is not None:
65 | assert isinstance(quant_method, BaseKVCacheMethod)
66 | # TODO (mgoin): kv cache dtype should be specified in the FP8
67 | # checkpoint config and become the "auto" behavior
68 | if self.kv_cache_dtype == "fp8_e5m2":
69 | raise ValueError("fp8_e5m2 kv-cache is not supported with "
70 | "fp8 checkpoints.")
71 | # If quantization is enabled, we make "k_scale" and "v_scale"
72 | # parameters so that it can be loaded from the model checkpoint.
73 | # The k/v_scale will then be converted back to native float32
74 | # values after weight loading.
75 | self.quant_method = quant_method
76 | self.quant_method.create_weights(self)
77 |
78 | # During model initialization, the default dtype is set as the model
79 | # weight and activation dtype.
80 | dtype = torch.get_default_dtype()
81 | attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
82 | block_size, is_attention_free,
83 | blocksparse_params is not None)
84 | impl_cls = attn_backend.get_impl_cls()
85 | self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
86 | alibi_slopes, sliding_window, kv_cache_dtype,
87 | blocksparse_params, logits_soft_cap)
88 |
89 | def forward(
90 | self,
91 | query: torch.Tensor,
92 | key: torch.Tensor,
93 | value: torch.Tensor,
94 | kv_cache: torch.Tensor,
95 | attn_metadata: AttentionMetadata,
96 | attn_type: AttentionType = AttentionType.DECODER,
97 | ) -> torch.Tensor:
98 |
99 | return self.impl.forward(query,
100 | key,
101 | value,
102 | kv_cache,
103 | attn_metadata,
104 | self._k_scale,
105 | self._v_scale,
106 | attn_type=attn_type)
107 |
108 | def extra_repr(self) -> str:
109 | s = f"head_size={self.impl.head_size}" # type: ignore
110 | s += f", num_heads={self.impl.num_heads}" # type: ignore
111 | s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
112 | s += f", scale={self.impl.scale}" # type: ignore
113 | s += f", backend={self.impl.__class__.__name__}"
114 | return s
--------------------------------------------------------------------------------
/rewriter/rewriter_output_post_process.py:
--------------------------------------------------------------------------------
1 | """
2 | Post-processing: Extract Key information from Rewriter's Output
3 | """
4 | import json
5 | import re
6 | import pandas as pd
7 | import os
8 |
9 | def extract_code_entities_from_rewriter(text):
10 | # find code_entities
11 | pattern_code_entities = r'\[start_of_related_code_entities\]\s*(.*?)\s*\[end_of_related_code_entities\]'
12 | match_code_entities = re.search(pattern_code_entities, text, re.DOTALL)
13 | if match_code_entities:
14 | code_entities = match_code_entities.group(1).strip().split('\n')
15 | else: # execute exceptional output
16 | pattern_1 = r'\s*(.*?)\s*'
17 | match_1 = re.search(pattern_1, text, re.DOTALL)
18 | pattern_2 = r'\[Start_of_Related_Code_Entities\]\s*(.*?)\s*\[End_of_Related_Code_Entities\]'
19 | match_2 = re.search(pattern_2, text, re.DOTALL)
20 |
21 | if match_1:
22 | code_entities = match_1.group(1).strip().split('\n')
23 | elif match_2:
24 | code_entities = match_2.group(1).strip().split('\n')
25 | else:
26 | code_entities = []
27 |
28 | # add post processing
29 | for idx, entity in enumerate(code_entities):
30 | if entity.startswith("- "):
31 | code_entities[idx] = entity[2:]
32 |
33 | return code_entities
34 |
35 | def extract_related_keywords_from_rewriter(text):
36 | # find related_keywords
37 | pattern_related_keywords = r'\[start_of_related_keywords\]\s*(.*?)\s*\[end_of_related_keywords\]'
38 | match_related_keywords = re.search(pattern_related_keywords, text, re.DOTALL)
39 | if match_related_keywords:
40 | related_keywords = match_related_keywords.group(1).strip().split('\n')
41 | else:
42 | pattern_1 = r'\s*(.*?)\s*'
43 | match_1 = re.search(pattern_1, text, re.DOTALL)
44 | pattern_2 = r'\[Start_of_Related_Keywords\]\s*(.*?)\s*\[End_of_Related_Keywords\]'
45 | match_2 = re.search(pattern_2, text, re.DOTALL)
46 |
47 | if match_1:
48 | related_keywords = match_1.group(1).strip().split('\n')
49 | elif match_2:
50 | related_keywords = match_2.group(1).strip().split('\n')
51 | else:
52 | related_keywords = []
53 |
54 | # add post processing
55 | for idx, keyword in enumerate(related_keywords):
56 | if keyword.startswith("- "):
57 | related_keywords[idx] = keyword[2:]
58 |
59 | return related_keywords
60 |
61 | def extract_query_from_rewriter(text):
62 | # match query
63 | pattern_query = r'\[start_of_related_queries\]\s*(.*?)\s*\[end_of_related_queries\]'
64 | match_query = re.search(pattern_query, text, re.DOTALL)
65 | if match_query:
66 | queries = match_query.group(1).strip().split('\n')
67 | else:
68 | pattern_1 = r'\s*(.*?)\s*'
69 | match_1 = re.search(pattern_1, text, re.DOTALL)
70 | pattern_2 = r'\[Start_of_Related_Queries\]\s*(.*?)\s*\[End_of_Related_Queries\]'
71 | match_2 = re.search(pattern_2, text, re.DOTALL)
72 |
73 | if match_1:
74 | queries = match_1.group(1).strip().split('\n')
75 | elif match_2:
76 | queries = match_2.group(1).strip().split('\n')
77 | else:
78 | queries = []
79 |
80 | # add post processing
81 | for idx, query in enumerate(queries):
82 | if query.startswith("query"):
83 | queries[idx] = query[9:]
84 | elif query.startswith("-"):
85 | queries[idx] = query[2:]
86 | queries = [query for query in queries if len(query)>0]
87 | return queries
88 |
89 |
90 | if __name__ == "__main__":
91 |
92 | test_basic_info = pd.read_json("test_rewriter_output.json")
93 |
94 | # start post processing
95 | test_basic_info["rewriter_inferer_output"] = test_basic_info["rewriter_inferer"].apply(lambda item:extract_query_from_rewriter(item))
96 | test_basic_info["rewriter_extractor_output_entity"] = test_basic_info["rewriter_extractor"].apply(lambda item:extract_code_entities_from_rewriter(item))
97 | test_basic_info["rewriter_extractor_output_keyword"] = test_basic_info["rewriter_extractor"].apply(lambda item:extract_related_keywords_from_rewriter(item))
98 |
99 | rewriter_output_dict = {}
100 | error_case = []
101 | for idx, item in test_basic_info.iterrows():
102 | instance_id = item.instance_id
103 | entity = item.rewriter_extractor_output_entity
104 | keyword = item.rewriter_extractor_output_keyword
105 | query = item.rewriter_inferer_output
106 | # if entity or keyword or query:
107 | if entity and keyword and query:
108 | rewriter_output_dict[instance_id] = {
109 | "code_entity": entity,
110 | "keyword": keyword,
111 | "query": query
112 | }
113 | else:
114 | error_case.append(instance_id)
115 |
116 | with open("rewriter_output.json", 'w', encoding='utf-8') as file:
117 | json.dump(rewriter_output_dict, file)
118 |
119 | # save trajs
120 | test_basic_info.to_json("test_rewriter_output.json", index=False)
121 |
122 |
123 |
124 |
125 |
126 |
127 |
--------------------------------------------------------------------------------
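A small sanity check of extract_query_from_rewriter on a hand-written response; the text below is made up:

    from rewriter_output_post_process import extract_query_from_rewriter

    sample = (
        "[start_of_related_queries]\n"
        "query 1: Files implementing the training loop.\n"
        "query 2: Function that configures early stopping.\n"
        "[end_of_related_queries]"
    )
    print(extract_query_from_rewriter(sample))
    # ['Files implementing the training loop.', 'Function that configures early stopping.']
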
/reranker/prompt.py:
--------------------------------------------------------------------------------
1 | """
2 | Prompt Template for Reranker
3 | """
4 |
5 | reranker_stage_1_system_prompt = """
6 | You are an experienced software developer who specializes in extracting the most relevant files for solving issues from many reference files.
7 |
8 | Task:
9 | Based on the information received about the issue from a repository, find the most likely few files from among those that may be able to resolve the issue.
10 |
11 | Instructions:
12 | 1. Analysis:
13 | - Analyze the provided issue description and files, and pay attention to the relevance of the provided files with the given issue, especially those might be modified during fixing the issue.
14 | - Determine the specific problem or error mentioned in the issue and note any clues that could help your judgment.
15 | 2. Extraction:
16 | - Based on your analysis, choose the Top **1** relevant files which might be used in fixing the issue.
17 | - You should choose files from the provided files, and should not modify their name in any way.
18 |
19 | Respond in the following format:
20 | [start_of_analysis]
21 |
22 | [end_of_analysis]
23 |
24 | [start_of_relevant_files]
25 | 1.
26 | 2.
27 | 3. ...
28 | [end_of_relevant_files]
29 |
30 | Notes:
31 | - You can refer to the information in the error logs (if they exist).
32 | - The relevant file usually exists in the project described in the issue (e.g., django, sklearn). Files needing modification are usually not in the test files or external packages.
33 | - The file you choose should be contained in the provided files.
34 | - Provide the file path with files. Do not include redundant suffix like '/home/username/', '/etc/service/' or '/tree/master'.
35 | - Do not include any additional information such as line numbers or explanations in your extraction result.
36 | - Files for initialization and configuration might be modified during changing the code.
37 |
38 | Preferred extraction Examples of Related Files:
39 | 1. src/utils/file_handler.py
40 | 2. core/services/service_manager.py
41 | 3. ...
42 | """.strip()
43 |
44 | reranker_stage_1_user_prompt_template = """
45 |
46 | {}
47 |
48 |
49 |
50 | {}
51 |
52 |
53 |
54 | {}
55 |
56 |
57 |
58 | {}
59 |
60 | """
61 |
62 | reranker_stage_2_system_prompt = """
63 | You are an experienced software developer who specializes in assessing the relevance of the file for solving the issue in software repositories.
64 |
65 | Task:
66 | For a file provided, evaluate the likelihood that modifying this file would resolve the given issue, and assign a score based on specific criteria.
67 |
68 | Instructions:
69 | 1. Analysis:
70 | - Analyze the provided issue description and the content of the single relevant file, pay attention to any keywords, error messages, or specific functionalities mentioned that relate to the file.
71 | - Determine how closely the contents and functionality of the file are tied to the problem or error described in the issue.
72 | - Consider the role of the file in the overall project structure (e.g., configuration files, core logic files versus test files, or utility scripts).
73 | 2. Scoring:
74 | - Based on your analysis, assign a score from 1 to 5 that represents the relevance of modifying the given file in order to solve the issue.
75 |
76 | Score Specifications:
77 | 1. **Score 1**: The file is almost certainly unrelated to the issue, with no apparent connection to the functionality or error described in the issue.
78 | 2. **Score 2**: The file may be tangentially related, but modifying it is unlikely to resolve the issue directly; possible in rare edge cases.
79 | 3. **Score 3**: The file has some relevance to the issue; it might interact with the affected functionality indirectly and tweaking it could be part of a broader fix.
80 | 4. **Score 4**: The file is likely related to the issue; it includes code that interacts directly with the functionality in question and could plausibly contain bugs that lead to the issue.
81 | 5. **Score 5**: The file is very likely the root cause or heavily involved in the issue and modifying it should directly address the error or problem mentioned.
82 |
83 | Respond in the following format:
84 | [start_of_analysis]
85 |
86 | [end_of_analysis]
87 |
88 | [start_of_score]
89 | Score
90 | [end_of_score]
91 |
92 | Notes:
93 | - The content of the file shows only the structure of this file, including the names of the classes and functions defined in this file.
94 | - You can refer to the information in the error logs (if they exist).
95 | """.strip()
96 |
97 | reranker_stage_2_user_prompt_template = """
98 |
99 | {}
100 |
101 |
102 |
103 | {}
104 |
105 |
106 |
107 | {}
108 |
109 |
110 |
111 | {}
112 |
113 | """
114 |
115 | def generate_prompt_for_reranker_stage_1(problem_statement, repo_name, py_file, other_file):
116 | """
117 | problem_statement: issue text
118 | repo_name: repository name
119 | py_file: list of possibly relevant .py file names
120 | other_file: list of other possibly relevant file names
121 | """
122 | return reranker_stage_1_system_prompt, reranker_stage_1_user_prompt_template.format(repo_name, problem_statement, py_file, other_file)
123 |
124 | def generate_prompt_for_reranker_stage_2(problem_statement, repo_name, file_name, file_content):
125 | """
126 | problem_statement: issue text
127 | repo_name: repository name
128 | file_name: file name
129 | file_content: file content, containing only class and function declarations (class xxx and def xxx)
130 | """
131 | return reranker_stage_2_system_prompt, reranker_stage_2_user_prompt_template.format(repo_name, problem_statement, file_name, file_content)
132 |
--------------------------------------------------------------------------------
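A sketch tying the stage-1 prompt to the QwenAPI wrapper from reranker/qwen_api.py, assuming the reranker/ directory is the working directory; the issue, repository name, and file lists are placeholders:

    from qwen_api import QwenAPI
    from prompt import generate_prompt_for_reranker_stage_1

    issue = "Progress bar is not updated when early stopping triggers."
    py_files = ["trainer/loop.py", "callbacks/early_stopping.py"]
    other_files = ["docs/trainer.md"]

    system_prompt, user_prompt = generate_prompt_for_reranker_stage_1(
        issue, "example/repo", py_files, other_files)
    llm = QwenAPI("Qwen/Qwen2.5-1.5B-Instruct")   # model path taken from qwen_api.py's __main__
    print(llm.get_response(system_prompt, user_prompt))
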
/rewriter/generate_rewriter_prompt.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate Prompts for Rewriter
3 | """
4 |
5 | import pandas as pd
6 |
7 | def generate_prompt_for_extractor_v1(problem_statement, repo_name):
8 | """
9 | generate prompt for extractor
10 | """
11 |
12 | prompt = """
13 |
14 | {}
15 |
16 | This is an issue related to repository '{}'.
17 | Instructions:
18 | 1. Analysis:
19 | ○ Analyze the provided issue description. Identify the relevant File, Class, or Function involved.
20 | ○ Determine the specific problem or error encountered and note any clues that may assist in locating the relevant or problematic area.
21 | 2. Extraction:
22 | ○ After the analysis, extract ALL the mentioned code entities (File, Class, or Function), especially Files.
23 | ○ Then extract three potential and meaningful keywords, responding in the following format:
24 |
25 | [start_of_analysis]
26 |
27 | [end_of_analysis]
28 |
29 | [start_of_related_code_entities]
30 |
31 | [end_of_related_code_entities]
32 |
33 | [start_of_related_keywords]
34 |
35 | [end_of_related_keywords]
36 |
37 | Notes:
38 | - Pay attention to the information in the error logs (if they exist).
39 | - The buggy code exists solely in the project described in the issue (e.g., django, sklearn). The buggy location is usually not in test files or external packages.
40 | - Your extracted entities should be CONCISE, ACCURATE and INFORMATIVE.
41 | - Provide the relative path for code entities if specified (e.g., package/foo.py). The path is relative to the repository root; do not include prefixes like '/home/username/', '/etc/service/' or '/tree/master'.
42 | - Do not include any additional information such as line numbers or explanations in your extraction result.
43 |
44 | Preferred extraction Examples of Code Entities:
45 | - repo/cart.py
46 | - Class User()
47 | - def getData()
48 | Preferred extraction Examples of Keywords:
49 | - train_loop
50 | - hooks
51 | - docker
52 |
53 | Unpreferred extraction Examples of keywords:
54 | - something wrong
55 | - input validation
56 | - TypeError
57 | """.format(problem_statement, repo_name)
58 |
59 | return prompt
60 |
61 | def generate_prompt_for_inferer_v1(problem_statement, repo_name):
62 | """
63 | generate prompt for inferer
64 | """
65 |
66 | prompt = """
67 |
68 | {}
69 |
70 | This is an issue related to repository '{}'.
71 | Task:
72 | Based on the issue description provided, identify the characteristics of code entities (files, functions, class) that might need to be modified.
73 | For each characteristic, generate a search query that could help locate relevant code entities in a codebase.
74 | Instructions:
75 | First, analyze the issue description and identify keywords, features, and functionalities that are likely relevant to the modification of code entities.
76 | Then, create queries that capture these characteristics, focusing on:
77 | ● File names that may implement relevant functionalities.
78 | ● Functions or methods that are related to the features described in the issue.
79 | ● Any patterns or structures that might be relevant to the functionalities mentioned.
80 | For example:
81 | ● File related to the initialization of a neural network.
82 | ● Function related to the training process.
83 | ● Code used to configure the service.
84 | Please answer in the following format:
85 |
86 | [start_of_analysis]
87 |
88 | [end_of_analysis]
89 |
90 | [start_of_related_queries]
91 | query 1:
92 | query 2:
93 | ...
94 | [end_of_related_queries]
95 |
96 | Notes:
97 | - Your queries should be DETAILED, ACCURATE and INFORMATIVE.
98 | - Your queries should be complete sentences and should not include additional explanations.
99 | - The number of queries is up to five, so focus on the most important characteristics.
100 | - Your queries should focus on the repository code itself, rather than other information like commit history.
101 | - Pay attention to the information in the error logs (if they exist).
102 |
103 | Preferred Query Examples:
104 | - Look for references to "tqdm" or "progress_bar" within the training loop files to find where progress bars are currently updated.
105 | - Code snippets where 'gethostbyname' function from 'socket' module is called.
106 | - File name containing 'mysql.py' AND functions related to 'MySQLStatementSamples' initialization.
107 | - Functions or methods handling hostname resolution or encoding within 'datadog_checks' directory.
108 | - Find all occurrences of "early_stopping" within files that also mention "Trainer" to identify where early stopping logic is implemented and potentially needs adjustment for non-default 'val_check_interval'.
109 | """.format(problem_statement, repo_name)
110 |
111 | return prompt
112 |
113 | if __name__ == "__main__":
114 |
115 | # please refer to the format of SWE-bench lite dataset
116 | # test_basic_info should have keys: {"instance_id", "problem_statement", "repo"}
117 | test_basic_info = pd.read_json("test_basic_info.json")
118 |
119 | extractor_prompt_list = []
120 | inferer_prompt_list = []
121 | for idx, item in test_basic_info.iterrows():
122 | # generate prompt for each instance
123 | tmp_extractor_prompt = generate_prompt_for_extractor_v1(item.problem_statement, item.repo)
124 | tmp_inferer_prompt = generate_prompt_for_inferer_v1(item.problem_statement, item.repo)
125 | extractor_prompt_list.append(tmp_extractor_prompt)
126 | inferer_prompt_list.append(tmp_inferer_prompt)
127 |
128 | test_basic_info["extractor_prompt"] = extractor_prompt_list
129 | test_basic_info["inferer_prompt"] = inferer_prompt_list
130 |
131 | test_basic_info.to_json("test_rewriter_prompt.json", index=False)
132 |
133 |
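For reference, a minimal sketch (all field values hypothetical) of the test_basic_info.json records this script expects, i.e. one record per SWE-bench instance with the three keys used above:

import json

records = [
    {
        "instance_id": "acme__widgets-1234",
        "problem_statement": "Saving a Widget with a null name raises TypeError ...",
        "repo": "acme/widgets",
    }
]
with open("test_basic_info.json", "w", encoding="utf-8") as f:
    json.dump(records, f, indent=2)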
--------------------------------------------------------------------------------
/retriever/locate_anchor_node.py:
--------------------------------------------------------------------------------
1 | """
2 | Anchor node localization based on rapidfuzz + faiss
3 | """
4 |
5 | from rapidfuzz import process, fuzz
6 | import pandas as pd
7 | import json
8 | import tqdm
9 | import sys
10 | import pickle
11 | import numpy as np
12 | import faiss
13 |
14 |
15 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType
16 | from utils import codegraph_to_nxgraph
17 |
18 | def extract_info(item):
19 | """
20 |     Extract the string portion used for matching
21 | """
22 | return item[1]
23 |
24 | ################################# Extractor #################################
25 | def get_extractor_anchor(graph, entity_query, keywords_query):
26 | """
27 |     Get keyword-matching results
28 | """
29 |
30 | all_nodes = graph.get_nodes()
31 |
32 | cand_name_list = []
33 | cand_path_name_list = []
34 |
35 | for node in all_nodes:
36 | node_type = node.get_type()
37 | if node_type in [NodeType.REPO, NodeType.PACKAGE]:
38 | continue
39 |
40 |         try:
41 |             node.name  # skip nodes that do not expose a name attribute
42 |         except AttributeError:
43 |             continue
44 |
45 | cand_name_list.append((node.node_id, node.name))
46 |
47 | if node_type == NodeType.FILE:
48 | if node.path:
49 | name_with_path = node.path + "/" + node.name
50 | else:
51 | name_with_path = node.name
52 | cand_path_name_list.append((node.node_id, name_with_path))
53 |
54 | cand_name_all = []
55 | cand_path_name_all = []
56 |
57 | for query in entity_query + keywords_query:
58 |
59 | if "/" in query:
60 | cand_path_name = process.extract((-1, query), cand_path_name_list, scorer=fuzz.WRatio, limit=3, processor=extract_info)
61 | cand_path_name_all.append(cand_path_name)
62 |
63 | query_wo_path = query.split('/')[-1]
64 | cand_name = process.extract((-1, query_wo_path), cand_name_list, scorer=fuzz.WRatio, limit=3, processor=extract_info)
65 | cand_name_all.append(cand_name)
66 |
67 |
68 | res = set()
69 | for query in cand_name_all:
70 | for item in query:
71 | res.add(item[0][0])
72 | for query in cand_path_name_all:
73 | for item in query:
74 | res.add(item[0][0])
75 |
76 | return res
77 |
78 | ################################# Extractor #################################
79 |
80 | ################################# Inferer #################################
81 | def get_inferer_anchor(query_emb, node_embedding, k=15):
82 | """
83 |     Semantic retrieval based on embeddings
84 | """
85 |
86 | node2id_dict = {}
87 | id2node_dict = {}
88 | cand_vec = []
89 |
90 | raw_node_embedding = node_embedding["code"]
91 | for i, node_id in enumerate(raw_node_embedding):
92 | node2id_dict[node_id] = i
93 | id2node_dict[i] = node_id
94 | cand_vec.append(raw_node_embedding[node_id])
95 |
96 |     cand_vec_np = np.array(cand_vec, dtype=np.float32)  # faiss expects float32
97 |
98 |     ######### search #########
99 |     d = 1024                   # embedding dimension of the node vectors
100 |     nb = len(cand_vec_np)      # number of candidate node vectors
101 |
102 |
103 |     index = faiss.IndexFlatL2(d)
104 |     index.add(cand_vec_np)
105 |     # retrieve the k nearest candidate nodes for each query embedding
106 |     D, I = index.search(query_emb, k)
107 |
108 | anchor_node = []
109 | for query in I:
110 | tmp_node_list = []
111 | for trans_id in query:
112 | tmp_node_list.append(int(id2node_dict[trans_id]))
113 | anchor_node.append(tmp_node_list)
114 |
115 | return anchor_node
116 |
117 |
118 | ################################# Inferer #################################
119 |
120 | ################################# Helper functions #################################
121 | def get_graph_file_name(item):
122 | """
123 | return graph_file_name
124 | """
125 |
126 | raise NotImplementedError
127 | ################################# Helper functions #################################
128 |
129 |
130 | if __name__ == "__main__":
131 |
132 |     # data / variable definitions
133 | test_basic_df = pd.read_json("test_basic_info.json")
134 | test_basic_df["graph_file"] = test_basic_df.apply(lambda item: get_graph_file_name(item), axis=1)
135 |
136 | graph_data_path = "codegraph/" # the path of codegraphs
137 |
138 |     # load rewriter extraction results and node embeddings
139 | rewriter_output_path = "/rewriter_output.json"
140 | query_embedding_path = "/rewriter_embedding.pkl"
141 | node_embedding_path = "/node_embedding/"
142 | with open(rewriter_output_path, "r", encoding="utf-8") as file:
143 | rewriter_output = json.load(file)
144 | file.close()
145 |
146 | with open(query_embedding_path, "rb") as file:
147 | query_embedding = pickle.load(file)
148 | file.close()
149 |
150 | # save path
151 | anchor_node_dict = {}
152 |
153 | for idx, item in tqdm.tqdm(test_basic_df.iterrows()):
154 |
155 | instance_id = item.instance_id
156 | graph_file = item.graph_file
157 | tmp_graph_data_path = graph_data_path + graph_file
158 | query_emb = query_embedding[instance_id]
159 |
160 |         # parse the graph data
161 | graph = parse(tmp_graph_data_path)
162 | graph_nx = codegraph_to_nxgraph(graph)
163 |
164 |         # get the rewriter outputs
165 | entity_query = rewriter_output[instance_id]["code_entity"]
166 | keyword_query = rewriter_output[instance_id]["keyword"]
167 |
168 |         # load node_embedding
169 | tmp_node_embedding = node_embedding_path + "{}.pkl".format(instance_id)
170 | with open(tmp_node_embedding, "rb") as file:
171 | tmp_node_embedding = pickle.load(file)
172 | file.close()
173 |
174 |         # locate anchor nodes
175 | res_extractor = get_extractor_anchor(graph, entity_query, keyword_query)
176 | res_inferer = get_inferer_anchor(query_emb, tmp_node_embedding)
177 |
178 | anchor_node = {
179 | "extractor_anchor_nodes": list(res_extractor),
180 | "inferer_anchor_nodes": list(res_inferer),
181 | }
182 |
183 | anchor_node_dict[instance_id] = anchor_node
184 |
185 | # save
186 | with open("anchor_node.json", 'w', encoding='utf-8') as file:
187 | json.dump(anchor_node_dict, file)
188 |
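A minimal sketch (random data, not part of the repository) of the shape conventions get_inferer_anchor assumes: node_embedding["code"] maps node ids to 1024-dim vectors and query_emb is a 2-D float32 array with one row per query:

import numpy as np

rng = np.random.default_rng(0)
node_embedding = {"code": {i: rng.random(1024, dtype=np.float32) for i in range(100)}}
query_emb = rng.random((5, 1024), dtype=np.float32)  # faiss requires float32

anchors = get_inferer_anchor(query_emb, node_embedding, k=15)
print(len(anchors), len(anchors[0]))  # 5 queries, 15 candidate anchor node ids each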
--------------------------------------------------------------------------------
/retriever/subgraph.py:
--------------------------------------------------------------------------------
1 | """
2 | Heuristic search logic
3 | - expand anchor nodes by one hop
4 | - make the expanded result connected
5 | """
6 | import sys
7 | import json
8 | import os
9 | import tqdm
10 | import pandas as pd
11 |
12 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType
13 | from utils import codegraph_to_nxgraph
14 |
15 | ################################# Subgraph reconstruction #################################
16 | def get_path_to_repo(node, pre_node_dict, graph_nx):
17 |     """Get the path from this node up to the repo node
18 |     :param node -> CodeGraph Node, a node sampled into the subgraph
19 |     :param pre_node_dict -> cache: node_id -> list(Node) of predecessors, one entry per node
20 |     :return list of predecessor nodes on the path towards the repo
21 | """
22 | if node.get_type() == NodeType.REPO:
23 | return [node]
24 |
25 | pre_nodes = list()
26 | if node.node_id in pre_node_dict:
27 | pre_nodes = pre_node_dict[node.node_id]
28 | else:
29 | for pre_node in graph_nx.predecessors(node):
30 |             # check the edge type - contains
31 | if graph_nx[pre_node][node][0]['type'] == EdgeType.CONTAINS:
32 | pre_nodes.append(pre_node)
33 | if pre_node.get_type() != NodeType.REPO:
34 | pre_nodes.extend(get_path_to_repo(pre_node, pre_node_dict, graph_nx))
35 | pre_node_dict[node.node_id] = pre_nodes
36 | break
37 |
38 | return pre_nodes
39 |
40 | def reconstruct_graph(subgraph_nodes, graph_nx, pre_node_dict):
41 | """
42 |     Reconstruct a connected CodeGraph from the given nodes
43 |     pre_node_dict is reused globally as a cache
44 | """
45 |
46 | nodes = subgraph_nodes
47 | all_nodes = set(nodes)
48 | for node in nodes:
49 | pre_nodes = get_path_to_repo(node, pre_node_dict, graph_nx)
50 | all_nodes |= set(pre_nodes)
51 |
52 |     # induce the subgraph on the selected nodes
53 | subgraph = graph_nx.subgraph(list(all_nodes))
54 |
55 | return subgraph
56 |
57 | ################################# Subgraph reconstruction #################################
58 |
59 | ################################# BFS #################################
60 | def bfs_expand(graph_nx, subgraph_nodes, hops=1):
61 | """
62 |     Expand via BFS
63 |     - most general version: direction-agnostic 1-hop expansion
64 |     :param graph_nx original graph in networkx format
65 |     :param subgraph_nodes nodes to expand
66 |     :param hops number of hops to expand
67 | """
68 |
69 | seed_node = subgraph_nodes
70 | visited_node = set()
71 |     # record all nodes covered within n hops
72 | nhops_neighbors = set([node.node_id for node in seed_node])
73 |
74 | for hop_idx in range(hops):
75 | tmp_seed_node = []
76 | for node in seed_node:
77 | if node.node_id in visited_node:
78 | continue
79 | visited_node.add(node.node_id)
80 | suc_nodes = graph_nx.successors(node)
81 | pre_nodes = graph_nx.predecessors(node)
82 | for node in suc_nodes:
83 | tmp_seed_node.append(node)
84 | nhops_neighbors.add(node.node_id)
85 |
86 | for node in pre_nodes:
87 | tmp_seed_node.append(node)
88 | nhops_neighbors.add(node.node_id)
89 |
90 | seed_node = tmp_seed_node
91 | return nhops_neighbors
92 |
93 | def bfs_expand_file(graph_nx, subgraph_nodes, hops=1):
94 |     Expand via BFS
95 |     - restricted to File nodes (typically traversed for 2 hops)
96 |     :param graph_nx original graph in networkx format
97 |     :param subgraph_nodes nodes to expand
98 |     :param hops number of hops to expand
99 | :param hops 需要扩展的跳数
100 | """
101 |
102 | seed_node = subgraph_nodes
103 | visited_node = set()
104 | nhops_neighbors = set([node.node_id for node in seed_node])
105 |
106 | for hop_idx in range(hops):
107 | tmp_seed_node = []
108 | for node in seed_node:
109 | if node.node_id in visited_node:
110 | continue
111 | visited_node.add(node.node_id)
112 | suc_nodes = graph_nx.successors(node)
113 | pre_nodes = graph_nx.predecessors(node)
114 | for node in suc_nodes:
115 | if node.get_type() == NodeType.FILE:
116 | tmp_seed_node.append(node)
117 | nhops_neighbors.add(node.node_id)
118 |
119 | for node in pre_nodes:
120 | if node.get_type() == NodeType.FILE:
121 | tmp_seed_node.append(node)
122 | nhops_neighbors.add(node.node_id)
123 |
124 | seed_node = tmp_seed_node
125 | return nhops_neighbors
126 | ################################# BFS #################################
127 |
128 | ################################# Helper functions #################################
129 | def get_graph_file_name(item):
130 | """
131 |     generate graph_file_name
132 | """
133 | repo = item.repo
134 | repo = repo.replace("/", "#", 1)
135 | base_commit = item.base_commit
136 | return repo + "#" + base_commit + ".graph.json"
137 | ################################# Helper functions #################################
138 |
139 | if __name__ == "__main__":
140 |
141 |     # data / variable definitions
142 | test_basic_df = pd.read_json("/test_lite_basic_info.json")
143 | test_basic_df["graph_file"] = test_basic_df.apply(lambda item: get_graph_file_name(item), axis=1)
144 |
145 | graph_data_path = "/swe-bench-lite3/"
146 | anchor_node_path = "/anchor_nodes.json"
147 |
148 | with open(anchor_node_path, "r", encoding="utf-8") as file:
149 | anchor_node_dict = json.load(file)
150 |
151 | subgraph_id_dict = {}
152 |
153 | for idx, item in tqdm.tqdm(test_basic_df.iterrows()):
154 |
155 | instance_id = item.instance_id
156 | graph_file = item.graph_file
157 | tmp_graph_data_path = graph_data_path + graph_file
158 |
159 |         # parse the graph data
160 | graph = parse(tmp_graph_data_path)
161 | graph_nx = codegraph_to_nxgraph(graph)
162 |
163 |         # get anchor_nodes
164 | anchor_nodes_raw = anchor_node_dict[instance_id]
165 | extractor_anchors = anchor_nodes_raw["extractor_anchor_nodes"]
166 | inferer_anchors = [node for node_list in anchor_nodes_raw["inferer_anchor_nodes"] for node in node_list]
167 | anchor_nodes = list(set(extractor_anchors + inferer_anchors))
168 |
169 |         # BFS first, then reconstruct
170 | anchor_nodes = [graph.get_node_by_id(node_id) for node_id in anchor_nodes]
171 | expanded_nodes = bfs_expand_file(graph_nx, anchor_nodes, hops=2)
172 |
173 | expanded_nodes = [graph.get_node_by_id(node_id) for node_id in expanded_nodes]
174 |
175 | pre_node_dict = {}
176 | subgraph = reconstruct_graph(expanded_nodes, graph_nx, pre_node_dict)
177 |
178 | result_nodes = subgraph.nodes()
179 |
180 |         # keep the node ids of File nodes in the subgraph
181 | result_nodes = [node.node_id for node in result_nodes if node.get_type() == NodeType.FILE]
182 |
183 | subgraph_id_dict[instance_id] = list(result_nodes)
184 |
185 | # save
186 | with open("subgraph_nodes.json", 'w', encoding='utf-8') as file:
187 |         json.dump(subgraph_id_dict, file)
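A small standalone sketch (toy stand-in node class, not the repository's parser types) of how bfs_expand walks a networkx MultiDiGraph one hop in both directions from an anchor node:

import networkx as nx
from dataclasses import dataclass

@dataclass(frozen=True)
class ToyNode:                 # hypothetical stand-in for a CodeGraph node
    node_id: int

a, b, c = ToyNode(1), ToyNode(2), ToyNode(3)
g = nx.MultiDiGraph()
g.add_edge(a, b)               # a -> b
g.add_edge(c, a)               # c -> a

print(bfs_expand(g, [a], hops=1))   # {1, 2, 3}: the anchor plus its successor and predecessor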
--------------------------------------------------------------------------------
/cgm/data/encode.py:
--------------------------------------------------------------------------------
1 |
2 | # Template: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n'
3 | # CodeQwen1.5-7B-Chat, Qwen2-72B-Instruct
4 | QwenTokenizerConfig = {
5 | 'seq_length': 8192,
6 | 'SYSTEM': 'system',
7 | 'HUMAN': 'user',
8 | 'BOT': 'assistant',
9 | 'SENTENCE_START_MARKER': '',
10 | 'SENTENCE_END_MARKER': '',
11 | 'SYSTEM_START_MARKER': '<|im_start|>',
12 | 'SYSTEM_START_MARKER_2': '\n',
13 | 'SYSTEM_END_MARKER': '<|im_end|>\n',
14 | 'HUMAN_START_MARKER': '<|im_start|>',
15 | 'HUMAN_START_MARKER_2': '\n',
16 | 'HUMAN_END_MARKER': '<|im_end|>\n',
17 | 'BOT_START_MARKER': '<|im_start|>',
18 | 'BOT_START_MARKER_2': '\n',
19 | 'BOT_END_MARKER': '<|im_end|>\n',
20 | }
21 |
22 | # Template: <|begin▁of▁sentence|>{system_message}<|User|>{user_message_1}<|Assistant|>{assistant_message_1}<|end▁of▁sentence|><|User|>{user_message_2}<|Assistant|>
23 | # DeepSeek-V2.5
24 | DeepSeekTokenizerConfig = {
25 | 'seq_length': 8192,
26 | 'SYSTEM': 'system',
27 | 'HUMAN': 'user',
28 | 'BOT': 'assistant',
29 | 'SENTENCE_START_MARKER': '<|begin▁of▁sentence|>',
30 | 'SENTENCE_END_MARKER': '<|end▁of▁sentence|>',
31 | 'SYSTEM_START_MARKER': '',
32 | 'SYSTEM_START_MARKER_2': '',
33 | 'SYSTEM_END_MARKER': '',
34 | 'HUMAN_START_MARKER': '<|User|>',
35 | 'HUMAN_START_MARKER_2': '',
36 | 'HUMAN_END_MARKER': '',
37 | 'BOT_START_MARKER': '<|Assistant|>',
38 | 'BOT_START_MARKER_2': '',
39 | 'BOT_END_MARKER': '',
40 | }
41 |
42 | DeepSeekCoderTokenizerConfig = {
43 | 'seq_length': 8192,
44 | 'SYSTEM': 'system',
45 | 'HUMAN': 'user',
46 | 'BOT': 'assistant',
47 | 'SENTENCE_START_MARKER': '<|begin▁of▁sentence|>',
48 | 'SENTENCE_END_MARKER': '<|end▁of▁sentence|>',
49 | 'SYSTEM_START_MARKER': '',
50 | 'SYSTEM_START_MARKER_2': '',
51 | 'SYSTEM_END_MARKER': '\n\n',
52 | 'HUMAN_START_MARKER': 'User: ',
53 | 'HUMAN_START_MARKER_2': '',
54 | 'HUMAN_END_MARKER': '\n\n',
55 | 'BOT_START_MARKER': 'Assistant: ',
56 | 'BOT_START_MARKER_2': '',
57 | 'BOT_END_MARKER': '',
58 | }
59 |
60 | def format_eol(text):
61 | if not text.endswith("\n"):
62 | text += "\n"
63 | return text
64 |
65 | def get_template(data):
66 | template = [
67 | {'role': 'system', 'content': ''},
68 | {'role': 'user', 'content': data['prompt']},
69 | {'role': 'assistant', 'content': data['answer']}
70 |     ]
71 |     return template
72 | def get_config(name):
73 | if name == 'Qwen':
74 | return QwenTokenizerConfig
75 | elif name == 'DeepSeek':
76 | return DeepSeekTokenizerConfig
77 | elif name == 'DeepSeek-Coder':
78 | return DeepSeekCoderTokenizerConfig
79 | else:
80 | raise NotImplementedError
81 |
82 | class BaseEncoder(object):
83 | def __init__(self, tokenizer, config_name):
84 | # self.args = args
85 | # seq_length - 1 for shifting
86 | # self.seq_length = args.seq_length - 1
87 | config = get_config(config_name)
88 | self.tokenizer = tokenizer
89 | # self.seq_length = tokenizer.model_max_length
90 | self.seq_length = config.get('seq_length')
91 |
92 | # TODO: default Qwen
93 | self.SYSTEM = config.get('SYSTEM')
94 | self.HUMAN = config.get('HUMAN')
95 | self.BOT = config.get('BOT')
96 |
97 | self.SENTENCE_START_MARKER = config.get('SENTENCE_START_MARKER')
98 |         self.SENTENCE_END_MARKER = config.get('SENTENCE_END_MARKER')
99 |
100 | self.SYSTEM_START_MARKER = config.get('SYSTEM_START_MARKER')
101 | self.SYSTEM_START_MARKER_2 = config.get('SYSTEM_START_MARKER_2')
102 | self.SYSTEM_END_MARKER = config.get('SYSTEM_END_MARKER')
103 |
104 | self.HUMAN_START_MARKER = config.get('HUMAN_START_MARKER')
105 | self.HUMAN_START_MARKER_2 = config.get('HUMAN_START_MARKER_2')
106 | self.HUMAN_END_MARKER = config.get('HUMAN_END_MARKER')
107 |
108 | self.BOT_START_MARKER = config.get('BOT_START_MARKER')
109 | self.BOT_START_MARKER_2 = config.get('BOT_START_MARKER_2')
110 | self.BOT_END_MARKER = config.get('BOT_END_MARKER')
111 |
112 | self.sentence_start_ids = self.tokenizer.encode(f"{self.SENTENCE_START_MARKER}", add_special_tokens=False) if self.SENTENCE_START_MARKER != '' else []
113 | self.sentence_end_ids = self.tokenizer.encode(f"{self.SENTENCE_END_MARKER}", add_special_tokens=False) if self.SENTENCE_END_MARKER != '' else []
114 |
115 | self.system_start_ids = self.tokenizer.encode(f"{self.SYSTEM_START_MARKER}{self.SYSTEM}{self.SYSTEM_START_MARKER_2}", add_special_tokens=False)
116 | self.system_end_ids = self.tokenizer.encode(f"{self.SYSTEM_END_MARKER}", add_special_tokens=False) if self.SYSTEM_END_MARKER != '' else []
117 |
118 | self.human_start_ids = self.tokenizer.encode(f"{self.HUMAN_START_MARKER}{self.HUMAN}{self.HUMAN_START_MARKER_2}", add_special_tokens=False)
119 | self.human_end_ids = self.tokenizer.encode(f"{self.HUMAN_END_MARKER}", add_special_tokens=False) if self.HUMAN_END_MARKER != '' else []
120 |
121 | self.bot_start_ids = self.tokenizer.encode(f"{self.BOT_START_MARKER}{self.BOT}{self.BOT_START_MARKER_2}", add_special_tokens=False)
122 | self.bot_end_ids = self.tokenizer.encode(f"{self.BOT_END_MARKER}", add_special_tokens=False) if self.BOT_END_MARKER != '' else []
123 |
124 | self.end_ids = [self.tokenizer.eos_token_id]
125 |
126 | def padding(self, input_ids, loss_mask, qa_mask):
127 | pad_id = self.tokenizer.pad_token_id
128 |
129 | assert len(input_ids) <= self.seq_length, f"padding sequence: {len(input_ids)} > {self.seq_length}"
130 | input_ids += [pad_id] * (self.seq_length - len(input_ids))
131 | loss_mask += [0] * (self.seq_length - len(loss_mask))
132 |         qa_mask += [0] * (self.seq_length - len(qa_mask))
133 | return {
134 | "input_ids": input_ids,
135 | "loss_mask": loss_mask,
136 | "qa_mask": qa_mask
137 | }
138 |
139 | class CGMEncoder(BaseEncoder):
140 | def __init__(self, tokenizer, config_name):
141 | super().__init__(tokenizer, config_name)
142 |
143 | def dataToInput(self, data, seg_role = None):
144 | input_ids, loss_mask, qa_mask = [], [], []
145 | # TODO: expand
146 | message = [
147 | {'role': self.HUMAN, 'content': 'prompt', 'marker': True, 'loss': 0},
148 | {'role': self.BOT, 'content': 'answer', 'marker': True, 'loss': 1}
149 | ]
150 | if seg_role is not None:
151 | message = [item for item in message if item['role'] == seg_role]
152 |
153 | input_ids += self.sentence_start_ids
154 | loss_mask += [0] * len(self.sentence_start_ids)
155 | qa_mask += [0] * len(self.sentence_start_ids)
156 |
157 | for segment in message:
158 | role = segment['role']
159 | content = segment['content']
160 | marker = segment['marker']
161 | loss = segment['loss']
162 |
163 | if role == self.SYSTEM:
164 | system_ids = self.tokenizer.encode(str(data[content]), add_special_tokens=False)
165 | if marker:
166 | input_ids += self.system_start_ids + system_ids + self.system_end_ids
167 | loss_mask += [0] * len(self.system_start_ids) + [loss] * len(system_ids) + [0] * len(self.system_end_ids)
168 | qa_mask += [0] * len(self.system_start_ids) + [1] * len(system_ids) + [0] * len(self.system_end_ids)
169 | else:
170 | input_ids += system_ids
171 | loss_mask += [loss] * len(system_ids)
172 | qa_mask += [loss] * len(system_ids)
173 |
174 | elif role == self.HUMAN:
175 |
176 | human_ids = self.tokenizer.encode(str(data[content]), add_special_tokens=False)
177 | if marker:
178 | input_ids += self.human_start_ids + human_ids + self.human_end_ids
179 | loss_mask += [0] * len(self.human_start_ids) + [loss] * len(human_ids) + [0] * len(self.human_end_ids)
180 | qa_mask += [0] * len(self.human_start_ids) + [1] * len(human_ids) + [0] * len(self.human_end_ids)
181 | else:
182 | input_ids += human_ids
183 | loss_mask += [loss] * len(human_ids)
184 | qa_mask += [1] * len(human_ids)
185 |
186 | elif role == self.BOT:
187 | bot_ids = self.tokenizer.encode(str(data[content]), add_special_tokens=False)
188 | if marker:
189 | input_ids += self.bot_start_ids + bot_ids + self.bot_end_ids
190 | loss_mask += [0] * len(self.bot_start_ids) + [loss] * len(bot_ids) + [0] * len(self.bot_end_ids)
191 | qa_mask += [0] * len(self.bot_start_ids) + [0] * len(bot_ids) + [0] * len(self.bot_end_ids)
192 | else:
193 | input_ids += bot_ids
194 | loss_mask += [loss] * len(bot_ids)
195 | qa_mask += [0] * len(bot_ids)
196 |
197 | else:
198 |                 raise ValueError(f"unsupported role: {role}")
199 |
200 | input_ids += self.sentence_end_ids
201 | loss_mask += [1] * len(self.sentence_end_ids)
202 | qa_mask += [0] * len(self.sentence_end_ids)
203 |
204 | assert len(input_ids) == len(loss_mask)
205 |
206 | if len(input_ids) <= self.seq_length:
207 | # features = self.padding(input_ids, loss_mask, qa_mask)
208 | features = {}
209 | features['input_ids'] = input_ids
210 | features['loss_mask'] = loss_mask
211 | features['qa_mask'] = qa_mask
212 | else:
213 | features = {}
214 | features['input_ids'] = input_ids[:self.seq_length - 1]
215 | features['loss_mask'] = loss_mask[:self.seq_length - 1]
216 | features['qa_mask'] = qa_mask[:self.seq_length - 1]
217 |
218 | features['input_ids'] += self.sentence_end_ids
219 | features['loss_mask'] += [1] * len(self.sentence_end_ids)
220 | features['qa_mask'] += [0] * len(self.sentence_end_ids)
221 |
222 | assert len(features['input_ids']) == len(features['loss_mask'])
223 |
224 | return features
225 |
226 |
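A minimal sketch (the checkpoint name is only an example) of how CGMEncoder is meant to be driven; loss_mask marks the tokens that contribute to the loss (assistant answer) and qa_mask marks the user-question tokens:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # any Qwen-style chat tokenizer
encoder = CGMEncoder(tokenizer=tokenizer, config_name="Qwen")

features = encoder.dataToInput({"prompt": "Where is the bug?", "answer": "In utils/io.py."})
assert len(features["input_ids"]) == len(features["loss_mask"]) == len(features["qa_mask"])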
--------------------------------------------------------------------------------
/cgm/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.getcwd())
5 |
6 | import time, os, json, math, logging
7 | import numpy as np
8 | import torch
9 | import random
10 | from torch.utils.data import DataLoader, random_split, Subset
11 | import random
12 | # from deepspeed.ops.adam import FusedAdam as AdamW
13 | from torch.optim import AdamW
14 | from transformers import AutoModel, AutoTokenizer
15 | from accelerate import Accelerator
16 | from accelerate.logging import get_logger
17 | from transformers import (
18 | set_seed,
19 | get_scheduler,
20 | )
21 | from utils.arguments import prepare_args
22 |
23 | from modeling.cgm import CGM
24 |
25 | from data.encode import CGMEncoder
26 |
27 | from utils.common_utils import print_args, print_with_rank, print_rank_0
28 | from utils.train_utils import accelerate_train_CGM
29 |
30 | from datasets import load_dataset
31 | import datetime
32 | from peft import (
33 | LoraConfig,
34 | TaskType,
35 | get_peft_model,
36 | prepare_model_for_kbit_training,
37 | PeftModel,
38 | )
39 |
40 | from torch.optim.lr_scheduler import ReduceLROnPlateau
41 |
42 |
43 | # from load_hetero_dataset import load_dataset, perpare_dataloader
44 |
45 | def str_to_tuple(s):
46 | st = s.strip('()')
47 | return tuple(item.strip().strip("'") for item in st.split(','))
48 |
49 | def getRawGraph(filename, suffix="json"):
50 | if os.path.exists(filename):
51 | if suffix == 'json':
52 | with open(filename) as f:
53 | example_graph = json.load(f)
54 | f.close()
55 | elif suffix == 'pt':
56 | with open(filename, 'rb') as f:
57 | example_graph = torch.load(f)
58 | # example_graph = torch.load(filename)
59 | f.close()
60 | return example_graph
61 | return None
62 |
63 | task_ids = {
64 | (0, 'graph_query'),
65 | (1, 'api'),
66 | (2, 'issue_fix'),
67 | (3, 'unit_test'),
68 | (4, 'readme_summary'),
69 | }
70 |
71 | task_to_id = {task: idx for idx, task in task_ids}
72 |
73 | def collate_cgm(graph_dir, encoder, qa_type='mft', seq_l=8192, use_chat=True):
74 | def collate(batches):
75 | result_batches = []
76 | for batch in batches:
77 | result_batch = {}
78 | graph = getRawGraph(batch['repo'], suffix='json')
79 |
80 | if graph is not None:
81 | graph['reponame'] = batch['repo'].split('/')[-1].split('.')[0]
82 | graph['language'] = batch['language']
83 | result_batch['graph'] = graph
84 | if use_chat:
85 | features = encoder.dataToInput(batch)
86 | input_ids = features['input_ids']
87 | loss_mask = features['loss_mask']
88 | qa_mask = features['qa_mask']
89 | else:
90 | query_ids = encoder.tokenizer.encode(batch['prompt'], add_special_tokens=False)
91 | answer_ids = encoder.tokenizer.encode(batch['answer'], add_special_tokens=False) + [
92 | encoder.tokenizer.eos_token_id]
93 | qa_mask = [1] * len(query_ids) + [0] * len(answer_ids)
94 | loss_mask = [0] * len(query_ids) + [1] * len(answer_ids)
95 | input_ids = query_ids + answer_ids
96 |
97 | min_seq = min(seq_l, len(input_ids))
98 | result_batch['x'] = torch.tensor(input_ids, dtype=torch.int64)[:min_seq - 1].contiguous()
99 | result_batch['qa_mask'] = torch.tensor(qa_mask, dtype=torch.bool)[:min_seq - 1].contiguous()
100 | result_batch['y'] = torch.tensor(input_ids, dtype=torch.int64)[1:min_seq].contiguous()
101 | result_batch['loss_mask'] = torch.tensor(loss_mask, dtype=torch.bool)[1:min_seq].contiguous()
102 |
103 | if qa_type == 'mft':
104 | result_batch['task'] = task_to_id[batch['task']]
105 | else:
106 | raise ValueError(f"graph none for {batch['repo']}")
107 |
108 | result_batches.append(result_batch)
109 |
110 | final_result_batch = {}
111 | for key in result_batches[0].keys():
112 | if key == 'task':
113 | final_result_batch[key] = torch.tensor([rb[key] for rb in result_batches])
114 | elif key == 'graph':
115 | final_result_batch[key] = [rb[key] for rb in result_batches]
116 | else:
117 | final_result_batch[key] = torch.stack([rb[key] for rb in result_batches])
118 | return final_result_batch
119 |
120 | return collate
121 |
122 |
123 | def train(args):
124 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
125 |
126 | print_args(args, accelerator)
127 |
128 | # prepare logger
129 | logger = get_logger(__name__)
130 | logging.basicConfig(
131 | format="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
132 | datefmt="%Y-%m-%d %H:%M:%S",
133 | level=logging.INFO,
134 | )
135 | logger.info(accelerator.state, main_process_only=True)
136 |
137 | train_files = args.train_files
138 | valid_files = args.valid_files
139 |
140 | graph_dir = args.graph_dir
141 |
142 | dataset = load_dataset('json', data_files={'train': train_files, 'valid': valid_files})
143 |
144 | train_dataset = dataset['train']
145 | valid_dataset = dataset['valid']
146 |
147 | epoch_train = len(train_dataset)
148 |
149 | if args.peft:
150 | save_suffix = args.framework_type + '_' + str(
151 | args.pretrained_model_path.split('/')[-1]) + '_' + args.task + '_M' + args.mode + '_LR' + str(
152 | args.learning_rate) + '_GA' + str(args.gradient_accumulation_steps) + '_' + str(args.peft) + '_r' + str(
153 | args.lora_rank) + '_alpha' + str(args.lora_alpha) + '_d' + str(args.lora_dropout) + '_m' + str(
154 | args.lora_modules) + str(datetime.datetime.now().strftime('%Y%m%d%H')) + '/'
155 | else:
156 | save_suffix = args.framework_type + '_' + str(
157 | args.pretrained_model_path.split('/')[-1]) + '_' + args.task + '_M' + args.mode + '_LR' + str(
158 | args.learning_rate) + '_GA' + str(args.gradient_accumulation_steps) + '_' + str(
159 | datetime.datetime.now().strftime('%Y%m%d%H')) + '/'
160 |
161 | args.output_dir = args.output_dir + save_suffix
162 | args.tb_dir = args.tb_dir + save_suffix
163 | if 'l' not in args.mode and args.peft:
164 | args.mode = args.mode + 'l'
165 | if args.peft == "QLoRA":
166 | if not args.quantization:
167 | args.quantization = "4bit"
168 |
169 | if accelerator.is_main_process:
170 | if not os.path.exists(args.output_dir):
171 | os.makedirs(args.output_dir)
172 | if not os.path.exists(args.tb_dir):
173 | os.makedirs(args.tb_dir)
174 |
175 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_tokenizer_path, trust_remote_code=False)
176 |
177 | encoder = CGMEncoder(tokenizer=tokenizer, config_name=args.model_type)
178 |
179 | collate_fn = collate_cgm(
180 | graph_dir,
181 | encoder,
182 | qa_type=args.task,
183 | seq_l=8192,
184 | use_chat=args.use_chat,
185 | )
186 |
187 | train_unit_batch_size = accelerator.num_processes * args.gradient_accumulation_steps * args.per_device_train_batch_size
188 | total_train_samples = len(train_dataset)
189 | if args.max_train_samples:
190 | max_train_samples = args.max_train_samples
191 | if total_train_samples > max_train_samples:
192 | total_train_samples = max_train_samples
193 |
194 | max_divisible_samples = (total_train_samples // train_unit_batch_size) * train_unit_batch_size
195 | subset_indices = list(range(max_divisible_samples))
196 | train_subset = Subset(train_dataset, subset_indices)
197 | train_dataloader = DataLoader(train_subset, batch_size=args.per_device_train_batch_size, collate_fn=collate_fn,
198 | shuffle=True)
199 |
200 | if args.max_valid_samples:
201 | max_valid_samples = args.max_valid_samples
202 | valid_unit_batch_size = accelerator.num_processes * args.per_device_eval_batch_size
203 | total_valid_samples = len(valid_dataset)
204 | if total_valid_samples > max_valid_samples:
205 |             indices = list(range(total_valid_samples))
206 | random.shuffle(indices)
207 | subset_indices = indices[:max_valid_samples]
208 | valid_subset = Subset(valid_dataset, subset_indices)
209 | else:
210 | max_divisible_samples = (total_valid_samples // valid_unit_batch_size) * valid_unit_batch_size
211 | subset_indices = list(range(max_divisible_samples))
212 | valid_subset = Subset(valid_dataset, subset_indices)
213 | valid_dataloader = DataLoader(valid_subset, batch_size=args.per_device_eval_batch_size, collate_fn=collate_fn,
214 | shuffle=True)
215 | else:
216 | valid_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=collate_fn,
217 | shuffle=True)
218 |
219 |     logger.info(f"Train batches: {len(train_dataloader)}", main_process_only=True)
220 |     logger.info(f"Valid batches: {len(valid_dataloader)}", main_process_only=True)
221 |
222 | model = CGM(args)
223 | # Please disable checkpointing and re-enable use-cache for inference
224 | model.lm.gradient_checkpointing_enable()
225 | model.lm.config.use_cache = False
226 |
227 | if args.peft == "QLoRA":
228 | model.lm = prepare_model_for_kbit_training(model.lm) # use_gradient_checkpointing default is True
229 | else:
230 | model.lm.gradient_checkpointing_enable()
231 |
232 | if args.peft:
233 | peft_config = LoraConfig(
234 | task_type=TaskType.CAUSAL_LM,
235 | inference_mode=False,
236 | r=args.lora_rank,
237 | lora_alpha=args.lora_alpha,
238 | lora_dropout=args.lora_dropout,
239 | target_modules=args.lora_modules,
240 | bias="lora_only",
241 | )
242 | model.lm = get_peft_model(model.lm, peft_config)
243 |
244 | encoder_peft_config = LoraConfig(
245 | task_type=TaskType.CAUSAL_LM,
246 | inference_mode=False,
247 | r=args.enc_lora_rank,
248 | lora_alpha=args.enc_lora_alpha,
249 | lora_dropout=args.enc_lora_dropout,
250 | target_modules=args.enc_lora_modules,
251 | bias="lora_only",
252 | )
253 | model.encoder = get_peft_model(model.encoder, encoder_peft_config)
254 |
255 | if args.adapter_warmup:
256 | if 'l' in args.mode:
257 | for param in model.lm.parameters():
258 | param.requires_grad = False
259 |
260 | encoder_params = list(model.encoder.parameters()) if 'e' in args.mode else []
261 | pma_params = list(model.pma.parameters()) if 'p' in args.mode else []
262 | adapter_params = list(model.adapter.parameters()) if 'a' in args.mode else []
263 | lm_params = list(model.lm.parameters()) if 'l' in args.mode else []
264 |
265 | trained_params = encoder_params + pma_params + adapter_params + lm_params
266 | # trained_params = adapter_params + lm_params
267 | if not trained_params:
268 | raise ValueError("No parameters to train. Please check the mode argument.")
269 |
270 | optimizer = AdamW(
271 | trained_params,
272 | weight_decay=args.weight_decay,
273 | lr=args.learning_rate,
274 | betas=(0.9, 0.95),
275 | )
276 | overrode_max_train_steps = False
277 | num_update_steps_per_epoch = math.ceil(epoch_train / args.gradient_accumulation_steps)
278 | if args.max_train_steps is None:
279 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
280 | overrode_max_train_steps = True
281 |
282 | if args.lr_scheduler_type == "reduce_lr_on_plateau":
283 | lr_scheduler = ReduceLROnPlateau(
284 | optimizer,
285 | mode='min',
286 | factor=0.75,
287 | patience=3,
288 | threshold=0.0001,
289 | threshold_mode='rel',
290 | cooldown=0,
291 | min_lr=args.min_lr,
292 | eps=1e-08,
293 | )
294 | else:
295 | lr_scheduler = get_scheduler(
296 | name=args.lr_scheduler_type,
297 | optimizer=optimizer,
298 | num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
299 | num_training_steps=args.max_train_steps,
300 | )
301 |
302 | logger.info(
303 | f"{'==' * 100}\nbefore accelerator preparation: [dataloader: {epoch_train}][epochs: {args.num_train_epochs}][total steps: {args.max_train_steps}]\n{'==' * 100}")
304 | if torch.cuda.is_available():
305 | model, train_dataloader, valid_dataloader, optimizer, lr_scheduler = accelerator.prepare(
306 | model, train_dataloader, valid_dataloader, optimizer, lr_scheduler
307 | )
308 |
309 | epoch_train = epoch_train / accelerator.num_processes
310 | num_update_steps_per_epoch = math.ceil(epoch_train / args.gradient_accumulation_steps)
311 | if overrode_max_train_steps:
312 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
313 |
314 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
315 | logger.info(
316 | f"{'==' * 100}\nafter accelerator preparation: [dataloader: {epoch_train}][epochs: {args.num_train_epochs}][total steps: {args.max_train_steps}]\n{'==' * 100}")
317 |
318 | logger.info(f"{'==' * 100}Training...")
319 |
320 | accelerate_train_CGM(accelerator,
321 | model,
322 | train_dataloader,
323 | valid_dataloader,
324 | optimizer,
325 | lr_scheduler,
326 | tokenizer,
327 | epoch_train,
328 | args)
329 |
330 | if __name__ == "__main__":
331 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
332 | args = prepare_args()
333 | set_seed(args.seed)
334 | train(args)
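For reference, a minimal sketch (all paths and field values hypothetical) of one JSON-lines training record consumed by load_dataset and collate_cgm above; 'repo' points at a serialized graph json and 'task' must be one of the names registered in task_ids:

import json

record = {
    "repo": "/graphs/acme#widgets#abc123.graph.json",
    "language": "python",
    "task": "issue_fix",
    "prompt": "Fix the TypeError described in the issue ...",
    "answer": "--- a/utils/io.py\n+++ b/utils/io.py\n...",
}
with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")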
--------------------------------------------------------------------------------
/cgm/modeling/cgm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from data.preprocess import getJavaSentence, getPythonSentence, getSentence
5 | from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer, BitsAndBytesConfig
6 | from utils.common_utils import count_parameters, print_rank_0
7 | import torch.nn.functional as F
8 |
9 | from models.qwen2._4_46_1.modeling_qwen2 import Qwen2ForCausalLM
10 | from models.qwen2._4_46_1.modeling_attn_mask_utils import AttentionMaskConverter
11 |
12 | def graph2embedding(self, data, model, tokenizer, reponame, language, save_adj, peft, return_type=None):
13 | node_embeddings = {}
14 | node_id_to_index = {}
15 | index_counter = 0
16 |
17 | device = model.device
18 |
19 | for node in data['nodes']:
20 | nodeType = node['nodeType']
21 |
22 | if 'nodeId' in node.keys():
23 | node_id = node['nodeId']
24 | elif 'id' in node.keys():
25 | node_id = node['id']
26 | else:
27 | raise ValueError("No key named id/nodeId")
28 |
29 | sentence = getSentence(node, nodeType, reponame, 1024000)
30 |
31 | if sentence == "":
32 | node_embedding = torch.zeros((1, self.args.embedding_dim), dtype=torch.float32).to(device)
33 | node_embeddings[node_id] = [node_embedding]
34 | # sentence_dict[index_counter] = ""
35 | node_id_to_index[node_id] = [index_counter]
36 | index_counter += 1
37 | else:
38 |             # manual tokenization
39 | tokens = tokenizer.tokenize(sentence)
40 | num_tokens = len(tokens)
41 | num_segments = (num_tokens + 511) // 512 # Calculate number of segments
42 | embeddings = []
43 | # segments = []
44 | node_id_to_index[node_id] = list(range(index_counter, index_counter + num_segments))
45 | for i in range(num_segments):
46 | start = i * 512
47 | end = min((i + 1) * 512, num_tokens)
48 | segment_tokens = tokens[start:end]
49 | segment_ids = torch.tensor(tokenizer.convert_tokens_to_ids(segment_tokens), device=device).unsqueeze(0)
50 |
51 | if peft:
52 | # return_type: ALL_256, ALL_768
53 | segment_embedding = model.model(segment_ids, return_type=return_type)
54 | else:
55 | segment_embedding = model(segment_ids)
56 | embeddings.append(segment_embedding)
57 | index_counter += 1
58 |
59 | node_embeddings[node_id] = embeddings
60 |
61 | num_nodes = index_counter
62 |
63 | # TODO: add sparse adj
64 | if save_adj:
65 | adj_matrix = torch.zeros((num_nodes, num_nodes)).to(device)
66 |
67 | for edge in data['edges']:
68 | source_id = edge['source']
69 | target_id = edge['target']
70 | source_indices = node_id_to_index.get(source_id)
71 | target_indices = node_id_to_index.get(target_id)
72 | if source_indices is None or target_indices is None:
73 | # if source_indices is None:
74 | # print(f"{source_id} not exists")
75 | # if target_indices is None:
76 | # print(f"{target_id} not exists")
77 | continue
78 |
79 | for source_index in source_indices:
80 | for target_index in target_indices:
81 | adj_matrix[source_index, target_index] = 1
82 |
83 | # Connect embeddings of the same node
84 | for node_id, indices in node_id_to_index.items():
85 | for i in range(len(indices)):
86 | for j in range(i + 1, len(indices)):
87 | adj_matrix[indices[i], indices[j]] = 1
88 | adj_matrix[indices[j], indices[i]] = 1
89 | else:
90 | adj_matrix = None
91 |
92 | all_embeddings = []
93 | for value in node_embeddings.values():
94 | if isinstance(value, torch.Tensor):
95 | all_embeddings.append(value)
96 | elif isinstance(value, list):
97 | for tensor in value:
98 | all_embeddings.append(tensor)
99 |
100 | embeddings = torch.stack(all_embeddings, dim=0).squeeze(1)
101 |
102 | # embeddings = torch.stack(list(node_embeddings.values()))
103 | # embeddings = torch.stack(sum(node_embeddings.values(), []))
104 | # embeddings = torch.cat(list(node_embeddings.values()), dim=0)
105 |
106 | return embeddings, adj_matrix # sentence_dict
107 |
108 | class adapter(nn.Module):
109 | def __init__(self, args):
110 | super(adapter, self).__init__()
111 | self.fc1 = nn.Linear(args.embedding_dim, args.adapter_hidden_dim)
112 | self.gelu = nn.GELU()
113 | self.fc2 = nn.Linear(args.adapter_hidden_dim, args.lm_hidden_dim)
114 |
115 | def forward(self, x):
116 | return self.fc2(self.gelu(self.fc1(x)))
117 |
118 | class CGM(nn.Module):
119 | def __init__(self, args):
120 | super(CGM, self).__init__()
121 | # text encoder
122 | self.encoder_tokenizer = AutoTokenizer.from_pretrained(args.pretrained_encoder_path, trust_remote_code=True)
123 | self.encoder = AutoModel.from_pretrained(
124 | args.pretrained_encoder_path,
125 | torch_dtype="auto",
126 | trust_remote_code=True
127 | )
128 |
129 | if args.self_defined:
130 | if args.quantization == "8bit":
131 | self.lm = Qwen2ForCausalLM.from_pretrained(
132 | args.pretrained_model_path,
133 | attn_implementation=args.attn_implementation,
134 | torch_dtype="auto",
135 | trust_remote_code=False,
136 | quantization_config=(
137 | BitsAndBytesConfig(
138 |                         # bitsandbytes' 8-bit mode takes no compute/storage dtype
139 |                         # or quant-type arguments (and torch has no plain float8
140 |                         # dtype), so only load_in_8bit is set for this branch
141 |                         load_in_8bit=True,
142 |
143 | )
144 | if args.quantization == "8bit"
145 | else None
146 | ),
147 | )
148 | elif args.quantization == "4bit":
149 | self.lm = Qwen2ForCausalLM.from_pretrained(
150 | args.pretrained_model_path,
151 | attn_implementation=args.attn_implementation,
152 | torch_dtype="auto",
153 | trust_remote_code=False,
154 | quantization_config=(
155 | BitsAndBytesConfig(
156 | load_in_4bit=(args.quantization == "4bit"),
157 | bnb_4bit_compute_dtype=torch.bfloat16,
158 | bnb_4bit_use_double_quant=True,
159 | bnb_4bit_quant_type="nf4",
160 | bnb_4bit_quant_storage=torch.bfloat16,
161 | )
162 | if args.quantization == "4bit"
163 | else None
164 | ),
165 | )
166 | elif not args.quantization:
167 | self.lm = Qwen2ForCausalLM.from_pretrained(
168 | args.pretrained_model_path,
169 | attn_implementation=args.attn_implementation,
170 | torch_dtype="auto",
171 | trust_remote_code=False,
172 | )
173 | else:
174 |                 raise NotImplementedError(f"unrecognized args.quantization: {args.quantization}")
175 | else:
176 | self.lm = AutoModelForCausalLM.from_pretrained(
177 | args.pretrained_model_path,
178 | attn_implementation=args.attn_implementation,
179 | torch_dtype="auto",
180 | trust_remote_code=True,
181 | quantization_config=(
182 | BitsAndBytesConfig(
183 | load_in_4bit=(args.quantization == "4bit"),
184 | bnb_4bit_compute_dtype=torch.bfloat16,
185 | bnb_4bit_use_double_quant=True,
186 | bnb_4bit_quant_type="nf4",
187 | bnb_4bit_quant_storage=torch.bfloat16,
188 | )
189 | if args.quantization == "4bit"
190 | else None
191 | ),
192 | )
193 |
194 | args.lm_hidden_dim = self.lm.config.hidden_size
195 | self.args = args
196 | self.adapter = adapter(args)
197 | if args.load_pretrained_adapter:
198 | self.adapter.load_state_dict(torch.load(args.pretrained_adapter_path))
199 | print_rank_0(f"Adapter loaded from {args.pretrained_adapter_path}")
200 | else:
201 | print_rank_0("Adapter initialized")
202 | print_rank_0(f"Parameters of Encoder: {count_parameters(self.encoder) / 1e6:.1f}M")
203 | print_rank_0(f"Parameters of Adapter: {count_parameters(self.adapter) / 1e6:.1f}M")
204 | print_rank_0(f"Parameters of LLM: {count_parameters(self.lm) / 1e9:.2f}B")
205 |
206 | def graph2embedding(self, data, reponame, return_type=None):
207 | node_embeddings = {}
208 | node_id_to_index = {}
209 | index_counter = 0
210 |
211 | model = self.encoder
212 | tokenizer = self.encoder_tokenizer
213 |         save_adj = self.args.use_adj
214 | peft = self.args.peft
215 |
216 | device = model.device
217 |
218 | for node in data['nodes']:
219 | nodeType = node['nodeType']
220 |
221 | if 'nodeId' in node.keys():
222 | node_id = node['nodeId']
223 | elif 'id' in node.keys():
224 | node_id = node['id']
225 | else:
226 | raise ValueError("No key named id/nodeId")
227 |
228 | sentence = getSentence(node, nodeType, reponame, 1024000)
229 |
230 | if sentence == "":
231 | node_embedding = torch.zeros((1, self.args.embedding_dim), dtype=torch.float32).to(device)
232 | node_embeddings[node_id] = [node_embedding]
233 | node_id_to_index[node_id] = [index_counter]
234 | index_counter += 1
235 | else:
236 | tokens = tokenizer.tokenize(sentence)
237 | num_tokens = len(tokens)
238 | num_segments = (num_tokens + 511) // 512 # Calculate number of segments
239 | embeddings = []
240 | node_id_to_index[node_id] = list(range(index_counter, index_counter + num_segments))
241 | for i in range(num_segments):
242 | start = i * 512
243 | end = min((i + 1) * 512, num_tokens)
244 | segment_tokens = tokens[start:end]
245 | segment_ids = torch.tensor(tokenizer.convert_tokens_to_ids(segment_tokens),
246 | device=device).unsqueeze(0)
247 |
248 | if peft:
249 | # return_type: ALL_256, ALL_768
250 | segment_embedding = model.model(segment_ids, return_type=return_type)
251 | else:
252 | segment_embedding = model(segment_ids)
253 | embeddings.append(segment_embedding)
254 | index_counter += 1
255 |
256 | node_embeddings[node_id] = embeddings
257 |
258 | num_nodes = index_counter
259 |
260 | # TODO: add sparse adj
261 | if save_adj:
262 | adj_matrix = torch.zeros((num_nodes, num_nodes)).to(device)
263 |
264 | for edge in data['edges']:
265 | source_id = edge['source']
266 | target_id = edge['target']
267 | source_indices = node_id_to_index.get(source_id)
268 | target_indices = node_id_to_index.get(target_id)
269 | if source_indices is None or target_indices is None:
270 | continue
271 |
272 | for source_index in source_indices:
273 | for target_index in target_indices:
274 | adj_matrix[source_index, target_index] = 1
275 |
276 | # Connect embeddings of the same node
277 | for node_id, indices in node_id_to_index.items():
278 | for i in range(len(indices)):
279 | for j in range(i + 1, len(indices)):
280 | adj_matrix[indices[i], indices[j]] = 1
281 | adj_matrix[indices[j], indices[i]] = 1
282 | else:
283 | adj_matrix = None
284 |
285 | all_embeddings = []
286 | for value in node_embeddings.values():
287 | if isinstance(value, torch.Tensor):
288 | all_embeddings.append(value)
289 | elif isinstance(value, list):
290 | for tensor in value:
291 | all_embeddings.append(tensor)
292 |
293 | embeddings = torch.stack(all_embeddings, dim=0).squeeze(1)
294 |
295 | return embeddings, adj_matrix # sentence_dict
296 |
297 | def forward(self, graph, qa_ids, qa_mask):
298 |         graph_embeddings, adj_matrix = self.graph2embedding(
299 | data=graph,
300 | reponame=graph['reponame'],
301 | return_type="ALL_256",
302 | )
303 |
304 | embeddings = self.adapter(graph_embeddings)
305 |
306 | if self.args.peft:
307 | inputs_embeds = self.lm.model.model.embed_tokens(qa_ids)
308 | else:
309 | inputs_embeds = self.lm.model.embed_tokens(qa_ids)
310 |
311 | input_embeddings = torch.cat((embeddings, inputs_embeds), dim=-2)
312 | input_embeddings = input_embeddings.unsqueeze(0)
313 |
314 | if adj_matrix is not None and self.args.use_adj:
315 |
316 | if len(adj_matrix.shape) == 2:
317 | adj_matrix = adj_matrix.unsqueeze(0)
318 | batch_size, seq_len_x, _ = adj_matrix.shape
319 |
320 | seq_len_q = inputs_embeds.size(-2)
321 |
322 | qa_matrix = torch.ones(batch_size, seq_len_q, seq_len_q, device=qa_mask.device)
323 | qa_matrix = torch.tril(qa_matrix)
324 |
325 | matrix_xq = qa_mask.unsqueeze(1) * torch.ones(batch_size, seq_len_x, seq_len_q, device=qa_mask.device)
326 |
327 | matrix_qx = torch.ones(batch_size, seq_len_q, seq_len_x, device=qa_mask.device)
328 |
329 | # Construct the full attention mask
330 | attention_mask = torch.cat([
331 | torch.cat([adj_matrix, matrix_xq], dim=2), # x_embeddings part
332 | torch.cat([matrix_qx, qa_matrix], dim=2) # q_embeddings part
333 | ], dim=1).squeeze(1)
334 |
335 | outputs = self.lm(inputs_embeds=input_embeddings,
336 | attention_mask=attention_mask,
337 | return_dict=True)
338 |
339 | else:
340 | outputs = self.lm(inputs_embeds=input_embeddings,
341 | return_dict=True)
342 |
343 | return outputs
344 |
345 |
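A small standalone sketch (toy sizes, plain torch) mirroring the block attention mask assembled in forward(): graph tokens attend to each other through the adjacency matrix, text tokens attend causally to each other and to all graph tokens:

import torch

num_graph, num_text = 3, 4
adj = torch.eye(num_graph).unsqueeze(0)                          # (1, 3, 3) graph -> graph
qa_mask = torch.ones(1, num_text)                                # text positions visible to graph tokens

xq = qa_mask.unsqueeze(1) * torch.ones(1, num_graph, num_text)   # graph -> text
qx = torch.ones(1, num_text, num_graph)                          # text -> graph
qq = torch.tril(torch.ones(1, num_text, num_text))               # causal text -> text

mask = torch.cat([torch.cat([adj, xq], dim=2),
                  torch.cat([qx, qq], dim=2)], dim=1)
print(mask.shape)                                                # torch.Size([1, 7, 7])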
--------------------------------------------------------------------------------
/reranker/reranker.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import os
4 | import pandas as pd
5 | import argparse
6 |
7 | from qwen_api import QwenAPI
8 | from codegraph_parser.python.codegraph_python_local import parse, NodeType, EdgeType
9 |
10 | stage_1_system_prompt = """
11 | You are an experienced software developer who specializes in extracting the most relevant files for solving issues from many reference files.
12 |
13 | Task:
14 | Based on the information received about the issue from a repository, find the most likely few files from among those that may be able to resolve the issue.
15 |
16 | Instructions:
17 | 1. Analysis:
18 | - Analyze the provided issue description and files, and pay attention to the relevance of the provided files with the given issue, especially those might be modified during fixing the issue.
19 | - Determine the specific problem or error mentioned in the issue and note any clues that could help your judgment.
20 | 2. Extraction:
21 | - Based on your analysis, choose the Top **{}** relevant files which might be used in fixing the issue.
22 | - You should choose files from the provided files, and should not modify their name in any way.
23 |
24 | Respond in the following format:
25 | [start_of_analysis]
26 |
27 | [end_of_analysis]
28 |
29 | [start_of_relevant_files]
30 | 1.
31 | 2.
32 | 3. ...
33 | [end_of_relevant_files]
34 |
35 | Notes:
36 | - You can refer to the information in the error logs (if they exist).
37 | - The relevant file usually exists in the project described in the issue (e.g., django, sklearn). Files needing modification are usually not test files or external packages.
38 | - The file you choose should be contained in the provided files.
39 | - Provide the path together with each file name. Do not include redundant prefixes like '/home/username/', '/etc/service/' or '/tree/master'.
40 | - Do not include any additional information such as line numbers or explanations in your extraction result.
41 | - Files for initialization and configuration might be modified during changing the code.
42 |
43 | Preferred extraction Examples of Related Files:
44 | 1. src/utils/file_handler.py
45 | 2. core/services/service_manager.py
46 | 3. ...
47 | """.strip()
48 |
49 | stage_1_user_prompt_template = """
50 |
51 | {}
52 |
53 |
54 |
55 | {}
56 |
57 |
58 |
59 | {}
60 |
61 |
62 |
63 | {}
64 |
65 | """
66 |
67 | stage_2_system_prompt_v3 = """
68 | You are an experienced software developer who specializes in assessing the relevance of the file for solving the issue in software repositories.
69 |
70 | Task:
71 | For a file provided, evaluate the likelihood that modifying this file would resolve the given issue, and assign a score based on specific criteria.
72 |
73 | Instructions:
74 | 1. Analysis:
75 | - Analyze the provided issue description and the content of the single relevant file, paying attention to any keywords, error messages, or specific functionalities mentioned that relate to the file.
76 | - Determine how closely the contents and functionality of the file are tied to the problem or error described in the issue.
77 | - Consider the role of the file in the overall project structure (e.g., configuration files, core logic files versus test files, or utility scripts).
78 | 2. Scoring:
79 | - Based on your analysis, assign a score from 1 to 5 that represents the relevance of modifying the given file in order to solve the issue.
80 |
81 | Score Specifications:
82 | 1. **Score 1**: The file is almost certainly unrelated to the issue, with no apparent connection to the functionality or error described in the issue.
83 | 2. **Score 2**: The file may be tangentially related, but modifying it is unlikely to resolve the issue directly; possible in rare edge cases.
84 | 3. **Score 3**: The file has some relevance to the issue; it might interact with the affected functionality indirectly and tweaking it could be part of a broader fix.
85 | 4. **Score 4**: The file is likely related to the issue; it includes code that interacts directly with the functionality in question and could plausibly contain bugs that lead to the issue.
86 | 5. **Score 5**: The file is very likely the root cause or heavily involved in the issue and modifying it should directly address the error or problem mentioned.
87 |
88 | Respond in the following format:
89 | [start_of_analysis]
90 |
91 | [end_of_analysis]
92 |
93 | [start_of_score]
94 | Score
95 | [end_of_score]
96 |
97 | Notes:
98 | - The content of the file shows only the structure of this file, including the names of the classes and functions defined in this file.
99 | - You can refer to the information in the error logs (if they exist).
100 | """.strip()
101 |
102 | stage_2_user_prompt_template = """
103 |
104 | {}
105 |
106 |
107 |
108 | {}
109 |
110 |
111 |
112 | {}
113 |
114 |
115 |
116 | {}
117 |
118 | """
119 |
120 |
121 | def get_python_inner_class_and_function(graph, node_id, layer_cnt = 0):
122 | """
123 |     Find the inner function and class nodes of a given node and return them as a list
124 |     DFS; each item in the returned list is (depth, node)
125 | """
126 | ret_list = []
127 |
128 |     # limit recursion depth
129 | if layer_cnt > 5:
130 | return ret_list
131 |
132 | node = graph.get_node_by_id(node_id)
133 | inner_node_ids = graph.get_out_nodes(node_id)
134 | for inner_node_id in inner_node_ids:
135 | inner_node = graph.get_node_by_id(inner_node_id)
136 |
137 | if inner_node.get_type() == NodeType.FUNCTION and "def " + inner_node.name in node.text:
138 | ret_list.append((layer_cnt, inner_node))
139 | ret_list.extend(get_python_inner_class_and_function(graph, inner_node.node_id, layer_cnt + 1))
140 | elif inner_node.get_type() == NodeType.CLASS and "class " + inner_node.name in node.text:
141 | ret_list.append((layer_cnt, inner_node))
142 | ret_list.extend(get_python_inner_class_and_function(graph, inner_node.node_id, layer_cnt + 1))
143 |
144 | return ret_list
145 |
146 | def parse_reranker_stage_1(response):
147 | # relevant_file
148 | pattern = r"\[start_of_relevant_files\]\s*(.*?)\s*\[end_of_relevant_files\]"
149 | match = re.search(pattern, response, re.DOTALL)
150 | if match:
151 | relevant_files = match.group(1).strip().split("\n")
152 | else:
153 | pattern = r"\s*(.*?)\s*"
154 | match = re.search(pattern, response, re.DOTALL)
155 | if match:
156 | relevant_files = match.group(1).strip().split("\n")
157 | else:
158 | pattern = r"\[Start_of_Relevant_Files\]\s*(.*?)\s*\[End_of_Relevant_Files\]"
159 | match = re.search(pattern, response, re.DOTALL)
160 | if match:
161 | relevant_files = match.group(1).strip().split("\n")
162 | else:
163 | relevant_files = []
164 |
165 | print(relevant_files)
166 | for idx, relevant_file in enumerate(relevant_files):
167 | new_relevant_file = relevant_file
168 | if new_relevant_file.startswith("- "):
169 | new_relevant_file = new_relevant_file[2:]
170 |
171 | pattern = r"\d+ *\.(.+)"
172 | match = re.search(pattern, new_relevant_file)
173 | if match:
174 | new_relevant_file = match.group(1).strip()
175 | relevant_files[idx] = new_relevant_file
176 |
177 | return relevant_files
178 |
179 | def parse_reranker_stage_2(response):
180 | # score
181 | pattern = r"\[start_of_score\]\s*(.*?)\s*\[end_of_score\]"
182 | match = re.search(pattern, response, re.DOTALL)
183 | if match:
184 | score = match.group(1).strip().split("\n")
185 | else:
186 | pattern = r"\s*(.*?)\s*"
187 | match = re.search(pattern, response, re.DOTALL)
188 | if match:
189 | score = match.group(1).strip().split("\n")
190 | else:
191 | pattern = r"\[Start_of_Score\]\s*(.*?)\s*\[End_of_Score\]"
192 | match = re.search(pattern, response, re.DOTALL)
193 | if match:
194 | score = match.group(1).strip().split("\n")
195 | else:
196 | score = ["0"]
197 |
198 | score = score[0]
199 | if score.startswith("- "):
200 | score = score[2:]
201 |
202 | pattern = r"Score (\d+)"
203 | match = re.search(pattern, score)
204 | if match:
205 | score = match.group(1)
206 | score = int(score)
207 | else:
208 | score = 0
209 |
210 | return score
211 |
212 | def extract_files_from_subgraph(subgraph_path, output_path):
213 | subgraph_list = os.listdir(subgraph_path)
214 | # print(len(subgraph_list))
215 |
216 | for subgraph in subgraph_list:
217 | if not subgraph.endswith(".json"):
218 | continue
219 | # print(subgraph)
220 | try:
221 | with open(os.path.join(subgraph_path, subgraph), "r", encoding="utf-8") as f:
222 | subgraph_json = json.load(f)
223 |         except Exception:
224 | print(f"broken json file: {subgraph}")
225 | continue
226 | subgraph_nodes = subgraph_json["nodes"]
227 | file_nodes = [node for node in subgraph_nodes if node["nodeType"] == "File"]
228 | pred_files = []
229 | for node in file_nodes:
230 | file_path = node["filePath"]
231 | file_name = node["fileName"]
232 | if file_path is None:
233 | file = file_name
234 | else:
235 | file = os.path.join(file_path, file_name)
236 | pred_files.append(file)
237 |
238 | subgraph_name = subgraph.split(".")[0]
239 | with open((os.path.join(output_path, subgraph_name + ".json")), "w", encoding="utf-8") as f:
240 | json.dump(pred_files, f, indent=4)
241 |
242 | def parse_args():
243 | parser = argparse.ArgumentParser(description="Run Reranker.")
244 |
245 | parser.add_argument('--stage_1_k', type=int, default=10, help='Specify the k for stage 1')
246 | parser.add_argument('--stage_2_k', type=int, default=5, help='Specify the k for stage 2')
247 |
248 | return parser.parse_args()
249 |
250 |
251 | if __name__ == "__main__":
252 | llm = QwenAPI("Qwen/Qwen2.5-72B-Instruct")
253 |
254 | output_dir = "reranker_outputs/"
255 |
256 | subgraph_file_dir = "subgraph/"
257 |
258 |     # File list produced by the Retriever
259 | retriever_filtered_files_dir = "subgraph_extracted_files/"
260 | os.makedirs(retriever_filtered_files_dir, exist_ok=True)
261 | extract_files_from_subgraph(subgraph_file_dir, retriever_filtered_files_dir)
262 |
263 | df = pd.read_json("test_basic_info.json")
264 |
265 | args = parse_args()
266 |
267 | # stage_1
268 | stage_1 = True
269 | stage_1_output_dir = os.path.join(output_dir, f"stage_1_top_{args.stage_1_k}")
270 | os.makedirs(os.path.join(stage_1_output_dir, "relevant_files"), exist_ok=True)
271 | os.makedirs(os.path.join(stage_1_output_dir, "response"), exist_ok=True)
272 | stage_1_system_prompt = stage_1_system_prompt.format(args.stage_1_k)
273 | if stage_1:
274 | reranker_stage_1_outputs = os.listdir(os.path.join(stage_1_output_dir, "relevant_files"))
275 | reranker_stage_1_outputs = [item.split(".")[0] for item in reranker_stage_1_outputs]
276 |         for i, (_, data) in enumerate(df.iterrows()):
277 | repo, instance_id, base_commit, patch, test_patch, problem_statement, hints_text, created_at, version, fail_to_pass, pass_to_pass = data["repo"], data["instance_id"], data["base_commit"], data["patch"], data["test_patch"], data["problem_statement"], data["hints_text"], data["created_at"], data["version"], data["FAIL_TO_PASS"], data["PASS_TO_PASS"]
278 |
279 | if instance_id in reranker_stage_1_outputs:
280 | print(f"Stage 1 index {i} skip")
281 | continue
282 |
283 | if os.path.exists(os.path.join(retriever_filtered_files_dir, instance_id + ".json")):
284 | with open(os.path.join(retriever_filtered_files_dir, instance_id + ".json"), "r") as f:
285 | filtered_files = json.load(f)
286 | else:
287 |             raise ValueError(f"retriever output not found for {instance_id}")
288 |
289 | python_files = [item for item in filtered_files if item.endswith(".py")]
290 | other_files = [item for item in filtered_files if not item.endswith(".py")]
291 | user_prompt = stage_1_user_prompt_template.format(repo, problem_statement, "\n".join(python_files), "\n".join(other_files))
292 | print(user_prompt)
293 |
294 | response = llm.get_response(stage_1_system_prompt, user_prompt)
295 | print(response)
296 |
297 | relevant_files = parse_reranker_stage_1(response)
298 | with open(os.path.join(stage_1_output_dir, "relevant_files", instance_id + ".json"), "w") as f:
299 | json.dump(relevant_files, f, indent=4)
300 | with open(os.path.join(stage_1_output_dir, "response", instance_id + ".txt"), "w", encoding="utf-8") as f:
301 | f.write(response)
302 |
303 | print(f"Stage 1 index {i} done")
304 |
305 | # stage_2
306 | stage_2 = True
307 | stage_2_output_dir = os.path.join(output_dir, f"stage_2_{args.stage_2_k}")
308 | os.makedirs(os.path.join(stage_2_output_dir, "relevant_files"), exist_ok=True)
309 | os.makedirs(os.path.join(stage_2_output_dir, "response"), exist_ok=True)
310 | if stage_2:
311 | reranker_stage_2_outputs = os.listdir(os.path.join(stage_2_output_dir, "relevant_files"))
312 | reranker_stage_2_outputs = [item.split(".")[0] for item in reranker_stage_2_outputs]
313 |         for i, (_, data) in enumerate(df.iterrows()):
314 | repo, instance_id, base_commit, patch, test_patch, problem_statement, hints_text, created_at, version, fail_to_pass, pass_to_pass = data["repo"], data["instance_id"], data["base_commit"], data["patch"], data["test_patch"], data["problem_statement"], data["hints_text"], data["created_at"], data["version"], data["FAIL_TO_PASS"], data["PASS_TO_PASS"]
315 |
316 | if instance_id in reranker_stage_2_outputs:
317 | print(f"Stage 2 index {i} skip")
318 | continue
319 |
320 | with open(os.path.join(stage_1_output_dir, "relevant_files", instance_id + ".json"), "r") as f:
321 | stage_1_relevant_files = json.load(f)
322 |
323 |         # Load the subgraph
324 | if os.path.exists(os.path.join(subgraph_file_dir, instance_id + ".json")):
325 | subgraph_file_path = os.path.join(subgraph_file_dir, instance_id + ".json")
326 | graph = parse(subgraph_file_path)
327 | else:
328 |             raise ValueError(f"subgraph not found for {instance_id}")
329 |
330 | relevant_file_score = {}
331 | relevant_file_response = {}
332 | for relevant_file in stage_1_relevant_files:
333 | relevant_file_content = ""
334 | find = False
335 |
336 |             # Keep only the classes and functions defined in the file
337 | for file_node in graph.get_nodes_by_type(NodeType.FILE):
338 | file_path = file_node.path
339 | file_name = file_node.name
340 | if file_path is None:
341 | file = file_name
342 | else:
343 | file = os.path.join(file_path, file_name)
344 | if file == relevant_file:
345 | class_and_function_list = get_python_inner_class_and_function(graph, file_node.node_id)
346 |
347 | relevant_file_content = ""
348 | for layer, node in class_and_function_list:
349 |                         # Indent according to nesting depth
350 | if node.get_type() == NodeType.CLASS:
351 |                             relevant_file_content += " " * layer + "class " + node.name # the interface doc names this field className, but the parser stores it in name
352 | relevant_file_content += "\n"
353 | elif node.get_type() == NodeType.FUNCTION:
354 | if node.name != "":
355 | relevant_file_content += " " * layer + "def " + node.name
356 | relevant_file_content += "\n"
357 |
358 | find = True
359 | break
360 |
361 |             # If the file is not found, default the score to 0 without calling the model
362 | if not find:
363 | relevant_file_score[relevant_file] = 0
364 | relevant_file_response[relevant_file] = ""
365 | else:
366 | user_prompt = stage_2_user_prompt_template.format(repo, problem_statement, relevant_file, relevant_file_content)
367 | print(user_prompt)
368 |
369 | response = llm.get_response(stage_2_system_prompt_v3, user_prompt)
370 |
371 | score = parse_reranker_stage_2(response)
372 | relevant_file_score[relevant_file] = score
373 | relevant_file_response[relevant_file] = response
374 |
375 | k = args.stage_2_k
376 |             # Select the top-k files with the highest scores
377 | sorted_relevant_files = sorted(relevant_file_score.items(), key=lambda item: item[1], reverse=True)
378 | if k <= len(sorted_relevant_files):
379 | selected_relevant_files = sorted_relevant_files[:k]
380 | selected_relevant_files = [item[0] for item in selected_relevant_files]
381 | else:
382 | selected_relevant_files = [item[0] for item in sorted_relevant_files]
383 |
384 | with open(os.path.join(stage_2_output_dir, "relevant_files", instance_id + ".json"), "w") as f:
385 | json.dump(dict(relevant_file_score=relevant_file_score, selected_relevant_files=selected_relevant_files), f, indent=4)
386 | with open(os.path.join(stage_2_output_dir, "response", instance_id + ".json"), "w") as f:
387 | json.dump(relevant_file_response, f, indent=4)
388 |
389 | print(f"Stage 2 index {i} done")
390 |
391 |
--------------------------------------------------------------------------------
/cgm/data/preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | import json
3 | # from codegraph import *
4 |
5 | from torch.utils.data.dataset import Dataset
6 | from transformers import AutoModel, AutoTokenizer
7 |
8 | # from FlagEmbedding import BGEM3FlagModel
9 | # from sentence_transformers import SentenceTransformer
10 |
11 | from datasets import Dataset as HFDataset
12 | from datasets import load_dataset
13 |
14 | import torch
15 | import numpy as np
16 | import logging
17 | import time
18 | import gc
19 |
20 | import random
21 | import string
22 | import os
23 | import sys
24 |
25 | import json
26 | from collections import defaultdict
27 | import random
28 |
29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 |
31 | def getJavaSentence(node, nodeType, reponame, max_len):
32 | def process_Repo(node):
33 | return reponame
34 |
35 | def process_Module(node):
36 | return node['name']
37 |
38 | def process_Package(node):
39 | return node['name']
40 |
41 | def process_File(node):
42 | path = node.get('path','')
43 | if len(path) > 0:
44 | path = path + '/'
45 | return f"{path}{node['name']}"
46 |
47 | def process_TextFile(node):
48 | return f"{node['name']}\n{node.get('text','')}"
49 |
50 | def process_Class(node):
51 | return f"{node.get('modifiers','')} {node['name']}\n{node.get('comment','')}".strip(' ')
52 |
53 | def process_Field(node):
54 | return f"{node.get('modifiers','')} {node['fieldType']} {node['name']}\n{node.get('comment','')}".strip(' ')
55 |
56 | def process_Method(node):
57 | className = node.get('className','')
58 | methodName = node.get('methodName', '')
59 | if len(methodName) == 0 or len(className) == 0:
60 | split = node['signature'].split('#')
61 | className = split[0]
62 | methodName = split[1].split('(')[0]
63 | name = className + '.' + methodName
64 | comment = f"{node.get('comment','')}\n" if not node.get('comment','') == '' else ''
65 | text = f"{node.get('modifiers','')} {node.get('text','')}" if not node.get('modifiers','') == '' else node.get('text','')
66 | return f"{name}\n{comment}{text}"
67 |
68 | def process_default(node):
69 | raise ValueError(f"unrecognized nodeType for node.keys {node['nodeType']} {str(node.keys())}")
70 | return ""
71 |
72 | processors = {
73 | 'Repo': process_Repo,
74 | 'Module': process_Module,
75 | 'Package': process_Package,
76 | 'File': process_File,
77 | 'TextFile': process_TextFile,
78 | 'Textfile': process_TextFile,
79 | 'Class': process_Class,
80 | 'Field': process_Field,
81 | 'Method': process_Method
82 | }
83 |
84 | sentence = processors.get(nodeType, process_default)(node)
85 |
86 | # TODO: limit token not str size
87 | if len(sentence) > max_len:
88 | sentence = sentence[:max_len]
89 |
90 | return sentence
91 |
92 | def getPythonSentence(node, nodeType, reponame, max_len):
93 | def process_Repo(node):
94 | return reponame
95 |
96 | def process_Package(node):
97 | return node['name']
98 |
99 | def process_File(node):
100 | path = node.get('filePath','')
101 | if len(path) > 0:
102 | path = path + '/'
103 | return f"{path}{node['fileName']}\n{node.get('text','')}"
104 |
105 | def process_TextFile(node):
106 | return f"{node['name']}\n{node.get('text','')}"
107 |
108 | def process_Class(node):
109 | return f"{node.get('classType','')} {node['className']}\n{node.get('comment','')}\n{node.get('text','')}".strip(' ')
110 |
111 | def process_Attribute(node):
112 | return f"{node.get('attributeType','')} {node['name']}\n{node.get('comment','')}\n{node.get('text','')}".strip(' ')
113 |
114 | def process_Function(node):
115 | comment = f"{node.get('comment','')}\n" if not node.get('comment','') == '' else ''
116 | return f"{node.get('header','')} {node['name']}\n{comment}{node.get('text','')}".strip(' ')
117 |
118 | def process_Lambda(node):
119 | return f"{node.get('text','')}".strip(' ')
120 |
121 | def process_default(node):
122 | raise ValueError(f"unrecognized nodeType for node.keys {node['nodeType']} {str(node.keys())}")
123 | return ""
124 |
125 | processors = {
126 | 'Repo': process_Repo,
127 | 'Package': process_Package,
128 | 'File': process_File,
129 | 'TextFile': process_TextFile,
130 | 'Textfile': process_TextFile,
131 | 'Class': process_Class,
132 | 'Attribute': process_Attribute,
133 | 'Function': process_Function,
134 | 'Lambda': process_Lambda
135 | }
136 |
137 | sentence = processors.get(nodeType, process_default)(node)
138 |
139 | # TODO: limit token not str size
140 | if len(sentence) > max_len:
141 | sentence = sentence[:max_len]
142 |
143 | return sentence
144 |
145 | def graph2embedding(data, model, tokenizor, reponame, language, save_adj):
146 | node_embeddings = {}
147 | sentence_dict = {}
148 | node_id_to_index = {}
149 | index_counter = 0
150 |
151 | for node in data['nodes']:
152 | nodeType = node['nodeType']
153 |
154 | if 'nodeId' in node.keys():
155 | node_id = node['nodeId']
156 | elif 'id' in node.keys():
157 | node_id = node['id']
158 | else:
159 | raise ValueError("No key named id/nodeId")
160 |
161 | if language == 'java':
162 | sentence = getJavaSentence(node, nodeType, reponame, 1024000)
163 | elif language == 'python':
164 | sentence = getPythonSentence(node, nodeType, reponame, 1024000)
165 | else:
166 | raise ValueError(f"Language {language} not supported")
167 |
168 | if sentence == "":
169 | node_embedding = torch.zeros((1, 256), dtype=torch.float32).to(device)
170 | node_embeddings[node_id] = [node_embedding]
171 | sentence_dict[index_counter] = ""
172 | node_id_to_index[node_id] = [index_counter]
173 | index_counter += 1
174 | else:
175 |             # Manually tokenize and split the sentence into 512-token segments
176 | tokens = tokenizor.tokenize(sentence)
177 | num_tokens = len(tokens)
178 | num_segments = (num_tokens + 511) // 512 # Calculate number of segments
179 | embeddings = []
180 | segments = []
181 | node_id_to_index[node_id] = list(range(index_counter, index_counter + num_segments))
182 | for i in range(num_segments):
183 | start = i * 512
184 | end = min((i + 1) * 512, num_tokens)
185 | segment_tokens = tokens[start:end]
186 | segment_sentence = tokenizor.convert_tokens_to_string(segment_tokens)
187 | segment_ids = tokenizor.encode(segment_sentence, return_tensors="pt").to(device)
188 | with torch.no_grad():
189 | segment_embedding = model(segment_ids)
190 | embeddings.append(segment_embedding)
191 | segments.append(segment_sentence)
192 | sentence_dict[index_counter] = segment_sentence
193 | index_counter += 1
194 |
195 | node_embeddings[node_id] = embeddings
196 |
197 | num_nodes = index_counter
198 |
199 | if save_adj:
200 | adj_matrix = torch.zeros((num_nodes, num_nodes))
201 |
202 | for edge in data['edges']:
203 | source_id = edge['source']
204 | target_id = edge['target']
205 | source_indices = node_id_to_index.get(source_id)
206 | target_indices = node_id_to_index.get(target_id)
207 | if source_indices is None or target_indices is None:
208 | # if source_indices is None:
209 | # print(f"{source_id} not exists")
210 | # if target_indices is None:
211 | # print(f"{target_id} not exists")
212 | continue
213 |
214 | for source_index in source_indices:
215 | for target_index in target_indices:
216 | adj_matrix[source_index, target_index] = 1
217 |
218 | # Connect embeddings of the same node
219 | for node_id, indices in node_id_to_index.items():
220 | for i in range(len(indices)):
221 | for j in range(i + 1, len(indices)):
222 | adj_matrix[indices[i], indices[j]] = 1
223 | adj_matrix[indices[j], indices[i]] = 1
224 | else:
225 | adj_matrix = None
226 |
227 | all_embeddings = []
228 | for value in node_embeddings.values():
229 | if isinstance(value, torch.Tensor):
230 | all_embeddings.append(value)
231 | elif isinstance(value, list):
232 | for tensor in value:
233 | all_embeddings.append(tensor)
234 |
235 | embeddings = torch.stack(all_embeddings, dim=0)
236 |
237 | # embeddings = torch.stack(list(node_embeddings.values()))
238 | # embeddings = torch.stack(sum(node_embeddings.values(), []))
239 | # embeddings = torch.cat(list(node_embeddings.values()), dim=0)
240 |
241 | return embeddings, adj_matrix, sentence_dict
242 |
243 | def preprocess_graph(graphdir, savedir, recdir, jsondir, language = 'java', model = None, tokenizor = None, filenum = 1, suffix = 'pt', node_limit = 20000, save_adj = True, save_rec = True):
244 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
245 | logger = logging.getLogger(__name__)
246 | logger.info(f"Parsing json file: {jsondir}{filenum}.json")
247 |
248 | if not os.path.exists(savedir):
249 | os.makedirs(savedir)
250 |
251 | if not os.path.exists(recdir):
252 | os.makedirs(recdir)
253 |
254 | if jsondir == graphdir:
255 | glist = os.listdir(graphdir)
256 | else:
257 | with open(f'{jsondir}{filenum}.json', 'r') as f:
258 | glist = json.load(f)
259 | f.close()
260 |
261 | for gname in glist:
262 | if gname.startswith('._'):
263 | gname = gname[2:]
264 | raw_file = os.path.join(graphdir, gname)
265 | print(gname)
266 |
267 | if len(gname.split('#')) == 2:
268 | appName = gname.split('#')[1].split('-graph.json')[0]
269 | repoName = appName
270 | groupName = gname.split('#')[0]
271 | commitId = '0'
272 | elif len(gname.split('#')) == 3:
273 | appName = gname.split('#')[1]
274 | repoName = appName
275 | groupName = gname.split('#')[0]
276 | commitId = gname.split('#')[2].split('.graph.json')[0]
277 | elif len(gname.split('___')) == 3:
278 | parts = gname.split('___')
279 | appName = parts[0]
280 | repoName = parts[1].split('__')[1]
281 | groupName = parts[1].split('__')[0]
282 | commitId = parts[2].split('.')[0]
283 | else:
284 | print(f"{gname} can't be renamed")
285 | continue
286 | tmp1 = f"{appName}___{repoName}___{groupName}___{commitId}.{suffix}"
287 | tmp2 = f"{appName}___{repoName}___{groupName}___{commitId}.json"
288 | print(tmp1)
289 |
290 | save_file = os.path.join(savedir, tmp1)
291 | rec_file = os.path.join(recdir, tmp2)
292 |
293 | if not os.path.exists(raw_file):
294 | continue
295 | if os.path.exists(save_file) and os.path.exists(rec_file):
296 | continue
297 | logger.info(f'Start {gname} transforming...')
298 | try:
299 | with open(raw_file, 'r') as f1:
300 | content = f1.read()
301 | data = json.loads(content)
302 |
303 | if len(data['nodes']) > node_limit:
304 | continue
305 | embeddings, adj_matrix, sentence_dict = graph2embedding(data, model, tokenizor, gname, language, save_adj)
306 | f1.close()
307 |
308 | if suffix == 'json':
309 | if save_adj:
310 | data_dict = {
311 | "embeddings": embeddings.tolist(),
312 | "adj_matrix": adj_matrix.tolist()
313 | }
314 | else:
315 | data_dict = {
316 | "embeddings": embeddings.tolist(),
317 | }
318 |
319 | with open(save_file, 'w') as f:
320 | json.dump(data_dict, f)
321 | f.close()
322 |
323 | elif suffix == 'pt':
324 | if save_adj:
325 | data_dict = {
326 | "embeddings": embeddings.detach(),
327 | "adj_matrix": adj_matrix.detach()
328 | }
329 | else:
330 | data_dict = {
331 | "embeddings": embeddings.detach(),
332 | }
333 | torch.save(data_dict, save_file)
334 |
335 | if save_rec:
336 | rec_dict = {
337 | "text": list(sentence_dict.values())
338 | }
339 |
340 | with open(rec_file, 'w') as f:
341 | json.dump(rec_dict, f)
342 | f.close()
343 | except json.JSONDecodeError as e:
344 | print('Json Decode Error: '+ gname)
345 |
346 | def preprocess(graphdir, savedir, recdir, jsondir, language = 'java', mode = 'pretrain', filenum = 1, suffix = 'pt', node_limit = 20000, save_adj = True, save_rec = True):
347 |
348 | model1_path = "salesforce/codet5p-110m-embedding"
349 | tokenizer1 = AutoTokenizer.from_pretrained(model1_path, trust_remote_code=True, device = device)
350 | model1 = AutoModel.from_pretrained(model1_path, trust_remote_code=True, torch_dtype="auto").to(device).eval()
351 |
352 | if mode == 'pretrain':
353 | preprocess_graph(
354 | graphdir=graphdir,
355 | savedir=savedir,
356 | recdir=recdir,
357 | jsondir=jsondir,
358 | language=language,
359 | model=model1,
360 | tokenizor=tokenizer1,
361 | filenum=filenum,
362 | suffix=suffix,
363 | node_limit=node_limit,
364 | save_adj=save_adj,
365 | save_rec=save_rec)
366 | else:
367 | raise NotImplementedError
368 |
369 | def json_split(loaddirs, savedir, split_num=64):
370 | if not os.path.exists(savedir):
371 | os.makedirs(savedir)
372 |
373 | file_list = []
374 | for loaddir in loaddirs:
375 | file_list += os.listdir(loaddir)
376 | total_num = len(file_list)
377 | sep_num = total_num // split_num
378 | print(f'total num: {total_num}, sep num: {sep_num}')
379 |
380 | for i in range(split_num):
381 | start = i * sep_num
382 | end = start + sep_num if i != split_num - 1 else total_num
383 | with open(f'{savedir}/{i+1}.json', 'w') as f:
384 | json.dump(file_list[start:end], f)
385 | f.close()
386 |
387 | def json_split_from_json(input_json, savedir, split_num=64):
388 | with open(input_json, 'r') as file:
389 | data = json.load(file)
390 |
391 | total_items = len(data)
392 |     num_files = (total_items + split_num - 1) // split_num # ceiling division
393 |
394 | if not os.path.exists(savedir):
395 | os.makedirs(savedir)
396 |
397 | for i in range(num_files):
398 | start = i * split_num
399 | end = min(start + split_num, total_items)
400 | split_data = data[start:end]
401 |
402 | save_file = os.path.join(savedir, f"{i+1}.json")
403 |
404 | with open(save_file, 'w') as file:
405 | json.dump(split_data, file, indent=4)
406 |
407 | def detect_pt_file_errors(directory, output_json):
408 | error_files = []
409 |
410 | for root, _, files in os.walk(directory):
411 | for file in files:
412 | if file.endswith('.pt'):
413 | file_path = os.path.join(root, file)
414 | try:
415 | with open(file_path, 'rb') as f:
416 | tmp = torch.load(f)
417 | del tmp
418 | gc.collect()
419 | except Exception as e:
420 | error_files.append(file_path)
421 | print(f"Error loading {file_path}: {e}")
422 |
423 | with open(output_json, 'w') as f:
424 | json.dump(error_files, f, indent=4)
425 |
426 | print(f"Detected {len(error_files)} error files. Details saved in {output_json}")
427 |
428 | def transfer_pt_file_errors(input_json, output_json):
429 | with open(input_json, 'r') as file:
430 | data = json.load(file)
431 |
432 | def transform_path(path):
433 | repo_str = path.split('/')[-1]
434 | appName = repo_str.split('___')[0]
435 | groupName = repo_str.split('___')[2]
436 | new_repo_str = f"{groupName}#{appName}-graph.json"
437 | return new_repo_str
438 |
439 | transformed_data = [transform_path(path) for path in data]
440 |
441 | with open(output_json, 'w') as file:
442 | json.dump(transformed_data, file, indent=4)
443 |
444 | def get_list(graph_dirs):
445 | all_files = [os.path.join(graph_dir, file)
446 | for graph_dir in graph_dirs
447 | for file in os.listdir(graph_dir)]
448 | return all_files
449 |
450 | def get_list_constrained(graph_dirs, size_limit = 500 * 1024 * 1024):
451 | filtered_files = []
452 | for graph_dir in graph_dirs:
453 | glist = os.listdir(graph_dir)
454 | for file_name in glist:
455 | file_path = os.path.join(graph_dir, file_name)
456 | if os.path.isfile(file_path):
457 | file_size = os.path.getsize(file_path)
458 | if file_size < size_limit:
459 | filtered_files.append(file_path)
460 |
461 | return filtered_files
462 |
463 | def get_graph_path(glist, filename, suffix):
464 | sp = filename.split('___')
465 | if len(sp) == 4:
466 | appName = sp[0]
467 | repoName = sp[1]
468 | groupName = sp[2]
469 | commitId = sp[3].split('.')[0]
470 |
471 | matched_graphs = []
472 | for graph in glist:
473 | graph_parts = graph.split('/')[-1].split('___')
474 | if len(graph_parts) == 4:
475 | graph_appName = graph_parts[0]
476 | graph_repoName = graph_parts[1]
477 | graph_groupName = graph_parts[2]
478 | graph_commitId = graph_parts[3].split('.')[0]
479 |
480 | if graph_appName == appName:
481 | matched_graphs.append((graph, graph_repoName, graph_groupName, graph_commitId))
482 |
483 | if not matched_graphs:
484 | return None
485 |
486 | if not commitId == '0':
487 | for graph, graph_repoName, graph_groupName, graph_commitId in matched_graphs:
488 | if commitId == graph_commitId:
489 | return graph
490 |
491 | best_match = None
492 | best_match_score = -2
493 | for graph, graph_repoName, graph_groupName, _ in matched_graphs:
494 | score = (repoName == graph_repoName) + (groupName == graph_groupName)
495 | if score > best_match_score:
496 | best_match_score = score
497 | best_match = graph
498 |
499 | return best_match
500 | else:
501 | raise ValueError(f"{filename} to graph not supported")
502 |
503 | def split_jsonl_dataset(input_file, train_file, test_file, train_ratio=0.98):
504 | def read_jsonl(file_path):
505 | with open(file_path, 'r') as file:
506 | for line in file:
507 | yield json.loads(line)
508 |
509 | data = list(read_jsonl(input_file))
510 | repo_dict = defaultdict(list)
511 |
512 | for item in data:
513 | repo_dict[item['repo']].append(item)
514 |
515 | repos = list(repo_dict.keys())
516 | random.shuffle(repos)
517 |
518 | split_index = int(len(repos) * train_ratio)
519 | train_repos = repos[:split_index]
520 | test_repos = repos[split_index:]
521 |
522 | train_data = []
523 | test_data = []
524 |
525 | for repo in train_repos:
526 | train_data.extend(repo_dict[repo])
527 | for repo in test_repos:
528 | test_data.extend(repo_dict[repo])
529 |
530 | with open(train_file, 'w') as file:
531 | for item in train_data:
532 | file.write(json.dumps(item) + '\n')
533 |
534 | with open(test_file, 'w') as file:
535 | for item in test_data:
536 | file.write(json.dumps(item) + '\n')
537 |
538 |
539 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CGM: Code Graph LLM
2 |
3 | 
4 |
5 | ## Contents
6 | - [News](#news)
7 | - [Introduction](#introduction)
8 | - [Installation](#installation)
9 | - [Examples](#examples)
10 | - [Rewriter](#rewriter)
11 | - [Retriever](#retriever)
12 | - [Reranker](#reranker)
13 | - [Reader](#reader)
14 | - [Contributing](#contributing)
15 | - [Citation](#citation)
16 | - [Join Us](#join-us)
17 |
18 | ## News
19 |
20 | 🔥🔥🔥 [2025/09/19] Our paper [Code Graph Model (CGM): A Graph-Integrated Large Language Model for Repository-Level Software Engineering Tasks](https://arxiv.org/abs/2505.16901) has been accepted to NeurIPS 2025!
21 |
22 | 
23 |
24 | 🔥🔥🔥 [2025/01/15] We are pleased to announce the updated version of the CGM-72B-V1.2. The model further achieves a remarkable 44.00% resolve rate on the SWE-Bench-Lite leaderboard.
25 |
26 | 🔥🔥🔥 [2024/12/28] We are pleased to announce the updated version of the CGM-72B-V1.1. The model further achieves a remarkable 41.67% resolve rate on the SWE-Bench-Lite leaderboard.
27 |
28 | 🔥🔥🔥 [2024/10/28] We are pleased to announce that CGM-72B achieves a remarkable 35.67% resolve rate on the SWE-Bench-Lite leaderboard.
29 |
30 | 🔥🔥🔥 [2024/10/28] We released **CGM**, mainly for repository-level coding tasks.
31 |
32 | - 📜 **Paper**: [Code Graph Model (CGM): A Graph-Integrated Large Language Model for Repository-Level Software Engineering Tasks](https://arxiv.org/abs/2505.16901)
33 | - 🤖 **Model**: [codefuse-ai/CodeFuse-CGM-72B](https://huggingface.co/codefuse-ai/CodeFuse-CGM-72B)
34 | - 📊 **Data**: [codefuse-ai/CodeGraph](https://huggingface.co/datasets/codefuse-ai/CodeGraph)
35 |
36 | ## Introduction
37 | We propose CGM, a graph-based framework for real-world SE tasks. Before CGM starts its work, we construct a repository-level code graph with the Code Graph Generator to better represent the repository context and its structure. Inspired by the Retrieval-Augmented Generation (RAG) approach, the CGM framework is designed as a chain of four atomic nodes, termed the R4 chain (Rewriter, Retriever, Reranker, and Reader). Given an issue, the initial input to the CGM framework consists of the issue description and the corresponding code graph. The Rewriter first rewrites the original issue by extracting keywords and generating relevant queries over the code graph. The Retriever then uses the anchor nodes matched from the Rewriter's output to retrieve a heuristic code subgraph. Since the resulting subgraph provides a relatively broad context, the Reranker identifies the files most likely to be modified as a further hint. Finally, both the retrieved subgraph and the identified files are fed into a trainable, graph-based Reader to generate the corresponding code patch.
38 |
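In code, the R4 chain amounts to composing these four stages in sequence. The sketch below is purely illustrative: the stage callables (`rewriter`, `retriever`, `reranker`, `reader`) are hypothetical placeholders for the corresponding modules in this repository, not their actual APIs.

```python
from typing import Any, Callable

def run_cgm_chain(issue_text: str, code_graph: Any,
                  rewriter: Callable, retriever: Callable,
                  reranker: Callable, reader: Callable) -> str:
    """Illustrative composition of the R4 chain (hypothetical callables)."""
    rewritten = rewriter(issue_text)                # keywords, code entities and search queries
    subgraph = retriever(code_graph, rewritten)     # connected code subgraph around anchor nodes
    top_files = reranker(issue_text, subgraph)      # files most likely to be modified
    return reader(issue_text, subgraph, top_files)  # generated code patch
```
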
39 | ### Framework
40 |
41 | 
42 |
43 | ### Highlights
44 | :white_check_mark: **Code Graph**: Train models on multiple tasks while maintaining a balance between them. The models can even generalize to new, previously unseen tasks.
45 |
46 | :white_check_mark: **Multi-framework**: It supports Accelerate with both DeepSpeed and FSDP.
47 |
48 | :white_check_mark: **Efficient fine-tuning**: It supports LoRA, QLoRA as well as Full-parameters training, enabling fine-tuning of large models with minimal resources. The training speed meets the demands of almost all fine-tuning scenarios.
49 |
50 | ## Installation
51 | ### Prerequisites
52 | - Python 3.8+
53 | - pip
54 |
55 | ### Required Packages
56 |
57 | ```bash
58 | transformers==4.46.1
59 | tokenizers==0.20.0
60 | accelerate==1.0.1
61 | peft==0.13.2
62 | jinja2==2.11.3
63 | fuzzywuzzy==0.18.0
64 | python-Levenshtein==0.25.1
65 | networkx==3.0
66 | ```
67 |
68 | ## Examples
69 |
70 | The following chart illustrates the whole processing pipeline of R3.
71 | 
72 |
73 |
74 | ### Pre-process for Retriever: Generate Node Embedding
75 |
76 | Before running the Retriever, we need to embed
77 | - all the nodes in the Code Graph, and
78 | - the queries generated by the Rewriter
79 |
80 | using [CGE-large](https://huggingface.co/codefuse-ai/CodeFuse-CGE-Large). A minimal embedding sketch follows the requirements below.
81 |
82 | ```bash
83 | python generate_code_content.py # this step will preprocess each code graph and extract the content of each node (saves a .json file for each repo)
84 | python generate_code_embedding.py # this step will generate an embedding for each node with CGE-large (saves a .pkl file for each repo)
85 | python generate_rewriter_embedding.py # this step will generate an embedding for each query generated by the Rewriter (saves a .pkl file)
86 | ```
87 |
88 | Requirements for CGE-large
89 | ```bash
90 | torch==2.1.0
91 | transformers==4.39.2
92 | tokenizers==0.15.2
93 | accelerate==0.28.0
94 | ```
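
For intuition, the sketch below shows how one node's serialized text (or one Rewriter query) can be turned into a single embedding vector. It uses `salesforce/codet5p-110m-embedding`, the encoder used in `cgm/data/preprocess.py`; the scripts above use CGE-large instead, whose loading code may differ.

```python
import torch
from transformers import AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "salesforce/codet5p-110m-embedding"  # embedding model used in cgm/data/preprocess.py
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype="auto").to(device).eval()

def embed_text(text: str) -> torch.Tensor:
    """Embed one node's serialized content (or one Rewriter query) into a single vector."""
    ids = tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        return model(ids)  # this embedding model returns the sentence vector directly

node_vector = embed_text("repo/cart.py\nclass User: ...")                # content of a code-graph node
query_vector = embed_text("Function related to the training process.")  # a query from the Rewriter
```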
95 |
96 | ### Rewriter
97 |
98 | Given the issue metadata, execute the following scripts to generate the Rewriter results (both Inferer and Extractor).
99 |
100 | ```bash
101 | python generate_rewriter_prompt.py # this step will generate a json file "test_rewriter_prompt.json" containing prompts used in Rewriter
102 |
103 | python inference_rewriter.py --prompt_path test_rewriter_prompt.json # this step will load the Qwen model to execute inference and generate the Rewriter's output "test_rewriter_output.json"
104 |
105 | python rewriter_output_post_process.py # this step will load the Rewriter's output "rewriter_output.json" and generate the post-processed output "test_rewriter_output.json"
106 | ```
107 |
108 | Use the functions ```generate_prompt_for_extractor``` and ```generate_prompt_for_inferer``` in ```rewriter/prompt.py```:
109 | ```python
110 | def generate_prompt_for_extractor(problem_statement, repo_name):
111 | prompt = """
112 |
113 | {}
114 |
115 | This is an issue related to repository '{}'.
116 | Instructions:
117 | 1. Analysis:
118 | ○ Analyze the provided issue description. Identify the relevant File, Class, or Function involved.
119 | ○ Determine the specific problem or error encountered and note any clues that may assist in locating the relevant or problematic area.
120 | 2. Extraction:
121 | ○ After the analysis, extract ALL the mentioned code entities (File, Class, or Function), especially Files.
122 | ○ Then extract three potential and meaningful keywords, responding in the following format:
123 |
124 | [start_of_analysis]
125 |
126 | [end_of_analysis]
127 |
128 | [start_of_related_code_entities]
129 |
130 | [end_of_related_code_entities]
131 |
132 | [start_of_related_keywords]
133 |
134 | [end_of_related_keywords]
135 |
136 | Notes:
137 | - Pay attention to the information in the error logs (if exists).
138 | - The buggy code exists solely in the project described in the issue (e.g., django, sklearn). The buggy location is usually not in the test files or external packages.
139 | - Your extracted entities should be CONCISE, ACCURATE and INFORMATIVE.
140 | - Provide the relative path for code entities if specified (e.g., package/foo.py). Relative path is relative to the repository itself, do not include suffix like '/home/username/', '/etc/service/' or '/tree/master'.
141 | - Do not include any additional information such as line numbers or explanations in your extraction result.
142 |
143 | Preferred extraction Examples of Code Entities:
144 | - repo/cart.py
145 | - Class User()
146 | - def getData()
147 | Preferred extraction Examples of Keywords:
148 | - train_loop
149 | - hooks
150 | - docker
151 |
152 | Unpreferred extraction Examples of keywords:
153 | - something wrong
154 | - input validation
155 | - TypeError
156 | """.format(problem_statement, repo_name)
157 |
158 | return prompt
159 |
160 | def generate_prompt_for_inferer(problem_statement, repo_name):
161 | prompt = """
162 |
163 | {}
164 |
165 | This is an issue related to repository '{}'.
166 | Task:
167 | Based on the issue description provided, identify the characteristics of code entities (files, functions, class) that might need to be modified.
168 | For each characteristic, generate a search query that could help locate relevant code entities in a codebase.
169 | Instructions:
170 | First, analyze the issue description and identify keywords, features, and functionalities that are likely relevant to the modification of code entities.
171 | Then, create queries that capture these characteristics, focusing on:
172 | ● File names that may implement relevant functionalities.
173 | ● Functions or methods that are related to the features described in the issue.
174 | ● Any patterns or structures that might be relevant to the functionalities mentioned.
175 | For example:
176 | ● File related to the initialization of a neural network.
177 | ● Function related to the training process.
178 | ● Code used to configure the service.
179 | Please answer in the following format:
180 |
181 | [start_of_analysis]
182 |
183 | [end_of_analysis]
184 |
185 | [start_of_related_queries]
186 | query 1:
187 | query 2:
188 | ...
189 | [end_of_related_queries]
190 |
191 | Notes:
192 | - Your queries should be DETAILED, ACCURATE and INFORMATIVE.
193 | - Your queries should be complete sentences and should not include additional explanation.
194 | - The number of queries is up to five, so focus on the most important characteristics.
195 | - Your queries should focus on the repository code itself, rather than other information like commit history.
196 | - Pay attention to the information in the error logs (if exists).
197 |
198 | Preferred Query Examples:
199 | - Look for references to "tqdm" or "progress_bar" within the training loop files to find where progress bars are currently updated.
200 | - Code snippets where 'gethostbyname' function from 'socket' module is called.
201 | - File name containing 'mysql.py' AND functions related to 'MySQLStatementSamples' initialization.
202 | - Functions or methods handling hostname resolution or encoding within 'datadog_checks' directory.
203 | - Find all occurrences of "early_stopping" within files that also mention "Trainer" to identify where early stopping logic is implemented and potentially needs adjustment for non-default 'val_check_interval'.
204 | """.format(problem_statement, repo_name)
205 |
206 | return prompt
207 | ```
208 | You can use the Rewriter prompts as follows:
209 | ```python
210 | from rewriter.prompt import generate_prompt_for_extractor, generate_prompt_for_inferer
211 |
212 | # Generate extraction prompt
213 | extraction_prompt = generate_prompt_for_extractor(problem_statement, repo_name)
214 |
215 | # Generate inference prompt
216 | inference_prompt = generate_prompt_for_inferer(problem_statement, repo_name)
217 | ```
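
The Rewriter answers inside the `[start_of_...]` / `[end_of_...]` markers defined in the prompts above. A minimal, hypothetical way to pull those sections out with regular expressions is sketched below; the repository's own logic lives in `rewriter/rewriter_output_post_process.py` and may differ.

```python
import re
from typing import List

def extract_tagged_list(response: str, tag: str) -> List[str]:
    """Return the non-empty lines between [start_of_<tag>] and [end_of_<tag>]."""
    match = re.search(rf"\[start_of_{tag}\]\s*(.*?)\s*\[end_of_{tag}\]", response, re.DOTALL)
    if not match:
        return []
    return [line.strip("- ").strip() for line in match.group(1).splitlines() if line.strip()]

sample = """[start_of_related_keywords]
- train_loop
- hooks
[end_of_related_keywords]"""
print(extract_tagged_list(sample, "related_keywords"))  # ['train_loop', 'hooks']
# The same helper also works for "related_code_entities" and "related_queries".
```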
218 |
219 | ### Retriever
220 | Now, we have
221 | - Original CodeGraph: `codegraph/`
222 | - Node Embedding of CodeGraph: `node_embedding/`
223 | - Query Embedding of Rewriter's Inferer: `rewriter_embedding.pkl`
224 | - Output of Rewriter's Extractor: `rewriter_output.json`
225 |
226 | Then we can execute Retriever:
227 |
228 | ```bash
229 | python locate_anchor_node.py # this step will use the above inputs and generate anchor nodes "anchor_node.json" for all samples
230 | python subgraph.py # this step will then expand the anchor nodes into a connected subgraph (saved as a set of node ids)
231 | python serialize_subgraph.py # based on the above subgraph node ids, this step will serialize the subgraph into json format (which is the final output of Retriever)
232 | ```
233 |
234 | Requirements for Retriever
235 | ```bash
236 | RapidFuzz==1.5.0
237 | faiss-cpu
238 | ```
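
Conceptually, `locate_anchor_node.py` matches the query embeddings against the node embeddings to find anchor nodes. Below is a minimal sketch of such a similarity search with FAISS; the function and argument names are hypothetical, and the actual script (which also consumes the Extractor's output and lists RapidFuzz among its requirements) may work differently.

```python
import faiss
import numpy as np
from typing import List

def locate_anchor_nodes(node_vectors: np.ndarray, node_ids: List[str],
                        query_vectors: np.ndarray, top_k: int = 5) -> List[List[str]]:
    """For each query embedding, return the ids of the top_k most similar graph nodes."""
    node_vectors = np.ascontiguousarray(node_vectors, dtype="float32")
    query_vectors = np.ascontiguousarray(query_vectors, dtype="float32")
    faiss.normalize_L2(node_vectors)   # cosine similarity via normalized inner product
    faiss.normalize_L2(query_vectors)
    index = faiss.IndexFlatIP(node_vectors.shape[1])
    index.add(node_vectors)
    _, hits = index.search(query_vectors, top_k)
    return [[node_ids[j] for j in row] for row in hits]
```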
239 |
240 | ### Reranker
241 |
242 | The Reranker determines the most relevant files from the subgraph generated by the Retriever. Its input is the subgraph JSON file produced by the Retriever.
243 |
244 | ```bash
245 | python reranker.py --stage_1_k 10 --stage_2_k 5 # this step will load the subgraph json file and generate the output of Reranker.
246 | ```
247 |
248 | Requirements for Reranker
249 | ```bash
250 | vllm>=0.8.5
251 | ```
252 |
253 | Use the functions ```generate_prompt_for_reranker_stage_1``` and ```generate_prompt_for_reranker_stage_2``` in ```reranker/prompt.py```:
254 | ```python
255 | """
256 | Prompt Template for Reranker
257 | """
258 |
259 | reranker_stage_1_system_prompt = """
260 | You are an experienced software developer who specializes in extracting the most relevant files for solving issues from many reference files.
261 |
262 | Task:
263 | Based on the information received about the issue from a repository, find the most likely few files from among those that may be able to resolve the issue.
264 |
265 | Instructions:
266 | 1. Analysis:
267 |     - Analyze the provided issue description and files, paying attention to the relevance of the provided files to the given issue, especially those that might be modified when fixing the issue.
268 | - Determine the specific problem or error mentioned in the issue and note any clues that could help your judgment.
269 | 2. Extraction:
270 | - Based on your analysis, choose the Top **1** relevant files which might be used in fixing the issue.
271 | - You should choose files from the provided files, and should not modify their name in any way.
272 |
273 | Respond in the following format:
274 | [start_of_analysis]
275 |
276 | [end_of_analysis]
277 |
278 | [start_of_relevant_files]
279 | 1.
280 | 2.
281 | 3. ...
282 | [end_of_relevant_files]
283 |
284 | Notes:
285 | - You can refer to the information in the error logs (if present).
286 | - The relevant file usually exists in the project described in the issue (e.g., django, sklearn). The file needing modification is usually not in the test files or external packages.
287 | - The file you choose should be contained in the provided files.
288 | - Provide the file path with files. Do not include redundant suffix like '/home/username/', '/etc/service/' or '/tree/master'.
289 | - Do not include any additional information such as line numbers or explanations in your extraction result.
290 | - Files for initialization and configuration might be modified during changing the code.
291 |
292 | Preferred extraction Examples of Related Files:
293 | 1. src/utils/file_handler.py
294 | 2. core/services/service_manager.py
295 | 3. ...
296 | """.strip()
297 |
298 | reranker_stage_1_user_prompt_template = """
299 |
300 | {}
301 |
302 |
303 |
304 | {}
305 |
306 |
307 |
308 | {}
309 |
310 |
311 |
312 | {}
313 |
314 | """
315 |
316 | reranker_stage_2_system_prompt = """
317 | You are an experienced software developer who specializes in assessing the relevance of the file for solving the issue in software repositories.
318 |
319 | Task:
320 | For a file provided, evaluate the likelihood that modifying this file would resolve the given issue, and assign a score based on specific criteria.
321 |
322 | Instructions:
323 | 1. Analysis:
324 |     - Analyze the provided issue description and the content of the single relevant file, paying attention to any keywords, error messages, or specific functionalities mentioned that relate to the file.
325 | - Determine how closely the contents and functionality of the file are tied to the problem or error described in the issue.
326 | - Consider the role of the file in the overall project structure (e.g., configuration files, core logic files versus test files, or utility scripts).
327 | 2. Scoring:
328 | - Based on your analysis, assign a score from 1 to 5 that represents the relevance of modifying the given file in order to solve the issue.
329 |
330 | Score Specifications:
331 | 1. **Score 1**: The file is almost certainly unrelated to the issue, with no apparent connection to the functionality or error described in the issue.
332 | 2. **Score 2**: The file may be tangentially related, but modifying it is unlikely to resolve the issue directly; possible in rare edge cases.
333 | 3. **Score 3**: The file has some relevance to the issue; it might interact with the affected functionality indirectly and tweaking it could be part of a broader fix.
334 | 4. **Score 4**: The file is likely related to the issue; it includes code that interacts directly with the functionality in question and could plausibly contain bugs that lead to the issue.
335 | 5. **Score 5**: The file is very likely the root cause or heavily involved in the issue and modifying it should directly address the error or problem mentioned.
336 |
337 | Respond in the following format:
338 | [start_of_analysis]
339 |
340 | [end_of_analysis]
341 |
342 | [start_of_score]
343 | Score
344 | [end_of_score]
345 |
346 | Notes:
347 | - The content of the file shows only the structure of this file, including the names of the classes and functions defined in this file.
348 | - You can refer to the information in the error logs (if present).
349 | """.strip()
350 |
351 | reranker_stage_2_user_prompt_template = """
352 |
353 | {}
354 |
355 |
356 |
357 | {}
358 |
359 |
360 |
361 | {}
362 |
363 |
364 |
365 | {}
366 |
367 | """
368 |
369 | def generate_prompt_for_reranker_stage_1(problem_statement, repo_name, py_file, other_file):
370 | """
371 | problem_statement: issue
372 | repo_name: repo
373 | py_file: py file list
374 | other_file: related file list
375 | """
376 | return reranker_stage_1_system_prompt, reranker_stage_1_user_prompt_template.format(repo_name, problem_statement, py_file, other_file)
377 |
378 | def generate_prompt_for_reranker_stage_2(problem_statement, repo_name, file_name, file_content):
379 | """
380 | problem_statement: issue
381 | repo_name: repo
382 | file_name: file
383 |     file_content: file content (class xxx and def xxx)
384 | """
385 | return reranker_stage_2_system_prompt, reranker_stage_2_user_prompt_template.format(repo_name, problem_statement, file_name, file_content)
386 | ```
387 | You can use the Reranker prompts as follows:
388 | ```python
389 | from reranker.prompt import generate_prompt_for_reranker_stage_1, generate_prompt_for_reranker_stage_2
390 |
391 | # Stage 1: Identify relevant files
392 | system_prompt, user_prompt = generate_prompt_for_reranker_stage_1(
393 | problem_statement,
394 | repo_name,
395 | py_file_list,
396 | other_file_list
397 | )
398 |
399 | # Stage 2: Score file relevance
400 | system_prompt, user_prompt = generate_prompt_for_reranker_stage_2(
401 | problem_statement,
402 | repo_name,
403 | target_file,
404 | file_content
405 | )
406 | ```
407 |
408 | ### Reader
409 | Execute the Reader module with DeepSpeed configurations:
410 | ```bash
411 | # Zero-2 Configuration
412 | export N_NODE={YOUR_MACHINE_NUM} && \
413 | export N_GPU_PER_NODE={YOUR_GPU_NUM} && \
414 | export TRAIN_CONFIG={TRAIN_CONFIG}.json && \
415 | bash launch/zero2.sh
416 |
417 | # Zero-3 Configuration
418 | export N_NODE={YOUR_MACHINE_NUM} && \
419 | export N_GPU_PER_NODE={YOUR_GPU_NUM} && \
420 | export TRAIN_CONFIG={TRAIN_CONFIG}.json && \
421 | bash launch/zero3.sh
422 | ```
423 |
424 | ## Contributing
425 | Contributions are welcome! If you have any suggestions, ideas, bug reports, or requests for new models or features, please open an issue or submit a pull request.
426 |
427 | We welcome contributions from the community! Please follow these guidelines:
428 |
429 | 1. Fork the repository
430 |
431 | 2. Create your feature branch
432 |
433 | 3. Commit your changes
434 |
435 | 4. Push to the branch
436 |
437 | 5. Open a Pull Request
438 |
439 | For major changes, please open an issue first to discuss the proposed changes.
440 |
441 |
442 | ## Citation
443 | If you find our work useful or helpful for your R&D work, please feel free to cite our paper as below.
444 | ```bibtex
445 | @misc{tao2025codegraphmodelcgm,
446 | title={Code Graph Model (CGM): A Graph-Integrated Large Language Model for Repository-Level Software Engineering Tasks},
447 | author={Hongyuan Tao and Ying Zhang and Zhenhao Tang and Hongen Peng and Xukun Zhu and Bingchang Liu and Yingguang Yang and Ziyin Zhang and Zhaogui Xu and Haipeng Zhang and Linchao Zhu and Rui Wang and Hang Yu and Jianguo Li and Peng Di},
448 | year={2025},
449 | eprint={2505.16901},
450 | archivePrefix={arXiv},
451 | primaryClass={cs.SE},
452 | url={https://arxiv.org/abs/2505.16901},
453 | }
454 | ```
455 | ## Join Us
456 |
457 | We are the AI Native team within the Platform Technology Business Group at Ant Group, dedicated to the intelligentization of Ant Group's platform engineering. Established for over three years, our team has played a pivotal role in supporting the intelligent operation and maintenance of Ant Group's cloud computing infrastructure. Our mission is to build algorithm services and platforms with a wide user base through world-class technological innovation and impact, supporting the implementation of internal and external products and businesses.
458 | Embracing an innovation-driven ethos, our team not only supports business implementation but also propels technological influence. Over the past three years, we have published more than 20 papers at top conferences like ICLR, NeurIPS, KDD, and ACL. Our innovative business outcomes have earned us two Ant Technology's highest T-Star awards and one SuperMA award from Ant Group. Our open-source project CodeFuse has received 4K stars as of February 2024, and our models have been downloaded over 1.5 million times on Huggingface and Modelscope.
459 |
460 | We are on the lookout for top talents to join our vibrant team! If you're eager to develop your career in an environment filled with energy, innovation, and a culture of excellence, we welcome you to explore our career opportunities for both campus and experienced hires. Join us and be a part of creating the next milestone in the industry.
461 |
462 | **Contact**: hyu.hugo@antgroup.com
463 |
--------------------------------------------------------------------------------
/cgm/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.functional import one_hot
7 | from tqdm.auto import tqdm
8 | import time
9 | import datetime
10 | from collections import OrderedDict
11 |
12 | sys.path.append("..")
13 | from utils.common_utils import touch_print, print_rank_0
14 | from utils.loss import loss_CGM
15 | from torch.utils.tensorboard import SummaryWriter
16 | from accelerate.logging import get_logger
17 | from torch.cuda.amp import autocast
18 |
19 | logger = get_logger(__name__)
20 |
21 | task_ids = {
22 | (0, 'graph_query'),
23 | (1, 'api'),
24 | (2, 'issue_fix'),
25 | (3, 'unit_test'),
26 | (4, 'readme_summary'),
27 | }
28 |
29 | task_to_id = {task: idx for idx, task in task_ids}
30 | id_to_task = {idx: task for idx, task in task_ids}
31 |
32 | def check_weight_dtype(model):
33 | for name, param in model.named_parameters():
34 | print_rank_0(f"Layer {name}: {param.dtype}")
35 |
36 | def write_tensorboard(summary_writer: SummaryWriter, log_dict: dict, completed_steps):
37 | for key, value in log_dict.items():
38 | summary_writer.add_scalar(f'{key}', value, completed_steps)
39 |
40 | def accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir: str, completed_steps: int, args):
41 | accelerator.wait_for_everyone()
42 |
43 | accelerator.print(f"[CHECKPOINT] Saving checkpoint")
44 | unwrapped_model = accelerator.unwrap_model(model)
45 |
46 | save_encoder = False
47 | save_adapter = False
48 | save_lm = False
49 | if 'e' in args.mode:
50 | save_encoder = True
51 | if 'a' in args.mode:
52 | save_adapter = True
53 | if 'l' in args.mode:
54 | save_lm = True
55 |
56 | if accelerator.is_main_process:
57 | if not os.path.exists(output_dir):
58 | os.makedirs(output_dir)
59 | tokenizer.save_pretrained(output_dir)
60 |
61 | if save_adapter:
62 | torch.save(accelerator.get_state_dict(model.adapter), f"{output_dir}/adapter.pth")
63 |
64 |
65 | if save_encoder:
66 | unwrapped_model.encoder.save_pretrained(
67 | f"{output_dir}/encoder",
68 | is_main_process=accelerator.is_main_process,
69 | save_function=accelerator.save,
70 | state_dict=accelerator.get_state_dict(model.encoder)
71 | )
72 |
73 | if save_lm:
74 | unwrapped_model.lm.save_pretrained(
75 | output_dir,
76 | is_main_process=accelerator.is_main_process,
77 | save_function=accelerator.save,
78 | state_dict=accelerator.get_state_dict(model.lm)
79 | )
80 |
81 | accelerator.print(
82 | f"[CHECKPOINT][complete_steps={completed_steps}], checkpoint {output_dir} saved"
83 | )
84 |
85 | accelerator.wait_for_everyone()
86 |
87 | def accelerate_evaluate_CGM(accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, min_eval_loss, stall_num,
88 | best_step, summary_writer):
89 |
90 | losses = []
91 | end_eval = False
92 | eval_step = 0
93 | for batch in valid_dataloader:
94 | with torch.no_grad():
95 | if args.task == 'mft':
96 | task = batch['task']
97 |
98 | x = batch['x']
99 | y = batch['y']
100 | loss_mask = batch['loss_mask']
101 | qa_mask = batch['qa_mask']
102 | embeddings = batch['embeddings']
103 |
104 | len_y = y.shape[1]
105 |
106 | outputs = model(
107 | graph_embeddings = embeddings,
108 | qa_embeddings = x,
109 | qa_mask = qa_mask
110 | )
111 | output_logits = outputs['logits'][:,-len_y:,:]
112 |
113 | if args.task == 'mft':
114 | loss_dict = {task_id: torch.tensor(0.0, device=output_logits.device) for task_id, _ in task_ids}
115 |
116 | task_id = task.item()
117 | loss_dict[task_id] += loss_CGM(
118 | output_logits = output_logits,
119 | labels = y,
120 | loss_mask = loss_mask,
121 | )
122 |
123 | loss = sum(loss_dict.values())
124 | else:
125 | loss = loss_CGM(
126 | output_logits = output_logits,
127 | labels = y,
128 | loss_mask = loss_mask,
129 | )
130 |
131 | eval_step += 1
132 |
133 | losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
134 |
135 | accelerator.wait_for_everyone()
136 | valid_batch_num = len(losses)
137 | gathered_size = losses[0].shape
138 | losses = torch.cat(losses)
139 |
140 | try:
141 | eval_loss = torch.mean(losses)
142 | if eval_loss <= min_eval_loss:
143 | min_eval_loss = eval_loss
144 | stall_num = 0
145 | best_step = completed_steps
146 | else:
147 | stall_num += 1
148 | perplexity = math.exp(eval_loss)
149 | except OverflowError:
150 | perplexity = float("inf")
151 |
152 | logger.info(f"[EVAL][global_steps={step + 1}][completed_steps={completed_steps}]"
153 | f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]"
154 | f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}]")
155 | eval_log_dict = {
156 | "valid/valid_loss": eval_loss.float(),
157 | "valid/perplexity": perplexity
158 | }
159 |
160 | if accelerator.is_main_process:
161 | write_tensorboard(summary_writer, eval_log_dict, completed_steps)
162 |
163 | return eval_loss, min_eval_loss, stall_num, best_step
164 |
165 | def accelerate_evaluate_CGM_mft(accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step, min_eval_loss, stall_num,
166 | best_step, summary_writer):
167 |
168 | losses = []
169 | task_eval_counts = {}
170 | task_losses = {}
171 | eval_step = 0
172 |
173 | for batch in valid_dataloader:
174 | with torch.no_grad():
175 | if args.task == 'mft':
176 | task = batch['task']
177 | task_id = task.item()
178 | if task_id not in task_eval_counts:
179 | task_eval_counts[task_id] = 0
180 | task_losses[task_id] = []
181 | if task_eval_counts[task_id] >= 50:
182 | continue
183 |
184 | x = batch['x']
185 | y = batch['y']
186 | loss_mask = batch['loss_mask']
187 | qa_mask = batch['qa_mask']
188 | embeddings = batch['embeddings']
189 |
190 | len_y = y.shape[1]
191 |
192 | outputs = model(
193 | graph_embeddings = embeddings,
194 | qa_embeddings = x,
195 | qa_mask = qa_mask
196 | )
197 | output_logits = outputs['logits'][:,-len_y:,:]
198 |
199 | if args.task == 'mft':
200 | loss = loss_CGM(
201 | output_logits = output_logits,
202 | labels = y,
203 | loss_mask = loss_mask,
204 | )
205 | task_losses[task_id].append(loss.item())
206 | else:
207 | loss = loss_CGM(
208 | output_logits = output_logits,
209 | labels = y,
210 | loss_mask = loss_mask,
211 | )
212 |
213 | eval_step += 1
214 | losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
215 | task_eval_counts[task_id] += 1
216 |
217 | accelerator.wait_for_everyone()
218 | valid_batch_num = len(losses)
219 | gathered_size = losses[0].shape
220 | losses = torch.cat(losses)
221 |
222 | try:
223 | eval_loss = torch.mean(losses)
224 | if eval_loss <= min_eval_loss:
225 | min_eval_loss = eval_loss
226 | stall_num = 0
227 | best_step = completed_steps
228 | else:
229 | stall_num += 1
230 | perplexity = math.exp(eval_loss)
231 | except OverflowError:
232 | perplexity = float("inf")
233 |
234 | logger.info(f"[EVAL][global_steps={step + 1}][completed_steps={completed_steps}]"
235 | f"[valid_batch_num={valid_batch_num}], [gather_size={gathered_size}]"
236 | f"[perplexity={perplexity:.4f}][eval_loss={eval_loss:.6f}]")
237 |
238 |     eval_log_dict = {
239 |         "valid/valid_loss": eval_loss.float(),
240 |         "valid/perplexity": perplexity,
241 |     }
242 |     for task_id, task_loss_list in task_losses.items():
243 |         task_eval_loss = sum(task_loss_list) / len(task_loss_list) if task_loss_list else 0.0
244 |         logger.info(f"[EVAL][task_id={task_id}][task_loss={task_eval_loss:.6f}]")
245 |         eval_log_dict[f"valid/{id_to_task[task_id]}"] = task_eval_loss
246 |
247 | if accelerator.is_main_process:
248 | write_tensorboard(summary_writer, eval_log_dict, completed_steps)
249 |
250 | return eval_loss, min_eval_loss, stall_num, best_step
251 |
252 | def accelerate_monitor_CGM_mft(accelerator, reduce_loss_dict, reduce_loss_count_dict, args, completed_steps,
253 |                                lr_scheduler, optimizer, summary_writer):
254 |
255 | """
256 |     Aggregate the per-task reduced losses (the cross-device gather is currently disabled)
257 |     and write training logs and TensorBoard summaries.
258 | """
259 | # gathered_loss_dict = {task_id: accelerator.gather(reduce_loss) for task_id, reduce_loss in reduce_loss_dict.items()}
260 | gathered_loss_dict = {task_id: reduce_loss for task_id, reduce_loss in reduce_loss_dict.items()}
261 | print_rank_0(f"*******************gathered_loss_dict*******************")
262 | print_rank_0(gathered_loss_dict)
263 |
264 | train_log_dict = {
265 | f"train/{id_to_task[task_id]}": torch.mean(gathered_loss) / max(reduce_loss_count_dict[task_id], 1)
266 | for task_id, gathered_loss in gathered_loss_dict.items()
267 | }
268 | train_log_dict["train/lr"] = optimizer.param_groups[0]['lr']
269 |
270 | logger.info(
271 | f"[TRAIN][completed_steps={completed_steps}]"
272 | f"[lr={optimizer.param_groups[0]['lr']:.4e}]",
273 | )
274 | for task_id, train_loss in train_log_dict.items():
275 | if task_id != "train/lr":
276 | logger.info(f"{task_id}={train_loss:.6f}")
277 |
278 | if accelerator.is_main_process:
279 | write_tensorboard(summary_writer, train_log_dict, completed_steps)
280 |
281 |
282 | def accelerate_monitor_CGM(accelerator, reduce_loss, args, completed_steps,
283 | lr_scheduler, optimizer, summary_writer):
284 |
285 | """
286 |     Gather reduce_loss from all devices.
287 |     Write training logs and TensorBoard summaries.
288 | """
289 | reduce_losses = accelerator.gather(reduce_loss)
290 | # reduce_losses = reduce_loss
291 |
292 | train_loss = torch.mean(reduce_losses) / (args.log_interval * args.gradient_accumulation_steps)
293 |
294 |
295 | logger.info(
296 |         f"[TRAIN][completed_steps={completed_steps}][train_loss={train_loss:.6f}]"
297 | f"[gather shape={reduce_losses.shape}][lr={optimizer.param_groups[0]['lr']:.4e}]",
298 | )
299 |
300 | train_log_dict = {
301 | "train/train_loss": train_loss,
302 | "train/lr": optimizer.param_groups[0]['lr']
303 | }
304 |
305 | if accelerator.is_main_process:
306 | write_tensorboard(summary_writer, train_log_dict, completed_steps)
307 |
308 |
309 | def accelerate_train_CGM(accelerator, model, train_dataloader, valid_dataloader, optimizer, lr_scheduler, tokenizer,
310 | total_train_dataset_size, args):
311 |
312 | summary_writer = SummaryWriter(log_dir=args.tb_dir, filename_suffix=args.tb_dir.split('/')[-1]) if accelerator.is_main_process else None
313 | # Train!
314 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
315 | logger.info("**************************************** Running training ****************************************")
316 | logger.info(f" Num examples = {total_train_dataset_size}")
317 | logger.info(f" Num Epochs = {args.num_train_epochs}")
318 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
319 | logger.info(f" Total global train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
320 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
321 | logger.info(f" Total optimization(update/completed) steps = {args.max_train_steps}")
322 | logger.info(f" Complete/Optimization steps per Epoch = {args.max_train_steps // args.num_train_epochs}")
323 | logger.info("***************************************************************************************************")
324 |
325 | # Only show the progress bar once on each machine.
326 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
327 | # check_weight_dtype(model.lm)
328 | # exit()
329 | # set starting_epoch, completed_steps and resume_step of train_dataloader
330 | completed_steps = 0
331 | starting_epoch = 0
332 |
333 | # monitor minimum eval_loss, stalling num, and best_step
334 | min_eval_loss = float('inf')
335 | eval_loss = 100.0
336 | checkpoint_eval_loss = float('inf')
337 | checkpoint_stall_num = 0
338 | stall_num = 0
339 | best_step = None
340 |
341 | reduce_loss = 0
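    |     # per-task loss sums and batch counts used for MFT logging; reset after every log_interval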
342 | reduce_loss_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids)
343 | reduce_loss_count_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids)
344 | for epoch in range(starting_epoch, args.num_train_epochs):
345 | model.train()
346 |
347 | for step, batch in enumerate(train_dataloader):
348 |
349 | with accelerator.accumulate(model):
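    |                 # accelerator.accumulate() defers gradient synchronization until the accumulation window completes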
350 |
351 | graph = batch['graph']
352 | x = batch['x']
353 | y = batch['y'].unsqueeze(0)
354 | loss_mask = batch['loss_mask'].unsqueeze(0)
355 | qa_mask = batch['qa_mask'].unsqueeze(0)
356 |
357 | if args.task == 'mft':
358 | task = batch['task']
359 |
360 | outputs = model(
361 | graph = graph,
362 | qa_ids = x,
363 | qa_mask = qa_mask,
364 | )
365 |
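    |                 # keep only the trailing len_y logits so they align with the label sequence y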
366 | len_y = y.shape[1]
367 | output_logits = outputs['logits'][:,-len_y:,:]
368 |
369 | if args.task == 'mft':
370 | loss_dict = {task_id: torch.tensor(0.0, device=output_logits.device) for task_id, _ in task_ids}
371 |
372 | task_id = task.item()
373 | loss_dict[task_id] += loss_CGM(
374 | output_logits = output_logits,
375 | labels = y,
376 | loss_mask = loss_mask,
377 | )
378 |
379 | loss = sum(loss_dict.values())
380 | else:
381 | loss = loss_CGM(
382 | output_logits = output_logits,
383 | labels = y,
384 | loss_mask = loss_mask,
385 | )
386 |
387 | accelerator.backward(loss)
388 |
389 | optimizer.step()
390 | if not args.lr_scheduler_type == "reduce_lr_on_plateau":
391 | lr_scheduler.step()
392 | optimizer.zero_grad()
393 |
394 | if optimizer.param_groups[0]['lr'] <= args.min_lr:
395 | optimizer.param_groups[0]['lr'] = args.min_lr
396 |
397 | if args.task == 'mft':
398 | for task_id in reduce_loss_dict.keys():
399 | if not torch.isnan(loss_dict[task_id]):
400 | reduce_loss_dict[task_id] += loss_dict[task_id].detach().float()
401 | reduce_loss_count_dict[task_id] += 1
402 | else:
403 |
404 | if not torch.isnan(loss):
405 | reduce_loss += loss.detach().float()
406 | else:
407 | logger.info("loss nan")
408 |
409 | if accelerator.sync_gradients:
410 | completed_steps += 1
411 | if args.task == 'mft':
412 | reduce_loss = sum(reduce_loss_dict.values())
413 | logger.info(f"accelerator step (accumulate) {completed_steps}, loss: {reduce_loss}")
414 |
415 | if completed_steps % args.log_interval == 0:
416 | if args.task == 'mft':
417 | progress_bar.update(args.log_interval)
418 | accelerate_monitor_CGM_mft(
419 | accelerator, reduce_loss_dict, reduce_loss_count_dict, args, completed_steps,
420 | lr_scheduler, optimizer, summary_writer
421 | )
422 | reduce_loss_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids)
423 | reduce_loss_count_dict = OrderedDict((task_id, 0) for task_id, _ in task_ids)
424 | else:
425 | if isinstance(reduce_loss, torch.Tensor):
426 | progress_bar.update(args.log_interval)
427 | accelerate_monitor_CGM(
428 | accelerator, reduce_loss, args, completed_steps,
429 | lr_scheduler, optimizer, summary_writer
430 | )
431 | reduce_loss = 0
432 |
433 | # steps checkpointing
434 | if args.step_checkpointing and completed_steps % args.checkpointing_steps == 0:
435 | output_dir = f"step_{completed_steps}"
436 | if args.output_dir is not None:
437 | output_dir = os.path.join(args.output_dir, output_dir)
438 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args)
439 |
440 | if args.step_evaluation and completed_steps % args.evaluation_steps == 0:
441 | logger.info(f"start evaluation...")
442 | model.eval()
443 | model.lm.gradient_checkpointing_disable()
444 | model.lm.config.use_cache = True
445 | if args.task == 'mft':
446 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM_mft(
447 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step,
448 | min_eval_loss, stall_num, best_step, summary_writer
449 | )
450 | else:
451 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM(
452 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step,
453 | min_eval_loss, stall_num, best_step, summary_writer
454 | )
455 | model.train()
456 | model.lm.gradient_checkpointing_enable()
457 | model.lm.config.use_cache = False
458 |
459 | if args.lr_scheduler_type == "reduce_lr_on_plateau":
460 | lr_scheduler.step(eval_loss)
461 |
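    |                         # save a checkpoint whenever eval loss improves; otherwise keep at most two
    |                         # "stall" checkpoints and decay the LR (reduce_lr_on_plateau already stepped above)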
462 | if eval_loss < checkpoint_eval_loss:
463 | checkpoint_eval_loss = eval_loss
464 | output_dir = f"step_{completed_steps}_stall_{checkpoint_stall_num}"
465 | if args.output_dir is not None:
466 | output_dir = os.path.join(args.output_dir, output_dir)
467 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args)
468 | checkpoint_stall_num = 0
469 | else:
470 | if checkpoint_stall_num < 2:
471 | output_dir = f"step_{completed_steps}_stall_{checkpoint_stall_num}"
472 | if args.output_dir is not None:
473 | output_dir = os.path.join(args.output_dir, output_dir)
474 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args)
475 | checkpoint_stall_num += 1
476 |
477 | if args.lr_scheduler_type == "reduce_lr_on_plateau":
478 | pass
479 | elif args.lr_scheduler_type == 'cosine':
480 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33
481 |                             lr_scheduler.base_lrs = [group['lr'] for group in optimizer.param_groups]
482 | lr_scheduler.step()
483 | else:
484 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33
485 | lr_scheduler.step()
486 |
487 | # adapter warmup
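    |                 # once warmup ends, unfreeze the LM parameters and clear the warmup flag (only when 'l' is in args.mode)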
488 | if args.adapter_warmup and completed_steps >= args.adapter_warmup_steps:
489 | if 'l' in args.mode:
490 | for param in model.lm.parameters():
491 | param.requires_grad = True
492 | args.adapter_warmup = False
493 |
494 | accelerator.wait_for_everyone()
495 |
496 | load_t = time.time()
497 |
498 | if args.epoch_evaluation:
499 | model.eval()
500 | model.lm.gradient_checkpointing_disable()
501 | model.lm.config.use_cache = True
502 | if args.task == 'mft':
503 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM_mft(
504 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step,
505 | min_eval_loss, stall_num, best_step, summary_writer
506 | )
507 | else:
508 | eval_loss, min_eval_loss, stall_num, best_step = accelerate_evaluate_CGM(
509 | accelerator, model, tokenizer, valid_dataloader, args, completed_steps, step,
510 | min_eval_loss, stall_num, best_step, summary_writer
511 | )
512 | model.train()
513 | model.lm.gradient_checkpointing_enable()
514 | model.lm.config.use_cache = False
515 |
516 | if args.lr_scheduler_type == "reduce_lr_on_plateau":
517 | lr_scheduler.step(eval_loss)
518 |
519 | if eval_loss < checkpoint_eval_loss:
520 | checkpoint_eval_loss = eval_loss
521 | output_dir = f"epoch_{epoch}"
522 | ckpt_tag = output_dir
523 | if args.output_dir is not None:
524 | output_dir = os.path.join(args.output_dir, output_dir)
525 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args)
526 | else:
527 | if args.lr_scheduler_type == "reduce_lr_on_plateau":
528 | pass
529 | # lr_scheduler.step(eval_loss)
530 | elif args.lr_scheduler_type == 'cosine':
531 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33
532 |                     lr_scheduler.base_lrs = [group['lr'] for group in optimizer.param_groups]
533 | lr_scheduler.step()
534 | else:
535 | optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.33
536 | lr_scheduler.step()
537 |
538 | # epoch checkpointing
539 | if args.epoch_checkpointing:
540 | output_dir = f"epoch_{epoch}"
541 | if args.output_dir is not None:
542 | output_dir = os.path.join(args.output_dir, output_dir)
543 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args)
544 |
545 | if summary_writer:
546 | summary_writer.close()
547 |
548 | output_dir = f"final_step_{completed_steps}"
549 | if args.output_dir is not None:
550 | output_dir = os.path.join(args.output_dir, output_dir)
551 | accelerate_saving_checkpoint_CGM(accelerator, model, tokenizer, output_dir, completed_steps, args)
552 |
--------------------------------------------------------------------------------
/cgm/models/qwen2/_4_46_1/modeling_attn_mask_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import List, Optional, Tuple, Union
16 |
17 | import torch
18 |
19 | from transformers.utils.import_utils import is_torchdynamo_compiling
20 |
21 | # NOTE: copied from transformers == 4.46.1
22 |
23 | @dataclass
24 | class AttentionMaskConverter:
25 | """
26 | A utility attention mask class that allows one to:
27 | - Create a causal 4d mask
28 |         - Create a causal 4d mask with sliding window
29 | - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
30 |           key_value_length) that can be added to attention scores as a large negative bias
31 |
32 | Examples:
33 |
34 | ```python
35 | >>> import torch
36 | >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
37 |
38 | >>> converter = AttentionMaskConverter(True)
39 | >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
40 | tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
41 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
42 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
43 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
44 | [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
45 | ```
46 |
47 | Parameters:
48 | is_causal (`bool`):
49 | Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
50 |
51 | sliding_window (`int`, *optional*):
52 | Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
53 | """
54 |
55 | is_causal: bool
56 | sliding_window: int
57 |
58 | def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
59 | self.is_causal = is_causal
60 | self.sliding_window = sliding_window
61 |
62 | if self.sliding_window is not None and self.sliding_window <= 0:
63 | raise ValueError(
64 | f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
65 | )
66 |
67 | def to_causal_4d(
68 | self,
69 | batch_size: int,
70 | query_length: int,
71 | key_value_length: int,
72 | dtype: torch.dtype,
73 | device: Union[torch.device, "str"] = "cpu",
74 | ) -> Optional[torch.Tensor]:
75 | """
76 | Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
77 | bias to upper right hand triangular matrix (causal mask).
78 | """
79 | if not self.is_causal:
80 | raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
81 |
82 | # If shape is not cached, create a new causal mask and cache it
83 | input_shape = (batch_size, query_length)
84 | past_key_values_length = key_value_length - query_length
85 |
86 | # create causal mask
87 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
88 | causal_4d_mask = None
89 | if input_shape[-1] > 1 or self.sliding_window is not None:
90 | causal_4d_mask = self._make_causal_mask(
91 | input_shape,
92 | dtype,
93 | device=device,
94 | past_key_values_length=past_key_values_length,
95 | sliding_window=self.sliding_window,
96 | )
97 |
98 | return causal_4d_mask
99 |
100 | def to_4d(
101 | self,
102 | attention_mask_2d: torch.Tensor,
103 | query_length: int,
104 | dtype: torch.dtype,
105 | key_value_length: Optional[int] = None,
106 | ) -> torch.Tensor:
107 | """
108 | Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
109 | key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
110 | causal, a causal mask will be added.
111 | """
112 | input_shape = (attention_mask_2d.shape[0], query_length)
113 |
114 | # create causal mask
115 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
116 | causal_4d_mask = None
117 | if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
118 | if key_value_length is None:
119 | raise ValueError(
120 | "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
121 | )
122 |
123 | past_key_values_length = key_value_length - query_length
124 | causal_4d_mask = self._make_causal_mask(
125 | input_shape,
126 | dtype,
127 | device=attention_mask_2d.device,
128 | past_key_values_length=past_key_values_length,
129 | sliding_window=self.sliding_window,
130 | )
131 | elif self.sliding_window is not None:
132 | raise NotImplementedError("Sliding window is currently only implemented for causal masking")
133 |
134 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
135 | expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
136 | attention_mask_2d.device
137 | )
138 |
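    |         # fold the 2D padding mask into the causal mask: padded positions are set to the dtype minimum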
139 | if causal_4d_mask is not None:
140 | expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
141 |
142 | # expanded_attn_mask + causal_4d_mask can cause some overflow
143 | expanded_4d_mask = expanded_attn_mask
144 |
145 | return expanded_4d_mask
146 |
147 | def _3d_to_4d(
148 | self,
149 | attention_mask_3d: torch.Tensor,
150 | query_length: int,
151 | dtype: torch.dtype,
152 | key_value_length: Optional[int] = None,
153 | ) -> torch.Tensor:
154 | """
155 | Converts 3D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
156 | key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
157 | causal, a causal mask will be added.
158 | """
159 | input_shape = (attention_mask_3d.shape[0], query_length, key_value_length)
160 |
161 | # create causal mask
162 | # [bsz, tgt_seq_len, src_seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
163 | causal_4d_mask = None
164 | if (input_shape[-2] > 1 or self.sliding_window is not None) and self.is_causal:
165 | if key_value_length is None:
166 | raise ValueError(
167 | "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
168 | )
169 |
170 | past_key_values_length = key_value_length - query_length
171 | causal_4d_mask = self._make_causal_mask(
172 | input_shape,
173 | dtype,
174 | device=attention_mask_3d.device,
175 | past_key_values_length=past_key_values_length,
176 | sliding_window=self.sliding_window,
177 | )
178 | elif self.sliding_window is not None:
179 | raise NotImplementedError("Sliding window is currently only implemented for causal masking")
180 |
181 | # [bsz, tgt_seq_len, src_seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
182 | expanded_attn_mask = self._expand_mask_3d(attention_mask_3d, dtype, tgt_len=input_shape[-2]).to(
183 | attention_mask_3d.device
184 | )
185 |
186 | if causal_4d_mask is not None:
187 | expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
188 |
189 | # expanded_attn_mask + causal_4d_mask can cause some overflow
190 | expanded_4d_mask = expanded_attn_mask
191 |
192 | return expanded_4d_mask
193 |
194 | @staticmethod
195 | def _make_causal_mask(
196 | input_ids_shape: torch.Size,
197 | dtype: torch.dtype,
198 | device: torch.device,
199 | past_key_values_length: int = 0,
200 | sliding_window: Optional[int] = None,
201 | ):
202 | """
203 |         Make causal mask used for uni-directional (causal) self-attention.
204 | """
205 | bsz, tgt_len = input_ids_shape
206 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
207 | mask_cond = torch.arange(mask.size(-1), device=device)
208 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
209 |
210 | mask = mask.to(dtype)
211 |
212 | if past_key_values_length > 0:
213 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
214 |
215 | # add lower triangular sliding window mask if necessary
216 | if sliding_window is not None:
217 | diagonal = past_key_values_length - sliding_window - 1
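    |             # mask out (fill with dtype min) keys more than `sliding_window` positions behind each query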
218 |
219 | context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
220 | mask.masked_fill_(context_mask, torch.finfo(dtype).min)
221 |
222 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
223 |
224 | @staticmethod
225 | def _make_causal_mask_3d(
226 | input_ids_shape: torch.Size,
227 | dtype: torch.dtype,
228 | device: torch.device,
229 | past_key_values_length: int = 0,
230 | sliding_window: Optional[int] = None,
231 | ):
232 | """
233 | Make causal mask used for bi-directional self-attention.
234 |         Make causal mask used for uni-directional (causal) self-attention (3D mask variant).
235 | bsz, tgt_len, src_len = input_ids_shape
236 | mask = torch.full((tgt_len, src_len), torch.finfo(dtype).min, device=device)
237 | mask_cond = torch.arange(mask.size(-1), device=device)
238 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
239 |
240 | mask = mask.to(dtype)
241 |
242 | if past_key_values_length > 0:
243 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
244 |
245 | # add lower triangular sliding window mask if necessary
246 | if sliding_window is not None:
247 | diagonal = past_key_values_length - sliding_window - 1
248 |
249 | context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
250 | mask.masked_fill_(context_mask, torch.finfo(dtype).min)
251 |
252 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, src_len + past_key_values_length)
253 |
254 | @staticmethod
255 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
256 | """
257 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
258 | """
259 | bsz, src_len = mask.size()
260 | tgt_len = tgt_len if tgt_len is not None else src_len
261 |
262 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
263 |
264 | inverted_mask = 1.0 - expanded_mask
265 |
266 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
267 |
268 | @staticmethod
269 | def _expand_mask_3d(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
270 | """
271 | Expands attention_mask from `[bsz, tgt_seq_len, src_seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
272 | """
273 | bsz, tgt_seq_len, src_seq_len = mask.size()
274 | tgt_len = tgt_len if tgt_len is not None else tgt_seq_len
275 |
276 | expanded_mask = mask[:, None, :, :].expand(bsz, 1, tgt_len, src_seq_len).to(dtype)
277 |
278 | inverted_mask = 1.0 - expanded_mask
279 |
280 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
281 |
282 | @staticmethod
283 | def _unmask_unattended(
284 | expanded_mask: torch.FloatTensor,
285 | min_dtype: float,
286 | ):
287 | # fmt: off
288 | """
289 | Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
290 | using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
291 | Details: https://github.com/pytorch/pytorch/issues/110213
292 |
293 | `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
294 | `attention_mask` is [bsz, src_seq_len].
295 |
296 | The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
297 |
298 | For example, if `expanded_mask` is (e.g. here left-padding case)
299 | ```
300 | [[[[0, 0, 0],
301 | [0, 0, 0],
302 | [0, 0, 1]]],
303 | [[[1, 0, 0],
304 | [1, 1, 0],
305 | [1, 1, 1]]],
306 | [[[0, 0, 0],
307 | [0, 1, 0],
308 | [0, 1, 1]]]]
309 | ```
310 | then the modified `expanded_mask` will be
311 | ```
312 | [[[[1, 1, 1], <-- modified
313 | [1, 1, 1], <-- modified
314 | [0, 0, 1]]],
315 | [[[1, 0, 0],
316 | [1, 1, 0],
317 | [1, 1, 1]]],
318 | [[[1, 1, 1], <-- modified
319 | [0, 1, 0],
320 | [0, 1, 1]]]]
321 | ```
322 | """
323 | # fmt: on
324 | if expanded_mask.dtype == torch.bool:
325 | raise ValueError(
326 | "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
327 | )
328 |
329 | return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
330 |
331 | @staticmethod
332 | def _ignore_causal_mask_sdpa(
333 | attention_mask: Optional[torch.Tensor],
334 | inputs_embeds: torch.Tensor,
335 | past_key_values_length: int,
336 | sliding_window: Optional[int] = None,
337 | is_training: bool = False,
338 | ) -> bool:
339 | """
340 | Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
341 |         ignored when PyTorch's SDPA is used, relying instead on SDPA's `is_causal` argument.
342 |
343 | In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
344 | `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
345 | allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
346 | passed).
347 | """
348 |
349 | _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
350 | key_value_length = query_length + past_key_values_length
351 |
352 | is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
353 |
354 | ignore_causal_mask = False
355 |
356 | if attention_mask is None:
357 | # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
358 | # shape, thus SDPA's `is_causal` argument is rightfully updated
359 | # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
360 | # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
361 | # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
362 | # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
363 | # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
364 | #
365 | # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
366 | # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
367 | if (
368 | (is_training or not is_tracing)
369 | and (query_length == 1 or key_value_length == query_length)
370 | and (sliding_window is None or key_value_length < sliding_window)
371 | ):
372 | ignore_causal_mask = True
373 | elif sliding_window is None or key_value_length < sliding_window:
374 | if len(attention_mask.shape) == 4:
375 | return False
376 | elif not is_tracing and torch.all(attention_mask == 1):
377 | if query_length == 1 or key_value_length == query_length:
378 | # For query_length == 1, causal attention and bi-directional attention are the same.
379 | ignore_causal_mask = True
380 |
381 | # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
382 | # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
383 | # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
384 | # Reference: https://github.com/pytorch/pytorch/issues/108108
385 | # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
386 |
387 | return ignore_causal_mask
388 |
389 |
390 | def _prepare_4d_causal_attention_mask(
391 | attention_mask: Optional[torch.Tensor],
392 | input_shape: Union[torch.Size, Tuple, List],
393 | inputs_embeds: torch.Tensor,
394 | past_key_values_length: int,
395 | sliding_window: Optional[int] = None,
396 | ):
397 | """
398 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
399 | `(batch_size, key_value_length)`
400 |
401 | Args:
402 | attention_mask (`torch.Tensor` or `None`):
403 | A 2D attention mask of shape `(batch_size, key_value_length)`
404 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
405 | The input shape should be a tuple that defines `(batch_size, query_length)`.
406 | inputs_embeds (`torch.Tensor`):
407 | The embedded inputs as a torch Tensor.
408 | past_key_values_length (`int`):
409 | The length of the key value cache.
410 | sliding_window (`int`, *optional*):
411 | If the model uses windowed attention, a sliding window should be passed.
412 | """
413 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
414 |
415 | key_value_length = input_shape[-1] + past_key_values_length
416 |
417 | # 4d mask is passed through the layers
418 | if attention_mask is not None and len(attention_mask.shape) == 2:
419 | attention_mask = attn_mask_converter.to_4d(
420 | attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
421 | )
422 | elif attention_mask is not None and len(attention_mask.shape) == 4:
423 | expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
424 | if tuple(attention_mask.shape) != expected_shape:
425 | raise ValueError(
426 | f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
427 | )
428 | else:
429 | # if the 4D mask has correct shape - invert it and fill with negative infinity
430 | inverted_mask = 1.0 - attention_mask
431 | attention_mask = inverted_mask.masked_fill(
432 | inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
433 | )
434 | else:
435 | attention_mask = attn_mask_converter.to_causal_4d(
436 | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
437 | )
438 |
439 | return attention_mask
440 |
441 |
442 | # Adapted from _prepare_4d_causal_attention_mask
443 | def _prepare_4d_causal_attention_mask_for_sdpa(
444 | attention_mask: Optional[torch.Tensor],
445 | input_shape: Union[torch.Size, Tuple, List],
446 | inputs_embeds: torch.Tensor,
447 | past_key_values_length: int,
448 | sliding_window: Optional[int] = None,
449 | ):
450 | """
451 | Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
452 |
453 | In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
454 | `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
455 | allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
456 | """
457 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
458 |
459 | key_value_length = input_shape[-1] + past_key_values_length
460 |
461 | # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
462 | # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
463 | # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
464 | is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
465 |
466 | ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
467 | attention_mask=attention_mask,
468 | inputs_embeds=inputs_embeds,
469 | past_key_values_length=past_key_values_length,
470 | sliding_window=sliding_window,
471 | )
472 |
473 | if ignore_causal_mask:
474 | expanded_4d_mask = None
475 | elif attention_mask is None:
476 | expanded_4d_mask = attn_mask_converter.to_causal_4d(
477 | input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
478 | )
479 | else:
480 | if attention_mask.dim() == 4:
481 | expanded_4d_mask = attention_mask
482 | else:
483 | expanded_4d_mask = attn_mask_converter.to_4d(
484 | attention_mask,
485 | input_shape[-1],
486 | dtype=inputs_embeds.dtype,
487 | key_value_length=key_value_length,
488 | )
489 |
490 | # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
491 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
492 | # Details: https://github.com/pytorch/pytorch/issues/110213
493 | if not is_tracing and expanded_4d_mask.device.type == "cuda":
494 | expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
495 | expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
496 | )
497 |
498 | return expanded_4d_mask
499 |
500 |
501 | def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
502 | """
503 | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
504 | `(batch_size, key_value_length)`
505 |
506 | Args:
507 | mask (`torch.Tensor`):
508 | A 2D attention mask of shape `(batch_size, key_value_length)`
509 | dtype (`torch.dtype`):
510 | The torch dtype the created mask shall have.
511 | tgt_len (`int`):
512 | The target length or query length the created mask shall have.
513 | """
514 | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
515 |
516 |
517 | def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
518 | """
519 | Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
520 | `(batch_size, key_value_length)`
521 |
522 | Args:
523 | mask (`torch.Tensor`):
524 | A 2D attention mask of shape `(batch_size, key_value_length)`
525 | dtype (`torch.dtype`):
526 | The torch dtype the created mask shall have.
527 | tgt_len (`int`):
528 | The target length or query length the created mask shall have.
529 | """
530 | _, key_value_length = mask.shape
531 | tgt_len = tgt_len if tgt_len is not None else key_value_length
532 |
533 | is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()
534 |
535 | # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
536 | if not is_tracing and torch.all(mask == 1):
537 | return None
538 | else:
539 | return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
540 |
541 |
542 | def _create_4d_causal_attention_mask(
543 | input_shape: Union[torch.Size, Tuple, List],
544 | dtype: torch.dtype,
545 | device: torch.device,
546 | past_key_values_length: int = 0,
547 | sliding_window: Optional[int] = None,
548 | ) -> Optional[torch.Tensor]:
549 | """
550 | Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
551 |
552 | Args:
553 | input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
554 | The input shape should be a tuple that defines `(batch_size, query_length)`.
555 | dtype (`torch.dtype`):
556 | The torch dtype the created mask shall have.
557 |         device (`torch.device`):
558 | The torch device the created mask shall have.
559 | sliding_window (`int`, *optional*):
560 | If the model uses windowed attention, a sliding window should be passed.
561 | """
562 | attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
563 |
564 | key_value_length = past_key_values_length + input_shape[-1]
565 | attention_mask = attn_mask_converter.to_causal_4d(
566 | input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
567 | )
568 |
569 | return attention_mask
--------------------------------------------------------------------------------