"
11 | middle_tokens: ""
12 | suffix_tokens: ""
13 | # context truncation length
14 | max_context_length: 15500
15 | eos_sequences: ["\\sclass\\s", "\\sdef\\s", "^def\\s", "^class\\s", "@", ""]
16 | tokenizer_fix: 1
17 |
18 | model_kwargs:
19 | use_flash_attention_2: True
20 |
--------------------------------------------------------------------------------
/config/model/codeparrot.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # number of parameters in the model
4 | size: "110m"
5 | # name of the model on HuggingFace
6 | model_path: 'codeparrot/codeparrot-small'
7 | model_short_name: "codeparrot-small"
8 | # FIM special tokens (left empty for this model)
9 | lm_prefix_tokens: ""
10 | prefix_tokens: ""
11 | middle_tokens: ""
12 | suffix_tokens: ""
13 | # context truncation length
14 | max_context_length: 512
15 | # model_kwargs:
16 | # use_flash_attention_2: True
17 |
--------------------------------------------------------------------------------
/config/model/deepseek.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | size: "1.3b"
4 | model_path: 'deepseek-ai/deepseek-coder-${size}-base'
5 | model_short_name: "deepseek-coder-1.3b"
6 | lm_prefix_tokens: ""
7 | lm_suffix_tokens: ""
8 | prefix_tokens: "<|fim▁begin|>"
9 | middle_tokens: "<|fim▁hole|>"
10 | suffix_tokens: "<|fim▁end|>"
11 | max_context_length: 7500
12 | model_kwargs:
13 | use_flash_attention_2: True
14 |
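Note: each model config is merged into the global Hydra config (the `# @package _global_` header), so these keys are read directly by main.py. Below is an illustrative, string-level sketch (not part of the repository) of how the FIM tokens above wrap the two contexts; InfillGenerator._prepare_tokens performs the same concatenation at the token level, the example contexts are invented.

# Illustrative only: prompt layout built as prefix + left + middle + right + suffix.
prefix_tokens = "<|fim▁begin|>"   # deepseek.yaml: prefix_tokens
middle_tokens = "<|fim▁hole|>"    # deepseek.yaml: middle_tokens
suffix_tokens = "<|fim▁end|>"     # deepseek.yaml: suffix_tokens

left_context = "def add(a, b):\n"
right_context = "\n\nprint(add(1, 2))\n"

prompt = prefix_tokens + left_context + middle_tokens + right_context + suffix_tokens
# -> <|fim▁begin|>def add(a, b):\n<|fim▁hole|>\n\nprint(add(1, 2))\n<|fim▁end|>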
--------------------------------------------------------------------------------
/config/model/local.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | size: ""
4 | model_path: '${model_base_path}/${model_short_name}'
5 | model_base_path: ""
6 | model_short_name: ""
7 | lm_prefix_tokens: ""
8 | lm_suffix_tokens: ""
9 | prefix_tokens: ""
10 | middle_tokens: ""
11 | suffix_tokens: ""
12 | max_context_length: 1024
13 |
--------------------------------------------------------------------------------
/config/model/mistral.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | size: "7B"
4 | model_path: 'mistralai/Mistral-${size}-v0.1'
5 | model_short_name: "Mistral-${size}"
6 | lm_prefix_tokens: ""
7 | lm_suffix_tokens: ""
8 | prefix_tokens: ""
9 | middle_tokens: ""
10 | suffix_tokens: ""
11 | max_context_length: 7500
12 | model_kwargs:
13 | use_flash_attention_2: True
14 |
--------------------------------------------------------------------------------
/config/model/phi1.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | model_path: 'microsoft/phi-1'
4 | model_short_name: "phi1"
5 | lm_prefix_tokens: ""
6 | prefix_tokens: "0"
7 | middle_tokens: "0"
8 | suffix_tokens: "0"
9 | eos_sequences: ["\\sclass\\s", "\\sdef\\s", "^def\\s", "^class\\s", "@", "from", "import", "<|endoftext|>"]
10 | max_context_length: 7500
11 |
--------------------------------------------------------------------------------
/config/model/qwen25coder.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | size: "1.5B"
4 | model_path: Qwen/${model_short_name}
5 | model_short_name: Qwen2.5-Coder-${size}
6 | lm_prefix_tokens: ""
7 | lm_suffix_tokens: ""
8 | prefix_tokens: "<|fim_prefix|>"
9 | middle_tokens: "<|fim_suffix|>"
10 | suffix_tokens: "<|fim_middle|>"
11 | max_context_length: 7500
12 | model_kwargs:
13 | attn_implementation: flash_attention_2
14 |
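Note: the middle/suffix values above only look swapped. Because the generator emits prefix_tokens + left + middle_tokens + right + suffix_tokens, this config yields the standard Qwen2.5-Coder FIM layout, sketched here with invented strings:

# Illustrative only: resulting prompt layout for the config above.
left, right = "def f(x):\n", "\n    return y\n"
prompt = "<|fim_prefix|>" + left + "<|fim_suffix|>" + right + "<|fim_middle|>"
# the model then generates the missing middle after <|fim_middle|>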
--------------------------------------------------------------------------------
/config/model/starcoder.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | size: "1b"
4 | model_path: 'bigcode/starcoderbase-${size}'
5 | model_short_name: "starcoderbase-${size}"
6 | lm_prefix_tokens: ""
7 | lm_suffix_tokens: "\\n\\n" # oddly, this works better for plain LM mode
8 | prefix_tokens: ""
9 | middle_tokens: ""
10 | suffix_tokens: ""
11 | max_context_length: 7500
12 |
--------------------------------------------------------------------------------
/config/model/starcoder15b.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | variant: "rbase"
4 | model_path: 'bigcode/starcode${variant}'
5 | model_short_name: "starcode${variant}"
6 | prefix_tokens: ""
7 | middle_tokens: ""
8 | suffix_tokens: ""
9 | max_context_length: 7500
10 |
--------------------------------------------------------------------------------
/config/task/FG.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | dataset_meta_file: 'realcode_v3_FG.json'
--------------------------------------------------------------------------------
/config/task/SG.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | dataset_meta_file: 'realcode_v3_SG.json'
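Note: the task configs only switch the dataset metadata file. A hypothetical sketch of composing the full config the way main.py's @hydra.main decorator does; it assumes the root config/config.yaml (not shown in this section) declares "model" and "task" as config groups, so the override names below are illustrative.

# Hypothetical sketch, assuming the root config.yaml defines "model" and "task" groups.
from hydra import compose, initialize

with initialize(config_path="config", version_base="1.3"):
    cfg = compose(config_name="config", overrides=["model=deepseek", "task=FG"])
print(cfg.dataset_meta_file)  # realcode_v3_FG.json
print(cfg.model_path)         # deepseek-ai/deepseek-coder-1.3b-base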
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/data/generations/FG/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/data/generations/SG/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/lm_eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NLP-Core-Team/RealCode_eval/f70984fb83022eb191ba94bcf55729c2fc64aa80/lm_eval/__init__.py
--------------------------------------------------------------------------------
/lm_eval/context_parser.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import typing as tp
3 | from collections import deque, namedtuple
4 | import re
5 | import ast
6 | from pathlib import Path
7 |
8 | from .datatypes import Task
9 | from transformers import AutoTokenizer
10 |
11 |
12 | Import = namedtuple("Import", ["module", "name", "alias"])
13 |
14 | """
15 | >>> Imports
16 | import math
17 | <<< imports
18 |
19 | >>> file scope
20 | def get_c():
21 | return 1
22 |
23 | <<< file scope
24 | >>> outer scope
25 | class Foo:
26 | def __init__(self, a):
27 | self.a = a
28 | <<< outer scope
29 | >>> inner scope
30 | @staticmethod
31 | def bar():
32 | '''
33 | Turn Foo into bar
34 | '''
35 | <<< inner scope
36 | >>> body (unavailable for model)
37 | bar = 'B'
38 | self.a = bar
39 | return self
40 | <<< body (unavailable for model)
41 | >>> outer scope
42 | def bar2():
43 | self.a = 'C'
44 | return self
45 | <<< outer scope
46 | >>> file scope
47 | class Foo2:
48 | ...
49 |
50 | <<< file scope
51 | """
52 |
53 | @dataclass(frozen=False)
54 | class ParsedContext:
55 | imports: str = ''
56 | file: str = ''
57 | outer: str = ''
58 | inner: str = ''
59 |
60 | def __setitem__(self, key, value):
61 | setattr(self, key, value)
62 |
63 | def __getitem__(self, key):
64 | return getattr(self, key)
65 |
66 | def __str__(self):
67 | return (
68 |
69 | '\n----- imports -----\n' +
70 | self.imports +
71 | '\n----- end imports -----\n' +
72 | '\n----- file -----\n' +
73 | (
74 | ('\n'.join(self.file.split('\n')[:10]) + '\n...\n' + '\n'.join(self.file.split('\n')[-15:])) if len(self.file.split('\n')) > 20 else self.file
75 | ) +
76 | '\n----- end file -----\n' +
77 | '\n----- outer -----\n' +
78 | (
79 | ('\n'.join(self.outer.split('\n')[:15]) + '\n...\n' + '\n'.join(self.outer.split('\n')[-15:])) if len(self.outer.split('\n')) > 20 else self.outer
80 | ) +
81 | '\n----- end outer -----\n' +
82 | '\n----- inner -----\n' +
83 | '\n'.join(self.inner.split('\n')) +
84 | '\n----- end inner -----\n'
85 |
86 | )
87 |
88 |
89 | def get_indent(code):
90 | line = code.split('\n')[0]
91 | return len(line) - len(line.strip())
92 |
93 |
94 | def parse_context(context: str, indent: int, side: tp.Literal['left', 'right']) -> ParsedContext:
95 | res = ParsedContext()
96 | if side == 'left':
97 | cur_scope = deque()
98 | state = 'inner'
99 |
100 | for line in reversed(context.split('\n')):
101 | if line.startswith('import') or (line.startswith('from') and ' import ' in line):
102 | res['imports'] += line + '\n'
103 | continue
104 |
105 | if state == 'inner_wait@':
106 | if not line.lstrip().startswith('@'):
107 | res['inner'] = "\n".join(cur_scope)
108 | cur_scope = deque()
109 | if indent > 0:
110 | state = 'outer'
111 | else:
112 | state = 'file'
113 |
114 | cur_scope.appendleft(line)
115 | if state == 'inner':
116 | if line.strip().startswith('def '):
117 | state = 'inner_wait@'
118 | if state == 'outer':
119 | if line.startswith('class'):
120 | res['outer'] = "\n".join(cur_scope)
121 | state = 'file'
122 | cur_scope = deque()
123 | if state == 'inner_wait@':
124 | state = 'inner'
125 | res[state] = "\n".join(cur_scope)
126 | elif side == 'right':
127 | cur_scope = deque()
128 | state = 'outer'
129 |
130 | for line in context.split('\n'):
131 | if state == 'outer':
132 | if (
133 | line.strip()
134 | and not line.startswith(' ')
135 | ):
136 | res['outer'] = "\n".join(cur_scope)
137 | state = 'file'
138 | cur_scope = deque()
139 | cur_scope.append(line)
140 | res[state] = "\n".join(cur_scope)
141 | return res
142 |
143 |
144 | class BaseParser:
145 | def get_left_and_right_context(self, task: Task) -> tp.Tuple[str, str]:
146 | """
147 | main method that returns the tuple (left_context, right_context) for the task
148 | """
149 | raise NotImplementedError()
150 |
151 |
152 | class TrivialContextParser(BaseParser):
153 | def get_left_and_right_context(self, task: Task) -> tp.Tuple[str, str]:
154 | """
155 | returns left and right context without processing
156 | """
157 | return task.left_context, task.right_context
158 |
159 |
160 | class SmartContextParser(BaseParser):
161 | def __init__(self,
162 | left_config = ['imports', 'file', 'outer', 'inner'],
163 | right_config = ['outer', 'file']
164 | ):
165 | self.left_config = left_config
166 | self.right_config = right_config
167 |
168 | def get_left_and_right_context(self, task: Task) -> tp.Tuple[str, str]:
169 | """
170 | returns left and right contexts rebuilt from the configured scope sections (imports, file, outer, inner)
171 | """
172 | indent = (len(task.gt) - len(task.gt.lstrip()))
173 | left_context_parsed = parse_context(task.left_context, indent, 'left')
174 | left_context = "\n".join([left_context_parsed[k] for k in self.left_config])
175 | right_context_parsed = parse_context(task.right_context, indent, 'right')
176 | right_context = "\n".join([right_context_parsed[k] for k in self.right_config])
177 | return left_context, right_context
178 |
179 | class ImportResolutionParser(BaseParser):
180 | def __init__(self,
181 | data_root: str,
182 | left_config = ['imports', 'file', 'outer', 'inner'],
183 | right_config = ['outer', 'file']
184 | ):
185 | """
186 | data_root: path to the dataset root that contains the benchmark repositories
187 | """
188 | self.data_root = data_root
189 | self.left_config = left_config
190 | self.right_config = right_config
191 |
192 | def _desc_func(self, functionNode, lines):
193 | return " ".join([t.strip() for t in lines[functionNode.lineno-1: functionNode.body[0].lineno - 1]])
194 |
195 | def _parse_file(self, filename, func_names):
196 | ans = []
197 | with open(filename, 'r', encoding='UTF-8') as f:
198 | text = f.read()
199 | lines = text.split('\n')
200 | node = ast.parse(text)
201 | if func_names:
202 | functions = [n for n in node.body if isinstance(n, ast.FunctionDef) and n.name in func_names]
203 | classes = [n for n in node.body if isinstance(n, ast.ClassDef) and n.name in func_names]
204 | else:
205 | functions = [n for n in node.body if isinstance(n, ast.FunctionDef)]
206 | classes = [n for n in node.body if isinstance(n, ast.ClassDef)]
207 |
208 | for function in functions:
209 | s = self._desc_func(function, lines)
210 | ans.append(s)
211 |
212 | for class_ in classes:
213 | ans.append("class " + class_.name)
214 | methods = [n for n in class_.body if isinstance(n, ast.FunctionDef)]
215 | for method in methods:
216 | s = self._desc_func(method, lines)
217 | ans.append(' ' + s)
218 | return "\n".join(ans)
219 |
220 | def _get_imports(self, code):
221 | root = ast.parse(code)
222 |
223 | for node in ast.iter_child_nodes(root):
224 | if isinstance(node, ast.Import):
225 | module = [t.name for t in node.names]
226 | yield (
227 | Import(module, [], []),
228 | " ".join(code.split('\n')[node.lineno-1: node.end_lineno])
229 | )
230 | elif isinstance(node, ast.ImportFrom):
231 | module = node.module.split('.')
232 | yield (
233 | Import(module, [n.name for n in node.names], [n.name for n in node.names]),
234 | " ".join(code.split('\n')[node.lineno-1: node.end_lineno])
235 | )
236 | else:
237 | continue
238 |
239 | def _resolve_imports(self, task: Task) -> str:
240 | repo = (Path(self.data_root) / task.repo).resolve()
241 | ans = []
242 | for imp, line in self._get_imports(task.left_context):
243 | pth = repo / ("/".join(imp.module) + '.py')
244 | if imp.module and pth.exists():
245 | ans.append(line)
246 | ans.append(self._parse_file(pth, imp.name))
247 | else:
248 | ans.append(line)
249 | return '\n'.join(ans)
250 |
251 | def get_left_and_right_context(self, task: Task) -> tp.Tuple[str, str]:
252 | indent = (len(task.gt) - len(task.gt.lstrip()))
253 | left_context_parsed = parse_context(task.left_context, indent, 'left')
254 | left_context = "\n".join([
255 | left_context_parsed[k] if k != 'imports' else self._resolve_imports(task) + '\n'
256 | for k in self.left_config
257 | ])
258 | right_context_parsed = parse_context(task.right_context, indent, 'right')
259 | right_context = "\n".join([right_context_parsed[k] for k in self.right_config])
260 | return left_context, right_context
261 |
262 |
263 | class ImportCopyParser(ImportResolutionParser):
264 | def _parse_file(self, filename, func_names):
265 | ans = []
266 | with open(filename, 'r', encoding='UTF-8') as f:
267 | text = f.read()
268 | lines = text.split('\n')
269 | node = ast.parse(text)
270 | if func_names:
271 | functions = [n for n in node.body if isinstance(n, ast.FunctionDef) and n.name not in func_names and n.col_offset == 0]
272 | classes = [n for n in node.body if isinstance(n, ast.ClassDef) and n.name not in func_names and n.col_offset == 0]
273 | skip_intervals = [(t.lineno-1, t.end_lineno-1) for t in functions + classes]
274 | skip_intervals.sort()
275 | else:
276 | functions = [n for n in node.body if isinstance(n, ast.FunctionDef)]
277 | classes = [n for n in node.body if isinstance(n, ast.ClassDef)]
278 | skip_intervals = []
279 | interval_id = 0
280 | i = 0
281 | while i < len(lines):
282 | if interval_id < len(skip_intervals) and i >= skip_intervals[interval_id][0]:
283 | i = skip_intervals[interval_id][1]
284 | interval_id += 1
285 | else:
286 | ans.append(lines[i])
287 | i += 1
288 | return "\n".join(ans)
289 |
290 | def _resolve_imports(self, task: Task) -> str:
291 | repo = (Path(self.data_root) / task.repo).resolve()
292 | ans = []
293 | for imp, line in self._get_imports(task.left_context):
294 | module_pth = ("/".join(imp.module) + '.py')
295 | pth = repo / module_pth
296 | if imp.module and pth.exists():
297 | ans.append('#' + module_pth)
298 | ans.append(self._parse_file(pth, imp.name))
299 |
300 | cur_module = task.path_from_root.replace('/', '.').replace('.py', '')
301 | for file in [
302 | f for f in repo.rglob('*.py')
303 | if {"venv_bench", '.ipynb_checkpoints'}.isdisjoint(set([str(p) for p in f.parts]))
304 | ]:
305 | file = file.absolute()
306 | with open(file, 'r', encoding='UTF-8') as f:
307 | text = f.read()
308 | if cur_module in text:
309 | ans.append('#' + str(file.relative_to(repo)))
310 | ans.append(text)
311 | ans.append('#' + task.path_from_root)
312 | for imp, line in self._get_imports(task.left_context):
313 | ans.append(line)
314 | return '\n'.join(ans)
315 |
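A hypothetical usage sketch (not part of the repository) of parse_context, matching the scope diagram in the module docstring; the example context and the indent value are invented.

# Illustrative usage of parse_context on a tiny left context.
# indent=8 corresponds to a method body, the same value that
# SmartContextParser derives from the ground-truth indentation.
from lm_eval.context_parser import parse_context

left_context = (
    "import math\n"
    "\n"
    "class Foo:\n"
    "    def helper(self):\n"
    "        return 1\n"
    "\n"
    "    def target(self):\n"
)
parsed = parse_context(left_context, indent=8, side='left')
print(parsed.imports)  # the import lines
print(parsed.outer)    # the enclosing class with its other methods
print(parsed.inner)    # the signature of the function being completed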
--------------------------------------------------------------------------------
/lm_eval/datatypes.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import typing as tp
3 |
4 | @dataclass(frozen=True)
5 | class Task:
6 | repo: str
7 | repo_n: int
8 | path_from_root: str
9 | left_context: str
10 | right_context: str
11 | gt: str
12 | total_tests: int
13 | doc: str = ''
14 |
15 |
16 |
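Since load_dataset builds each entry with Task(**entry), the dataset JSON objects carry exactly these fields. A hypothetical example instance (all values invented):

# Hypothetical Task instance mirroring one entry of realcode_v3_*.json.
from lm_eval.datatypes import Task

task = Task(
    repo="example_repo",
    repo_n=0,
    path_from_root="pkg/module.py",
    left_context="def add(a, b):\n",
    right_context="\n\ndef sub(a, b):\n    return a - b\n",
    gt="    return a + b\n",
    total_tests=3,
)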
--------------------------------------------------------------------------------
/lm_eval/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import typing as tp
3 | import math
4 | from collections import defaultdict
5 | import json
6 | import re
7 | from statistics import mean
8 | from dataclasses import asdict
9 | from multiprocessing import Pool, Manager
10 |
11 | from .utils import evaluate_override, evaluate_override_wrapped
12 | from .datatypes import Task
13 |
14 | import logging
15 | logger = logging.getLogger("RealCode")
16 |
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 |
19 | def get_num_lines_bin(t: Task):
20 | lines = t.gt.strip().count('\n') + 1
21 | if 1 <= lines <= 2:
22 | return '1-2'
23 | elif 3 <= lines <= 5:
24 | return '3-5'
25 | elif 6 <= lines <= 10:
26 | return '6-10'
27 | elif lines > 10:
28 | return '10+'
29 |
30 |
31 | METRIC_AGGREGATIONS = {
32 | 'total': lambda t: 1,
33 | 'repo': lambda t: t.repo,
34 | 'nlines_bin': get_num_lines_bin,
35 | # 'detailed': lambda t: t,
36 | }
37 |
38 | class PassK:
39 | def __init__(self, k: int, n: int):
40 | self.k = k
41 | self.n = n
42 |
43 | def __call__(self, correct: int):
44 | return (1 - (math.comb(self.n - correct, self.k) / math.comb(self.n, self.k)))
45 |
46 | def name(self):
47 | return f"Pass@{self.k}"
48 |
49 |
50 | class Evaluator:
51 | def __init__(self,
52 | dataset_root: os.PathLike,
53 | num_samples: int,
54 | pass_k_list: tp.List[int] = [1],
55 | njobs: int = 1,
56 | working_dir: tp.Optional[os.PathLike] = None,
57 | metric_aggregations: tp.Dict[str, tp.Callable[[Task], int]] = METRIC_AGGREGATIONS
58 | ):
59 | self.metrics = []
60 | for pass_k in pass_k_list:
61 | if num_samples < pass_k:
62 | raise ValueError(f"num_samples {num_samples} must be greater than or equal to PassK={pass_k}")
63 | self.metrics.append(PassK(pass_k, num_samples))
64 | self.dataset_root = dataset_root
65 | self.num_samples = num_samples
66 | self.njobs = njobs
67 | self.working_dir = working_dir
68 | self.metric_aggregations = metric_aggregations
69 |
70 | def evaluate(self,
71 | tasks: tp.List[Task],
72 | generations: tp.List[tp.List[str]],
73 | ) -> tp.Dict[tp.Literal["aggregated", "detailed"], tp.Any]:
74 | logger.info(f"Evaluating {len(tasks)} tasks with {self.num_samples} samples on {self.njobs} CPUs")
75 | # Run test evaluation
76 | if self.njobs == 1:
77 | results = [
78 | [evaluate_override( self.dataset_root, task, gen, os.path.join(self.working_dir) ) for gen in generations[i]]
79 | for i, task in enumerate(tasks)
80 | ]
81 | else:
82 | with Manager() as manager:
83 | cache = manager.dict()
84 | with manager.Pool(processes=self.njobs) as pool:
85 | results = [[None for _2 in range(self.num_samples)] for _ in tasks]
86 | async_result = pool.starmap_async(
87 | evaluate_override_wrapped, [
88 | ( self.dataset_root, task, gen, os.path.join(self.working_dir, f"{j}_{i}"), j, i, cache )
89 | for j, task in enumerate(tasks) for i, gen in enumerate(generations[j])
90 | ]
91 | )
92 | res = async_result.get()
93 | for task_n, gen_n, result in res:
94 | results[task_n][gen_n] = result
95 | if task_n % 25 == 0 and gen_n == 0:
96 | logger.debug(result['output'])
97 |
98 | # Calculate metrics per task
99 | all_metric_names = ['compilation_error_rate', 'exact_match'] + [t.name() for t in self.metrics]
100 | metrics = []
101 | agg_metrics = {level: {metric_name: defaultdict(list) for metric_name in all_metric_names} for level in self.metric_aggregations}
102 | for task, task_results, task_generations in zip(tasks, results, generations):
103 | if len(task_results) != self.num_samples:
104 | raise ValueError(f"Task {task} has {len(task_results)} samples, expected {self.num_samples}")
105 | correct = sum([int(t['passed'] == task.total_tests) for t in task_results])
106 | not_compiles = mean([int(t['passed'] + t['failed'] == 0) for t in task_results])
107 | exact_match = mean([int(re.sub(r'\W+', '', task.gt) == re.sub(r'\W+', '', gen)) for gen in task_generations])
108 | task_metrics = {'compilation_error_rate': not_compiles, 'exact_match': exact_match}
109 | for metric in self.metrics:
110 | # If the generation exactly matches the repository code, count Pass@k as 1
111 | if exact_match > 1 - 1e-3:
112 | task_metrics[metric.name()] = 1.0
113 | else:
114 | task_metrics[metric.name()] = metric(correct)
115 | task_metrics['evaluations'] = [t['output'] for t in task_results]
116 | metrics.append(task_metrics)
117 | for level, level_func in self.metric_aggregations.items():
118 | for metric in all_metric_names:
119 | agg_metrics[level][metric][level_func(task)].append(task_metrics[metric])
120 |
121 | for level in self.metric_aggregations:
122 | for metric_name in all_metric_names:
123 | means = {val: mean(agg_metrics[level][metric_name][val]) for val in agg_metrics[level][metric_name]}
124 | agg_metrics[level][metric_name] = means
125 |
126 | # Save metrics
127 | metrics = agg_metrics | {
128 | "detailed": [asdict(task) | task_metric for task, task_metric in zip(tasks, metrics)]
129 | }
130 | return metrics
131 |
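PassK implements the unbiased pass@k estimator 1 - C(n - c, k) / C(n, k), where n is the number of generations per task and c the number that pass all tests. A small standalone worked check (the helper name and numbers are illustrative):

import math

def pass_at_k(n: int, c: int, k: int) -> float:
    # Same formula as PassK.__call__: probability that at least one of k
    # samples drawn without replacement from the n generations is correct.
    return 1 - math.comb(n - c, k) / math.comb(n, k)

print(pass_at_k(n=10, c=3, k=1))  # 0.3
print(pass_at_k(n=10, c=3, k=5))  # ~0.9167, i.e. 1 - C(7,5)/C(10,5) = 1 - 21/252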
--------------------------------------------------------------------------------
/lm_eval/generators.py:
--------------------------------------------------------------------------------
1 | import os
2 | import typing as tp
3 | import json
4 | from pathlib import Path
5 | from dataclasses import asdict, fields
6 | import re
7 |
8 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
9 | import torch
10 | from tqdm import tqdm
11 |
12 | from .datatypes import Task
13 | from .context_parser import BaseParser, TrivialContextParser
14 | import logging
15 | logger = logging.getLogger("RealCode")
16 |
17 |
18 | def get_indent(code):
19 | line = [t for t in code.split('\n') if t.strip()][0]
20 | return len(line) - len(line.strip())
21 |
22 |
23 | class InfillGenerator:
24 | def __init__(self,
25 | accelerator,
26 | model_path: str,
27 | num_samples: int,
28 | prefix_tokens: tp.Union[str, tp.List[int]] = [],
29 | middle_tokens: tp.Union[str, tp.List[int]] = [],
30 | suffix_tokens: tp.Union[str, tp.List[int]] = [],
31 | max_context_length: int = None,
32 | left_context_ratio: int = 1,
33 | dtype = torch.bfloat16,
34 | model_kwargs: tp.Dict = {},
35 | generation_params: tp.Dict[str, tp.Any] = {},
36 | context_parser: BaseParser = TrivialContextParser(),
37 | ):
38 | """
39 | Class to generate code in fill-in-the-middle mode
40 | params:
41 | model_path: str - which model to use for generation, anything that can be passed to AutoModelForCausalLM.from_pretrained
42 | num_samples: int - number of samples to generate per task, values > 1 should be paired with generation_params
43 | prefix_tokens: tp.Union[str, tp.List[int]] = [] - tokens to insert before the left context. Can be either str or list of int tokens
44 | middle_tokens: tp.Union[str, tp.List[int]] = [] - tokens to insert before the right context (see Fill-In-the-Middle). Can be either str or list of int tokens
45 | suffix_tokens: tp.Union[str, tp.List[int]] = [] - tokens to insert after the right context (see Fill-In-the-Middle). Can be either str or list of int tokens
46 | max_context_length: int = None - truncation length for prompt, measured in tokens (len(left_context) + len(right_context) < max_context_length)
47 | left_context_ratio: int = 1 - proportion of max_context_length given to left_context. 1 means 1:1 split between left and right, 3 means 3:1 split in favor of left context
48 | dtype=torch.bfloat16 - torch dtype to use for inference
49 | eos_sequences: tp.List[str] = ["\sclass\s", "\sdef\s", "\s@", "<|endoftext|>", ""] - regular expressions that determine end of generation
50 | model_kwargs: tp.Dict = {} - kwargs to be passed to AutoModelForCausalLM.from_pretrained
51 | generation_params: tp.Dict[str, tp.Any] = {} - kwargs to be passed to AutoModelForCausalLM.generate
52 | context_parser: BaseParser = TrivialContextParser() - parser for left and right contexts
53 | add_extra_spaces_to_generation=0 - number of extra spaces added at the beginning of the generation to fix indentation. May be required due to bugs in some tokenizers (e.g. Codellama)
54 | """
55 | logger.info(f"Loading model from {model_path} with kwargs f{model_kwargs}")
56 | self.device = accelerator.device
57 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
58 | model = AutoModelForCausalLM.from_pretrained(model_path,
59 | torch_dtype=dtype, trust_remote_code=True, **model_kwargs
60 | )
61 | self.model = model.to(self.device).eval()
62 | logger.info(f"Loaded model from {model_path} with kwargs f{model_kwargs}")
63 | logger.info(f"{self.model}")
64 |
65 | self.num_samples = num_samples
66 |
67 | self.prefix_tokens = self.tokenize_special_tokens(prefix_tokens)
68 | self.middle_tokens = self.tokenize_special_tokens(middle_tokens)
69 | self.suffix_tokens = self.tokenize_special_tokens(suffix_tokens)
70 |
71 | logger.debug(f"prefix_tokens: {self.prefix_tokens}, middle_tokens: {self.middle_tokens}, suffix_tokens: {self.suffix_tokens}")
72 |
73 | #context truncation parameters
74 | self.max_context_length = max_context_length
75 | self.left_context_truncate_at = left_context_ratio / (left_context_ratio + 1)
76 | self.right_context_truncate_at = 1 / (left_context_ratio + 1)
77 |
78 | self.generation_params = generation_params
79 | self.generation_params['num_return_sequences'] = self.num_samples
80 |
81 | self.context_parser = context_parser
82 |
83 | def tokenize_special_tokens(self, str_or_list: tp.Union[str, tp.List[int]]) -> torch.Tensor:
84 | if isinstance(str_or_list, str):
85 | return self.tokenizer.encode(str_or_list, return_tensors="pt", add_special_tokens=False) # ['input_ids']
86 | else:
87 | return torch.as_tensor(str_or_list).unsqueeze(0)
88 |
89 | def _prepare_tokens(self, task: Task) -> torch.Tensor:
90 | left_context_str, right_context_str = self.context_parser.get_left_and_right_context(task)
91 | logger.info("Task\n" + "\n".join(left_context_str.split('\n')[-20:]))
92 | left_tokens = self.tokenizer.encode(
93 | left_context_str, return_tensors="pt", add_special_tokens=False, max_length=self.max_context_length)# ['input_ids']
94 | right_tokens = self.tokenizer.encode(
95 | right_context_str, return_tensors="pt", add_special_tokens=False) # ['input_ids']
96 | if self.max_context_length and left_tokens.shape[1] + right_tokens.shape[1] > self.max_context_length:
97 | logger.debug("Truncating context")
98 |
99 | left_tokens = left_tokens[:, -min(int(self.max_context_length * self.left_context_truncate_at), left_tokens.shape[1]) + 1:]
100 | right_tokens = right_tokens[:, :min(int(self.max_context_length * self.right_context_truncate_at), right_tokens.shape[1]) - 1]
101 | tokens = torch.cat([self.prefix_tokens, left_tokens, self.middle_tokens, right_tokens, self.suffix_tokens], dim=-1).type(torch.long)
102 | return tokens
103 |
104 | def _postprocess(self, generation: str, indent: int):
105 | new_gen = []
106 | for i, line in enumerate(generation.split('\n')):
107 | line = line.replace("<|fim_pad|>", "")
108 | if i == 0:
109 | logger.debug("/".join(line))
110 | logger.debug(len(line) - len(line.lstrip()))
111 | if i == 0 and (len(line) - len(line.lstrip())) % 4 == 3:
112 | line = " " + line
113 | if line.strip() != '' and get_indent(line) < indent:
114 | break
115 | new_gen.append(line)
116 | return "\n".join(new_gen).rstrip() + '\n\n'
117 |
118 | @torch.no_grad()
119 | def generate(self, tasks: tp.List[Task]) -> tp.List[tp.List[str]]:
120 | res = []
121 | for i, task in tqdm(enumerate(tasks), desc='Generating (main process)', total=len(tasks)):
122 | tokens = self._prepare_tokens(task).to(self.device)
123 | if i == 0:
124 | logger.debug(f"\nTokens: {tokens[:, :5]} ... {tokens[:, -5:]}\n")
125 | generated_tokens = self.model.generate(tokens, **self.generation_params)
126 | generations = self.tokenizer.batch_decode(generated_tokens[:, tokens.shape[1]:], skip_special_tokens=True)
127 | gt_indent = get_indent(task.gt)
128 | if i % 1 == 0:
129 | logger.info(f"Raw Generation for task {i}:\n{generations[0]}")
130 | logger.info(f"Generation for task {i}:\n{self._postprocess(generations[0], gt_indent)}")
131 | res.append([self._postprocess(t, gt_indent) for t in generations])
132 | return res
133 |
134 |
135 | class LMGenerator(InfillGenerator):
136 | def __init__(self,
137 | lm_prefix_tokens: tp.Union[str, tp.List[int]] = [],
138 | lm_suffix_tokens: tp.Union[str, tp.List[int]] = [],
139 | **kwargs
140 | ):
141 | """
142 | Class to generate code in causal LM mode, uses only left context
143 | params:
144 | lm_prefix_tokens: tp.Union[str, tp.List[int]] = [] - tokens to insert before the context. Can be either str or list of int tokens
145 | lm_suffix_tokens: tp.Union[str, tp.List[int]] = [] - tokens to insert after the context. Can be either str or list of int tokens
146 | """
147 | super().__init__(**kwargs)
148 | self.lm_prefix_tokens = super().tokenize_special_tokens(lm_prefix_tokens)
149 | self.lm_suffix_tokens = super().tokenize_special_tokens(lm_suffix_tokens)
150 | logger.debug(f"lm_prefix_tokens: {self.lm_prefix_tokens}, lm_suffix_tokens: {self.lm_suffix_tokens}")
151 |
152 | def _prepare_tokens(self, task: Task) -> torch.Tensor:
153 | left_context_str, _ = self.context_parser.get_left_and_right_context(task)
154 | logger.info("\n" + "\n".join(left_context_str.split('\n')[-20:]))
155 | left_tokens = self.tokenizer.encode(
156 | left_context_str, return_tensors="pt", add_special_tokens=False) # ['input_ids']
157 | if self.max_context_length and left_tokens.shape[1] > self.max_context_length:
158 | left_tokens = left_tokens[:, -self.max_context_length:]
159 | tokens = torch.cat([self.lm_prefix_tokens, left_tokens, self.lm_suffix_tokens], dim=-1).type(torch.long)
160 | return tokens
161 |
162 |
163 |
164 |
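When the left and right contexts together exceed max_context_length, _prepare_tokens splits the budget according to left_context_ratio: the left context keeps roughly ratio/(ratio+1) of the tokens and the right context 1/(ratio+1). A small illustrative check of that split (numbers invented; the actual code truncates at the token level with a one-token margin):

# Illustrative budget split used by InfillGenerator._prepare_tokens.
max_context_length = 8000
left_context_ratio = 3

left_budget = int(max_context_length * left_context_ratio / (left_context_ratio + 1))
right_budget = int(max_context_length * 1 / (left_context_ratio + 1))
print(left_budget, right_budget)  # 6000 2000: tail of the left context, head of the right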
--------------------------------------------------------------------------------
/lm_eval/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Dict, Any, Tuple
3 | from shutil import copytree
4 | from pathlib import Path
5 | import json
6 | from subprocess import Popen, TimeoutExpired, PIPE, run
7 | import os
8 | import re
9 | import shutil
10 |
11 | from .datatypes import Task
12 | CONDA_BIN = '/home/user/conda/bin/conda'
13 |
14 |
15 | TIMEOUT = 30
16 |
17 | def get_indent(code):
18 | line = code.split('\n')[0]
19 | return len(line) - len(line.strip())
20 |
21 | def run_wrapper(cmd, cwd):
22 | my_env = os.environ.copy()
23 | my_env['PATH'] = f"{cwd}:" + my_env['PATH']
24 | my_env['PYTHONPATH'] = f"{cwd}"
25 | res = run([cmd.replace('\n', ' ')], shell=True, capture_output=True, check=False, env=my_env, timeout=TIMEOUT)
26 | return res.stdout.decode("utf-8") + res.stderr.decode("utf-8")
27 |
28 |
29 | def run_tests(bin: os.PathLike, repo: os.PathLike) -> Dict[str, int]:
30 | """
31 | Execute all tests in the given path using pytest from bin
32 | """
33 | try:
34 | cmd = run_wrapper(f"cd {str(repo)} && conda run -p {str(bin)} pytest tests --color=no -p no:cacheprovider", cwd=str(repo))
35 | except TimeoutExpired:
36 | print('TIMEOUT CAUGHT')
37 | return {'passed': 0, 'failed': 0, 'output': 'TIMEOUT'}
38 | passed = re.findall(r" \d+ passed", cmd)
39 | if passed:
40 | passed = int(passed[0][1:-7])
41 | else:
42 | passed = 0
43 | failed = re.findall(r" \d+ failed", cmd)
44 | if failed:
45 | failed = int(failed[0][1:-7])
46 | else:
47 | failed = 0
48 | if cmd.find("short test summary info") != -1:
49 | out = '\n'.join(cmd.split('\n')[-50:]) # cmd[cmd.find("short test summary info"):]
50 | else:
51 | out = '\n'.join(cmd.split('\n')[:])
52 | return {'passed': passed, 'failed': failed, 'output': out}
53 |
54 | def evaluate_override(
55 | root_path: os.PathLike, task: Task, generation: str, workdir: os.PathLike
56 | ) -> Dict[str, Any]:
57 | root_path = Path(root_path)
58 | workdir = Path(workdir).absolute()
59 | if os.path.exists(workdir):
60 | try:
61 | shutil.rmtree(workdir)
62 | except FileNotFoundError as e:
63 | print(f"Caught file not found at rmtree {workdir}")
64 | workdir.mkdir(parents=True, exist_ok=True)
65 |
66 | copytree(root_path / task.repo, workdir, dirs_exist_ok=True, # we do not want to copy venv, it is very slow
67 | ignore=shutil.ignore_patterns(
68 | 'venv_bench', '.github', '.git', '.pytest_cache', '*.egg-info', '__pycache__', 'testtemp'
69 | )
70 | )
71 | new_content = task.left_context + generation + task.right_context
72 | with open(workdir / task.path_from_root, 'w', encoding='utf-8') as f:
73 | f.write(new_content)
74 |
75 | metrics = run_tests(root_path / task.repo / "venv_bench", workdir)
76 |
77 | try:
78 | shutil.rmtree(workdir)
79 | except FileNotFoundError as e:
80 | print(f"Caught file not found at rmtree {workdir}")
81 | except OSError as e:
82 | print(f"OSError {e} while rm {workdir}")
83 | return metrics
84 |
85 | def evaluate_override_wrapped(
86 | root_path: os.PathLike, task: Task, generation: str, workdir: os.PathLike, task_n: int, gen_n: int, cache: dict
87 | ) -> Tuple[int, int, Dict[str, Any]]:
88 | cache_key = task.left_context + generation + task.right_context
89 | if cache_key in cache:
90 | return (task_n, gen_n, cache[cache_key])
91 | else:
92 | res = evaluate_override(root_path, task, generation, workdir)
93 | cache[cache_key] = res
94 | return (task_n, gen_n, res)
95 |
96 |
97 | def load_dataset(root_path: os.PathLike, meta_file: str = 'dataset.json', limit: int = 10_000) -> List[Task]:
98 | with open(Path(root_path) / meta_file, 'r') as f:
99 | dataset = [Task(**t) for t in json.load(f)][:limit]
100 | return dataset
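A hypothetical end-to-end sketch of loading the benchmark and scoring one completion; it assumes the repositories and their venv_bench environments have already been prepared under ./data/realcode_v3 (see prepare_data/run.py), and the workdir path is invented.

from lm_eval.utils import load_dataset, evaluate_override

tasks = load_dataset('./data/realcode_v3', 'realcode_v3_FG.json', limit=5)
task = tasks[0]
# Scoring the ground truth itself should pass all tests for the task.
result = evaluate_override('./data/realcode_v3', task, task.gt, './workdir/tmp')
print(result['passed'], '/', task.total_tests, 'tests passed')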
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 |
2 | import hydra
3 | import torch
4 | import numpy as np
5 | import random
6 | import json
7 | import os
8 |
9 | from lm_eval.generators import InfillGenerator, LMGenerator
10 | from lm_eval.evaluator import Evaluator
11 | from lm_eval.context_parser import TrivialContextParser
12 | from lm_eval.utils import load_dataset
13 |
14 | from omegaconf import DictConfig, OmegaConf
15 | from accelerate import Accelerator
16 | from accelerate.utils import gather_object
17 |
18 |
19 |
20 | import logging
21 | logger = logging.getLogger("RealCode")
22 | logger.setLevel(logging.INFO)
23 |
24 | def seed_all(seed):
25 | random.seed(seed)
26 | np.random.seed(seed)
27 | torch.manual_seed(seed)
28 | torch.cuda.manual_seed(seed)
29 |
30 | @hydra.main(config_path="config", config_name="config", version_base="1.3")
31 | def main(cfg: DictConfig) -> None:
32 | seed_all(cfg.seed)
33 | print(cfg)
34 | accelerator = Accelerator()
35 | dataset = load_dataset(cfg.dataset_root, cfg.dataset_meta_file, cfg.limit)
36 | logger.info(f"loaded {cfg.dataset_root} {cfg.dataset_meta_file}")
37 | if cfg.do_generation:
38 | if 'context_parser' in cfg:
39 | parser = hydra.utils.instantiate(cfg.context_parser)
40 | else:
41 | parser = TrivialContextParser()
42 |
43 | dtype_map = {'fp16': torch.float16, 'fp32': torch.float, 'bf16': torch.bfloat16}
44 | if cfg.generator_mode == 'infill':
45 | generator = InfillGenerator(
46 | accelerator=accelerator,
47 | model_path=cfg.model_path,
48 | dtype=dtype_map[cfg.dtype],
49 | num_samples=cfg.num_samples,
50 | prefix_tokens=cfg.prefix_tokens,
51 | middle_tokens=cfg.middle_tokens,
52 | suffix_tokens=cfg.suffix_tokens,
53 | max_context_length=cfg.max_context_length,
54 | generation_params=dict(cfg.generation_params),
55 | model_kwargs=cfg.model_kwargs if 'model_kwargs' in cfg else {},
56 | context_parser=parser,
57 | left_context_ratio=cfg.left_context_ratio,
58 | )
59 | elif cfg.generator_mode == 'lm':
60 | generator = LMGenerator(
61 | accelerator=accelerator,
62 | model_path=cfg.model_path,
63 | dtype=dtype_map[cfg.dtype],
64 | num_samples=cfg.num_samples,
65 | lm_prefix_tokens=cfg.lm_prefix_tokens if 'lm_prefix_tokens' in cfg else [],
66 | lm_suffix_tokens=cfg.lm_suffix_tokens if 'lm_suffix_tokens' in cfg else [],
67 | max_context_length=cfg.max_context_length,
68 | generation_params=dict(cfg.generation_params),
69 | model_kwargs=cfg.model_kwargs if 'model_kwargs' in cfg else {},
70 | context_parser=parser,
71 | )
72 | else:
73 | raise ValueError(f"generator_mode can be either 'lm' or 'infill', found {cfg.generator_mode}")
74 |
75 |
76 |
77 | logger.info(f"Starting generation")
78 | with accelerator.split_between_processes(dataset) as part:
79 | part_generations = generator.generate(part)
80 | generations = gather_object(part_generations)
81 | if accelerator.is_main_process:
82 | with open(cfg.generations_save_path, "w") as f:
83 | json.dump(generations, f)
84 | del generator.model
85 | else:
86 | with open(cfg.generations_save_path, "r") as f:
87 | generations = json.load(f)
88 |
89 | if cfg.do_eval and accelerator.is_main_process:
90 | evaluator = Evaluator(
91 | dataset_root=cfg.dataset_root,
92 | num_samples=cfg.num_samples,
93 | pass_k_list=cfg.pass_k_list,
94 | njobs=cfg.njobs,
95 | working_dir=cfg.working_dir,
96 | )
97 | logger.info(f"Starting evaluation")
98 | metrics = evaluator.evaluate(dataset, generations)
99 | logger.info(json.dumps(metrics['total'], indent=4))
100 | if cfg.metrics_save_path:
101 | try:
102 | with open(cfg.metrics_save_path, "w") as f:
103 | json.dump(metrics, f)
104 | except FileNotFoundError:
105 | logger.warning("Found slashes in the CLI args, metrics will not be saved")
106 |
107 |
108 | if __name__ == "__main__":
109 | main()
110 |
111 |
--------------------------------------------------------------------------------
/prepare_data/run.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from git import Repo
3 | import os
4 | import shutil
5 | import subprocess
6 | from pathlib import Path
7 | from tqdm import tqdm
8 | from joblib import Parallel, delayed
9 |
10 | os.environ['GIT_PYTHON_TRACE'] = 'full'
11 |
12 |
13 | def run(cmd, check=False):
14 | return subprocess.run(
15 | [cmd.replace('\n', ' ')],
16 | shell=True, capture_output=True, check=check,
17 | ).stdout.decode("utf-8")
18 |
19 | def delete_and_report(path):
20 | if os.path.exists(path):
21 | shutil.rmtree(path, ignore_errors=False)
22 | if os.path.exists(path):
23 | raise ValueError(f"Unable to delete {path}, please delete it manually")
24 |
25 | def setup(repo):
26 | """
27 | builds the conda environment in the repo that will be used to run tests
28 | """
29 | base_path = str(repo.resolve().absolute())
30 | delete_and_report(f"{base_path}/venv_bench")
31 | delete_and_report(f"{base_path}/build")
32 | delete_and_report(f"{base_path}/*.egg-info")
33 | try:
34 | d = run(f"conda create -p {base_path}/venv_bench --copy -y python=3.11 poetry", check=True)
35 | except subprocess.CalledProcessError as e:
36 | print(repo, 'create')
37 | print(e.stdout)
38 | print(e.stderr)
39 | raise e
40 | if os.path.exists(f"{base_path}/poetry.lock"):
41 | run(f"rm {base_path}/reqs_p.txt")
42 | out = run(f"cd {base_path} && conda run -p {base_path}/venv_bench poetry export -o reqs_p.txt --without-hashes")
43 | out = run(f"cd {base_path} && conda run -p {base_path}/venv_bench poetry export --with dev -o reqs_p.txt --without-hashes")
44 | out = run(f"cd {base_path} && conda run -p {base_path}/venv_bench poetry export --with test -o reqs_p.txt --without-hashes")
45 |
46 | for req_filename in ["reqs_p.txt", "requirements.txt", "linux_requirements.txt",
47 | "requirements-ci.txt","requirements_ci.txt", "dev-requirements.txt",
48 | 'requirements_dev.txt', "requirements-dev.txt"]:
49 | if os.path.exists(f"{base_path}/{req_filename}"):
50 | out = run(f"conda run -p {base_path}/venv_bench python -m pip install -r {base_path}/{req_filename}", check=True)
51 | skip_install = False
52 | try:
53 | if not skip_install and (os.path.exists(f"{base_path}/setup.py") or os.path.exists(f"{base_path}/pyproject.toml")):
54 | out = run(f"conda run -p {base_path}/venv_bench python -m pip install {base_path}", check=True)
55 | except subprocess.CalledProcessError as e:
56 | print('='*40)
57 | print(repo, 'pip install warn')
58 | print('='*40)
59 | for toml_option in ["[test]", "[dev]", "[all]"]:
60 | out = run(f"conda run -p {base_path}/venv_bench python -m pip install {base_path}.{toml_option}")
61 | out = run(f"conda run -p {base_path}/venv_bench pip install pytest")
62 | if not os.path.exists(f"{repo}/venv_bench/bin/python"):
63 | raise ValueError(f"{repo}/venv_bench/bin/python not found")
64 | print(repo, "done")
65 | return base_path
66 |
67 |
68 | def build_envs(source_dir):
69 | repos_parent = Path(source_dir)
70 | Parallel(n_jobs=8)(
71 | delayed(setup)(path)
72 | for path in tqdm([t for t in repos_parent.iterdir() if os.path.isdir(t)], desc='building_envs')
73 | )
74 |
75 |
76 |
77 | if __name__ == '__main__':
78 | dataset_dir = '../data/realcode_v3'
79 | build_envs(dataset_dir)
80 | print('Done')
81 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core==1.3.2
2 | hydra-joblib-launcher==1.2.0
3 | pandas
4 | tqdm
5 | pytest
6 | transformers==4.48.0
7 |
--------------------------------------------------------------------------------
/results/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NLP-Core-Team/RealCode_eval/f70984fb83022eb191ba94bcf55729c2fc64aa80/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_evaluator.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pathlib import Path
3 | import os
4 | import json
5 | import joblib
6 |
7 | import lm_eval.utils
8 | import lm_eval.evaluator
9 |
10 | @pytest.fixture
11 | def dataset_path():
12 | return str(Path('./data/realcode_v3').resolve())
13 |
14 | def get_indent(code):
15 | line = code.split('\n')[0]
16 | return len(line) - len(line.strip())
17 |
18 |
19 | @pytest.mark.parametrize('dataset_file', ['realcode_v3_SG.json', 'realcode_v3_FG.json'])
20 | def test_perfect_preds(dataset_path, dataset_file, workdir='./workdir'):
21 | print("Testing where Pass@1 should be 1")
22 | root = Path(dataset_path)
23 | print(f"Dataset is at ", root, dataset_file)
24 | NJOBS = 8
25 |
26 | dataset = lm_eval.utils.load_dataset(root, dataset_file, limit=10_000)
27 | gt_ans = [[t.gt] for t in dataset]
28 | evaluator = lm_eval.evaluator.Evaluator(
29 | root,
30 | num_samples=1,
31 | pass_k_list=[1],
32 | njobs=NJOBS,
33 | working_dir=workdir
34 | )
35 | metrics = evaluator.evaluate(dataset, gt_ans)
36 | wrong = []
37 | for metric in metrics['detailed']:
38 | if metric['Pass@1'] < 1 - 1e-3:
39 | wrong.append(metric)
40 | print(metric['Pass@1'], metric['repo'], metric['repo_n'], metric['path_from_root'], metric['evaluations'][0])
41 | with open('test_perfect_preds_fails.json', 'w') as f:
42 | json.dump(wrong, f)
43 | for x in wrong:
44 | print(x['repo'], x['path_from_root'], x['repo_n'])
45 | assert len(wrong) == 0
46 |
47 |
48 | @pytest.mark.parametrize('dataset_file', ['realcode_v3_SG.json', 'realcode_v3_FG.json'])
49 | def test_incorrect_answers(dataset_path, dataset_file):
50 | print("Testing where Pass@1 should be 0")
51 | root = Path(dataset_path)
52 | print(f"Dataset is at ", root, dataset_file)
53 | NJOBS = 8
54 |
55 | dataset = lm_eval.utils.load_dataset(root, dataset_file, limit=10_000)
56 | empty_ans = [[" "*get_indent(t.gt) + 'pass\n'] for t in dataset]
57 | evaluator = lm_eval.evaluator.Evaluator(
58 | root,
59 | 1,
60 | [1],
61 | njobs=NJOBS,
62 | working_dir='./workdir'
63 | )
64 | metrics = evaluator.evaluate(dataset, empty_ans)
65 | wrong = []
66 | for metric in metrics['detailed']:
67 | if metric['Pass@1'] > 1e-3:
68 | wrong.append(metric)
69 | print(metric['Pass@1'], metric['repo'], metric['repo_n'], metric['path_from_root'], metric['evaluations'][0])
70 | print('\n' * 10)
71 | with open('test_incorrect_answers_fails.json', 'w') as f:
72 | json.dump(wrong, f)
73 | for x in wrong:
74 | print(x['repo'], x['path_from_root'], x['repo_n'])
75 | assert len(wrong) == 0
76 |
77 |
78 |
79 | @pytest.mark.parametrize('dataset_file', ['realcode_v3_SG.json', 'realcode_v3_FG.json'])
80 | def test_perfect_preds_parallel(dataset_path, dataset_file):
81 | """
82 | Like test_perfect_preds but with parallel evaluation
83 | """
84 | joblib.Parallel(n_jobs=8)(joblib.delayed(test_perfect_preds)(dataset_path, dataset_file, workdir=str(i)) for i in range(2))
85 |
--------------------------------------------------------------------------------
/workdir/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
--------------------------------------------------------------------------------