├── .gitignore ├── README.md ├── data └── rawcode_1k.jsonl ├── environment.yaml ├── figures ├── ioexample.png └── overview.png ├── processed_data ├── codeio_1k_gens.jsonl ├── codeio_1k_gens_rev.jsonl ├── codeio_1k_gens_rev_verified.jsonl ├── codeio_1k_gens_verified.jsonl ├── codeio_1k_msg.jsonl ├── codeio_1k_msg_rev.jsonl ├── codeio_demo_final.jsonl ├── rawcode_1k.jsonl ├── rawcode_1k_msg.jsonl ├── rawcode_1k_parsed.jsonl └── rawcode_1k_unified.jsonl ├── requirements.txt ├── scripts └── pipeline_check.sh └── src ├── assemble_codeio_demo.py ├── batched_api_inference.py ├── build_codeio_msg.py ├── build_codeio_rev_msg.py ├── build_transform_msg.py ├── check_io_pred_acc_mp.py ├── check_io_pred_acc_mp_inplace.py ├── codeio_utils.py ├── parse_gen_ios.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | temp/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # CodeI/O: Condensing Reasoning Patterns via Code Input-Output Prediction 4 | 5 |

📑 Paper &nbsp;&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;&nbsp; 🌐 Project Page &nbsp;&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;&nbsp; 🤗 Released Resources &nbsp;&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;&nbsp; 💾 Dataset &nbsp;&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;&nbsp; 📦 Repo
## Table of contents

- [Introduction](#Introduction)
- [Released Resources](#Released-Resources)
  - [Dataset](#Dataset)
  - [Models](#Models)
- [Get Started](#Get-Started)
  - [Setup](#Setup)
  - [Data Processing](#Data-Processing)
  - [Training](#Training)
- [Citation](#Citation)
- [Acknowledgement](#Acknowledgement)

## Introduction
CodeI/O is a novel approach that transforms code-based reasoning patterns into natural language formats to enhance Large Language Models' reasoning capabilities. Unlike traditional methods focusing on specific skills, our approach systematically extracts universal reasoning primitives while maintaining procedural rigor, enabling better performance across various reasoning tasks.

**Key Features & Contributions**
- 🔄 Universal Transformation: Converts diverse code patterns into natural language Chain-of-Thought rationales
- 🧠 Syntax-Decoupled: Decouples reasoning from code syntax while preserving logical structure
- 📊 Multi-Task Enhancement: Improves performance across symbolic, scientific, logic, mathematical, commonsense and code reasoning
- ✨ Fully-Verifiable: Supports precise prediction verification through cached ground-truth matching or code re-execution
- 🚀 Advanced Iteration: Enhanced version (CodeI/O++) with multi-turn revision for better accuracy

## Released Resources

#### Dataset

|Dataset|Link|
|-|-|
|CodeI/O-PythonEdu-Reasoning|[🤗](https://huggingface.co/datasets/hkust-nlp/CodeIO-Pyedu-Reasoning)|
|CodeI/O-PythonEdu-Raw|[🤗](https://huggingface.co/datasets/hkust-nlp/CodeIO-PyEdu-Reasoning-Raw)|
|LeetCode-O Benchmark|[🤗](https://huggingface.co/datasets/hkust-nlp/LeetCode-O)|

Due to our collaborators' compliance requirements, we only release the PythonEdu-Reasoning subset of the CodeI/O(++) dataset.

#### Models
|Base Model / Training|CodeI/O Stage 1|CodeI/O Stage 2|CodeI/O++ Stage 1|CodeI/O++ Stage 2|
|-|-|-|-|-|
|Qwen 2.5 7B Coder|🤗|🤗|🤗|🤗|
|LLaMA 3.1 8B|🤗|🤗|🤗|🤗|
|DeepSeek v2 Lite Coder|🤗|🤗|🤗|🤗|
## Get Started

### Setup

We provide both `requirements.txt` and `environment.yaml`; you can use either one to set up the environment.
```
conda create -n codeio_exec python=3.11
conda activate codeio_exec
pip install -r requirements.txt
```
or
```
conda env create -f environment.yaml --name codeio_exec
conda activate codeio_exec
```
Please note that our setup does not guarantee the execution of all types of Python code; you may need to update the environment to meet your own requirements when processing different code files.

### Data Processing

We provide a complete guide to building CodeI/O data on a toy dataset. After all these steps you will get a dataset in the same format as our [huggingface dataset](https://huggingface.co/datasets/hkust-nlp/CodeIO-Pyedu-Reasoning).

All intermediate results are stored under `./data`, and we also provide a set of pre-processed files under `./processed_data`.

#### Step 1: Convert raw code files into the unified format.

##### Step 1.1: Build Messages
```
python ./src/build_transform_msg.py \
    --raw_code_file data/rawcode_1k.jsonl \
    --raw_code_msg_file data/rawcode_1k_msg.jsonl
```
##### Step 1.2: Inference
```
python ./src/batched_api_inference.py \
    --input data/rawcode_1k_msg.jsonl \
    --output data/rawcode_1k_unified.jsonl \
    --model deepseek-chat \
    --num_process 10 \
    --num_thread 10 \
    --key <your-api-key> \
    --temperature 0.7 \
    --max_tokens 4096
```
You can also use GPT-series models for this transformation step, since the DeepSeek API has recently been under heavy load. For example, set `--model` to `gpt-4o-mini-2024-07-18` and change `--key` accordingly.
You may find that some requests fail; that is OK, and we simply skip them.

*Note that we only provide code for inference with OpenAI-style APIs. However, it is also 100% feasible to deploy other open-source models and run inference locally via frameworks like [vllm](https://github.com/vllm-project/vllm) or [sglang](https://github.com/sgl-project/sglang). Please refer to their official websites for more details.*

#### Step 2: Parse & Generate I/O Pairs
```
python ./src/parse_gen_ios.py \
    --input_file data/rawcode_1k_unified.jsonl \
    --output_file data/rawcode_1k_parsed.jsonl \
    --python_path "python" \
    --run_path "./temp/temp/temp"
```
`--python_path` is the Python interpreter used to run the I/O pair generation code; it can differ from the one used in the main workflow, e.g., it may have some specific packages installed. `--run_path` is the directory where the I/O pair generation code is executed; the generated code sometimes writes temporary files, so we explicitly assign a place for it to store them.
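Under the hood, each unified sample defines a `main_solution` function together with an `input_generator`, and `parse_gen_ios.py` executes them to collect input/output pairs. Below is a minimal sketch of the idea; the helper is illustrative only, and the actual script runs the generated code in a separate interpreter with timeouts, which is why it accepts a dedicated `--python_path`:
```
import json

def generate_io_pairs(refcode, num_pairs=5):
    # refcode is the unified code: it defines main_solution and input_generator
    namespace = {}
    exec(refcode, namespace)
    pairs = []
    for _ in range(num_pairs):
        inp = namespace["input_generator"]()       # a JSON-serializable dict
        out = namespace["main_solution"](**inp)    # a JSON-serializable value
        pairs.append({"input": inp, "output": out})
    # round-trip through JSON to fail loudly on non-serializable values
    return [json.loads(json.dumps(p)) for p in pairs]
```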
#### Step 3: Build Input-Output Prediction Instances
We only pick 3 input prediction and 3 output prediction instances for each sample.
```
python ./src/build_codeio_msg.py \
    --input_file data/rawcode_1k_parsed.jsonl \
    --output_file data/codeio_1k_msg.jsonl
```

#### Step 4: Inference on CodeI/O data
```
python ./src/batched_api_inference.py \
    --input data/codeio_1k_msg.jsonl \
    --output data/codeio_1k_gens.jsonl \
    --model deepseek-chat \
    --num_process 10 \
    --num_thread 10 \
    --key <your-api-key> \
    --temperature 0.7 \
    --max_tokens 4096
```
#### Step 5: Verification
```
bash ./scripts/pipeline_check.sh \
    data/rawcode_1k_parsed.jsonl \
    data/codeio_1k_gens.jsonl \
    data/codeio_1k_gens_verified.jsonl \
    python \
    ./temp/temp/temp
```
In the bash script we run the verification several times to mitigate runtime effects introduced by multi-processing execution (e.g., spurious timeouts). This is helpful for large-scale verification. You can change the number of processes to match your machine (e.g., use more if you have many CPUs and a large amount of memory).

#### Step 6: Second Turn - Revision and Re-verification
##### Step 6.1: Build Multi-turn Messages
```
python ./src/build_codeio_rev_msg.py \
    --input_file data/codeio_1k_gens_verified.jsonl \
    --output_file data/codeio_1k_msg_rev.jsonl
```
##### Step 6.2: Re-generate
```
python ./src/batched_api_inference.py \
    --input data/codeio_1k_msg_rev.jsonl \
    --output data/codeio_1k_gens_rev.jsonl \
    --model deepseek-chat \
    --num_process 10 \
    --num_thread 10 \
    --key <your-api-key> \
    --temperature 0.7 \
    --max_tokens 4096
```
##### Step 6.3: Re-verification
```
bash ./scripts/pipeline_check.sh \
    data/rawcode_1k_parsed.jsonl \
    data/codeio_1k_gens_rev.jsonl \
    data/codeio_1k_gens_rev_verified.jsonl \
    python \
    ./temp/temp/temp
```
##### Step 6.4: Final Data
```
python ./src/assemble_codeio_demo.py \
    --result_file_turn1 data/codeio_1k_gens_verified.jsonl \
    --result_file_turn2 data/codeio_1k_gens_rev_verified.jsonl \
    --output_file data/codeio_demo_final.jsonl
```
By doing so, you get `data/codeio_demo_final.jsonl` in the same format as our [huggingface dataset](https://huggingface.co/datasets/hkust-nlp/CodeIO-Pyedu-Reasoning).

### Training
You can use any popular training framework, such as [llama-factory](https://github.com/hiyouga/LLaMA-Factory), to train your model.

## Citation
If you find this work helpful, please kindly cite it as:
```
@article{li2025codeio,
  title={CodeI/O: Condensing Reasoning Patterns via Code Input-Output Prediction},
  author={Li, Junlong and Guo, Daya and Yang, Dejian and Xu, Runxin and Wu, Yu and He, Junxian},
  journal={arXiv preprint arXiv:2502.07316},
  year={2025}
}
```

## Acknowledgement
We thank Fan Zhou, Wei Liu and Yiheng Xu for their valuable feedback and suggestions!
🤗🤗🤗 229 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | dependencies: 4 | - _libgcc_mutex=0.1=main 5 | - _openmp_mutex=5.1=1_gnu 6 | - bzip2=1.0.8=h5eee18b_6 7 | - ca-certificates=2024.7.2=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_1 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - libuuid=1.41.5=h5eee18b_0 14 | - ncurses=6.4=h6a678d5_0 15 | - openssl=3.0.15=h5eee18b_0 16 | - python=3.11.9=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - sqlite=3.45.3=h5eee18b_0 19 | - tk=8.6.14=h39e8969_0 20 | - xz=5.4.6=h5eee18b_1 21 | - zlib=1.2.13=h5eee18b_1 22 | - pip: 23 | - absl-py==2.1.0 24 | - accelerate==1.1.1 25 | - addict==2.4.0 26 | - aiohappyeyeballs==2.4.4 27 | - aiohttp==3.11.9 28 | - aiosignal==1.3.1 29 | - annotated-types==0.7.0 30 | - anyio==4.6.2.post1 31 | - appdirs==1.4.4 32 | - argparse==1.4.0 33 | - arviz==0.20.0 34 | - astropy==6.1.4 35 | - astropy-iers-data==0.2024.11.4.0.33.34 36 | - attrs==24.2.0 37 | - beautifulsoup4==4.12.3 38 | - bibtexparser==1.4.3 39 | - biopython==1.84 40 | - blis==0.7.11 41 | - cachetools==5.5.0 42 | - catalogue==1.0.2 43 | - certifi==2024.8.30 44 | - cffi==1.17.1 45 | - charset-normalizer==3.3.2 46 | - chess==1.11.1 47 | - cirq==1.4.1 48 | - cirq-aqt==1.4.1 49 | - cirq-core==1.4.1 50 | - cirq-google==1.4.1 51 | - cirq-ionq==1.4.1 52 | - cirq-pasqal==1.4.1 53 | - cirq-rigetti==1.4.1 54 | - cirq-web==1.4.1 55 | - clarabel==0.9.0 56 | - click==8.1.7 57 | - cloudpathlib==0.19.0 58 | - cloudpickle==3.1.0 59 | - colorama==0.4.6 60 | - compressed-tensors==0.8.0 61 | - contourpy==1.3.0 62 | - cpm-kernels==1.0.11 63 | - cpmpy==0.9.23 64 | - cramjam==2.8.3 65 | - cryptography==43.0.3 66 | - cvxpy==1.5.3 67 | - cycler==0.12.1 68 | - cymem==2.0.8 69 | - datasets==3.1.0 70 | - deprecated==1.2.14 71 | - dill==0.3.8 72 | - diskcache==5.6.3 73 | - distro==1.9.0 74 | - docker-pycreds==0.4.0 75 | - docopt==0.6.2 76 | - duet==0.2.9 77 | - ecos==2.0.14 78 | - einops==0.5.0 79 | - en-core-web-sm==2.3.1 80 | - et-xmlfile==2.0.0 81 | - evalplus==0.1.0.dev598 82 | - evaluate==0.4.3 83 | - fastapi==0.115.5 84 | - fastparquet==2024.5.0 85 | - fastprogress==1.0.3 86 | - filelock==3.16.1 87 | - fire==0.7.0 88 | - fonttools==4.54.1 89 | - frozendict==2.4.6 90 | - frozenlist==1.5.0 91 | - fsspec==2024.9.0 92 | - func-timeout==4.3.5 93 | - fuzzywuzzy==0.18.0 94 | - gguf==0.10.0 95 | - gitdb==4.0.11 96 | - gitpython==3.1.43 97 | - gmpy2==2.2.1 98 | - google-api-core==2.22.0 99 | - google-auth==2.35.0 100 | - googleapis-common-protos==1.65.0 101 | - greenlet==3.1.1 102 | - grilops==0.10.3 103 | - grpc-interceptor==0.15.4 104 | - grpcio==1.67.1 105 | - grpcio-status==1.62.3 106 | - h11==0.14.0 107 | - h5netcdf==1.4.0 108 | - h5py==3.12.1 109 | - html5lib==1.1 110 | - httpcore==1.0.6 111 | - httptools==0.6.4 112 | - httpx==0.27.2 113 | - huggingface-hub==0.24.7 114 | - idna==3.10 115 | - immutabledict==4.2.0 116 | - importlib-metadata==8.5.0 117 | - interegular==0.3.3 118 | - jieba==0.42.1 119 | - jinja2==3.1.4 120 | - jiter==0.8.0 121 | - joblib==1.4.2 122 | - json5==0.10.0 123 | - jsonschema==4.23.0 124 | - jsonschema-specifications==2024.10.1 125 | - kiwisolver==1.4.7 126 | - langcodes==3.4.0 127 | - language-data==1.2.0 128 | - lark==1.2.2 129 | - levenshtein==0.26.1 130 | - llvmlite==0.43.0 131 | - 
lm-format-enforcer==0.10.9 132 | - lxml==5.3.0 133 | - mando==0.7.1 134 | - marisa-trie==1.2.0 135 | - markdown==3.7 136 | - markdown-it-py==3.0.0 137 | - markupsafe==2.1.5 138 | - matplotlib==3.9.2 139 | - matplotlib-inline==0.1.7 140 | - mdurl==0.1.2 141 | - mistral-common==1.5.1 142 | - mistune==3.0.2 143 | - mmengine-lite==0.10.5 144 | - mosestokenizer==1.0.0 145 | - mpmath==1.3.0 146 | - msgpack==1.1.0 147 | - msgspec==0.18.6 148 | - multidict==6.1.0 149 | - multipledispatch==1.0.0 150 | - multiprocess==0.70.16 151 | - multitasking==0.0.11 152 | - murmurhash==1.0.10 153 | - nest-asyncio==1.6.0 154 | - networkx==3.4.1 155 | - nltk==3.8 156 | - numba==0.60.0 157 | - numpoly==1.2.14 158 | - numpy==1.26.4 159 | - openai==1.55.3 160 | - opencc==1.1.9 161 | - opencv-python-headless==4.10.0.84 162 | - openfile==0.0.7 163 | - openpyxl==3.1.5 164 | - ortools==9.9.3963 165 | - osqp==0.6.7.post3 166 | - outlines==0.0.46 167 | - packaging==23.2 168 | - pandas>=2.0.0 169 | - partial-json-parser==0.2.1.1.post4 170 | - patsy==0.5.6 171 | - peewee==3.17.7 172 | - pillow==10.4.0 173 | - pip==24.2 174 | - plac==1.1.3 175 | - platformdirs==4.3.6 176 | - playwright==1.49.1 177 | - poetry-core==1.9.0 178 | - portalocker==3.0.0 179 | - preshed==3.0.9 180 | - prettytable==3.12.0 181 | - prometheus-client==0.21.0 182 | - prometheus-fastapi-instrumentator==7.0.0 183 | - propcache==0.2.1 184 | - proto-plus==1.25.0 185 | - protobuf==4.25.5 186 | - psutil==6.1.0 187 | - pulp==2.9.0 188 | - py-cpuinfo==9.0.0 189 | - pyairports==2.1.1 190 | - pyarrow==17.0.0 191 | - pyasn1==0.6.1 192 | - pyasn1-modules==0.4.1 193 | - pycosat==0.6.6 194 | - pycountry==24.6.1 195 | - pycparser==2.22 196 | - pycryptodome==3.21.0 197 | - pydantic==2.9.2 198 | - pydantic-core==2.23.4 199 | - pyee==12.0.0 200 | - pyenchant==3.2.2 201 | - pyerfa==2.0.1.4 202 | - pyext==0.5 203 | - pygments==2.18.0 204 | - pymc3==3.11.4 205 | - pympler==1.1 206 | - pyparsing==3.2.0 207 | - pyperclip==1.9.0 208 | - pyquil==4.14.3 209 | - python-dateutil==2.9.0.post0 210 | - python-dotenv==1.0.1 211 | - python-levenshtein==0.26.1 212 | - python-rapidjson==1.20 213 | - pytz==2024.2 214 | - pywavelets==1.7.0 215 | - pyyaml==6.0.2 216 | - pyzmq==26.2.0 217 | - qcs-api-client-common==0.10.0 218 | - qcs-sdk-python==0.20.1 219 | - qdldl==0.1.7.post4 220 | - quil==0.13.1 221 | - radon==6.0.1 222 | - rank-bm25==0.2.2 223 | - rapidfuzz==3.10.1 224 | - ray==2.39.0 225 | - referencing==0.35.1 226 | - regex==2024.9.11 227 | - requests==2.32.3 228 | - retrying==1.3.4 229 | - rich==13.8.1 230 | - rouge==1.0.1 231 | - rouge-chinese==1.0.3 232 | - rouge-score==0.1.2 233 | - rpcq==3.11.0 234 | - rpds-py==0.21.0 235 | - rsa==4.9 236 | - ruamel-yaml==0.18.6 237 | - ruamel-yaml-clib==0.2.12 238 | - sacrebleu==2.4.3 239 | - safetensors==0.4.5 240 | - scikit-learn==1.5.0 241 | - scipy==1.14.1 242 | - scs==3.2.7 243 | - seaborn==0.13.2 244 | - semver==3.0.2 245 | - sentence-transformers==2.2.2 246 | - sentencepiece==0.2.0 247 | - sentry-sdk==2.18.0 248 | - setproctitle==1.3.3 249 | - setuptools==72.1.0 250 | - shellingham==1.5.4 251 | - shortuuid==1.0.13 252 | - six==1.16.0 253 | - smart-open==7.0.4 254 | - smmap==5.0.1 255 | - sniffio==1.3.1 256 | - sortedcontainers==2.4.0 257 | - soupsieve==2.6 258 | - spacy==2.3.9 259 | - spacy-legacy==3.0.12 260 | - spacy-loggers==1.0.5 261 | - srsly==1.0.7 262 | - starlette==0.41.3 263 | - strip-markdown==1.3 264 | - sympy==1.13.1 265 | - tabulate==0.9.0 266 | - tempdir==0.7.1 267 | - termcolor==2.5.0 268 | - theano-pymc==1.1.2 269 | - 
thinc==7.4.6 270 | - threadpoolctl==3.5.0 271 | - tiktoken==0.7.0 272 | - timeout-decorator==0.5.0 273 | - tokenizers==0.20.3 274 | - toolwrapper==2.1.0 275 | - torch==2.5.1 276 | - torchvision==0.20.1 277 | - tqdm==4.66.5 278 | - traitlets==5.14.3 279 | - transformers==4.46.2 280 | - triton==3.1.0 281 | - typer==0.12.5 282 | - types-deprecated==1.2.9.20240311 283 | - typing-extensions==4.12.2 284 | - tzdata==2024.1 285 | - urllib3==2.2.3 286 | - uvicorn==0.32.1 287 | - uvloop==0.21.0 288 | - wandb==0.18.7 289 | - wasabi==0.10.1 290 | - watchfiles==1.0.0 291 | - wcwidth==0.2.13 292 | - webencodings==0.5.1 293 | - websockets==14.1 294 | - wget==3.2 295 | - wheel==0.43.0 296 | - wrapt==1.16.0 297 | - xarray==2024.10.0 298 | - xarray-einstats==0.8.0 299 | - xformers==0.0.28.post3 300 | - xxhash==3.5.0 301 | - yapf==0.43.0 302 | - yarl==1.18.3 303 | - yfinance==0.2.48 304 | - z3-solver==4.13.3.0 305 | - zipp==3.20.2 306 | -------------------------------------------------------------------------------- /figures/ioexample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/CodeIO/1d3541cc928e9f76da9c80e95778635a11e0583c/figures/ioexample.png -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkust-nlp/CodeIO/1d3541cc928e9f76da9c80e95778635a11e0583c/figures/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==1.1.1 3 | addict==2.4.0 4 | aiohappyeyeballs==2.4.4 5 | aiohttp==3.11.9 6 | aiosignal==1.3.1 7 | annotated-types==0.7.0 8 | anyio==4.6.2.post1 9 | appdirs==1.4.4 10 | arviz==0.20.0 11 | astropy==6.1.4 12 | astropy-iers-data==0.2024.11.4.0.33.34 13 | attrs==24.2.0 14 | beautifulsoup4==4.12.3 15 | bibtexparser==1.4.3 16 | biopython==1.84 17 | blis==0.7.11 18 | cachetools==5.5.0 19 | catalogue==1.0.2 20 | certifi==2024.8.30 21 | cffi==1.17.1 22 | charset-normalizer==3.3.2 23 | chess==1.11.1 24 | cirq==1.4.1 25 | cirq-aqt==1.4.1 26 | cirq-core==1.4.1 27 | cirq-google==1.4.1 28 | cirq-ionq==1.4.1 29 | cirq-pasqal==1.4.1 30 | cirq-rigetti==1.4.1 31 | cirq-web==1.4.1 32 | clarabel==0.9.0 33 | click==8.1.7 34 | cloudpathlib==0.19.0 35 | cloudpickle==3.1.0 36 | colorama==0.4.6 37 | compressed-tensors==0.8.0 38 | contourpy==1.3.0 39 | cpm-kernels==1.0.11 40 | cpmpy==0.9.23 41 | cramjam==2.8.3 42 | cryptography==43.0.3 43 | cvxpy==1.5.3 44 | cycler==0.12.1 45 | cymem==2.0.8 46 | datasets==3.1.0 47 | Deprecated==1.2.14 48 | dill==0.3.8 49 | diskcache==5.6.3 50 | distro==1.9.0 51 | docker-pycreds==0.4.0 52 | docopt==0.6.2 53 | duet==0.2.9 54 | ecos==2.0.14 55 | einops==0.5.0 56 | et_xmlfile==2.0.0 57 | evaluate==0.4.3 58 | fastapi==0.115.5 59 | fastparquet==2024.5.0 60 | fastprogress==1.0.3 61 | filelock==3.16.1 62 | fire==0.7.0 63 | fonttools==4.54.1 64 | frozendict==2.4.6 65 | frozenlist==1.5.0 66 | fsspec==2024.9.0 67 | func_timeout==4.3.5 68 | fuzzywuzzy==0.18.0 69 | gguf==0.10.0 70 | gitdb==4.0.11 71 | GitPython==3.1.43 72 | gmpy2==2.2.1 73 | google-api-core==2.22.0 74 | google-auth==2.35.0 75 | googleapis-common-protos==1.65.0 76 | greenlet==3.1.1 77 | grilops==0.10.3 78 | grpc-interceptor==0.15.4 79 | grpcio==1.67.1 80 | grpcio-status==1.62.3 81 | h11==0.14.0 82 | h5netcdf==1.4.0 83 
| h5py==3.12.1 84 | html5lib==1.1 85 | httpcore==1.0.6 86 | httptools==0.6.4 87 | httpx==0.27.2 88 | huggingface-hub==0.24.7 89 | idna==3.10 90 | immutabledict==4.2.0 91 | importlib_metadata==8.5.0 92 | interegular==0.3.3 93 | jieba==0.42.1 94 | Jinja2==3.1.4 95 | jiter==0.8.0 96 | joblib==1.4.2 97 | json5==0.10.0 98 | jsonschema==4.23.0 99 | jsonschema-specifications==2024.10.1 100 | kiwisolver==1.4.7 101 | langcodes==3.4.0 102 | language_data==1.2.0 103 | lark==1.2.2 104 | Levenshtein==0.26.1 105 | llvmlite==0.43.0 106 | lm-format-enforcer==0.10.9 107 | lxml==5.3.0 108 | mando==0.7.1 109 | marisa-trie==1.2.0 110 | Markdown==3.7 111 | markdown-it-py==3.0.0 112 | MarkupSafe==2.1.5 113 | matplotlib==3.9.2 114 | matplotlib-inline==0.1.7 115 | mdurl==0.1.2 116 | mistral_common==1.5.1 117 | mistune==3.0.2 118 | mmengine-lite==0.10.5 119 | mosestokenizer==1.0.0 120 | mpmath==1.3.0 121 | msgpack==1.1.0 122 | msgspec==0.18.6 123 | multidict==6.1.0 124 | multipledispatch==1.0.0 125 | multiprocess==0.70.16 126 | multitasking==0.0.11 127 | murmurhash==1.0.10 128 | nest-asyncio==1.6.0 129 | networkx==3.4.1 130 | nltk==3.8 131 | numba==0.60.0 132 | numpoly==1.2.14 133 | numpy==1.26.4 134 | openai==1.55.3 135 | OpenCC==1.1.9 136 | opencv-python-headless==4.10.0.84 137 | openfile==0.0.7 138 | openpyxl==3.1.5 139 | ortools==9.9.3963 140 | osqp==0.6.7.post3 141 | outlines==0.0.46 142 | packaging==23.2 143 | pandas>=2.0.0 144 | partial-json-parser==0.2.1.1.post4 145 | patsy==0.5.6 146 | peewee==3.17.7 147 | pillow==10.4.0 148 | plac==1.1.3 149 | platformdirs==4.3.6 150 | playwright==1.49.1 151 | poetry-core==1.9.0 152 | portalocker==3.0.0 153 | preshed==3.0.9 154 | prettytable==3.12.0 155 | prometheus-fastapi-instrumentator==7.0.0 156 | prometheus_client==0.21.0 157 | propcache==0.2.1 158 | proto-plus==1.25.0 159 | protobuf==4.25.5 160 | psutil==6.1.0 161 | PuLP==2.9.0 162 | py-cpuinfo==9.0.0 163 | pyairports==2.1.1 164 | pyarrow==17.0.0 165 | pyasn1==0.6.1 166 | pyasn1_modules==0.4.1 167 | pycosat==0.6.6 168 | pycountry==24.6.1 169 | pycparser==2.22 170 | pycryptodome==3.21.0 171 | pydantic==2.9.2 172 | pydantic_core==2.23.4 173 | pyee==12.0.0 174 | pyenchant==3.2.2 175 | pyerfa==2.0.1.4 176 | pyext==0.5 177 | Pygments==2.18.0 178 | pymc3==3.11.4 179 | Pympler==1.1 180 | pyparsing==3.2.0 181 | pyperclip==1.9.0 182 | pyquil==4.14.3 183 | python-dateutil==2.9.0.post0 184 | python-dotenv==1.0.1 185 | python-Levenshtein==0.26.1 186 | python-rapidjson==1.20 187 | pytz==2024.2 188 | PyWavelets==1.7.0 189 | PyYAML==6.0.2 190 | pyzmq==26.2.0 191 | qcs-api-client-common==0.10.0 192 | qcs-sdk-python==0.20.1 193 | qdldl==0.1.7.post4 194 | quil==0.13.1 195 | radon==6.0.1 196 | rank-bm25==0.2.2 197 | RapidFuzz==3.10.1 198 | ray==2.39.0 199 | referencing==0.35.1 200 | regex==2024.9.11 201 | requests==2.32.3 202 | retrying==1.3.4 203 | rich==13.8.1 204 | rouge==1.0.1 205 | rouge-chinese==1.0.3 206 | rouge_score==0.1.2 207 | rpcq==3.11.0 208 | rpds-py==0.21.0 209 | rsa==4.9 210 | ruamel.yaml==0.18.6 211 | ruamel.yaml.clib==0.2.12 212 | sacrebleu==2.4.3 213 | safetensors==0.4.5 214 | scikit-learn==1.5.0 215 | scipy==1.14.1 216 | scs==3.2.7 217 | seaborn==0.13.2 218 | semver==3.0.2 219 | sentence-transformers==2.2.2 220 | sentencepiece==0.2.0 221 | sentry-sdk==2.18.0 222 | setproctitle==1.3.3 223 | shellingham==1.5.4 224 | shortuuid==1.0.13 225 | six==1.16.0 226 | smart-open==7.0.4 227 | smmap==5.0.1 228 | sniffio==1.3.1 229 | sortedcontainers==2.4.0 230 | soupsieve==2.6 231 | spacy==2.3.9 232 | spacy-legacy==3.0.12 233 | 
spacy-loggers==1.0.5 234 | srsly==1.0.7 235 | starlette==0.41.3 236 | strip-markdown==1.3 237 | sympy==1.13.1 238 | tabulate==0.9.0 239 | tempdir==0.7.1 240 | termcolor==2.5.0 241 | Theano-PyMC==1.1.2 242 | thinc==7.4.6 243 | threadpoolctl==3.5.0 244 | tiktoken==0.7.0 245 | timeout-decorator==0.5.0 246 | tokenizers==0.20.3 247 | toolwrapper==2.1.0 248 | torch==2.5.1 249 | torchvision==0.20.1 250 | tqdm==4.66.5 251 | traitlets==5.14.3 252 | transformers==4.46.2 253 | triton==3.1.0 254 | typer==0.12.5 255 | types-Deprecated==1.2.9.20240311 256 | typing_extensions==4.12.2 257 | tzdata==2024.1 258 | urllib3==2.2.3 259 | uvicorn==0.32.1 260 | uvloop==0.21.0 261 | wandb==0.18.7 262 | wasabi==0.10.1 263 | watchfiles==1.0.0 264 | wcwidth==0.2.13 265 | webencodings==0.5.1 266 | websockets==14.1 267 | wget==3.2 268 | wrapt==1.16.0 269 | xarray==2024.10.0 270 | xarray-einstats==0.8.0 271 | xformers==0.0.28.post3 272 | xxhash==3.5.0 273 | yapf==0.43.0 274 | yarl==1.18.3 275 | yfinance==0.2.48 276 | z3-solver==4.13.3.0 277 | zipp==3.20.2 278 | -------------------------------------------------------------------------------- /scripts/pipeline_check.sh: -------------------------------------------------------------------------------- 1 | pfn=${1:-data/rawcode_1k_parsed.jsonl} 2 | ifn=${2:-data/codeio_1k_gens.jsonl} 3 | ofn=${3:-data/codeio_1k_gens_verified.jsonl} 4 | pythonpath=${4:-"python"} 5 | runpath=${5:-"./temp/temp/temp"} 6 | 7 | python ./src/check_io_pred_acc_mp.py \ 8 | --parsed_file_name $pfn \ 9 | --pred_file_name $ifn \ 10 | --res_file_name $ofn \ 11 | --batchsize 1024 \ 12 | --num_processes 24 \ 13 | --python_path $pythonpath \ 14 | --run_path $runpath 15 | 16 | for i in {1..10} 17 | do 18 | echo "trial $i" 19 | 20 | if [ $i -eq 1 ]; then 21 | numprocess=24 22 | elif [ $i -eq 2 ]; then 23 | numprocess=16 24 | elif [ $i -eq 8 ] || [ $i -eq 9 ] || [ $i -eq 10 ]; then 25 | numprocess=4 26 | else 27 | numprocess=8 28 | fi 29 | 30 | echo "numprocess: $numprocess" 31 | 32 | python ./src/check_io_pred_acc_mp_inplace.py \ 33 | --parsed_file_name $pfn \ 34 | --pred_file_name $ofn \ 35 | --batchsize 1024 \ 36 | --write_batchsize 16 \ 37 | --num_processes $numprocess \ 38 | --python_path $pythonpath \ 39 | --run_path $runpath 40 | done -------------------------------------------------------------------------------- /src/assemble_codeio_demo.py: -------------------------------------------------------------------------------- 1 | 2 | from utils import * 3 | from tqdm import tqdm 4 | 5 | if __name__=="__main__": 6 | import argparse 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--result_file_turn1", type=str, default=None) 9 | parser.add_argument("--result_file_turn2", type=str, default=None) 10 | parser.add_argument("--output_file", type=str, default=None) 11 | args = parser.parse_args() 12 | 13 | fn1 = args.result_file_turn1 14 | fn2 = args.result_file_turn2 15 | ofn = args.output_file 16 | dt1 = load_jsonl_yield(fn1) 17 | dt2 = load_jsonl_yield(fn2) 18 | ndt = [] 19 | for item in tqdm(dt1): 20 | status = item['res']['status'] 21 | if status == 'success': 22 | sample = {"prompt":item['messages'][0]['content'], 23 | "turn_1":item['output'], 24 | "feedback_1":item['res']['message'], 25 | "turn_2":None, 26 | "feedback_2":None} 27 | ndt.append(sample) 28 | if len(ndt)==1000: 29 | write_jsonl(ndt,ofn,"a") 30 | ndt = [] 31 | for item in tqdm(dt2): 32 | # elegant_show(item,full=True) 33 | # raise ValueError 34 | sample = {"prompt":item['messages'][0]['content'], 35 | 
"turn_1":item['messages'][1]['content'], 36 | "feedback_1":item['messages'][2]['content'], 37 | "turn_2":item['output'], 38 | "feedback_2":item['res']['message']} 39 | ndt.append(sample) 40 | if len(ndt)==1000: 41 | write_jsonl(ndt,ofn,"a") 42 | ndt = [] 43 | 44 | write_jsonl(ndt,ofn,"a") -------------------------------------------------------------------------------- /src/batched_api_inference.py: -------------------------------------------------------------------------------- 1 | try: 2 | from openai import OpenAI 3 | except: 4 | pass 5 | import datetime 6 | import json 7 | import multiprocessing 8 | from argparse import ArgumentParser 9 | import os 10 | import time 11 | from tqdm import tqdm 12 | 13 | from multiprocessing import Process, Queue, Lock, Value 14 | import concurrent 15 | from concurrent.futures import ThreadPoolExecutor 16 | 17 | ############################################### 18 | max_try_one_call = 3 19 | SYSTEM = None 20 | ############################################### 21 | 22 | def get_client(): 23 | assert model.startswith("gpt") or model.startswith("deepseek") 24 | params = { 25 | "api_key": key, 26 | "timeout":10000.0 27 | } 28 | if model.startswith("deepseek"): 29 | params["base_url"] = "https://api.deepseek.com" 30 | 31 | client = OpenAI( 32 | **params 33 | ) 34 | return client 35 | 36 | def timer(func): 37 | def format_time(time_delta): 38 | hours, remainder = divmod(time_delta.total_seconds(), 3600) 39 | minutes, seconds = divmod(remainder, 60) 40 | return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}" 41 | def wrapper(*args, **kwargs): 42 | start_time = datetime.datetime.now() 43 | print("开始时间:", start_time.strftime("%Y-%m-%d %H:%M:%S")) 44 | result = func(*args, **kwargs) 45 | end_time = datetime.datetime.now() 46 | print("结束时间:", end_time.strftime("%Y-%m-%d %H:%M:%S")) 47 | elapsed_time = end_time - start_time 48 | print("执行时间:", format_time(elapsed_time)) 49 | return result 50 | return wrapper 51 | 52 | def load_jsonl_yield(path): 53 | with open(path) as f: 54 | for row, line in enumerate(f): 55 | try: 56 | line = json.loads(line) 57 | yield line 58 | except: 59 | pass 60 | 61 | def check_exists(line): 62 | if "output" in line and line["output"] is not None: 63 | return True 64 | return False 65 | 66 | def process_line(js, good_cnt, bad_cnt, lock, output_path): 67 | messages = js['messages'] 68 | response = None 69 | finish_reason = None 70 | for i in range(max_try_one_call): # retry if failed 71 | try: 72 | client = get_client() 73 | chat_completion = client.chat.completions.create( 74 | model = model, 75 | messages = messages, 76 | max_tokens = max_tokens, 77 | temperature = temperature, 78 | timeout=10000.0 79 | ) 80 | if model == "deepseek-reasoner": 81 | reasoning = chat_completion.choices[0].message.reasoning_content 82 | else: 83 | reasoning = None 84 | response = chat_completion.choices[0].message.content 85 | finish_reason = chat_completion.choices[0].finish_reason 86 | break 87 | except Exception as e: 88 | if i>>>", problem_statement).replace("<<<>>>", io_req) 8 | tag = "<<<>>>" if io=="i" else "<<<>>>" 9 | inputxx = f"{inputx}" 10 | outputxx = f"{outputx}" 11 | prompt = prompt.replace(tag, outputxx if io=="i" else inputxx) 12 | if w_refcode: 13 | refcodepart = refcode_template.replace("<<<>>>", refcode) 14 | prompt+="\n\n"+refcodepart 15 | return prompt 16 | 17 | if __name__=="__main__": 18 | import argparse 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--input_file', type=str, 
--------------------------------------------------------------------------------
/src/build_codeio_rev_msg.py:
--------------------------------------------------------------------------------
from utils import *
from tqdm import tqdm


def wrap_feedback(errormessage,xtype=0):
    if xtype==0:
        return errormessage
    elif xtype==1:
        return errormessage+'\n\nPlease redo it, and your prediction should no longer be any of the wrong ones you have made before!'
    else:
        raise ValueError

if __name__=="__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, default=None)
    parser.add_argument("--output_file", type=str, default=None)
    args = parser.parse_args()
    revision_type = 1

    ifn = args.input_file
    ofn = args.output_file

    dt = load_jsonl_yield(ifn)
    batch = []
    for iid, item in enumerate(tqdm(dt)):
        if item['res']['status']=='success':
            continue
        item['messages'].append({"role":"assistant","content": item['output']} )
        item.pop('output')
        item['messages'].append({"role":"user","content":wrap_feedback(item['res']['message'],revision_type)})
        item['retry_count'] = item.get('retry_count',0)+1
        item['history_errors'] = item.get('history_errors',[])
        item['history_errors'].append(item['res'])
        item.pop('res')

        item['original_id'] = item.get('original_id',iid)

        batch.append(item)
        if len(batch)==1000:
            write_jsonl(batch,ofn,"a")
            batch = []
    write_jsonl(batch,ofn,"a")
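# For reference, a revision record produced above extends the original messages with the
# failed first turn and its verifier feedback (an illustrative sketch, not real data):
# {"messages": [
#     {"role": "user", "content": "<original I/O prediction prompt>"},
#     {"role": "assistant", "content": "<the wrong first-turn response>"},
#     {"role": "user", "content": "[Mismatch] ... Please redo it, and your prediction should no longer be any of the wrong ones you have made before!"}],
#  "itemid": 0, "ioid": 0, "io_pred": "i",
#  "retry_count": 1, "history_errors": [{"status": "failed", "message": "..."}],
#  "original_id": 0}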
--------------------------------------------------------------------------------
/src/build_transform_msg.py:
--------------------------------------------------------------------------------
import argparse
from codeio_utils import build_testcases_prompt_advanced
from utils import *

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--raw_code_file', type=str, default='data/rawcode_1k.jsonl')
    parser.add_argument('--raw_code_msg_file', type=str, default='data/rawcode_1k_msg.jsonl')
    args = parser.parse_args()
    dt = read_jsonl(args.raw_code_file)
    messages = []
    for item in dt:
        messages.append(
            {
                "messages":[
                    {
                        "role":"user",
                        "content":build_testcases_prompt_advanced.replace("<<<>>>", item["content"])
                    }
                ]
            }
        )
    write_jsonl(messages, args.raw_code_msg_file)
--------------------------------------------------------------------------------
/src/check_io_pred_acc_mp.py:
--------------------------------------------------------------------------------
from os import write
from codeio_utils import *
from utils import *
import copy
import multiprocessing as mp
from tqdm import tqdm
from itertools import islice
import json
from multiprocessing.pool import ThreadPool
from collections import defaultdict

# global vars
python_path = "python"
run_path = "./temp/temp/temp"

# Your check_io_pred_acc function uses source_2_data
def check_io_pred_acc(item):
    output = item["output"]
    last_json = extract_last_complete_json(output)
    if last_json is None:
        return {"status":"no answer","message":"Fail to extract a complete and valid json from the output!"}
    needed_oriitem = parsed_ios[item["itemid"]]
    needed_oriio = needed_oriitem["ios"][item["ioid"]]
    if item['io_pred'] == "o":
        if not isinstance(last_json, dict):
            return {"status":"no answer","message":"The last json is not a dict!"}
        if "output" not in last_json:
            return {"status":"no answer","message":"No field 'output' in the last json!"}
        pred_output = last_json["output"]
        acc = is_close(pred_output, needed_oriio["output"])
        if acc:
            return {"status":"success","message":"Correct output!"}
        else:
            return {"status":"failed","message":f"[Mismatch] Your output is not correct! Given the input {json.dumps(needed_oriio['input'])}, your predicted output is {json.dumps(pred_output)} which is wrong!"}
    elif item['io_pred'] == "i":
        if not isinstance(last_json, dict):
            return {"status":"no answer","message":"The last json is not a dict!"}
        if "input" not in last_json:
            return {"status":"no answer","message":"No field 'input' in the last json!"}
        pred_input = last_json["input"]
        candio = {'input': pred_input, 'output': needed_oriio['output']}
        res = check_input(needed_oriitem['refcode'], candio, needed_oriitem['funcname'], solution_prefix=solution_prefix, runtime_limit=5, used_python_path = python_path, run_path=run_path)
        if "exception_type" in res:
            res['messages'] = f"[{res['exception_type']}] {res['message']}"
        return res

# Function to batch items from an iterator
def batcher(iterable, batch_size):
    """Batch an iterator into lists of length batch_size"""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            break
        yield chunk

def get_total_items_with_wc(filename):
    result = subprocess.run(['wc', '-l', filename], stdout=subprocess.PIPE, text=True)
    total_lines = int(result.stdout.split()[0])  # wc output has the form "<line count> <filename>", so take only the first field
    return total_lines

def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--parsed_file_name",type=str,help="The path to the parsed io file.")
    parser.add_argument("--pred_file_name", type=str, help="The path to the prediction file.")
    parser.add_argument("--res_file_name", type=str, help="The path to the result file.")
    parser.add_argument("--batchsize", type=int, help="The batch size.")
    parser.add_argument("--num_processes", type=int, help="The number of processes.")
    parser.add_argument('--python_path', type=str, default="python")
71 | parser.add_argument('--run_path', type=str, default="./temp/temp/temp") 72 | args = parser.parse_args() 73 | 74 | pred_file_name = args.pred_file_name 75 | res_file_name = args.res_file_name 76 | global parsed_ios 77 | parsed_ios = read_jsonl(args.parsed_file_name) 78 | 79 | global python_path, run_path 80 | python_path = args.python_path 81 | run_path = args.run_path 82 | 83 | if not os.path.exists(run_path): 84 | os.makedirs(run_path,exist_ok=True) 85 | 86 | # Load dt as a generator 87 | dt = load_jsonl_yield(pred_file_name) # This is a generator object 88 | 89 | if os.path.exists(res_file_name): 90 | existing = get_total_items_with_wc(res_file_name) 91 | else: 92 | existing = 0 93 | 94 | print(f"Existing items: {existing}, skipping them.") 95 | dt = islice(dt, existing, None) 96 | 97 | total_num_items = get_total_items_with_wc(pred_file_name)-existing 98 | 99 | batchsize = args.batchsize 100 | 101 | # Use fork method to avoid duplicating source_2_data 102 | ctx = mp.get_context('fork') 103 | 104 | num_processes = args.num_processes 105 | 106 | with ctx.Pool(processes = num_processes) as pool: 107 | # Create batches of items 108 | batches = batcher(dt, batchsize) 109 | 110 | # Initialize the progress bar 111 | pbar = tqdm(total=total_num_items) 112 | 113 | for batch_idx, batch in enumerate(batches): 114 | 115 | batch_i = [item for item in batch if item['io_pred'] == 'i'] 116 | batch_o = [item for item in batch if item['io_pred'] == 'o'] 117 | assert len(batch_i) + len(batch_o) == len(batch) 118 | print("I:", len(batch_i), "O:", len(batch_o)) 119 | 120 | batchstat = defaultdict(int) 121 | # Process the batch in parallel 122 | 123 | results = pool.map(check_io_pred_acc, batch_i) 124 | # Combine items with their results 125 | for item, res in zip(batch_i, results): 126 | item['res'] = res 127 | batchstat[res['status']] += 1 128 | results = pool.map(check_io_pred_acc, batch_o) 129 | # Combine items with their results 130 | for item, res in zip(batch_o, results): 131 | item['res'] = res 132 | batchstat[res['status']] += 1 133 | 134 | # Write the batch to the output file 135 | write_jsonl(batch_i, res_file_name, mode='a') 136 | write_jsonl(batch_o, res_file_name, mode='a') 137 | 138 | print(f"Wrote a batch of {len(batch)} items.") 139 | print(f"Batch {batch_idx} status: {batchstat}") 140 | 141 | # Update the progress bar 142 | pbar.update(len(batch)) 143 | 144 | pbar.close() 145 | 146 | if __name__ == "__main__": 147 | import multiprocessing as mp 148 | 149 | # Set the start method to 'fork' to avoid duplicating source_2_data 150 | mp.set_start_method('fork') 151 | 152 | main() 153 | -------------------------------------------------------------------------------- /src/check_io_pred_acc_mp_inplace.py: -------------------------------------------------------------------------------- 1 | from os import write 2 | from codeio_utils import * 3 | from utils import * 4 | import copy 5 | import multiprocessing as mp 6 | from tqdm import tqdm 7 | from itertools import islice 8 | import json 9 | from multiprocessing.pool import ThreadPool 10 | from collections import defaultdict 11 | 12 | # global vars 13 | python_path = "python" 14 | run_path = "./temp/temp/temp" 15 | 16 | # Your check_io_pred_acc function uses source_2_data 17 | def check_io_pred_acc(item): 18 | output = item["output"] 19 | last_json = extract_last_complete_json(output) 20 | if last_json is None: 21 | return {"status":"no answer","message":"Fail to extract a complete and valid json from the output!"} 22 | needed_oriitem = 
parsed_ios[item["itemid"]] 23 | needed_oriio = needed_oriitem["ios"][item["ioid"]] 24 | if item['io_pred'] == "o": 25 | if "output" not in last_json: 26 | return {"status":"no answer","message":"No field 'output' in the last json!"} 27 | pred_output = last_json["output"] 28 | acc = is_close(pred_output, needed_oriio["output"]) 29 | if acc: 30 | return {"status":"success","message":"Correct output!"} 31 | else: 32 | return {"status":"failed","message":f"[Mismatch] Your output is not correct! Given the input {json.dumps(needed_oriio['input'])}, your predicted output is {json.dumps(pred_output)} which is wrong!"} 33 | elif item['io_pred'] == "i": 34 | if "input" not in last_json: 35 | return {"status":"no answer","message":"No field 'input' in the last json!"} 36 | pred_input = last_json["input"] 37 | candio = {'input': pred_input, 'output': needed_oriio['output']} 38 | res = check_input(needed_oriitem['refcode'], candio, needed_oriitem['funcname'], solution_prefix=solution_prefix, used_python_path = python_path, run_path=run_path) 39 | if "exception_type" in res: 40 | res['messages'] = f"[{res['exception_type']}] {res['message']}" 41 | return res 42 | 43 | # Function to batch items from an iterator 44 | def batcher(iterable, batch_size): 45 | """Batch an iterator into lists of length batch_size""" 46 | it = iter(iterable) 47 | while True: 48 | chunk = list(islice(it, batch_size)) 49 | if not chunk: 50 | break 51 | yield chunk 52 | 53 | def main(): 54 | import argparse 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--parsed_file_name",type=str,help="The path to the parsed io file.") 57 | parser.add_argument("--pred_file_name", type=str, help="The path to the prediction file.") 58 | parser.add_argument("--batchsize", type=int, help="The batch size.", default=1024) 59 | parser.add_argument("--write_batchsize", type=int, help="The batch size of batch size", default=8) 60 | parser.add_argument("--num_processes", type=int, help="The number of processes.", default=16) 61 | parser.add_argument('--python_path', type=str, default="python") 62 | parser.add_argument('--run_path', type=str, default="./temp/temp/temp") 63 | args = parser.parse_args() 64 | 65 | pred_file_name = args.pred_file_name 66 | batchsize = args.batchsize 67 | write_batchsize = args.write_batchsize 68 | global parsed_ios 69 | parsed_ios = read_jsonl(args.parsed_file_name) 70 | 71 | global python_path, run_path 72 | python_path = args.python_path 73 | run_path = args.run_path 74 | 75 | if not os.path.exists(run_path): 76 | os.makedirs(run_path,exist_ok=True) 77 | 78 | print(f"Reading data from {pred_file_name}...") 79 | dt = read_jsonl(pred_file_name) 80 | total_num_items = len(dt) 81 | 82 | # Get list of items to process: items without 'res' or items with 'res','status' == "timeout" 83 | items_to_process = [(idx, item) for idx, item in enumerate(dt) if 'res' not in item or item['res']['status'] == "timeout"] 84 | total_items_to_process = len(items_to_process) 85 | 86 | print(f"Total items: {total_num_items}") 87 | print(f"Items to process: {total_items_to_process}") 88 | 89 | # Use fork method to avoid duplicating source_2_data 90 | ctx = mp.get_context('fork') 91 | num_processes = args.num_processes 92 | 93 | with ctx.Pool(processes=num_processes) as pool: 94 | # Create batches of items to process 95 | batches = batcher(items_to_process, batchsize) 96 | 97 | # Initialize the progress bar 98 | pbar = tqdm(total=total_items_to_process) 99 | 100 | bid=0 101 | fullbatchstat = defaultdict(int) 102 | 103 | for batch in 
batches:
            # batch is a list of (idx, item)
            # keep indices aligned with each io_pred subset so results map back to the right items
            batch_indices_o = [idx for idx, item in batch if item['io_pred'] == 'o']
            batch_items_o = [item for idx, item in batch if item['io_pred'] == 'o']
            batch_indices_i = [idx for idx, item in batch if item['io_pred'] == 'i']
            batch_items_i = [item for idx, item in batch if item['io_pred'] == 'i']
            assert len(batch) == len(batch_items_i) + len(batch_items_o)
            print(f"I: {len(batch_items_i)}, O: {len(batch_items_o)}")

            batchstat = defaultdict(int)

            # Process the batch in parallel - o
            results = pool.map(check_io_pred_acc, batch_items_o)
            # Update items in dt
            for idx, res in zip(batch_indices_o, results):
                dt[idx]['res'] = res
                batchstat[res['status']] += 1
                fullbatchstat[res['status']] += 1

            # Process the batch in parallel - i
            results = pool.map(check_io_pred_acc, batch_items_i)
            for idx, res in zip(batch_indices_i, results):
                dt[idx]['res'] = res
                batchstat[res['status']] += 1
                fullbatchstat[res['status']] += 1

            print(batchstat)

            if (bid+1) % write_batchsize == 0:
                # Write the batch to the output file
                write_jsonl(dt, pred_file_name, mode='w')
                print(f"[bid {bid} finished] Wrote several batches of items.")

            # # Write the entire dt back to the same file
            # write_jsonl(dt, pred_file_name, mode='w')

            # Update the progress bar
            pbar.update(len(batch))
            bid+=1

        # Write the remaining items to the output file
        write_jsonl(dt, pred_file_name, mode='w')

        print(fullbatchstat)

        pbar.close()

if __name__ == "__main__":
    import multiprocessing as mp
    import time
    # get the current time as a structured time tuple
    current_time = time.localtime()

    # print it formatted as hours, minutes, seconds, day, month, year
    print("Current time: ", time.strftime("%H:%M:%S %d-%m-%Y", current_time))

    # Set the start method to 'fork' to avoid duplicating source_2_data
    mp.set_start_method('fork')

    main()
--------------------------------------------------------------------------------
/src/codeio_utils.py:
--------------------------------------------------------------------------------
import json
import subprocess
import sys
import re
from pympler import asizeof
from decimal import Decimal

try:
    import shortuuid
except:
    print("Please install shortuuid by running 'pip install shortuuid'")
import os
import signal


refcode_template = """Tip: Here is a reference code snippet for this question. You can refer to this code to guide your reasoning but not copy spans of code directly.

<<<>>>"""

output_pred_template = """You are given a question that requires some input and output variables as follows:

<<<>>>

The input and output requirements are as follows:

<<<>>>

Given the following input:

<<<>>>

Can you predict the output without writing any code? Please reason and put your final answer in the following json format: {"output": <your output>}, where <your output> should strictly match the output requirement as specified."""
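# A model reply to an output-prediction prompt is expected to end with a JSON object, e.g.
#     ...step-by-step reasoning...
#     {"output": 42}
# (illustrative). The verification scripts recover that object with
# extract_last_complete_json() defined below, then either compare it against the cached
# ground truth via is_close, or, for input predictions, re-execute the reference code.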
input_pred_template = """You are given a question that requires some input and output variables as follows:

<<<>>>

The input and output requirements are as follows:

<<<>>>

Given the following output:

<<<>>>

Can you predict a feasible input without writing any code? Please reason and put your final answer in the following json format: {"input": <your input>}, where <your input> should be a dictionary, even if there is only one input variable, with keys strictly matching the input variables' names as specified."""

solution_prefix="""from itertools import accumulate, chain, combinations, count, permutations, product, groupby, islice, repeat
from copy import deepcopy
import signal
from string import ascii_lowercase, ascii_uppercase
from math import floor, log2, log10, sqrt, hypot, comb, gcd, ceil, inf, isqrt, lcm, factorial, dist
from collections import defaultdict, deque, Counter
from bisect import bisect, bisect_left, bisect_right, insort
from heapq import heappush, heappop, heapify, merge, heapreplace
from functools import reduce, lru_cache, cache, cmp_to_key
from random import randrange, shuffle
from operator import itemgetter, sub, or_, xor, and_
from re import search as re_search  # Assuming 're' refers to a regex search
from os.path import commonprefix
from typing import List, Tuple, Dict, Set, Optional, Union, Any, Callable, Iterable, Iterator, Generator
import copy
import datetime
import string
import math
from math import atan2, pi
import collections
import bisect
import heapq
from heapq import nlargest
import functools
import random
from random import randint
import itertools
import operator
import re
import json
import numpy as np
from math import log, prod  # 'log' and 'prod' are functions in the math module
from collections import deque, defaultdict, Counter, OrderedDict
from itertools import accumulate, permutations, combinations, product, groupby, islice, chain, repeat, zip_longest, cycle
from functools import lru_cache, reduce, partial
import sys
from itertools import pairwise"""

build_testcases_prompt_advanced="""Task Overview:

Given a code file (possibly not Python), you need to organize a main function in python and generate a problem based on this function. The request includes the following components:

1. Main Function:
- Refer to the provided code file to build your solution function to solve the problem; it should include the main logic of the code file.
- If some self-defined modules are imported, please keep them as they are instead of writing a placeholder for them.
- The function should be named as `main_solution`; you can do this by either renaming some functions in the code file or calling multiple of them in the new main function.
- The `main_solution` function must have JSON serializable input and output variables.
If no input variables are needed in the reference code file, please adjust the code to make sure input variables always exist, like rewriting the code for one specific case to general cases.
If some input variables in the original code file are not JSON serializable (like set, tuple, np.array, functions, self-defined objects, etc.), you need to convert the JSON serializable inputs of the `main_solution` function to the original input variables at the beginning of the function.
If some output variables in the original code file are not JSON serializable (like set, tuple, np.array, functions, self-defined objects, etc.), you must convert them to JSON serializable outputs at the end of the `main_solution` function before returning.
99 | The size of both input and output variables should be reasonable, like less than 1KB. 100 | Try to avoid too complex input and output variables, like too many nested structures or extremely large numbers or floats with too many decimals. 101 | - The `main_solution` function should return the final output instead of printing it. 102 | - Please remove all plotting code and only keep the core solution code as we never want to return plots. Also remove all print statements, and writing to files. 103 | - Please always define the `main_solution` function at the end of this part, before this, you must prepare all the necessary code by referring to the code file to make sure the `main_solution` function can run correctly. 104 | 105 | 2. Input Output Description: 106 | - You need to provide clear descriptions of the input and output variables in the `main_solution` function. 107 | - In the descriptions, you should include the type of each variable and a brief explanation of its meaning. 108 | For example, if the variable is a dictionary, you should specify the key names and the object type and meaning corresponding to each key's value. In short, please make sure the input and output requirements are very clear and unambiguous. 109 | For example, if the variable is a string, you need to specify what format the string should be presented (like what is the separator to link multiple items in the string), open-ended string inputs and outputs are not allowed. 110 | 111 | 3. Input Generator: 112 | - You need to provide a function named `input_generator` that generates the input arguments for the `main_solution` function. 113 | - The `input_generator` function should not require any input arguments, and each time it is called, it should return a set of input arguments that meet the requirements of the `main_solution` function. 114 | - The output of `input_generator` should always be a dictionary because we always call by `**kwargs` in the `main_solution` function. 115 | - Add some randomness in the `input_generator` function to ensure the input arguments are different each time it is called. 116 | - Please try to make the generated input arguments as reasonable as possible, try to avoid generating too complex or too trivial input variables, also the size of the variables should be reasonable, like less than 1KB. 117 | 118 | 4. Problem Statement: 119 | - Based on the `main_solution` function, you need to create a problem that is related to the provided code. 120 | - Please avoid writing contents such as "implement a function", "write a function" or "implement a system" in the problem, but instead, describe the background and requirements to present a non-programming problem and you must have a wh-question in your problem. 121 | - You should clearly denote the input variable names (but not an exact or specific value) in your problem statement, and clearly ask for the returned value, to be consistent with the `main_solution` function. 122 | - You do not need to include again the input and output variable requirements or any examples in this part. 123 | 124 | --------- 125 | 126 | Your final output should be like this: 127 | ## Main Function 128 | ```python 129 | # import necessary packages 130 | import ... 131 | from ... 132 | 133 | # all class and function definitions in the code file, if any 134 | # they will be used in the main_solution function, you may need to modify them to meet the requirements of the main_solution function (optional) 135 | class ... 136 | def ... 
# main function
def main_solution(arg1, arg2, ...):
    # all input arguments of the main_solution function should be json serializable (no self-defined objects, functions, np.array, set, tuple, etc.)
    # if you need to invoke functions that require non-json serializable inputs, like those defined in the reference code file, you need to convert them (optional)
    ...
    # return, the returned value must be json serializable (no self-defined objects, functions, np.array, set, tuple, etc.)
    return ...
```
## Input Output Description
Input:
`arg1` (type): description
`arg2` (type): description
...
Output:
`return` (type): description
## Input Generator
```python
def input_generator():
    # generate input arguments for the main_solution function
    ...
    return {"arg1": ..., "arg2": ..., ...}
```
## Problem Statement
... (with a wh-question; the input variable names should be in the question) ...

---------

Here is the code file you need to process:
[Code Start]
<<<>>>
[Code End]"""

# Precompile the regular expression outside the function
exception_pattern = re.compile(
    r"Traceback \(most recent call last\):\s*(?:.*\n)+([a-zA-Z_][a-zA-Z0-9_]*):\s*(.+)",
    re.MULTILINE
)

template_check_input = """{solution_prefix}

{refcode}

def is_close(pred, target, tol=0.001):
    if isinstance(pred, dict) and isinstance(target, dict):
        if pred.keys() != target.keys():
            return False
        return all(is_close(pred[k], target[k], tol) for k in pred)

    elif isinstance(pred, list) and isinstance(target, list):
        if len(pred) != len(target):
            return False
        return all(is_close(p, t, tol) for p, t in zip(pred, target))

    elif isinstance(pred, (int, float)) and isinstance(target, (int, float)):
        if isinstance(pred, float) or isinstance(target, float):
            # if we have non number, like nan, inf, we should not compare them
            if math.isnan(pred) or math.isnan(target) or math.isinf(pred) or math.isinf(target):
                return False
            return (abs(pred - target) <= tol * abs(target)) and (int(pred) == int(target))
        return pred == target

    else:
        return pred == target

def diy_check_input_output():
    iiiioooo = {io}

    input_xx = iiiioooo['input'] # should be a json object
    output_xx = iiiioooo['output'] # should be a json object

    warning_string = "[Mismatch] Your input is not feasible! 
208 |     warning_string = "[Mismatch] Your input is not feasible! Given the output <<<<3>>>>, your predicted input is <<<<1>>>>, which actually gets a wrong output as <<<<2>>>>"
209 | 
210 |     string_iii = json.dumps(input_xx)
211 |     string_ooo = json.dumps(output_xx).strip()
212 | 
213 |     execed_output = None
214 | 
215 |     if not {bypass}:
216 |         if isinstance(input_xx, dict):
217 |             execed_output = {funcname}(**input_xx)
218 |         else:
219 |             execed_output = {funcname}(input_xx)
220 |     else:
221 |         execed_output = {funcname}(input_xx)
222 | 
223 |     string_eee = json.dumps(execed_output).strip()
224 | 
225 |     cond1 = string_ooo == string_eee
226 |     cond2 = is_close(execed_output, output_xx)
227 | 
228 |     assert cond1 or cond2, warning_string.replace(
229 |         "<<<<1>>>>", string_iii).replace("<<<<2>>>>", string_eee).replace("<<<<3>>>>", string_ooo)
230 | 
231 | diy_check_input_output()
232 | """
233 | 
234 | def sub_extract_last_complete_json(s):
235 |     if '```json' not in s:
236 |         # Stack to keep track of opening and closing braces
237 |         stack = []
238 |         last_json_start = None
239 |         last_json_str = None
240 | 
241 |         for i, char in enumerate(s):
242 |             if char == '{':
243 |                 stack.append(i)
244 |                 if last_json_start is None:
245 |                     last_json_start = i
246 |             elif char == '}':
247 |                 if stack:
248 |                     start = stack.pop()
249 |                     if not stack:
250 |                         # Complete JSON object found
251 |                         last_json_str = s[last_json_start:i+1]
252 |                         last_json_start = None
253 |     else:
254 |         # find the last ```json fenced block
255 |         last_json_start = s.rfind('```json')
256 |         last_json_end = s.find('```', last_json_start+len('```json'))
257 |         last_json_str = s[last_json_start+7:last_json_end].strip()
258 | 
259 |     # Load the last JSON object
260 | 
261 |     if last_json_str:
262 |         try:
263 |             return json.loads(last_json_str.replace("\n", ""))
264 |         except json.JSONDecodeError:
265 |             # replace the Python literals 'False', 'True', 'None' with the JSON literals 'false', 'true', 'null'
266 |             last_json_str = last_json_str.replace("False", "false").replace("True", "true").replace("None", "null")
267 |             try:
268 |                 return json.loads(last_json_str.replace("\n", ""))
269 |             except json.JSONDecodeError:
270 |                 pass
271 |     return None
272 | 
273 | def extract_last_complete_json(s):
274 |     res = sub_extract_last_complete_json(s)
275 |     if res is None:
276 |         s = s.replace("\\{", "{").replace("\\}", "}").replace('(', '[').replace(')', ']')  # unescape braces; coerce ()-style tuples into []
277 |         res = sub_extract_last_complete_json(s)
278 |     if res is None and "\\boxed{" in s:
279 |         boxstart = s.rfind("\\boxed{")+len("\\boxed{")
280 |         boxend = s.rfind("}",boxstart)
281 |         boxcontent = s[boxstart:boxend]
282 |         processed_box_content = boxcontent.replace("\\\\","\\").replace("\\{","{").replace("\\}","}").replace('\\left','').replace('\\right','')
283 |         res = sub_extract_last_complete_json(processed_box_content)
284 |     return res
285 | 
286 | def strict_check_size(obj):
287 |     # Check if object size is less than 1024 bytes
288 |     if asizeof.asizeof(obj) >= 1024:
289 |         return False
290 | 
291 |     # Check for dict type
292 |     if isinstance(obj, dict):
293 |         if len(obj) >= 20: # Check dict has fewer than 20 key-value pairs
294 |             return False
295 |         # Recursively check keys and values
296 |         for k, v in obj.items():
297 |             if not strict_check_size(k) or not strict_check_size(v):
298 |                 return False
299 | 
300 |     # Check for list, tuple, or set
301 |     elif isinstance(obj, (list, tuple, set)):
302 |         if len(obj) >= 20: # Check the length is less than 20
303 |             return False
304 |         # Recursively check each element
305 |         for item in obj:
306 |             if not strict_check_size(item):
307 |                 return False
308 | 
309 |     # Check for string
310 |     elif isinstance(obj, str):
311 |         if len(obj) >= 100: # Check if string length is less than 100 characters
312 |             return False
313 | 
314 |     # elif isinstance(obj, float):
315 |     #     d = Decimal(str(obj))
316 |     #     if d.as_tuple().exponent < -3:
317 |     #         return False
318 | 
319 |     # Other objects - check size in bytes
320 |     else:
321 |         if asizeof.asizeof(obj) >= 128: # Check if object size is less than 128 bytes
322 |             return False
323 | 
324 |     # If all checks are passed, return True
325 |     return True
326 | 
327 | def combine(mainbody, funcname, args, output_file="output.json"):
328 |     return solution_prefix+'\n\n\n'+mainbody+'\n\n\n'+exec_part.replace('<>',funcname).replace('<>',args).replace('<>',"\""+output_file+"\"")
329 | 
330 | def get_output(mainbody, funcname, args, debug=False):
331 |     uid = shortuuid.uuid()
332 |     pyfilename = "./temp/solutions/solution."+uid+".py"
333 |     outputfilename = "./temp/solutions/output."+uid+".json"
334 | 
335 |     # ensure the target directory exists (the original guard tested the file path rather than its directory)
336 |     os.makedirs(os.path.dirname(pyfilename), exist_ok=True)
337 | 
338 |     solution_py = combine(mainbody, funcname, args, outputfilename)
339 | 
340 |     if debug:
341 |         print('=================')
342 |         print(solution_py)
343 |         print('=================')
344 | 
345 |     # start a command to run the code
346 |     # write it into a file
347 | 
348 |     with open(pyfilename, 'w') as f:
349 |         f.write(solution_py)
350 | 
351 |     # run the code
352 |     # if error, raise ValueError
353 |     subprocess.run(["python3", pyfilename], check=True)
354 | 
355 |     # read the output.json
356 |     with open(outputfilename, 'r') as f:
357 |         output = f.read()
358 |     return output
359 | 
360 | def extract_last_python(text):
361 |     posstart = text.rfind("```python")
362 |     if posstart == -1:
363 |         return None
364 |     posstart += len("```python")
365 |     posend = text.find("```", posstart)
366 |     if posend == -1:
367 |         return None
368 |     return text[posstart:posend].strip()
369 | 
370 | def extract_input(ss):
371 |     last_json = extract_last_complete_json(ss)
372 |     if last_json is None:
373 |         return None
374 |     if isinstance(last_json, dict):
375 |         inputx = last_json.get('input', None)
376 |     else:
377 |         inputx = last_json
378 |     return inputx
379 | 
380 | def check_input_legacy(refcode, io, funcname,
381 |                        solution_prefix="",
382 |                        used_python_path="x",
383 |                        run_path="x",
384 |                        runtime_limit=5,
385 |                        bypass=False,
386 |                        ):
387 |     strbypass = "True" if bypass else "False"
388 | 
389 |     runnablepycode = template_check_input.format(
390 |         solution_prefix=solution_prefix, refcode=refcode, io=io, funcname=funcname, bypass=strbypass
391 |     )
392 | 
393 |     result_dict = {}
394 | 
395 |     try:
396 |         # Run the code with a timeout of runtime_limit seconds (default 5)
397 |         result = subprocess.run(
398 |             [used_python_path, '-'],
399 |             input=runnablepycode,
400 |             stdout=subprocess.DEVNULL, # Discard standard output
401 |             stderr=subprocess.PIPE, # Capture standard error
402 |             text=True,
403 |             timeout=runtime_limit,
404 |             cwd=run_path
405 |         )
406 |         if result.returncode == 0:
407 |             # Success
408 |             result_dict['status'] = 'success'
409 |             result_dict['message'] = 'Feasible input!'
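        # (added note, describing the existing behavior rather than changing it:
        #  result_dict['status'] ends up as one of 'success', 'AssertionError'
        #  (the diy check flagged an infeasible input), 'exception', 'error'
        #  (empty stderr), or 'timeout' below)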
410 |         else:
411 |             # Error occurred
412 |             stderr = result.stderr
413 | 
414 |             # Attempt to extract the specific exception type and message
415 |             exception_type = None
416 |             exception_message = None
417 | 
418 |             # Pattern to match Python traceback exceptions
419 |             match = exception_pattern.search(stderr)
420 | 
421 |             if match:
422 |                 # Extract exception type and message
423 |                 exception_type = match.group(1)
424 |                 exception_message = match.group(2).strip()
425 | 
426 |                 # Special handling for AssertionError
427 |                 if exception_type == 'AssertionError':
428 |                     result_dict['status'] = 'AssertionError'
429 |                     result_dict['message'] = exception_message or 'An assertion error occurred.'
430 |                 else:
431 |                     result_dict['status'] = 'exception'
432 |                     result_dict['exception_type'] = exception_type
433 |                     result_dict['message'] = exception_message
434 |             else:
435 |                 # If pattern matching fails, return the last line as the error message
436 |                 lines = stderr.strip().splitlines()
437 |                 if lines:
438 |                     last_line = lines[-1]
439 |                     result_dict['status'] = 'exception'
440 |                     result_dict['message'] = last_line.strip()
441 |                 else:
442 |                     result_dict['status'] = 'error'
443 |                     result_dict['message'] = 'An unknown error occurred.'
444 |     except subprocess.TimeoutExpired:
445 |         # Timeout
446 |         result_dict['status'] = 'timeout'
447 |         result_dict['message'] = f'Code execution time exceeded the limit of {runtime_limit} seconds; may have encountered an infinite loop.'
448 |     except Exception as e:
449 |         # Other exceptions
450 |         result_dict['status'] = 'exception'
451 |         result_dict['message'] = str(e)
452 |     finally:
453 |         pass
454 | 
455 |     return result_dict
456 | 
457 | def check_input(refcode, io, funcname,
458 |                 solution_prefix="",
459 |                 used_python_path="x",
460 |                 run_path="x",
461 |                 runtime_limit=5,
462 |                 bypass=False,
463 |                 ):
464 | 
465 | 
466 |     # Local exception pattern; it shadows the precompiled module-level one and uses DOTALL so the match can span the whole traceback
467 |     exception_pattern = re.compile(r'Traceback \(most recent call last\):.*\n([\w\.]+):\s+(.*)', re.DOTALL)
468 | 
469 |     strbypass = "True" if bypass else "False"
470 | 
471 |     runnablepycode = template_check_input.format(
472 |         solution_prefix=solution_prefix, refcode=refcode, io=io, funcname=funcname, bypass=strbypass
473 |     )
474 | 
475 |     result_dict = {}
476 | 
477 |     # Cross-platform process creation and termination functions
478 |     # if sys.platform == 'win32':
479 |     #     # Windows
480 |     #     def start_process(*args, **kwargs):
481 |     #         return subprocess.Popen(
482 |     #             *args,
483 |     #             **kwargs,
484 |     #             creationflags=subprocess.CREATE_NEW_PROCESS_GROUP
485 |     #         )
486 |     #     def kill_process(process):
487 |     #         try:
488 |     #             process.send_signal(signal.CTRL_BREAK_EVENT)
489 |     #         except Exception:
490 |     #             process.kill()
491 |     # else:
492 | 
493 |     # Unix/Linux
494 |     def start_process(*args, **kwargs):
495 |         return subprocess.Popen(
496 |             *args,
497 |             **kwargs,
498 |             preexec_fn=os.setsid
499 |         )
500 |     def kill_process(process):
501 |         try:
502 |             os.killpg(os.getpgid(process.pid), signal.SIGTERM)
503 |         except Exception:
504 |             process.kill()
505 | 
506 |     process = None
507 | 
508 |     try:
509 |         # Start the process
510 |         process = start_process(
511 |             [used_python_path, '-'],
512 |             stdin=subprocess.PIPE,
513 |             stdout=subprocess.DEVNULL,
514 |             stderr=subprocess.PIPE,
515 |             text=True,
516 |             cwd=run_path,
517 |         )
518 | 
519 |         try:
520 |             # Communicate with the process
521 |             stdout_data, stderr_data = process.communicate(
522 |                 input=runnablepycode,
523 |                 timeout=runtime_limit
524 |             )
525 |         except subprocess.TimeoutExpired:
526 |             # Timeout occurred, kill the process group
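            # (added note: start_process launches the child with preexec_fn=os.setsid,
            #  so it leads its own process group; killing the group, rather than the
            #  single pid, also terminates any grandchildren the evaluated code spawned)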
527 |             kill_process(process)
528 |             stdout_data, stderr_data = process.communicate()
529 |             result_dict['status'] = 'timeout'
530 |             result_dict['message'] = f'Code execution time exceeded the limit of {runtime_limit} seconds; may have encountered an infinite loop.'
531 |             return result_dict
532 |         except Exception as e:
533 |             # Kill the process group
534 |             kill_process(process)
535 |             stdout_data, stderr_data = process.communicate()
536 |             result_dict['status'] = 'exception'
537 |             result_dict['message'] = str(e)
538 |             return result_dict
539 |         else:
540 |             # Process completed
541 |             returncode = process.returncode
542 |             if returncode == 0:
543 |                 # Success
544 |                 result_dict['status'] = 'success'
545 |                 result_dict['message'] = 'Feasible input!'
546 |             else:
547 |                 # Error occurred
548 |                 stderr = stderr_data
549 | 
550 |                 # Attempt to extract the specific exception type and message
551 |                 exception_type = None
552 |                 exception_message = None
553 | 
554 |                 # Pattern to match Python traceback exceptions
555 |                 match = exception_pattern.search(stderr)
556 | 
557 |                 if match:
558 |                     # Extract exception type and message
559 |                     exception_type = match.group(1)
560 |                     exception_message = match.group(2).strip()
561 | 
562 |                     # Special handling for AssertionError
563 |                     if exception_type == 'AssertionError':
564 |                         result_dict['status'] = 'AssertionError'
565 |                         result_dict['message'] = exception_message or 'An assertion error occurred.'
566 |                     else:
567 |                         result_dict['status'] = 'exception'
568 |                         result_dict['exception_type'] = exception_type
569 |                         result_dict['message'] = exception_message
570 |                 else:
571 |                     # If pattern matching fails, return the last line as the error message
572 |                     lines = stderr.strip().splitlines()
573 |                     if lines:
574 |                         last_line = lines[-1]
575 |                         result_dict['status'] = 'exception'
576 |                         result_dict['message'] = last_line.strip()
577 |                     else:
578 |                         result_dict['status'] = 'error'
579 |                         result_dict['message'] = 'An unknown error occurred.'
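    # (added note: the finally block below is belt-and-braces cleanup: it re-kills
    #  the child if needed, waits on it to avoid zombie processes, and closes any
    #  open pipes; stdout is DEVNULL here, so only stdin and stderr are actually open)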
580 |     finally:
581 |         # Ensure that the process is terminated and resources are cleaned up
582 |         if process is not None:
583 |             try:
584 |                 kill_process(process)
585 |             except Exception:
586 |                 pass
587 |             # Wait for the process to terminate to prevent zombies
588 |             process.wait()
589 |             # Close any open file descriptors
590 |             if process.stdout: process.stdout.close()
591 |             if process.stderr: process.stderr.close()
592 |             if process.stdin: process.stdin.close()
593 | 
594 |     return result_dict
595 | 
596 | def is_close(pred, target, tol=0.001):
597 |     if isinstance(pred, dict) and isinstance(target, dict):
598 |         if pred.keys() != target.keys():
599 |             return False
600 |         return all(is_close(pred[k], target[k], tol) for k in pred)
601 | 
602 |     elif isinstance(pred, list) and isinstance(target, list):
603 |         if len(pred) != len(target):
604 |             return False
605 |         return all(is_close(p, t, tol) for p, t in zip(pred, target))
606 | 
607 |     elif isinstance(pred, (int, float)) and isinstance(target, (int, float)):
608 |         try:
609 |             if isinstance(pred, float) or isinstance(target, float):
610 |                 # non-finite values like nan or inf are not comparable
611 |                 if math.isnan(pred) or math.isnan(target) or math.isinf(pred) or math.isinf(target):
612 |                     return False
613 |                 return (abs(pred - target) <= tol * abs(target)) and (int(pred) == int(target))
614 |             return pred == target
615 |         except Exception:
616 |             return False
617 |     else:
618 |         return pred == target
--------------------------------------------------------------------------------
/src/parse_gen_ios.py:
--------------------------------------------------------------------------------
1 | 
2 | from utils import *
3 | from codeio_utils import *
4 | import subprocess
5 | import tempfile
6 | import os
7 | import json
8 | import signal
9 | from time import sleep
10 | 
11 | from multiprocessing import Pool
12 | from collections import defaultdict
13 | from tqdm import tqdm
14 | import gc
15 | import shutil
16 | 
17 | def resolve_output(output):
18 |     # split the generation into its four sections and extract the code blocks with extract_last_python
19 |     pos1 = output.find("## Main Function")
20 |     if pos1 == -1:
21 |         return None
22 |     pos11 = pos1+len("## Main Function")
23 |     pos2 = output.find("## Input Output Description", pos11)
24 |     if pos2 == -1:
25 |         return None
26 |     maincode = extract_last_python(output[pos11:pos2].strip())
27 |     if maincode is None:
28 |         return None
29 | 
30 |     pos21 = pos2+len("## Input Output Description")
31 |     pos3 = output.find("## Input Generator", pos21)
32 |     if pos3 == -1:
33 |         pos3 = output.find("## Input generator", pos21)
34 |     if pos3 == -1:
35 |         return None
36 |     io_desc = output[pos21:pos3].strip()
37 | 
38 |     pos31 = pos3+len("## Input Generator")
39 |     pos4 = output.find("## Problem Statement", pos31)
40 |     if pos4 == -1:
41 |         return None
42 |     inputgencode = extract_last_python(output[pos31:pos4].strip())
43 |     if inputgencode is None:
44 |         return None
45 | 
46 |     pos41 = pos4+len("## Problem Statement")
47 |     pos5 = len(output)
48 |     problem_statement = output[pos41:pos5].strip()
49 | 
50 |     if "main_solution" not in maincode:
51 |         return None
52 |     if "input_generator" not in inputgencode:
53 |         return None
54 | 
55 |     return {
56 |         "maincode": maincode,
57 |         "io_desc": io_desc,
58 |         "inputgencode": inputgencode,
59 |         "problem_statement": problem_statement
60 |     }
61 | 
62 | templatexx = """
63 | import json
64 | from pympler import asizeof
65 | 
66 | def strict_check_size(obj):
67 |     # Check if object size is less than 1024 bytes
68 |     if asizeof.asizeof(obj) >= 1024:
69 |         return False
70 | 
71 |     # Check for dict type
72 |     if isinstance(obj, dict):
73 |         if len(obj) >= 20: # Check dict has fewer than 20 key-value pairs
74 |             return False
75 |         # Recursively check keys and values
76 |         for k, v in obj.items():
77 |             if not strict_check_size(k) or not strict_check_size(v):
78 |                 return False
79 | 
80 |     # Check for list, tuple, or set
81 |     elif isinstance(obj, (list, tuple, set)):
82 |         if len(obj) >= 20: # Check the length is less than 20
83 |             return False
84 |         # Recursively check each element
85 |         for item in obj:
86 |             if not strict_check_size(item):
87 |                 return False
88 | 
89 |     # Check for string
90 |     elif isinstance(obj, str):
91 |         if len(obj) >= 100: # Check if string length is less than 100 characters
92 |             return False
93 | 
94 |     # elif isinstance(obj, float):
95 |     #     d = Decimal(str(obj))
96 |     #     if d.as_tuple().exponent < -3:
97 |     #         return False
98 | 
99 |     # Other objects - check size in bytes
100 |     else:
101 |         if asizeof.asizeof(obj) >= 128: # Check if object size is less than 128 bytes
102 |             return False
103 | 
104 |     # If all checks are passed, return True
105 |     return True
106 | 
107 | <<<|!!|!!|maincode|!!|!!|>>>
108 | 
109 | <<<|!!|!!|inputgencode|!!|!!|>>>
110 | 
111 | diff_inputs = []
112 | corr_outputs = []
113 | for i in range(1000):
114 |     cand_input = input_generator()
115 |     if cand_input not in diff_inputs and strict_check_size(cand_input):
116 |         cand_output = main_solution(**cand_input)
117 |         if strict_check_size(cand_output) and cand_output is not None:
118 |             diff_inputs.append(cand_input)
119 |             corr_outputs.append(cand_output)
120 |     if len(diff_inputs) >= 10:
121 |         break
122 | 
123 | assert len(diff_inputs) == len(corr_outputs)
124 | 
125 | iolist = [{"input": diff_inputs[i], "output": corr_outputs[i]} for i in range(len(diff_inputs))]
126 | 
127 | jsoniolist = json.dumps(iolist)
128 | 
129 | print("[JSON IOS START]" + jsoniolist + "[JSON IOS END]")
130 | """
131 | 
132 | def process_one_item(res):
133 | 
134 |     maincode = res['maincode']
135 |     io_desc = res['io_desc']
136 |     inputgencode = res['inputgencode']
137 |     problem_statement = res['problem_statement']
138 |     pyruncode = templatexx.replace("<<<|!!|!!|maincode|!!|!!|>>>", maincode).replace("<<<|!!|!!|inputgencode|!!|!!|>>>", inputgencode) # templatexx already imports json and pympler.asizeof
139 | 
140 |     runtime_limit = 60 # seconds
141 | 
142 |     try:
143 |         # Start the subprocess in a new session (process group)
144 |         proc = subprocess.Popen(
145 |             [used_python_path, '-'],
146 |             stdin=subprocess.PIPE,
147 |             stdout=subprocess.PIPE,
148 |             stderr=subprocess.PIPE,
149 |             cwd=run_path,
150 |             text=True,
151 |             start_new_session=True
152 |         )
153 | 
154 |         try:
155 |             # Communicate with the subprocess
156 |             stdout, stderr = proc.communicate(input=pyruncode, timeout=runtime_limit)
157 |         except subprocess.TimeoutExpired:
158 |             # Timeout expired; kill the process group
159 |             os.killpg(proc.pid, signal.SIGTERM) # Send SIGTERM to the process group
160 |             stdout, stderr = proc.communicate()
161 |             return None
162 |         except Exception as e:
163 |             # Other exception occurred; kill the process group
164 |             os.killpg(proc.pid, signal.SIGTERM)
165 |             stdout, stderr = proc.communicate()
166 |             return None
167 |         finally:
168 |             # Ensure the subprocess is terminated
169 |             proc.kill()
170 |             proc.wait()
171 | 
172 |         # Process stdout as before
173 |         start_marker = "[JSON IOS START]"
174 |         end_marker = "[JSON IOS END]"
175 | 
176 |         if start_marker in stdout and end_marker in stdout:
177 |             start_index = stdout.index(start_marker) + len(start_marker)
178 |             end_index = stdout.index(end_marker)
179 |             json_str = stdout[start_index:end_index].strip()
180 | 
181 |             # Load the JSON string
182 |             json_data = json.loads(json_str)
183 |             return json_data
184 |         else:
185 |             # If markers are not found, return None
186 |             return None
187 | 
188 |     except Exception as e:
189 |         # Handle any exceptions that might occur while setting up the subprocess
190 |         return None
191 | 
192 | def process_item(item):
193 |     res = resolve_output(item['output'])
194 |     if res is None:
195 |         return None
196 |     ios = process_one_item(res)
197 |     if ios is None or len(ios) <= 1:
198 |         return None
199 |     sample = {
200 |         "problem_description": res['problem_statement'],
201 |         "io_requirements": res['io_desc'],
202 |         "refcode": res['maincode'],
203 |         "funcname": "main_solution",
204 |         "ios": ios, # can be empty, to indicate all of them are too large
205 |         # "source": "z3examples",
206 |         "category": item.get('category', None),
207 |         "meta": {
208 |             # "repo_name": item['repo_name'],
209 |             # "path": item['path'],
210 |             # "original_source": item['source'],
211 |             # "eid": item['eid'],
212 |             "msgidx": item['index'],
213 |         }
214 |     }
215 |     # Clean up
216 |     del res
217 |     del ios
218 |     gc.collect()
219 |     sleep(0.05)
220 |     return sample
221 | 
222 | if __name__ == "__main__":
223 |     import argparse
224 |     parser = argparse.ArgumentParser()
225 |     parser.add_argument('--input_file', type=str, default="data/rawcode_1k_unified.jsonl")
226 |     parser.add_argument('--output_file', type=str, default="data/rawcode_1k_parsed_2.jsonl")
227 |     parser.add_argument('--python_path', type=str, default="python")
228 |     parser.add_argument('--run_path', type=str, default="./temp/temp/temp")
229 |     args = parser.parse_args()
230 | 
231 |     if not os.path.exists(args.run_path):
232 |         os.makedirs(args.run_path, exist_ok=True)
233 | 
234 |     used_python_path = args.python_path
235 |     run_path = args.run_path
236 | 
237 |     fn = args.input_file
238 | 
239 |     ofn = args.output_file
240 |     dt = read_jsonl(fn)
241 | 
242 |     if os.path.exists(ofn):
243 |         exdt = read_jsonl(ofn)
244 |         exindex = set([x['meta']['msgidx'] for x in exdt])
245 |     else:
246 |         exindex = set()
247 | 
248 |     dt = [x for x in dt if x['index'] not in exindex]
249 | 
250 |     print("Skip existing:", len(exindex))
251 | 
252 |     adt = []
253 |     goodcount = 0
254 |     prevgoodcount = 0
255 |     totalcount = 0
256 |     prevtotalcount = 0
257 |     with Pool(processes=64, maxtasksperchild=10) as pool:
258 |         for result in tqdm(pool.imap_unordered(process_item, dt), total=len(dt)):
259 |             totalcount += 1
260 |             if result is not None:
261 |                 adt.append(result)
262 |                 goodcount += 1
263 |             if len(adt) >= 100:
264 |                 write_jsonl(adt, ofn, "a")
265 |                 adt = []
266 |                 print(f"{goodcount}/{totalcount}")
267 |                 deltagoodcount = goodcount - prevgoodcount
268 |                 deltatotalcount = totalcount - prevtotalcount
269 |                 print(f"Delta: {deltagoodcount}/{deltatotalcount} = {deltagoodcount/deltatotalcount}")
270 |                 prevgoodcount = goodcount
271 |                 prevtotalcount = totalcount
272 | 
273 |     if len(adt) > 0:
274 |         write_jsonl(adt, ofn, "a")
275 |     print(f"Final - {goodcount}/{totalcount}")
276 | 
277 |     try:
278 |         shutil.rmtree(args.run_path)
279 |     except Exception as e:
280 |         print(f"Error: {e}")
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | 
4 | def load_jsonl_yield(path):
5 |     with open(path) as f:
6 |         for line in f:
7 |             try:
8 |                 line = json.loads(line)
9 |                 yield line
10 |             except json.JSONDecodeError: # skip malformed lines
11 |                 pass
12 | 
13 | def read_jsonl(jsonl_file_path):
14 |     s = []
15 |     with open(jsonl_file_path, "r") as f:
16 |         lines = f.readlines()
17 |     for line in lines:
18 |         linex = line.strip()
19 |         if linex == "":
20 |             continue
21 |         s.append(json.loads(linex))
22 |     return s
23 | 
24 | def write_jsonl(data, jsonl_file_path, mode="w"):
25 |     # data is a list; each item is JSON-serializable
26 |     assert isinstance(data, list)
27 |     if len(data) == 0:
28 |         return
29 |     if os.path.dirname(jsonl_file_path): # an empty dirname means the current directory, so nothing to create
30 |         os.makedirs(os.path.dirname(jsonl_file_path), exist_ok=True)
31 |     with open(jsonl_file_path, mode) as f:
32 |         for item in data:
33 |             f.write(json.dumps(item) + "\n")
34 | 
35 | def elegant_show(something, level=0, sid=0, full=False):
36 |     # str, float, int
37 |     # every print in this call is indented by level tabs
38 |     prefix = "\t" * level
39 | 
40 |     if isinstance(something, (str, float, int)) or something is None:
41 |         if isinstance(something, str):
42 |             # if '\n' in something:
43 |             #     something = '\n'+something
44 |             # add the prefix whenever this string goes to a new line
45 |             something = something.replace("\n", f"\n{prefix}")
46 |         print(prefix, f"\033[1;35mElement: \033[0m", something)
47 |     elif isinstance(something, list) or isinstance(something, tuple):
48 |         # show the length and a representative element
49 |         # sid = 0
50 |         if len(something) == 0:
51 |             print(
52 |                 prefix,
53 |                 f"\033[1;33mLen: \033[0m{len(something)} \t\033[1;33m& No elements! \033[0m",
54 |             )
55 |         elif not full or len(something) == 1:
56 |             print(
57 |                 prefix,
58 |                 f"\033[1;33mLen: \033[0m{len(something)} \t\033[1;33m& first element ...\033[0m",
59 |             )
60 |             elegant_show(something[sid], level + 1, sid, full)
61 |         else:
62 |             print(
63 |                 prefix,
64 |                 f"\033[1;33mLen: \033[0m{len(something)} \t\033[1;33m& Elements ...\033[0m",
65 |             )
66 |             for i in range(len(something) - 1):
67 |                 elegant_show(something[i], level + 1, sid, full)
68 |                 print(
69 |                     prefix + "\t", f"\033[1;33m-------------------------------\033[0m"
70 |                 )
71 |             elegant_show(something[-1], level + 1, sid, full)
72 | 
73 |     elif isinstance(something, dict):
74 |         for k, v in something.items():
75 |             print(prefix, f"\033[1;34mKey: \033[0m{k} \033[1;34m...\033[0m")
76 |             elegant_show(v, level + 1, sid, full)
77 |     else:
78 |         print(prefix, f"\033[1;31mError @ Type: \033[0m{type(something)}")
79 |         # raise NotImplementedError
80 | 
81 | def build_messages(prompt, response=None, system_message=None):
82 |     messages = []
83 |     if system_message is not None:
84 |         messages.append({"role": "system", "content": system_message})
85 |     messages.append({"role": "user", "content": prompt})
86 |     if response is not None:
87 |         messages.append({"role": "assistant", "content": response})
88 |     return messages
--------------------------------------------------------------------------------
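A minimal usage sketch of the helpers in src/utils.py (added for illustration; this is not a file from the repository, and the input path and "prompt" field below are hypothetical):

# illustration only: read a JSONL file, wrap each row's prompt into chat
# messages, and append the results to a new JSONL file
from utils import read_jsonl, write_jsonl, build_messages

rows = read_jsonl("data/example.jsonl")  # list of dicts, one per JSONL line
msgs = [{"messages": build_messages(prompt=r["prompt"])} for r in rows]
write_jsonl(msgs, "temp/example_msgs.jsonl", mode="w")  # creates temp/ if missing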