├── .gitignore ├── README.md ├── example.py ├── leetcode_dataset ├── build.py ├── build.sh └── lib │ ├── add_test_cases.py │ ├── clean_dataset.py │ ├── fetch_dataset.py │ ├── format_dataset.py │ └── utils │ ├── extract_tests.yaml │ ├── llm.py │ └── utils.py ├── leetcode_env ├── __init__.py ├── environment.py ├── types.py └── utils │ ├── formatting.py │ └── leetcode.py ├── pyproject.toml ├── requirements.txt └── tests └── utils ├── test_python_formatter.py └── test_rust_formatter.py /.gitignore: -------------------------------------------------------------------------------- 1 | env/ 2 | .env 3 | __pycache__/ 4 | SubmissionTest.ipynb 5 | .vscode/ 6 | data/ 7 | __init__.py 8 | .pytest_cache/ 9 | build/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Leetcode-Hard Gym 2 | RL environment interface to LeetCode's submission server for evaluating codegen agents. Built on top of OpenAI's [gym](https://github.com/openai/gym). 3 | 4 | Supports: 5 | - `c` 6 | - `c#` 7 | - `java` 8 | - `python` 9 | - `javascript` 10 | - `ruby` 11 | - `swift` 12 | - `go` 13 | - `scala` 14 | - `kotlin` 15 | - `rust` 16 | - `php` 17 | - `typescript` 18 | - `racket` 19 | - `erlang` 20 | - `elixir` 21 | - `dart` 22 | - `mysql` 23 | 24 | ### Leaderboard for Leetcode Hard (Python): Pass@1 25 | - OpenAI's GPT-4: `10.7` ([source](https://arxiv.org/pdf/2303.12712.pdf)) 26 | - OpenAI's Codex: `3.6` ([source](https://arxiv.org/pdf/2303.12712.pdf)) 27 | - OpenAI's GPT-3.5: `0.0` ([source](https://arxiv.org/pdf/2303.12712.pdf)) 28 | - Reflexion + GPT-4: `15.0` ([source](https://arxiv.org/abs/2303.11366)) 29 | 30 | ### Setup: 31 | 1. Clone the repository: 32 | ```bash 33 | git clone https://github.com/GammaTauAI/leetcode-hard-gym.git && cd leetcode-hard-gym 34 | ``` 35 | 36 | 2. Create a virtual environment and install the `leetcode_env` module and its dependencies: 37 | ```bash 38 | python -m venv venv 39 | source venv/bin/activate 40 | python -m pip install -e . 41 | ``` 42 | 43 | 3. Set the environment variables `LEETCODE_SESSION` to the cookie `LEETCODE_SESSION` and `LEETCODE_CSRF_TOKEN` to the cookie `csrftoken` from a signed-in Leetcode session. This cookie can be found by using browser DevTools or by using a browser extension like [EditThisCookie](https://www.editthiscookie.com/). 44 | ```bash 45 | export LEETCODE_SESSION=... 46 | export LEETCODE_CSRF_TOKEN=... 47 | ``` 48 | 49 | ### Example usage: 50 | First we write some code: 51 | 52 | ```python 53 | code = """ 54 | class Solution: 55 | def twoSum(self, nums, target): 56 | l = len(nums) 57 | for i in range(l - 1): 58 | for j in range(i + 1, l): 59 | if nums[i] + nums[j] == target: 60 | return [i, j] 61 | """ 62 | ``` 63 | 64 | Then we can build a submission ... 65 | 66 | ```python 67 | from leetcode_env.types import LeetCodeSubmission, ProgrammingLanguage 68 | sub = LeetCodeSubmission(code=code, 69 | lang=ProgrammingLanguage.PYTHON3, 70 | question_slug='two-sum', 71 | timeout=5) 72 | ``` 73 | 74 | ... and instantiate a submission environment ... 75 | 76 | ```python 77 | from leetcode_env.environment import LeetCodeEnv 78 | env = LeetCodeEnv() 79 | ``` 80 | 81 | Finally, we can step through the environment with the submission: 82 | 83 | ```python 84 | status, reward, done, submission_result = env.step(sub) 85 | print(status, reward, done, submission_result) 86 | # Wrong Answer 87 | # False 88 | # False 89 | # {'status_code': 11, 'lang': 'python3', 'run_success': True, 'status_runtime': 'N/A', 'memory': 14160000, 'question_id': '4', 'elapsed_time': 105, 'compare_result': '00010000000...00000000001000', 'code_output': '1.00000', 'std_output': '', 'last_testcase': '[1,3]\n[2]', 'expected_output': '2.00000', 'task_finish_time': 1680132323596, 'total_correct': 6, 'total_testcases': 2094, 'runtime_percentile': None, 'status_memory': 'N/A', 'memory_percentile': None, 'pretty_lang': 'Python3', 'submission_id': '924506780', 'input_formatted': '[1,3], [2]', 'input': '[1,3]\n[2]', 'status_msg': 'Wrong Answer', 'state': 'SUCCESS'} 90 | ``` 91 | 92 | Note: compare result was shortened here, it contains a sequence of booleans indicating if a test was passed 93 | 94 | ## LeetcodeHardGym Dataset 95 | 96 | A script is provided to build an uncontaminated set of free Leetcode Hard problems in a format similar to HumanEval. It fetches the dataset, filters out class-dependent, void, and class implementation problems, and formats the problems for the specified programming languages. Optionally, it can extract test cases from examples in problem descriptions using GPT, or remove these examples from generated docstrings. 97 | 98 | ### Usage 99 | 100 | To build the dataset, `leetcode_env` must be installed in the current environment. Then, we can run the following command from the `leetcode_dataset/` directory of this repository: 101 | ```bash 102 | python build.py --langs python3 rust --log_level INFO --output_dir ./build 103 | ``` 104 | 105 | ### Arguments 106 | 107 | - `--langs`: List of languages. Current options are: rust, python3. 108 | - `--log_level`: Logging level. Options: DEBUG, INFO, WARNING, ERROR, CRITICAL. Default is INFO. 109 | - `--output_dir`: Directory to save the built dataset. Default is ./build. 110 | - `--extract_test_cases`: If set, test cases will be extracted from problem descriptions using GPT. 111 | - `--remove_examples`: If set, examples will be removed. Cannot be used with --extract_test_cases. 112 | 113 | ### Environment Variables 114 | 115 | - `LEETCODE_SESSION`: This environment variable must be set for the script to run. Please refer to the Setup section for instructions on how to obtain your session cookie. 116 | - `LEETCODE_CSRF_TOKEN`: This environment variable must be set for the script to run. Please refer to the Setup section for instructions on how to obtain your csrf token. 117 | - `OPENAI_API_KEY`: This environment variable is required if the `--extract_test_cases` option is used. Please refer to the OpenAI API documentation for instructions on how to obtain your API key. 118 | 119 | ### Dependencies 120 | 121 | If the `--extract_test_cases` option is used, the `openai` and `langchain` libraries are required. These can be installed with: 122 | ```python 123 | pip3 install openai langchain termcolor 124 | ``` 125 | 126 | ### Output 127 | 128 | The script will output a .jsonl file for each specified language in the output directory. The filename will be in the format `leetcode-hard-uncontaminated-{lang}.jsonl`. 129 | 130 | ### Cite 131 | 132 | This benchmark was introduced in the following paper: 133 | 134 | ```bibtex 135 | @misc{shinn2023reflexion, 136 | title={Reflexion: Language Agents with Verbal Reinforcement Learning}, 137 | author={Noah Shinn and Federico Cassano and Edward Berman and Ashwin Gopinath and Karthik Narasimhan and Shunyu Yao}, 138 | year={2023}, 139 | eprint={2303.11366}, 140 | archivePrefix={arXiv}, 141 | primaryClass={cs.AI} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from leetcode_env.environment import LeetCodeEnv 2 | from leetcode_env.types import LeetCodeSubmission, ProgrammingLanguage 3 | 4 | code = """ 5 | class Solution: 6 | def twoSum(self, nums, target): 7 | l = len(nums) 8 | for i in range(l - 1): 9 | for j in range(i + 1, l): 10 | if nums[i] + nums[j] == target: 11 | return [i, j] 12 | """ 13 | 14 | sub = LeetCodeSubmission(code=code, 15 | lang=ProgrammingLanguage.PYTHON3, 16 | question_slug='two-sum') 17 | 18 | env = LeetCodeEnv() 19 | 20 | status, reward, done, submission_result = env.step(sub) 21 | 22 | print(status, reward, done, submission_result) 23 | 24 | -------------------------------------------------------------------------------- /leetcode_dataset/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | from lib.fetch_dataset import fetch_dataset, fetch_solutions 5 | from lib.utils.utils import get_api_instance 6 | from lib.clean_dataset import remove_class_dependent, remove_void, remove_class_impls, remove_examples 7 | from lib.format_dataset import format_problems, to_jsonl 8 | 9 | parser = argparse.ArgumentParser(description="Configuration for building uncontaminated Leetcode Hard dataset") 10 | parser.add_argument('--langs', nargs='+', default=['python3'], help="List of languages. Possible options are: rust, python3") 11 | parser.add_argument('--log_level', type=str, default='INFO', help="Logging level. Options: DEBUG, INFO, WARNING, ERROR, CRITICAL.") 12 | parser.add_argument('--output_dir', type=str, default="./build", help="Directory to save the built dataset.") 13 | parser.add_argument('--extract_test_cases', action='store_true', help="If set, test cases will be extracted from problem descriptions using GPT.") 14 | parser.add_argument('--remove_examples', action='store_true', help="If set, examples will be removed. Cannot be used with --extract_test_cases.") 15 | parser.add_argument('--fetch_solutions', action='store_true', help="If set, solutions to problems will be fetched. Currently only supports lang=python3.") 16 | 17 | args = parser.parse_args() 18 | 19 | langs = args.langs 20 | log_level = getattr(logging, args.log_level.upper()) 21 | output_dir = args.output_dir 22 | extract_test_cases_ = args.extract_test_cases 23 | remove_examples_ = args.remove_examples 24 | fetch_solutions_ = args.fetch_solutions 25 | 26 | try: 27 | os.environ["LEETCODE_SESSION"] 28 | except: 29 | print("Environment variable LEETCODE_SESSION is not set. Please refer to README") 30 | exit(1) 31 | 32 | if extract_test_cases_: 33 | try: 34 | os.environ["OPENAI_API_KEY"] 35 | import openai 36 | import langchain 37 | except: 38 | print("Extra dependencies and setup are required for test case extraction. Please refer to README") 39 | exit(1) 40 | if remove_examples_: 41 | print("Cannot use --remove_examples with --extract_test_cases") 42 | exit(1) 43 | 44 | 45 | logging.basicConfig(level=log_level) 46 | os.makedirs(output_dir, exist_ok=True) 47 | 48 | api_instance = get_api_instance() 49 | dataset = fetch_dataset(api_instance) 50 | 51 | filtered_dataset = \ 52 | remove_class_impls( 53 | remove_void( 54 | remove_class_dependent(dataset))).reset_index(drop=True) 55 | 56 | if remove_examples_: 57 | filtered_dataset = remove_examples(filtered_dataset) 58 | 59 | logging.info(f"Filtered out {len(dataset) - len(filtered_dataset)} problem(s)") 60 | 61 | for lang in langs: 62 | logging.info(f"Formatting dataset for {lang}") 63 | formatted_dataset = format_problems(filtered_dataset, lang) 64 | if extract_test_cases_: 65 | logging.info(f"Extracting test cases for {lang}") 66 | from lib.add_test_cases import extract_test_cases 67 | formatted_dataset = extract_test_cases(formatted_dataset, lang) 68 | if fetch_solutions_: 69 | logging.info(f"Fetching solutions for {lang}") 70 | formatted_dataset = fetch_solutions(formatted_dataset, lang) 71 | to_jsonl(formatted_dataset, os.path.join(output_dir, f'leetcode-hard-uncontaminated-{lang}.jsonl')) 72 | -------------------------------------------------------------------------------- /leetcode_dataset/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run_build.sh 4 | 5 | python3 build.py \ 6 | --langs python3 rust \ 7 | --log_level INFO \ 8 | --output_dir ./build \ 9 | # --extract_test_cases \ 10 | # --remove_examples \ 11 | -------------------------------------------------------------------------------- /leetcode_dataset/lib/add_test_cases.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from .utils.llm import LanguageFunction 3 | import os 4 | import inspect 5 | import logging 6 | from langchain.callbacks import get_openai_callback 7 | 8 | UTILS_DIR = os.path.join( 9 | os.path.dirname(inspect.getabsfile(inspect.currentframe())), 10 | "utils", 11 | ) 12 | 13 | def extract_test_cases(dataset: pd.DataFrame, lang: str) -> pd.DataFrame: 14 | """ 15 | Add test cases to the dataset 16 | Adds columns: 'test_cases' (List[str]) 17 | """ 18 | dataset = dataset.copy() 19 | dataset.reset_index(inplace=True, drop=True) 20 | dataset['test_cases'] = None 21 | with get_openai_callback() as callback: 22 | for ind, row in dataset.iterrows(): 23 | logging.info(f"Extracting test cases for problem {ind+1}/{len(dataset)}") 24 | examples = extract_examples(row['description']) 25 | function_signature = row['signature'] 26 | test_cases = examples_to_test_cases(examples, function_signature, lang) 27 | dataset.at[ind, 'test_cases'] = test_cases 28 | return dataset 29 | 30 | 31 | def extract_examples(description): 32 | """ 33 | Extract a natural language representation of the examples from the description 34 | """ 35 | inputs = [l for l in description.split('\n') if l.strip().startswith('Input')] 36 | outputs = [l.strip('Output: ') for l in description.split('\n') if l.strip().startswith('Output')] 37 | 38 | examples = [] 39 | 40 | for i, (input_str, output_str) in enumerate(zip(inputs, outputs)): 41 | example_str = f"Example {i+1}:\n{input_str}\nOutput: {output_str}" 42 | examples.append(example_str) 43 | return '\n\n'.join(examples) 44 | 45 | def examples_to_test_cases(examples: str, function_signature: str, language: str) -> str: 46 | """ 47 | Extract test cases from a natural language representation of the examples 48 | """ 49 | lang_function = LanguageFunction.from_yaml(os.path.join(UTILS_DIR, 'extract_tests.yaml')) 50 | response = lang_function(function_signature = function_signature, examples = examples, language = language, callback=True) 51 | test_cases = response['response'].split('\n') 52 | return test_cases 53 | 54 | 55 | -------------------------------------------------------------------------------- /leetcode_dataset/lib/clean_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | 4 | def remove_class_dependent(dataset: pd.DataFrame) -> pd.DataFrame: 5 | """ 6 | Remove problems that depend on class definitions 7 | """ 8 | dataset = dataset.copy() 9 | no_defs_inds = [ind for ind, row in dataset.iterrows() if row['cpp_snippet'].split(' ')[0] == 'class'] 10 | no_defs = dataset.iloc[no_defs_inds] 11 | return no_defs 12 | 13 | def remove_void(dataset: pd.DataFrame) -> pd.DataFrame: 14 | """ 15 | Remove problems that request a void implementation 16 | """ 17 | dataset = dataset.copy() 18 | ret_inds = [ind for ind, row in dataset.iterrows() if '\"\"\"' in row['python3_snippet'].split('\n')[2]] 19 | ret = dataset.drop(ret_inds) 20 | return ret 21 | 22 | def remove_class_impls(dataset: pd.DataFrame) -> pd.DataFrame: 23 | """ 24 | Remove problems that request a class implementation 25 | """ 26 | dataset = dataset.copy() 27 | function_name_regex = r"(?<=def\s)\w+" 28 | impl_inds = [ind for ind, row in dataset.iterrows() 29 | if re.search(function_name_regex, row['python3_snippet']).group(0) == '__init__'] 30 | no_impl = dataset.drop(impl_inds) 31 | return no_impl 32 | 33 | def remove_examples(dataset: pd.DataFrame) -> pd.DataFrame: 34 | """ 35 | Return a copy of the dataset without examples in the descriptions 36 | """ 37 | dataset = dataset.copy() 38 | for ind, row in dataset.iterrows(): 39 | res = docstring_remove_empty(docstring_remove_examples(row['description'])) 40 | dataset.at[ind, 'description'] = res 41 | 42 | return dataset 43 | 44 | def docstring_remove_examples(docstring: str): 45 | """ 46 | Remove the examples from the docstring 47 | """ 48 | lines = [l.strip() for l in docstring.split('\n')] 49 | for i, line in enumerate(lines): 50 | if 'Example' in line: 51 | return '\n'.join(lines[:i]) 52 | return docstring 53 | 54 | def docstring_remove_empty(desc: str): 55 | """ 56 | Remove empty lines from the docstring 57 | """ 58 | return '\n'.join(line for line in desc.split('\n') if line.strip()) 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /leetcode_dataset/lib/fetch_dataset.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import pandas as pd 3 | import ast 4 | from bs4 import BeautifulSoup 5 | import time 6 | import leetcode 7 | import logging 8 | import html2text 9 | import re 10 | import urllib.parse 11 | from typing import Dict 12 | 13 | import dotenv 14 | import html2text 15 | import leetcode 16 | import pandas as pd 17 | import requests 18 | from bs4 import BeautifulSoup 19 | 20 | from leetcode_env.utils.formatting import (PythonSubmissionFormatter, 21 | RustSubmissionFormatter, 22 | SubmissionFormatter) 23 | 24 | from .utils.utils import format_integer 25 | 26 | h = html2text.HTML2Text() 27 | h.ignore_links = True 28 | h.ignore_images = True 29 | h.ignore_emphasis = True 30 | 31 | dotenv.load_dotenv() 32 | 33 | def get_info(question_slug: str, api_instance): 34 | """ 35 | Retrieves the metadata of the question with the given slug 36 | """ 37 | graphql_request = leetcode.GraphqlQuery( 38 | query=""" 39 | query getQuestionDetail($titleSlug: String!) { 40 | question(titleSlug: $titleSlug) { 41 | codeSnippets { 42 | lang 43 | langSlug 44 | code 45 | __typename 46 | } 47 | content 48 | title 49 | } 50 | } 51 | """, 52 | variables={"titleSlug": question_slug}, 53 | operation_name="getQuestionDetail", 54 | ) 55 | response = ast.literal_eval(str(api_instance.graphql_post(body=graphql_request))) 56 | data = response['data']['question'] 57 | return data 58 | 59 | def fetch_solutions(dataset: pd.DataFrame, lang: str) -> pd.DataFrame: 60 | """ 61 | Fetch the solutions for the given lang 62 | """ 63 | dataset = dataset.copy() 64 | for ind, row in dataset.iterrows(): 65 | logging.info(f"Fetching solution for problem {ind+1}/{len(dataset)}") 66 | solution = fetch_solution(row['frontend_question_id'], row['question_title'], lang) 67 | dataset.at[ind, 'solution'] = solution if solution is not None else "" 68 | return dataset 69 | 70 | def fetch_solution(frontend_question_id: int, question_title: str, lang: str = "python3"): 71 | """Get the solution of the question from the LeetCode github repository.""" 72 | LANG_EXT_MAP = { 73 | "python3": "py", 74 | "java": "java", 75 | "cpp": "cpp", 76 | } 77 | 78 | if lang not in LANG_EXT_MAP: 79 | raise ValueError(f"Solutions not supported for Language {lang}") 80 | 81 | FORMATTER_MAP: Dict[str, SubmissionFormatter] = { 82 | "python3": PythonSubmissionFormatter, 83 | "rust": RustSubmissionFormatter, 84 | } 85 | question_id = format_integer(int(frontend_question_id)) 86 | 87 | url = f"https://raw.githubusercontent.com/walkccc/LeetCode/main/solutions/{question_id}. {question_title}/{question_id}.{LANG_EXT_MAP[lang]}" 88 | encoded_url = urllib.parse.quote(url, safe=":/") 89 | response = requests.get(encoded_url) 90 | if response.status_code == 404: 91 | return None 92 | return FORMATTER_MAP[lang].to_humaneval(response.text) 93 | 94 | def fetch_dataset(api_instance): 95 | """ 96 | Get the hard, free, uncontaminated, hard questions from the algorithms topic 97 | """ 98 | question_infos = api_instance.api_problems_topic_get(topic="algorithms") 99 | logging.info(f"Fetched question infos") 100 | 101 | hard = [q for q in question_infos.stat_status_pairs 102 | if q.difficulty.level == 3 103 | and q.paid_only == False] 104 | 105 | hard_dicts = [q.to_dict() for q in hard] 106 | 107 | slug = 'paths-in-matrix-whose-sum-is-divisible-by-k' # This is the first uncontaminated problem 108 | index = next((i for i, q in enumerate(hard_dicts) if q['stat']['question__title_slug'] == slug), None) 109 | uncontaminated = hard[:index + 1] 110 | uncontaminated = uncontaminated[-41:] # Need to get the oldest 41 problems so that the benchmark is consistent 111 | 112 | df = pd.DataFrame() 113 | for ind, question in enumerate(uncontaminated): 114 | logging.info(f"Fetching code snippets for problem {ind + 1}/{len(uncontaminated)}") 115 | question_slug = question.stat.question__title_slug 116 | info = get_info(question_slug, api_instance) 117 | snippets = info['code_snippets'] 118 | content = BeautifulSoup(info['content'], features='html.parser') 119 | text_content = h.handle(str(content)) 120 | text_content = "\n".join(line.lstrip() for line in text_content.split("\n")) 121 | text_content = re.sub('\n\n+', '\n\n', text_content) 122 | text_content = text_content.strip().strip('\n') 123 | 124 | df.at[ind, "question_slug"] = question.stat.question__title_slug 125 | df.at[ind, "question_title"] = question.stat.question__title 126 | df.at[ind, "frontend_question_id"] = int(question.stat.frontend_question_id) 127 | df.at[ind, "question_id"] = int(question.stat.question_id) 128 | df.at[ind, "description"] = text_content 129 | 130 | for snippet in snippets: 131 | df.at[ind, snippet['lang_slug'] + '_snippet'] = snippet['code'] 132 | 133 | return df 134 | 135 | -------------------------------------------------------------------------------- /leetcode_dataset/lib/format_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import logging 3 | from .utils.utils import lines_to_jsonl 4 | from leetcode_env.utils.formatting import ( 5 | PythonSubmissionFormatter, 6 | RustSubmissionFormatter, 7 | SubmissionFormatter, 8 | ) 9 | 10 | FORMATTERS = { 11 | "python3": PythonSubmissionFormatter, 12 | "rust": RustSubmissionFormatter, 13 | } 14 | 15 | 16 | def format_problems(dataset: pd.DataFrame, lang: str): 17 | """ 18 | Convert problems to functions with their descriptsions as docstrings 19 | Adds columns: 'signature', 'prompt' 20 | """ 21 | formatter: SubmissionFormatter = FORMATTERS.get(lang) 22 | dataset = dataset.copy() 23 | for ind, row in dataset.iterrows(): 24 | formatted_problem = formatter.to_humaneval(row[f"{lang}_snippet"]) 25 | prompt = formatter.add_docstring(formatted_problem, row["description"]) 26 | signature = formatter.extract_signature(formatted_problem) 27 | dataset.at[ind, "signature"] = signature 28 | dataset.at[ind, "prompt"] = prompt 29 | return dataset 30 | 31 | def to_jsonl(dataset: pd.DataFrame, path: str): 32 | """ 33 | Save the dataset to a jsonl file 34 | """ 35 | logging.info(f"Writing dataset to {path}") 36 | lines = [] 37 | for ind, row in dataset.iterrows(): 38 | task_id = row["question_slug"] 39 | test_cases = '\n'.join(row.get("test_cases", [])) 40 | solution = row.get("solution", "") 41 | prompt = row["prompt"] 42 | signature = row["signature"] 43 | docstring = row["description"] 44 | 45 | line = { 46 | "task_id": task_id, 47 | "prompt": prompt, 48 | "canonical_solution": solution, 49 | "test": test_cases, 50 | "signature": signature, 51 | "docstring": docstring, 52 | } 53 | 54 | lines.append(line) 55 | 56 | lines_to_jsonl(lines, path) 57 | 58 | -------------------------------------------------------------------------------- /leetcode_dataset/lib/utils/extract_tests.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model: "gpt-3.5-turbo" 3 | temperature: 0 4 | function: 5 | system_message: | 6 | You are a language-agnostic test writer. 7 | You will be given: 8 | - A function signature 9 | - Pseudocode tests in the form of example input output pairs 10 | - The programming language of the FUNCTION SIGNATURE 11 | 12 | Respond with: 13 | - New line seperated assert statements in the same language as the FUNCTION SIGNATURE. 14 | 15 | Rules: 16 | - Each test may only take up one line 17 | 18 | user_message_template: | 19 | FUNCTION SIGNATURE: 20 | {function_signature} 21 | 22 | PSEUDOCODE TESTS 23 | {examples} 24 | 25 | LANGUAGE: {language} 26 | 27 | few_shot_prompt: 28 | - role: "user" 29 | content: | 30 | FUNCTION SIGNATURE: 31 | def minReverseOperations(n: int, p: int, banned: List[int], k: int) -> List[int]: 32 | 33 | PSEUDOCODE TESTS 34 | Example 1: 35 | Input: n = 4, p = 0, banned = [1,2], k = 4 36 | Output: [0,-1,-1,1] 37 | 38 | Example 2: 39 | Input: n = 5, p = 0, banned = [2,4], k = 3 40 | Output: [0,-1,-1,-1,-1] 41 | 42 | Example 3: 43 | Input: n = 4, p = 2, banned = [0,1,3], k = 1 44 | Output: [-1,-1,0,-1] 45 | 46 | LANGUAGE: python 47 | 48 | - role: "assistant" 49 | content: | 50 | assert minReverseOperations(4, 0, [1,2], 4) == [0,-1,-1,1] 51 | assert minReverseOperations(5, 0, [2,4], 3) == [0,-1,-1,-1,-1] 52 | assert minReverseOperations(4, 2, [0,1,3], 1) == [-1,-1,0,-1] 53 | 54 | - role: "user" 55 | content: | 56 | FUNCTION SIGNATURE: 57 | fn collect_the_coins(coins: Vec, edges: Vec>) -> i32 { 58 | } 59 | 60 | PSEUDOCODE TESTS: 61 | Example 1: 62 | Input: coins = [1,0,0,0,0,1], edges = [[0,1],[1,2],[2,3],[3,4],[4,5]] 63 | Output: 2 64 | 65 | Example 2: 66 | Input: coins = [0,0,0,1,1,0,0,1], edges = [[0,1],[0,2],[1,3],[1,4],[2,5],[5,6],[5,7]] 67 | Output: 2 68 | 69 | LANGUAGE: Rust 70 | 71 | - role: "assistant" 72 | content: 73 | assert_eq!(collect_the_coins(vec![1,0,0,0,0,1], vec![vec![0,1],vec![1,2],vec![2,3],vec![3,4],vec![4,5]]), 2); 74 | assert_eq!(collect_the_coins(vec![0,0,0,1,1,0,0,1], vec![vec![0,1],vec![0,2],vec![1,3],vec![1,4],vec![2,5],vec![5,6],vec![5,7]]), 2); -------------------------------------------------------------------------------- /leetcode_dataset/lib/utils/llm.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | 4 | from typing import Dict, List, Sequence 5 | from langchain.chat_models.base import BaseChatModel 6 | from langchain.chat_models import ChatOpenAI 7 | import json 8 | import yaml 9 | from langchain.callbacks import get_openai_callback 10 | 11 | import logging 12 | from langchain.schema import ( 13 | HumanMessage, 14 | AIMessage, 15 | SystemMessage, 16 | BaseMessage, 17 | FunctionMessage, 18 | ChatMessage 19 | ) 20 | from termcolor import colored 21 | 22 | 23 | class LanguageFunction: 24 | """ 25 | Single turn natural language function which expects a text query in a structured format and returns a text response. 26 | """ 27 | def __init__(self, config: Dict, **model_kwargs) -> None: 28 | """ 29 | Initialize the function 30 | Args: 31 | config (Dict): The containing the function's configuration. 32 | """ 33 | function = dict(config["function"]) 34 | model = dict(config["model"]) 35 | model_kwargs.update(model) 36 | self.model_kwargs = model_kwargs 37 | self.system_message = SystemMessage(content = function.get("system_message", "")) 38 | self.few_shot_prompt = parse_conversation(function.get("few_shot_prompt", [])) 39 | self.user_message_template: str = function["user_message_template"] 40 | 41 | self.reset_messages() 42 | 43 | self.chat_model: BaseChatModel = ChatOpenAI(**self.model_kwargs) 44 | 45 | def reset_messages(self) -> None: 46 | """ 47 | Reset the agent's conversation. 48 | """ 49 | self.messages: List[BaseMessage] = [] 50 | 51 | def __call__(self, callback = False, **kwargs) -> Dict: 52 | """ 53 | Call the Agent Function with the given arguments. 54 | 55 | Args: 56 | callback (bool): Whether to use the OpenAI callback for logging 57 | kwargs (Dict): The arguments to the function. 58 | """ 59 | message = HumanMessage(content = self.user_message_template.format(**kwargs)) 60 | self.messages.append(message) 61 | return self._call_model(callback) 62 | 63 | def _call_model(self, callback_on: bool) -> Dict: 64 | if callback_on: 65 | with get_openai_callback() as callback: 66 | response = self.chat_model([self.system_message, *self.few_shot_prompt, *self.messages]) 67 | else: 68 | response = self.chat_model([self.system_message, *self.few_shot_prompt, *self.messages]) 69 | self.messages.append(response) 70 | chat_string = get_buffer_string(self.messages) 71 | logging.debug(f'Language Function thread:\n{chat_string}') 72 | if callback_on: 73 | logging.debug(f'Total prompt tokens: {callback.prompt_tokens}') 74 | logging.debug(f'Total completion tokens: {callback.completion_tokens}') 75 | logging.debug(f'Total Cost: ${callback.total_cost:.7f}') 76 | self.reset_messages() 77 | try: 78 | response_dict = json.loads(response.content) 79 | except json.decoder.JSONDecodeError: 80 | response_dict = {"response": response.content} 81 | return response_dict 82 | 83 | @classmethod 84 | def from_yaml(cls, filepath: str) -> LanguageFunction: 85 | """ 86 | Load an agent from a YAML file. 87 | 88 | Args: 89 | filepath (str): The path to the YAML file. 90 | 91 | Returns: 92 | Agent: The agent. 93 | """ 94 | yaml_obj = load_yaml_file(filepath) 95 | return cls(yaml_obj) 96 | 97 | def load_yaml_file(filepath: str) -> List[Dict]: 98 | """ 99 | Load a YAML file and return its contents as a dictionary. 100 | 101 | Args: 102 | filepath (str): The path to the YAML file. 103 | 104 | Returns: 105 | Dict: The contents of the YAML file as a dictionary. 106 | """ 107 | with open(filepath, "r", encoding="utf-8") as file: 108 | yaml_obj = yaml.safe_load(file) 109 | return yaml_obj 110 | 111 | def parse_conversation(raw_messages: List[Dict]) -> List[BaseMessage]: 112 | """ 113 | Parse a chat thread JSON object into a list of Messages. 114 | """ 115 | message_roles = { 116 | "user": HumanMessage, 117 | "assistant": AIMessage, 118 | "system": SystemMessage, 119 | } 120 | 121 | messages = [] 122 | for message in raw_messages: 123 | message_type = message_roles[message["role"]] 124 | messages.append(message_type(content=message["content"])) 125 | 126 | return messages 127 | 128 | def get_buffer_string( 129 | messages: Sequence[BaseMessage], human_prefix: str = "Input", ai_prefix: str = "Output" 130 | ) -> str: 131 | """Convert sequence of Messages to strings and concatenate them into one string. 132 | 133 | Args: 134 | messages: Messages to be converted to strings. 135 | human_prefix: The prefix to prepend to contents of HumanMessages. 136 | ai_prefix: THe prefix to prepend to contents of AIMessages. 137 | 138 | Returns: 139 | A single string concatenation of all input messages. 140 | 141 | Example: 142 | .. code-block:: python 143 | 144 | from langchain.schema import AIMessage, HumanMessage 145 | 146 | messages = [ 147 | HumanMessage(content="Hi, how are you?"), 148 | AIMessage(content="Good, how are you?"), 149 | ] 150 | get_buffer_string(messages) 151 | # -> "Human: Hi, how are you?\nAI: Good, how are you?" 152 | """ 153 | role_to_color = { 154 | "System": "red", 155 | human_prefix: "green", 156 | ai_prefix: "blue", 157 | "Function": "magenta", 158 | } 159 | formatted_messages = [] 160 | for m in messages: 161 | if isinstance(m, HumanMessage): 162 | role = human_prefix 163 | elif isinstance(m, AIMessage): 164 | role = ai_prefix 165 | elif isinstance(m, SystemMessage): 166 | role = "System" 167 | elif isinstance(m, FunctionMessage): 168 | role = "Function" 169 | elif isinstance(m, ChatMessage): 170 | role = m.role 171 | else: 172 | raise ValueError(f"Got unsupported message type: {m}") 173 | prefix_len = len(f'{role}: ') 174 | message_content = m.content 175 | message_lines = message_content.split("\n") 176 | if len(message_lines) > 1: # To align indent 177 | message_content = "\n".join( 178 | [message_lines[0]] 179 | + [" " * prefix_len + line for line in message_lines[1:]] 180 | ) 181 | message = f"{role}: {message_content}" 182 | if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: 183 | message += f"{m.additional_kwargs['function_call']}" 184 | 185 | formatted_messages.append(colored(message, role_to_color[role])) 186 | return "\n".join(formatted_messages) 187 | 188 | 189 | -------------------------------------------------------------------------------- /leetcode_dataset/lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import requests 4 | from bs4 import BeautifulSoup 5 | import os 6 | from typing import List 7 | import leetcode 8 | import leetcode.auth 9 | from typing import Dict 10 | import string 11 | 12 | def lines_to_jsonl(lines: List[Dict], file_path: str): 13 | """ 14 | Convert a list of dicts to a jsonl file 15 | """ 16 | # Empty the current file 17 | open(file_path, 'w').close() 18 | 19 | with open(file_path, 'a') as file: 20 | for dict_data in lines: 21 | json_line = json.dumps(dict_data) 22 | file.write(json_line + os.linesep) 23 | 24 | def get_api_instance(): 25 | """ 26 | Get the leetcode api instance 27 | """ 28 | configuration = leetcode.Configuration() 29 | 30 | # From Dev Tools/Application/Cookies/LEETCODE_SESSION 31 | leetcode_session = os.environ["LEETCODE_SESSION"] 32 | csrf_token = leetcode.auth.get_csrf_cookie(leetcode_session) 33 | 34 | configuration.api_key["x-csrftoken"] = csrf_token 35 | configuration.api_key["csrftoken"] = csrf_token 36 | configuration.api_key["LEETCODE_SESSION"] = leetcode_session 37 | configuration.api_key["Referer"] = "https://leetcode.com" 38 | configuration.debug = False 39 | 40 | api_instance = leetcode.DefaultApi(leetcode.ApiClient(configuration)) 41 | 42 | return api_instance 43 | 44 | def get_question(url): 45 | """ 46 | Get the question page 47 | """ 48 | while True: 49 | res = requests.get(url) # type: ignore 50 | status = res.status_code 51 | if status == 200: 52 | return res 53 | elif status == 404: 54 | return None 55 | else: 56 | print(status) 57 | time.sleep(300) 58 | 59 | def title_slug(title): 60 | """ 61 | Format the title into a title slug 62 | """ 63 | return '-'.join(title.lower().split()) 64 | 65 | def slug_to_title(question_slug: str) -> str: 66 | """Format a Leetcode question's slug as a title""" 67 | return string.capwords(question_slug.replace("-", " ")).strip() 68 | 69 | def format_integer(n): 70 | """Format the integer to have a length of 4 by padding with zeroes.""" 71 | return str(n).zfill(4)[:4] 72 | 73 | def get_code_snippets(url): 74 | """ 75 | Gets the code snippets for the given question url 76 | """ 77 | res = get_question(url) 78 | if res is None: 79 | return None 80 | soup = BeautifulSoup(res.content, "html.parser") 81 | script_tag = soup.find('script', {'type': 'application/json'}) 82 | data = dict(json.loads(script_tag.string)) 83 | queries = data['props']['pageProps']['dehydratedState']['queries'] 84 | query = [i for i in queries if 'question' in i['state']['data'] and 'codeSnippets' in i['state']['data']['question']][0] 85 | code_snippets = query["state"]["data"]["question"]["codeSnippets"] 86 | return code_snippets 87 | 88 | url = "https://leetcode.com/graphql/" 89 | 90 | payload = lambda slug: json.dumps({ 91 | "query": "\n query consolePanelConfig($titleSlug: String!) {\n question(titleSlug: $titleSlug) {\n exampleTestcaseList\n }\n}\n ", 92 | "variables": { 93 | "titleSlug": slug 94 | }, 95 | "operationName": "consolePanelConfig" 96 | }) 97 | 98 | headers = { 99 | 'authority': 'leetcode.com', 100 | 'accept': '*/*', 101 | 'accept-language': 'en-US,en;q=0.9', 102 | 'authorization': '', 103 | 'baggage': 'sentry-environment=production,sentry-release=8f466f72,sentry-transaction=%2Fproblems%2F%5Bslug%5D%2F%5B%5B...tab%5D%5D,sentry-public_key=2a051f9838e2450fbdd5a77eb62cc83c,sentry-trace_id=897972800d1c46e5a5d499f12244a91b,sentry-sample_rate=0.004', 104 | 'content-type': 'application/json', 105 | 'cookie': 'gr_user_id=35b498db-f28f-485f-8b44-417f8fba15ed; __stripe_mid=04d7a882-553c-499c-8866-bcf56aac8ef6ed918f; __atuvc=1%7C5; NEW_PROBLEMLIST_PAGE=1; csrftoken=9BiGVDJiJS7iFJKVYZ1CNMNulRAvYUdlezUlp1oYOrsR2zVsk9mZh1MD6C2d6twV; messages="9b526d67f2587ca52e83b4431db91f6bd6abdac1$[[\\"__json_message\\"\\0540\\05425\\054\\"You have signed out.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"Successfully signed in as beckles168.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"You have signed out.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"Successfully signed in as leetcodeexecutor.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"You have signed out.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"Successfully signed in as beckles168.\\"]]"; 87b5a3c3f1a55520_gr_last_sent_cs1=beckles168; _gid=GA1.2.2067840721.1681477917; LEETCODE_SESSION=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJfYXV0aF91c2VyX2lkIjoiOTIwNjcxMyIsIl9hdXRoX3VzZXJfYmFja2VuZCI6ImFsbGF1dGguYWNjb3VudC5hdXRoX2JhY2tlbmRzLkF1dGhlbnRpY2F0aW9uQmFja2VuZCIsIl9hdXRoX3VzZXJfaGFzaCI6IjU2MGIwZGIzMjVjOTcwNTk3OGFkZDI4MjY0MzM5NjU0NzVjZDhmMjYiLCJpZCI6OTIwNjcxMywiZW1haWwiOiJiZWNrbGVzMTY4QGdtYWlsLmNvbSIsInVzZXJuYW1lIjoiYmVja2xlczE2OCIsInVzZXJfc2x1ZyI6ImJlY2tsZXMxNjgiLCJhdmF0YXIiOiJodHRwczovL2Fzc2V0cy5sZWV0Y29kZS5jb20vdXNlcnMvYXZhdGFycy9hdmF0YXJfMTY4MDY1MjE2OC5wbmciLCJyZWZyZXNoZWRfYXQiOjE2ODE2NjIzMDAsImlwIjoiNzIuMTk1LjEzNC4zMSIsImlkZW50aXR5IjoiNzIzYzUxMjYzYzgwZjZiZTc5ZmEyMTE5MWVlMGIzODciLCJzZXNzaW9uX2lkIjozNzg5MzIzNH0.BJV_u27JVniHZ73kI76oTTkFGK4OHNJPpv-F58pZBUc; 87b5a3c3f1a55520_gr_session_id=73357b2f-2c35-49f4-8256-556aa503d604; 87b5a3c3f1a55520_gr_last_sent_sid_with_cs1=73357b2f-2c35-49f4-8256-556aa503d604; 87b5a3c3f1a55520_gr_session_id_73357b2f-2c35-49f4-8256-556aa503d604=true; _gat=1; 87b5a3c3f1a55520_gr_cs1=beckles168; _ga=GA1.1.1043183799.1675086637; __stripe_sid=d8eb8303-f932-4cfd-92ef-9ed80b781cae827bea; _ga_CDRWKZTDEX=GS1.1.1681662302.39.1.1681665678.0.0.0; _dd_s=rum=0&expire=1681666578197; LEETCODE_SESSION=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJfYXV0aF91c2VyX2lkIjoiOTIwNjcxMyIsIl9hdXRoX3VzZXJfYmFja2VuZCI6ImFsbGF1dGguYWNjb3VudC5hdXRoX2JhY2tlbmRzLkF1dGhlbnRpY2F0aW9uQmFja2VuZCIsIl9hdXRoX3VzZXJfaGFzaCI6IjU2MGIwZGIzMjVjOTcwNTk3OGFkZDI4MjY0MzM5NjU0NzVjZDhmMjYiLCJpZCI6OTIwNjcxMywiZW1haWwiOiJiZWNrbGVzMTY4QGdtYWlsLmNvbSIsInVzZXJuYW1lIjoiYmVja2xlczE2OCIsInVzZXJfc2x1ZyI6ImJlY2tsZXMxNjgiLCJhdmF0YXIiOiJodHRwczovL2Fzc2V0cy5sZWV0Y29kZS5jb20vdXNlcnMvYXZhdGFycy9hdmF0YXJfMTY4MDY1MjE2OC5wbmciLCJyZWZyZXNoZWRfYXQiOjE2ODE2NjIzMDAsImlwIjoiNTQuODYuNTAuMTM5IiwiaWRlbnRpdHkiOiI3MjNjNTEyNjNjODBmNmJlNzlmYTIxMTkxZWUwYjM4NyIsInNlc3Npb25faWQiOjM3ODkzMjM0fQ.DtQ8KCL7Qsua4Bp-vOMJfg4VJUjX4NSxhdNXs756x4M; csrftoken=9BiGVDJiJS7iFJKVYZ1CNMNulRAvYUdlezUlp1oYOrsR2zVsk9mZh1MD6C2d6twV', 106 | 'origin': 'https://leetcode.com', 107 | 'random-uuid': '4922dfe3-8c3c-1d65-9b7a-ca84bfe9f756', 108 | 'referer': 'https://leetcode.com/problems/two-sum/', 109 | 'sec-ch-ua': '"Chromium";v="112", "Google Chrome";v="112", "Not:A-Brand";v="99"', 110 | 'sec-ch-ua-mobile': '?0', 111 | 'sec-ch-ua-platform': '"macOS"', 112 | 'sec-fetch-dest': 'empty', 113 | 'sec-fetch-mode': 'cors', 114 | 'sec-fetch-site': 'same-origin', 115 | 'sentry-trace': '897972800d1c46e5a5d499f12244a91b-a37933a4a1d212e3-0', 116 | 'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36', 117 | 'x-csrftoken': '9BiGVDJiJS7iFJKVYZ1CNMNulRAvYUdlezUlp1oYOrsR2zVsk9mZh1MD6C2d6twV' 118 | } 119 | 120 | def test_cases_from_slug(slug: str) -> List[str]: 121 | response = requests.post(url, headers=headers, data=payload(slug)) 122 | return dict(response.json())['data']['question']['exampleTestcaseList'] 123 | 124 | -------------------------------------------------------------------------------- /leetcode_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GammaTauAI/leetcode-hard-gym/b21c1b2227acafe5cc92684c55920cccf8e17699/leetcode_env/__init__.py -------------------------------------------------------------------------------- /leetcode_env/environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from datetime import datetime 4 | 5 | import dotenv 6 | import gym 7 | import leetcode 8 | import leetcode.auth 9 | 10 | from .types import LeetCodeSubmission 11 | from .utils.leetcode import id_from_slug 12 | 13 | dotenv.load_dotenv() 14 | 15 | 16 | class LeetCodeEnv(gym.Env): 17 | """ 18 | Gym environment for LeetCode submissions 19 | """ 20 | 21 | metadata = {"render.modes": ["human"]} 22 | 23 | def __init__(self, cooldown=0): 24 | super(LeetCodeEnv, self).__init__() 25 | self.__configure_leetcode() 26 | self.reward = False 27 | self.last_run = None 28 | self.cooldown = cooldown # To avoid rate limit 29 | 30 | def __configure_leetcode(self): 31 | configuration = leetcode.Configuration() 32 | 33 | # From Dev Tools/Application/Cookies/LEETCODE_SESSION 34 | leetcode_session = os.environ["LEETCODE_SESSION"] 35 | if "LEETCODE_CSRF_TOKEN" in os.environ: 36 | csrf_token = os.environ["LEETCODE_CSRF_TOKEN"] 37 | else: 38 | try: 39 | csrf_token = leetcode.auth.get_csrf_cookie(leetcode_session) 40 | except KeyError as e: 41 | raise KeyError( 42 | "Could not find CSRF token in cookies. Set the token manually in an LEETCODE_CSRF_TOKEN environment variable." 43 | ) from e 44 | 45 | configuration.api_key["x-csrftoken"] = csrf_token 46 | configuration.api_key["csrftoken"] = csrf_token 47 | configuration.api_key["LEETCODE_SESSION"] = leetcode_session 48 | configuration.api_key["Referer"] = "https://leetcode.com" 49 | configuration.debug = False 50 | 51 | self.api_instance = leetcode.DefaultApi(leetcode.ApiClient(configuration)) 52 | 53 | def step(self, action: LeetCodeSubmission): 54 | """ 55 | Sends a submission to LeetCode and returns the result 56 | 57 | Args: 58 | action (LeetCodeSubmission): LeetCodeSubmission object 59 | 60 | Returns: 61 | status (str): 'Accepted' | 'Runtime Error'| 'Wrong Answer' | 'Submission Timed-Out' | 'Unknown' 62 | reward (bool): True if status is 'Accepted', False otherwise 63 | done (bool): True if status is 'Accepted', False otherwise 64 | submission_result (dict): LeetCode API response 65 | """ 66 | submission_result = self.__send_submission(action) 67 | 68 | reward, status = self.__calculate_reward(submission_result) 69 | 70 | self.reward = reward 71 | 72 | done = self.is_done() 73 | 74 | return status, reward, done, submission_result 75 | 76 | def reset(self): 77 | self.reward = False 78 | 79 | def __send_submission(self, sub: LeetCodeSubmission): 80 | self.__wait_for_cooldown() 81 | 82 | if sub.question_id is None: 83 | sub.question_id = id_from_slug(sub.question_slug, self.api_instance) 84 | 85 | submission = leetcode.Submission( 86 | judge_type="large", 87 | typed_code=sub.code, 88 | question_id=sub.question_id, 89 | test_mode=False, 90 | lang=sub.lang.value, 91 | ) 92 | 93 | submission_id = self.api_instance.problems_problem_submit_post( 94 | problem=sub.question_slug, body=submission 95 | ) 96 | 97 | time.sleep(sub.timeout) 98 | 99 | submission_result = self.api_instance.submissions_detail_id_check_get( 100 | id=submission_id.submission_id 101 | ) 102 | 103 | return submission_result 104 | 105 | def __calculate_reward(self, submission_result): 106 | if submission_result == {"state": "STARTED"}: 107 | status_msg = "Submission Timed-Out" 108 | 109 | elif ( 110 | "status" in submission_result.keys() 111 | and submission_result["status"] == "PENDING" 112 | ): 113 | status_msg = "Submission Timed-Out" 114 | 115 | elif "status_msg" in submission_result.keys(): 116 | status_msg = submission_result[ 117 | "status_msg" 118 | ] # 'Accepted' | 'Runtime Error'| 'Wrong Answer' 119 | 120 | else: 121 | status_msg = "Unknown" 122 | 123 | return status_msg == "Accepted", status_msg 124 | 125 | def __wait_for_cooldown(self): 126 | if self.last_run == None: 127 | self.last_run = datetime.now() 128 | else: 129 | while (datetime.now() - self.last_run).total_seconds() < self.cooldown: 130 | time.sleep(0.1) 131 | self.last_run = datetime.now() 132 | 133 | def is_done(self): 134 | return self.reward 135 | -------------------------------------------------------------------------------- /leetcode_env/types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pydantic import BaseModel 3 | from enum import Enum 4 | 5 | 6 | class ProgrammingLanguage(Enum): 7 | """ 8 | Enum for valid LeetCodeSubmission programming languages 9 | """ 10 | 11 | CPP = "c++" 12 | JAVA = "java" 13 | PYTHON = "python" 14 | PYTHON3 = "python3" 15 | C = "c" 16 | C_SHARP = "c#" 17 | JAVASCRIPT = "javascript" 18 | RUBY = "ruby" 19 | SWIFT = "swift" 20 | GO = "go" 21 | SCALA = "scala" 22 | KOTLIN = "kotlin" 23 | RUST = "rust" 24 | PHP = "php" 25 | TYPESCRIPT = "typescript" 26 | RACKET = "racket" 27 | ERLANG = "erlang" 28 | ELIXIR = "elixir" 29 | DART = "dart" 30 | MYSQL = "mysql" 31 | MS_SQL_SERVER = "ms sql server" 32 | ORACLE = "oracle" 33 | 34 | 35 | class LeetCodeSubmission(BaseModel): 36 | """ 37 | Model for a Leetcode Code Submission 38 | """ 39 | 40 | code: str 41 | lang: ProgrammingLanguage 42 | question_id: str 43 | question_slug: str 44 | question_id: Optional[str] = None 45 | timeout: int = 5 46 | -------------------------------------------------------------------------------- /leetcode_env/utils/formatting.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import re 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | import astunparse 7 | 8 | 9 | class SubmissionFormatter(ABC): 10 | """ 11 | Class that converts between HumanEval and Leetcode submission formats. 12 | """ 13 | 14 | @staticmethod 15 | @abstractmethod 16 | def to_leetcode(humaneval_snippet: str): 17 | """ 18 | Convert the string to leetcode format 19 | """ 20 | 21 | @staticmethod 22 | @abstractmethod 23 | def to_humaneval(leetcode_snippet: str): 24 | """ 25 | Convert the string to humaneval format 26 | """ 27 | 28 | @staticmethod 29 | @abstractmethod 30 | def add_docstring(snippet: str, description: str): 31 | """ 32 | Add a docstring to the snippet 33 | """ 34 | 35 | @staticmethod 36 | @abstractmethod 37 | def extract_signature(source: str) -> str: 38 | """ 39 | Extract the signature from the function 40 | """ 41 | 42 | 43 | 44 | class PythonSubmissionFormatter: 45 | @staticmethod 46 | def add_docstring(snippet: str, description: str): 47 | snippet = snippet.strip("\n") 48 | # Add 4 spaces to the beginning of every line 49 | description = "\n".join([" " * 4 + line for line in description.splitlines()]) 50 | docstring = f''' """ 51 | {description} 52 | """''' 53 | return f"{snippet}\n{docstring}\n" 54 | 55 | @staticmethod 56 | def to_humaneval(leetcode_snippet: str) -> str: 57 | try: 58 | tree = ast.parse(leetcode_snippet) 59 | except IndentationError: 60 | class_source = leetcode_snippet.strip() + "\n pass" 61 | tree = ast.parse(class_source) 62 | func_node = tree.body[0].body[0] 63 | func_node.args.args.pop(0) # Remove 'self' argument 64 | 65 | if isinstance(func_node.body[-1], ast.Pass): 66 | func_node.body.pop() 67 | 68 | new_tree = ast.Module(body=[func_node], type_ignores=[]) 69 | return f"{astunparse.unparse(new_tree).strip()}\n" 70 | 71 | @staticmethod 72 | def to_leetcode(humaneval_snippet: str, class_name: str = "Solution") -> str: 73 | # Get imports 74 | imports = "\n".join( 75 | PythonSubmissionFormatter.extract_imports(humaneval_snippet) 76 | ) 77 | # Remove imports 78 | # humaneval_snippet = re.sub(r"^from\s+\S+\s+import.*|^import.*", "", humaneval_snippet, flags=re.MULTILINE) 79 | try: 80 | tree = ast.parse(humaneval_snippet) 81 | except IndentationError: 82 | function_source = humaneval_snippet.strip() + "\n pass" 83 | tree = ast.parse(function_source) 84 | 85 | func_node = None 86 | for child in ast.iter_child_nodes(tree): 87 | if isinstance(child, ast.FunctionDef): 88 | func_node = child 89 | break 90 | 91 | docstring = ast.get_docstring(func_node) 92 | if docstring is not None: 93 | func_node.body.pop(0) 94 | 95 | if func_node.body and isinstance(func_node.body[-1], ast.Pass): 96 | func_node.body.pop() 97 | 98 | # Add 'self' argument back to the function 99 | self_arg = ast.arg(arg="self", annotation=None) 100 | func_node.args.args.insert(0, self_arg) 101 | class_node = ast.ClassDef( 102 | name=class_name, 103 | bases=[], 104 | keywords=[], 105 | body=[func_node], 106 | decorator_list=[], 107 | ) 108 | new_tree = ast.Module(body=[class_node], type_ignores=[]) 109 | return f"{imports}\n{astunparse.unparse(new_tree).strip()}\n" 110 | 111 | @staticmethod 112 | def extract_imports(source: str) -> List[str]: 113 | """ 114 | Extract top level imports 115 | """ 116 | standard_import = re.compile(r"^import (\w+(?:, \w+)*)") 117 | from_import = re.compile(r"^from (\w+) import (\w+(?:, \w+)*)") 118 | 119 | imports = [] 120 | 121 | for line in source.splitlines(): 122 | std_match = standard_import.match(line) 123 | from_match = from_import.match(line) 124 | 125 | if std_match: 126 | imports.append(std_match.group(0)) 127 | 128 | if from_match: 129 | imports.append(from_match.group(0)) 130 | 131 | return imports 132 | 133 | @staticmethod 134 | def extract_signature(source: str) -> str: 135 | return source.replace('def ', '', 1)[:-1] 136 | 137 | 138 | class RustSubmissionFormatter: 139 | @staticmethod 140 | def add_docstring(snippet: str, description: str): 141 | # Formatting the docstring in Rust style using /* */ 142 | rust_docstring = f"/*\n{description}\n*/" 143 | 144 | # Combining the docstring and the signature 145 | result = f"{rust_docstring}\n{snippet}" 146 | return result 147 | 148 | @staticmethod 149 | def extract_imports(source: str) -> List[str]: 150 | rust_import = re.compile(r"^use ([\w::]+(?:\s+as\s+\w+)?)(?:;\s*)?$") 151 | 152 | imports = [] 153 | 154 | for line in source.splitlines(): 155 | rust_match = rust_import.match(line) 156 | 157 | if rust_match: 158 | imports.append(rust_match.group(0).strip()) 159 | 160 | return imports 161 | 162 | @staticmethod 163 | def remove_imports(source: str) -> str: 164 | rust_import = re.compile(r"^use ([\w::]+(?:\s+as\s+\w+)?)(?:;\s*)?$") 165 | 166 | lines = source.splitlines() 167 | new_lines = [] 168 | for line in lines: 169 | if rust_import.match(line): 170 | print(f"Removing import: {line}") 171 | else: 172 | new_lines.append(line) 173 | 174 | return "\n".join(new_lines) 175 | 176 | @staticmethod 177 | def to_humaneval(leetcode_snippet: str) -> str: 178 | # Remove comments 179 | function_source = re.sub(r"//.*", "", leetcode_snippet) 180 | # Using the re.DOTALL flag to match across multiple lines 181 | function_source = re.sub(r"/\*.*?\*/", "", function_source, flags=re.DOTALL) 182 | 183 | # Remove solution class def 184 | function_source = re.sub(r"impl Solution \{\n", "", function_source) 185 | reversed_source = function_source[::-1] 186 | reversed_substituted = re.sub(r"\}", "", reversed_source, count=1) 187 | function_source = reversed_substituted[::-1] 188 | 189 | # Remove pub from function 190 | function_source = re.sub(r"pub ", "", function_source) 191 | 192 | # Unindent function 193 | whitespace = leading_whitespace_count(function_source) 194 | function_source = "\n".join( 195 | [line[whitespace:] for line in function_source.splitlines()] 196 | ) 197 | function_source = function_source.strip() 198 | 199 | # Remove whitespace from every line in the function 200 | return f"{function_source}\n" 201 | 202 | @staticmethod 203 | def to_leetcode(humaneval_snippet: str, struct_name: str = "Solution") -> str: 204 | imports = "\n".join(RustSubmissionFormatter.extract_imports(humaneval_snippet)) 205 | function_source = RustSubmissionFormatter.remove_imports(humaneval_snippet) 206 | 207 | function_source = re.sub(r"//.*", "", function_source) # Remove comments 208 | function_source = re.sub(r"/\*.*?\*/", "", function_source, flags=re.DOTALL) 209 | function_source = function_source.strip() 210 | function_source = re.sub( 211 | r"fn ", "pub fn ", function_source, count=1 212 | ) # Add pub to root function 213 | return f"{imports}\nimpl {struct_name} {{\n{function_source}\n}}\n" # Add impl struct_name { } around function 214 | 215 | @staticmethod 216 | def extract_signature(source: str) -> str: 217 | return source.strip('fn ').replace('{', '').replace('}', '').strip().strip('\n') 218 | 219 | 220 | def leading_whitespace_count(s): 221 | # Split the string into lines and get the first line 222 | first_line = [l for l in s.splitlines() if l][0] if s else "" 223 | 224 | # Find the index of the first non-whitespace character 225 | non_whitespace_index = next( 226 | (i for i, char in enumerate(first_line) if not char.isspace()), None 227 | ) 228 | 229 | # If the entire line consists of whitespaces (or is empty), then return its length 230 | if non_whitespace_index is None: 231 | return len(first_line) 232 | 233 | return non_whitespace_index -------------------------------------------------------------------------------- /leetcode_env/utils/leetcode.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import leetcode 3 | 4 | def id_from_slug(slug: str, api_instance) -> str: 5 | """ 6 | Retrieves the id of the question with the given slug 7 | """ 8 | graphql_request = leetcode.GraphqlQuery( 9 | query=""" 10 | query getQuestionDetail($titleSlug: String!) { 11 | question(titleSlug: $titleSlug) { 12 | questionId 13 | } 14 | } 15 | """, 16 | variables={"titleSlug": slug}, 17 | operation_name="getQuestionDetail", 18 | ) 19 | response = ast.literal_eval(str(api_instance.graphql_post(body=graphql_request))) 20 | frontend_id = response['data']['question']['question_id'] 21 | return frontend_id 22 | 23 | def metadata_from_slug(slug: str, api_instance) -> str: 24 | """ 25 | Retrieves the metadata of the question with the given slug 26 | """ 27 | graphql_request = leetcode.GraphqlQuery( 28 | query=""" 29 | query getQuestionDetail($titleSlug: String!) { 30 | question(titleSlug: $titleSlug) { 31 | metaData 32 | } 33 | } 34 | """, 35 | variables={"titleSlug": slug}, 36 | operation_name="getQuestionDetail", 37 | ) 38 | response = ast.literal_eval(str(api_instance.graphql_post(body=graphql_request))) 39 | metadata = response['data']['question'] 40 | return metadata 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "leetcode_env" 3 | version = "0.2" 4 | description = "Leetcode evaluation environment for code generation agents" 5 | authors = ["Beck LaBash "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://github.com/GammaTauAI/leetcode-hard-gym" 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.9" 12 | beautifulsoup4 = "4.12.0" 13 | gym = "0.26.2" 14 | pydantic = "1.10.7" 15 | python-dotenv = "1.0.0" 16 | requests = "2.28.2" 17 | python-leetcode = "1.2.1" 18 | astunparse = "1.6.3" 19 | pandas = "1.5.3" 20 | pytest = "7.4.0" 21 | ipykernel = "6.25.1" 22 | html2text = "2020.1.16" 23 | langchain = { version = "0.0.268", optional = true } 24 | openai = { version = "0.27.8", optional = true } 25 | termcolor = { version = "2.3.0", optional = true } 26 | 27 | [tool.poetry.extras] 28 | llms = ["langchain", "openai", "termcolor"] 29 | 30 | [build-system] 31 | requires = ["poetry-core>=1.0.8"] 32 | build-backend = "poetry.core.masonry.api" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4==4.12.0 2 | gym==0.26.2 3 | pydantic==1.10.7 4 | python-dotenv==1.0.0 5 | requests==2.28.2 6 | python-leetcode==1.2.1 7 | astunparse==1.6.3 8 | pandas==1.5.3 9 | pytest==7.4.0 10 | 11 | # LLM 12 | langchain==0.0.268 13 | termcolor==2.3.0 14 | openai==0.27.8 15 | -------------------------------------------------------------------------------- /tests/utils/test_python_formatter.py: -------------------------------------------------------------------------------- 1 | from leetcode_env.utils.formatting import PythonSubmissionFormatter 2 | import pytest 3 | 4 | to_humaneval_samples = [ 5 | (""" 6 | class Solution: 7 | def some_function(self, x: int, y: int) -> int: 8 | """, 9 | """ 10 | def some_function(x: int, y: int) -> int: 11 | """.strip()), 12 | 13 | (""" 14 | class Solution: 15 | def some_function(self, x: int, y: int) -> int: 16 | x = 1 17 | return (x + y) 18 | """, 19 | """ 20 | def some_function(x: int, y: int) -> int: 21 | x = 1 22 | return (x + y) 23 | """.strip()), 24 | ] 25 | 26 | to_leetcode_samples = [ 27 | (""" 28 | def some_function(x: int, y: int) -> int: 29 | """.strip(), 30 | """ 31 | class Solution(): 32 | 33 | def some_function(self, x: int, y: int) -> int: 34 | """.strip()), 35 | 36 | (""" 37 | def some_function(x: int, y: int) -> int: 38 | x = 1 39 | return (x + y) 40 | """, 41 | """ 42 | class Solution(): 43 | 44 | def some_function(self, x: int, y: int) -> int: 45 | x = 1 46 | return (x + y) 47 | """.strip()), 48 | 49 | (""" 50 | from collections import Counter 51 | def some_function(x: int, y: int) -> int: 52 | import string 53 | x = 1 54 | return (x + y) 55 | """, 56 | """ 57 | from collections import Counter 58 | class Solution(): 59 | 60 | def some_function(self, x: int, y: int) -> int: 61 | import string 62 | x = 1 63 | return (x + y) 64 | """.strip()), 65 | ] 66 | 67 | @pytest.mark.parametrize("leetcode_snippet, expected", to_humaneval_samples) 68 | def test_to_humaneval(leetcode_snippet, expected): 69 | result = PythonSubmissionFormatter.to_humaneval(leetcode_snippet).strip() 70 | assert result == expected 71 | 72 | # try: 73 | # assert expected == result 74 | # except AssertionError as e: 75 | # print(f"Expected:\n{expected}") 76 | # print('-'*20) 77 | # print(f"Result:\n{result}") 78 | # raise 79 | 80 | 81 | @pytest.mark.parametrize("humaneval_snippet, expected", to_leetcode_samples) 82 | def test_to_leetcode(humaneval_snippet, expected): 83 | result = PythonSubmissionFormatter.to_leetcode(humaneval_snippet).strip() 84 | assert result == expected 85 | -------------------------------------------------------------------------------- /tests/utils/test_rust_formatter.py: -------------------------------------------------------------------------------- 1 | from leetcode_env.utils.formatting import RustSubmissionFormatter 2 | import pytest 3 | 4 | to_humaneval_samples = [ 5 | (""" 6 | impl Solution { 7 | pub fn some_function(x: i32, y: i32) -> i32 { 8 | } 9 | } 10 | """, 11 | """ 12 | fn some_function(x: i32, y: i32) -> i32 { 13 | } 14 | """.strip()), 15 | 16 | (""" 17 | impl Solution { 18 | pub fn some_function(x: i32, y: i32) -> i32 { 19 | x = 1; 20 | return x + y; 21 | } 22 | } 23 | """, 24 | """ 25 | fn some_function(x: i32, y: i32) -> i32 { 26 | x = 1; 27 | return x + y; 28 | } 29 | """.strip()), 30 | ] 31 | 32 | to_leetcode_samples = [ 33 | (""" 34 | fn some_function(x: i32, y: i32) -> i32 { 35 | } 36 | """.strip(), 37 | """ 38 | impl Solution { 39 | pub fn some_function(x: i32, y: i32) -> i32 { 40 | } 41 | } 42 | """.strip()), 43 | 44 | (""" 45 | fn some_function(x: i32, y: i32) -> i32 { 46 | x = 1; 47 | return x + y; 48 | } 49 | """, 50 | """ 51 | impl Solution { 52 | pub fn some_function(x: i32, y: i32) -> i32 { 53 | x = 1; 54 | return x + y; 55 | } 56 | } 57 | """.strip()), 58 | 59 | (""" 60 | use std::collections::HashMap; 61 | fn some_function(x: i32, y: i32) -> i32 { 62 | let z = x + y; 63 | return z; 64 | } 65 | """, 66 | 67 | """ 68 | use std::collections::HashMap; 69 | impl Solution { 70 | pub fn some_function(x: i32, y: i32) -> i32 { 71 | let z = x + y; 72 | return z; 73 | } 74 | } 75 | """.strip()), 76 | ] 77 | 78 | @pytest.mark.parametrize("leetcode_snippet, expected", to_humaneval_samples) 79 | def test_to_humaneval(leetcode_snippet, expected): 80 | result = RustSubmissionFormatter.to_humaneval(leetcode_snippet).strip() 81 | print(f'Input:\n {leetcode_snippet}') 82 | print('-'*20) 83 | print(f"Expected:\n {expected}") 84 | print('-'*20) 85 | print(f"Result:\n {result}") 86 | assert result == expected 87 | 88 | @pytest.mark.parametrize("humaneval_snippet, expected", to_leetcode_samples) 89 | def test_to_leetcode(humaneval_snippet, expected): 90 | result = RustSubmissionFormatter.to_leetcode(humaneval_snippet).strip() 91 | print(f'Input:\n {humaneval_snippet}') 92 | print('-'*20) 93 | print(f"Expected:\n {expected}") 94 | print('-'*20) 95 | print(f"Result:\n {result}") 96 | assert result == expected 97 | --------------------------------------------------------------------------------