├── .gitignore
├── README.md
├── environment.sh
├── figures
│   ├── case_study.png
│   ├── logo.png
│   └── pipeline.png
├── requirements.txt
├── run_dpo_train.sh
├── run_testing.sh
├── run_train_testing.sh
├── run_training.sh
├── test
│   ├── PDEcontrol
│   │   ├── .gitignore
│   │   ├── cog.yaml
│   │   └── evaluation
│   │       ├── data_processing
│   │       │   ├── __init__.py
│   │       │   ├── answer_extraction.py
│   │       │   └── process_utils.py
│   │       ├── eval
│   │       │   ├── __init__.py
│   │       │   ├── eval_robustness_reasoning_wrapper.py
│   │       │   ├── eval_robustness_wrapper.py
│   │       │   ├── eval_script.py
│   │       │   └── utils.py
│   │       ├── infer
│   │       │   └── simulate_gt.py
│   │       └── scripts
│   │           └── infer_pdecontrol.sh
│   ├── README.md
│   ├── requirements.txt
│   └── scripts
│       ├── read_result.py
│       ├── simulate_gt.sh
│       └── test_pdecontrol.sh
├── train
│   ├── README.md
│   ├── config
│   │   ├── deepspeed.json
│   │   └── deepspeed_dpo.json
│   ├── scripts
│   │   ├── group_text.py
│   │   ├── group_text_dpo.py
│   │   ├── merge_model.py
│   │   ├── tokenize_data.py
│   │   ├── tokenize_data_dpo.py
│   │   ├── train.sh
│   │   ├── train_dpo.sh
│   │   └── utils
│   │       ├── loader.py
│   │       ├── trainer.py
│   │       └── util.py
│   ├── train.py
│   ├── train_dpo.py
│   ├── train_finetune.py
│   ├── utils
│   │   ├── loader.py
│   │   ├── trainer.py
│   │   └── util.py
│   └── validate.py
├── train_test_combined.yml
└── utils
    └── few_shot_prompts
        ├── __init__.py
        ├── cot_one_d_combined_fewshot.py
        ├── cot_one_d_heat_fewshot.py
        ├── cot_one_d_wave_fewshot.py
        ├── examples
        │   ├── DPO_one_d_combined
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── DPO_one_d_heat
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── DPO_one_d_wave
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── one_d_combined
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── one_d_heat
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   └── one_d_wave
        │       ├── dataset_description.json
        │       └── examples.jsonl
        ├── few_shot_prompting.py
        ├── few_shot_test.py
        ├── few_shot_train.py
        └── few_shot_train_dpo.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*/__pycache__/

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# WandB
wandb/

# extra outputs:
train/scripts/file-names_to-be-tokenized.json
train/scripts/file_names_to-be-trained.json

# Femformal

nusmv/
doc/**/*.pdf
.DS_Store
*.pyc
*.lp
.srclist
.tmuxp.yaml
tags
*.egg-info
out_vars.txt
temp_plots/
doc/source/femformal/stubs
doc/build
out.ilp
out.txt


# robustness stuff from delta:
*test/PDEcontrol/evaluation/eval/more_heat_robust_.npy
*test/PDEcontrol/evaluation/eval/more_heat_time_.npy

*more_heat_robust_.npy
*more_heat_time_.npy

*control/femformal/examples/heat_mix/*eval_llm*.py
*control/femformal/examples/mech_mix/*eval_llm*.py
*control/femformal/examples/heat_mix/*eval_llm*.pyc
*control/femformal/examples/mech_mix/*eval_llm*.pyc


# unprocessed data because they may be too large
data_processing/*

datasets/
models/
outputs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PDE-Controller

This repository contains the implementation for [PDE-Controller: LLMs for Autoformalization and Reasoning of PDEs](http://arxiv.org/abs/2502.00963).

<div align="center">
<img src="./figures/logo.png" alt="PDE-Controller logo">
</div>

<div align="center">
🤗 <a href="https://huggingface.co/datasets/delta-lab-ai/pde-controller">Hugging Face Datasets</a> &nbsp;|&nbsp; 📑 <a href="http://arxiv.org/abs/2502.00963">Paper</a> &nbsp;|&nbsp; 📖 Project Page
</div>

## Dataset and Models

Our datasets are released on [Hugging Face](https://huggingface.co/datasets/delta-lab-ai/pde-controller).

The PDE-Controller models are as follows:

|Model Name|Huggingface Link|
|:--|:--|
|Translator|🤗 [link](https://huggingface.co/delta-lab-ai/translator)|
|Coder|🤗 [link](https://huggingface.co/delta-lab-ai/coder)|
|Controller|🤗 [link](https://huggingface.co/delta-lab-ai/controller)|
|Fine-tuned Coder for Controller|🤗 [link](https://huggingface.co/delta-lab-ai/finetuned_coder)|

## Introduction

We present [PDE-Controller](http://arxiv.org/abs/2502.00963), a framework that enables large language models (LLMs) to control systems governed by partial differential equations (PDEs). Traditional LLMs have excelled in commonsense reasoning but fall short in rigorous logical reasoning. While recent AI-for-math has made strides in pure mathematics, areas of applied mathematics, particularly PDEs, remain underexplored despite their significant real-world applications. Our approach enables LLMs to transform informal natural language instructions into formal specifications, and then execute reasoning and planning steps to improve the utility of PDE control. We build a holistic solution comprising datasets (both human-written cases and 2 million synthetic samples), math-reasoning models, and novel evaluation metrics, all of which require significant effort. Our PDE-Controller significantly outperforms the latest open-source and GPT models in reasoning, autoformalization, and program synthesis, achieving up to a 62% improvement in utility gain for PDE control. By bridging the gap between language generation and PDE systems, we demonstrate the potential of LLMs in addressing complex scientific and engineering challenges.
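The released checkpoints are intended to be used like ordinary causal language models. A minimal loading sketch (this assumes the weights are compatible with the standard 🤗 `transformers` API; the prompt shown is illustrative only):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the released dataset and the Translator checkpoint from the table above.
data = load_dataset("delta-lab-ai/pde-controller")
tokenizer = AutoTokenizer.from_pretrained("delta-lab-ai/translator")
model = AutoModelForCausalLM.from_pretrained("delta-lab-ai/translator", device_map="auto")

# Autoformalization example: ask the Translator for an STL specification.
prompt = "Translate the following problem description into an STL specification:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
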
<div align="center">
<img src="./figures/pipeline.png" alt="PDE-Controller pipeline">
</div>

## Case Study

An example of LLM reasoning for PDE control on heat (top) and wave (bottom) problems.

<div align="center">
<img src="./figures/case_study.png" alt="Case study: heat (top) and wave (bottom)">
</div>

## Installation

For training, only the `trainenv` environment is required. For testing, `trainenv` is also sufficient to obtain model outputs; simulating utilities with the Gurobi optimizer and the Femformal repository (which runs on Python 2) additionally requires the `pdecontrol` environment. Only `trainenv` needs to be activated to run the code; `pdecontrol` is activated automatically by the [main test script](./test/PDEcontrol/evaluation/infer/run_1d_pdecontrol_eval_full.py).

## Training and Testing in Python 3

Create a conda environment:

```shell
conda create -n trainenv python=3.10
```

Activate the environment:

```shell
conda activate trainenv
```

Run the commands in `environment.sh` to install the required packages.

## Training

The documentation for training is at: [training](train/README.md)

## Testing

The documentation for testing is at: [evaluation](test/README.md).

## Citation

If you find this repository helpful, please consider citing our paper:

```
@article{soroco2025pdecontroller,
  title={PDE-Controller: LLMs for Autoformalization and Reasoning of PDEs},
  author={Soroco, Mauricio and Song, Jialin and Xia, Mengzhou and Emond, Kye and Sun, Weiran and Chen, Wuyang},
  journal={arXiv preprint arXiv:2502.00963},
  year={2025}
}
```
--------------------------------------------------------------------------------
/environment.sh:
--------------------------------------------------------------------------------
pip install torch==2.4.0
pip install torchaudio==2.4.0

pip install xformers==0.0.27.post2

# conda env update -f train_test_combined.yml
pip install -r requirements.txt

pip install gradio==5.9.1

pip install flash-attn --no-build-isolation
--------------------------------------------------------------------------------
/figures/case_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/figures/case_study.png
--------------------------------------------------------------------------------
/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/figures/logo.png
--------------------------------------------------------------------------------
/figures/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/figures/pipeline.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==2.0.0
accelerate==1.2.0
aiofiles==23.2.1
aiohttp==3.9.1
aiosignal==1.3.1
airportsdata==20241001
annotated-types==0.7.0
anyio==4.7.0
async-timeout==4.0.3
attrs==23.1.0
beautifulsoup4==4.12.3
bitarray==3.0.0
bs4==0.0.2
cachetools==5.3.2
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.7
cloudpickle==3.1.0
compressed-tensors==0.6.0
datasets==3.2.0
deepspeed==0.15.4
dill==0.3.7 23 | diskcache==5.6.3 24 | distro==1.9.0 25 | docker-pycreds==0.4.0 26 | docstring_parser==0.16 27 | editdistance==0.8.1 28 | einops==0.7.0 29 | exceptiongroup==1.2.2 30 | fastapi==0.115.6 31 | ffmpy==0.5.0 32 | filelock==3.16.1 33 | fire==0.5.0 34 | # flash-attn==2.7.2.post1 35 | frozenlist==1.4.0 36 | fsspec==2023.10.0 37 | gguf==0.10.0 38 | gitdb==4.0.11 39 | GitPython==3.1.43 40 | google-auth==2.25.2 41 | google-auth-oauthlib==1.0.0 42 | gradio==5.9.1 43 | gradio_client==1.5.2 44 | grpcio==1.60.0 45 | h11==0.14.0 46 | hjson==3.1.0 47 | httpcore==1.0.7 48 | httptools==0.6.4 49 | httpx==0.28.1 50 | huggingface-hub==0.26.5 51 | idna==3.4 52 | importlib_metadata==8.5.0 53 | interegular==0.3.3 54 | Jinja2==3.1.2 55 | jiter==0.8.2 56 | jsonlines==4.0.0 57 | jsonschema==4.23.0 58 | jsonschema-specifications==2024.10.1 59 | lark==1.2.2 60 | llvmlite==0.43.0 61 | lm-format-enforcer==0.10.6 62 | loader==2017.9.11 63 | loaderr==1.0.0 64 | Markdown==3.5.1 65 | markdown-it-py==3.0.0 66 | MarkupSafe==2.1.3 67 | mdurl==0.1.2 68 | mistral_common==1.4.4 69 | mpi4py==4.0.1 70 | mpmath==1.3.0 71 | msgpack==1.1.0 72 | msgspec==0.18.6 73 | multidict==6.0.4 74 | multiprocess==0.70.15 75 | nest-asyncio==1.6.0 76 | networkx==3.0 77 | ninja==1.11.1.1 78 | numba==0.60.0 79 | numpy==1.25.0 80 | nvidia-cublas-cu12==12.1.3.1 81 | nvidia-cuda-cupti-cu12==12.1.105 82 | nvidia-cuda-nvrtc-cu12==12.1.105 83 | nvidia-cuda-runtime-cu12==12.1.105 84 | nvidia-cudnn-cu12==9.1.0.70 85 | nvidia-cufft-cu12==11.0.2.54 86 | nvidia-curand-cu12==10.3.2.106 87 | nvidia-cusolver-cu12==11.4.5.107 88 | nvidia-cusparse-cu12==12.1.0.106 89 | nvidia-ml-py==12.560.30 90 | nvidia-nccl-cu12==2.20.5 91 | nvidia-nvjitlink-cu12==12.4.127 92 | nvidia-nvtx-cu12==12.1.105 93 | oauthlib==3.2.2 94 | openai==1.58.1 95 | opencv-python-headless==4.10.0.84 96 | orjson==3.10.12 97 | outlines==0.0.43 98 | outlines_core==0.1.26 99 | packaging==23.2 100 | pandas==2.1.4 101 | partial-json-parser==0.2.1.1.post4 102 | Pebble==5.1.0 103 | peft==0.14.0 104 | pillow==10.3.0 105 | platformdirs==4.2.2 106 | prometheus-fastapi-instrumentator==7.0.0 107 | prometheus_client==0.21.1 108 | protobuf==4.25.1 109 | psutil==5.9.6 110 | py-cpuinfo==9.0.0 111 | pyairports==2.1.1 112 | pyarrow==18.1.0 113 | pyarrow-hotfix==0.6 114 | pyasn1==0.5.1 115 | pyasn1-modules==0.3.0 116 | pycountry==24.6.1 117 | pydantic==2.10.3 118 | pydantic_core==2.27.1 119 | pydub==0.25.1 120 | Pygments==2.18.0 121 | python-dateutil==2.8.2 122 | python-dotenv==1.0.1 123 | python-multipart==0.0.20 124 | pytz==2023.3.post1 125 | PyYAML==6.0.1 126 | pyzmq==26.2.0 127 | ray==2.40.0 128 | referencing==0.35.1 129 | regex==2023.10.3 130 | requests==2.32.3 131 | requests-oauthlib==1.3.1 132 | rich==13.9.3 133 | rpds-py==0.22.3 134 | rsa==4.9 135 | ruff==0.8.4 136 | safehttpx==0.1.6 137 | safetensors==0.4.5 138 | semantic-version==2.10.0 139 | sentencepiece==0.2.0 140 | sentry-sdk==2.5.1 141 | setproctitle==1.3.3 142 | shellingham==1.5.4 143 | shtab==1.7.1 144 | six==1.16.0 145 | smmap==5.0.1 146 | sniffio==1.3.1 147 | soupsieve==2.5 148 | starlette==0.41.3 149 | sympy==1.13.1 150 | tensorboard==2.14.0 151 | tensorboard-data-server==0.7.2 152 | termcolor==2.4.0 153 | tiktoken==0.7.0 154 | tokenizers==0.20.3 155 | tomlkit==0.13.2 156 | torch==2.4.0 157 | torchaudio==2.4.0 158 | torchvision==0.19.0 159 | tqdm==4.67.1 160 | transformers==4.46.3 161 | triton==3.0.0 162 | trl==0.12.2 163 | typer==0.15.1 164 | typing_extensions==4.12.2 165 | tyro==0.8.14 166 | tzdata==2023.3 167 | urllib3==1.26.13 
uvicorn==0.32.1
uvloop==0.21.0
vllm==0.6.3.post1
wandb==0.17.1
watchfiles==1.0.0
websockets==14.1
Werkzeug==3.0.1
# xformers==0.0.27.post2
xxhash==3.4.1
yarl==1.9.4
zipp==3.21.0
--------------------------------------------------------------------------------
/run_dpo_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash


# DeepSeekMath-7B MathCoder2 version
MODEL_PATH=models/base_model/MathCoder2-DeepSeekMath-7B
OUTPUT_DIR=outputs/dpo

##### model parameters #####
context_length=4096

##### Data parameters #####
input_file_paths=(
    datasets/unprocessed/dpo/dpo_d0_train
    datasets/unprocessed/dpo/dpo_d1_train
    datasets/unprocessed/dpo/dpo_d2_train
)
out_file_path=datasets/dpo


python train/scripts/tokenize_data_dpo.py --paths "${input_file_paths[@]}" --out_file_path ${out_file_path}/train --model_path $MODEL_PATH
python train/scripts/group_text.py --max_len $context_length --out_file_path ${out_file_path}/train --model_path $MODEL_PATH --no_grouping --no_padding --balance 1 --dpo


echo Starting training...

bash train/scripts/train_dpo.sh
--------------------------------------------------------------------------------
/run_testing.sh:
--------------------------------------------------------------------------------
#!/bin/bash

eval "$(conda shell.bash hook)"

PROJ_DIR=/localhome/mms43/scratch/mathcoder2/MathCoder2
MODELS="/localhome/mms43/scratch/mathcoder2/outputs"
BASE_MODEL_PATH=/localhome/mms43/scratch/mathcoder2/model_ckpts/MathCoder2-DeepSeekMath-7B
OUTPUT_DIR=/localhome/mms43/scratch/mathcoder2/outputs_new

########
DIR_TRANSLATOR=$MODELS/30233_ds_to_STL_s16800/checkpoint-step-3000/merged_model
# DIR_CODER=$MODELS/30236_ds_to_python_given_STL_s16800/checkpoint-step-3000/merged_model ### coder (model to evaluate)
# DIR_CODER=$BASE_MODEL_PATH ### MathCoder2-DeepSeekMath-7B
# DIR_CODER=$MODELS/30234_ds_to_python_no_STL_s16800/checkpoint-step-6000/merged_model ### trained baseline
# DIR_CODER=$MODELS/30235_ds_to_python_GT_STL_s16800/checkpoint-step-3000/merged_model ### oracle
DIR_CODER=$MODELS/10136_ds_to_python_misaligned_s16800/checkpoint-step-1500/merged_model ### finetuned coder (model to evaluate)
DIR_CONTROLLER=$MODELS/10138_ds_DPO__s16800/checkpoint-step-16000/policy/merged_model ### controller


gpus="1"


use_openai=
# use_openai='gpt-4o'
# use_openai='o1-mini'


cd $PROJ_DIR


if [ -n "$use_openai" ]; then
    EVAL_DIR=$OUTPUT_DIR/$use_openai/eval
else
    EVAL_DIR=${OUTPUT_DIR}/eval
fi


bash test/scripts/test_pdecontrol.sh $EVAL_DIR $gpus $DIR_TRANSLATOR $DIR_CODER $DIR_CONTROLLER $use_openai
--------------------------------------------------------------------------------
/run_train_testing.sh:
--------------------------------------------------------------------------------
#!/bin/bash

dataset_list=$1
prompt_format=$2
few_shot_number=$3
checkpoint_dir=$4
out_file_path=$5
max_samples=$6

# Split the dataset_list into an array
IFS=' ' read -r -a datasets <<< "$dataset_list"


for dataset in "${datasets[@]}"
do
    CMD="CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false python \
test/PDEcontrol/evaluation/infer/run_1d_pdecontrol_eval_train.py \ 17 | --data_dir $dataset \ 18 | --save_dir $out_file_path \ 19 | --use_vllm \ 20 | --model_name_or_path $checkpoint_dir \ 21 | --tokenizer_name_or_path $checkpoint_dir \ 22 | --eval_batch_size 1 \ 23 | --temperature 0.0 \ 24 | --prompt_format few_shot \ 25 | --few_shot_number $few_shot_number \ 26 | --few_shot_prompt $prompt_format" 27 | 28 | if [[ $dataset == *"heat"* ]]; then 29 | CMD="$CMD --prompt_dataset heat" 30 | elif [[ $dataset == *"wave"* ]]; then 31 | CMD="$CMD --prompt_dataset wave" 32 | fi 33 | 34 | if [ $max_samples -gt 0 ]; then 35 | CMD="$CMD --max_num_examples $max_samples" 36 | fi 37 | 38 | echo $CMD 39 | eval $CMD 40 | done 41 | 42 | 43 | -------------------------------------------------------------------------------- /run_training.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python --version 3 | 4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")") 5 | 6 | 7 | ### --------- Set Parameters --------- ### 8 | 9 | 10 | ##### model parameters ##### 11 | 12 | tokenizer_path=models/base_model/MathCoder2-DeepSeekMath-7B 13 | context_length=4096 14 | 15 | ##### Data parameters ##### 16 | input_file_paths=( 17 | datasets/unprocessed/sft/heat_nc1_train 18 | datasets/unprocessed/sft/heat_nc2_train 19 | datasets/unprocessed/sft/heat_nc3_train 20 | datasets/unprocessed/sft/wave_nc1_train 21 | datasets/unprocessed/sft/wave_nc2_train 22 | datasets/unprocessed/sft/wave_nc3_train 23 | ) 24 | 25 | # ["to_python_no_STL", "to_STL", "to_python_GT_STL", "to_python_given_STL"] 26 | prompt_format=to_STL 27 | 28 | out_file_path=datasets/sft/${prompt_format} 29 | 30 | ### To set for "to_python_given_STL" #### 31 | max_samples=-1 # -1 to test all datapoints. 32 | TRANSLATOR_DIR=models/translator 33 | 34 | 35 | ###################################### 36 | ### --------- Data processing --------- ### 37 | ###################################### 38 | 39 | 40 | if [ "$prompt_format" = "to_python_given_STL" ]; then 41 | ### --------- prompt the model for its predictions and add this to the training data. --------- ### 42 | 43 | 44 | ### --------- First step: train the model on the original data. --------- ### 45 | prompt_format=to_STL 46 | 47 | few_shot_number=2 48 | 49 | # join the input file paths into a single string 50 | input_file_paths_str=$(IFS=" "; echo "${input_file_paths[*]}") 51 | 52 | predictions_output_dir=$out_file_path/train_eval 53 | # obtain the model's predictions for the second step 54 | "$SCRIPT_DIR/run_train_testing.sh" "$input_file_paths_str" "$prompt_format" "$few_shot_number" $TRANSLATOR_DIR $predictions_output_dir $max_samples 55 | # the second step continues below 56 | prompt_format=to_python_given_STL 57 | 58 | input_file_paths=( 59 | $predictions_output_dir/to_STL/ 60 | ) 61 | ### --------- End of first step --------- ### 62 | fi 63 | 64 | 65 | python train/scripts/tokenize_data.py --paths "${input_file_paths[@]}" --out_file_path ${out_file_path} --model_path $tokenizer_path --sft --prompt_format $prompt_format 66 | python train/scripts/group_text.py --max_len $context_length --out_file_path ${out_file_path} --model_path $tokenizer_path --sft --no_grouping --no_padding --balance 0.05 0.22 0.23 0.05 0.22 0.23 --total 128000 67 | 68 | 69 | ###################################### 70 | ### --------- Training --------- ### 71 | ###################################### 72 | 73 | echo Starting training... 
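# Data processing above records the tokenized parquet paths in
# train/scripts/file_names_to-be-trained.json (see train/README.md);
# the train.sh step below launches training on that processed data.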
74 | 75 | bash train/scripts/train.sh 76 | -------------------------------------------------------------------------------- /test/PDEcontrol/.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__ 2 | -------------------------------------------------------------------------------- /test/PDEcontrol/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | python_version: "3.11" 7 | python_packages: 8 | - torch==2.0.1 9 | - torchvision==0.15.2 10 | - transformers==4.37.2 11 | - accelerate==0.27.0 12 | - hf_transfer 13 | 14 | # predict.py defines how predictions are run on your model 15 | predict: "replicate/predict.py:Predictor" 16 | -------------------------------------------------------------------------------- /test/PDEcontrol/evaluation/data_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/test/PDEcontrol/evaluation/data_processing/__init__.py -------------------------------------------------------------------------------- /test/PDEcontrol/evaluation/data_processing/process_utils.py: -------------------------------------------------------------------------------- 1 | import regex 2 | 3 | from data_processing.answer_extraction import extract_math_answer, strip_string 4 | 5 | def process_gsm8k_test(item): 6 | sample = { 7 | 'dataset': 'gsm8k-cot', 8 | 'id': item['id'], 9 | 'messages': [ 10 | {'role': 'user', 'content': item['question']}, 11 | {'role': 'assistant', 'content': regex.sub(r"<<[^<>]*>>", "", item['cot']) + "\nSo the answer is $\\boxed{" + item['answer'].strip() + "}$."} 12 | ], 13 | 'answer': item['answer'].replace(',', '') 14 | } 15 | yield sample 16 | 17 | def process_math_test(item): 18 | question = item["problem"] 19 | try: 20 | answer = extract_math_answer(question, item['solution'], task="cot") 21 | except: 22 | return 23 | sample = { 24 | "dataset": "math-cot", 25 | "id": item['id'], 26 | "level": item["level"], 27 | "type": item["type"], 28 | "category": item["category"], 29 | "messages": [ 30 | {"role": "user", "content": question}, 31 | {"role": "assistant", "content": "\n".join(regex.split(r"(?<=\.) (?=[A-Z])", item["solution"]))} 32 | ], 33 | "answer": answer 34 | } 35 | yield sample 36 | 37 | def process_math_sat(item): 38 | options = item['options'].strip() 39 | assert 'A' == options[0] 40 | options = '(' + options 41 | for ch in 'BCDEFG': 42 | if f' {ch}) ' in options: 43 | options = regex.sub(f' {ch}\) ', f" ({ch}) ", options) 44 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? 
Explain your answer.\n{options.strip()}" 45 | messages = [ 46 | {'role': 'user', 'content': question}, 47 | {'role': 'assistant', 'content': item['Answer']} 48 | ] 49 | item = { 50 | 'dataset': 'math_sat', 51 | 'id': item['id'], 52 | 'language': 'en', 53 | 'messages': messages, 54 | 'answer': item['Answer'], 55 | } 56 | yield item 57 | 58 | def process_ocwcourses(item): 59 | messages = [ 60 | {'role': 'user', 'content': item['problem'].strip()}, 61 | {'role': 'assistant', 'content': item['solution'].strip()} 62 | ] 63 | item = { 64 | "dataset": "OCWCourses", 65 | "id": item['id'], 66 | "language": "en", 67 | "messages": messages, 68 | "answer": item['answer'] 69 | } 70 | yield item 71 | 72 | def process_mmlu_stem(item): 73 | options = item['options'] 74 | for i, (label, option) in enumerate(zip('ABCD', options)): 75 | options[i] = f"({label}) {str(option).strip()}" 76 | options = ", ".join(options) 77 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}" 78 | messages = [ 79 | {'role': 'user', 'content': question}, 80 | {'role': 'assistant', 'content': item['answer']} 81 | ] 82 | item = { 83 | "dataset": "MMLU-STEM", 84 | "id": item['id'], 85 | "language": "en", 86 | "messages": messages, 87 | "answer": item['answer'] 88 | } 89 | yield item 90 | 91 | def process_mgsm_zh(item): 92 | item['answer'] = item['answer'].replace(',', '') 93 | yield item 94 | 95 | def process_cmath(item): 96 | item = { 97 | 'dataset': 'cmath', 98 | 'id': item['id'], 99 | 'grade': item['grade'], 100 | 'reasoning_step': item['reasoning_step'], 101 | 'messages': [ 102 | {'role': 'user', 'content': item['question'].strip()}, 103 | {'role': 'assistant', 'content': ''} 104 | ], 105 | 'answer': item['golden'].strip().replace(",", "") 106 | } 107 | yield item 108 | 109 | def process_agieval_gaokao_math_cloze(item): 110 | item = { 111 | 'dataset': 'agieval-gaokao-math-cloze', 112 | 'id': item['id'], 113 | 'messages': [ 114 | {'role': 'user', 'content': item['question'].strip()}, 115 | {'role': 'assistant', 'content': ''} 116 | ], 117 | 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")] 118 | } 119 | yield item 120 | 121 | def process_agieval_gaokao_mathqa(item): 122 | question = item['question'].strip() 123 | options = [] 124 | for option in item['options']: 125 | option = option.strip() 126 | assert option[0] == '(' 127 | assert option[2] == ')' 128 | assert option[1] in 'ABCD' 129 | option = f"{option[1]}: {option[3:].strip()}" 130 | options.append(option.strip()) 131 | question = f"{question}\n{options}" 132 | item = { 133 | 'dataset': 'agieval-gaokao-mathqa', 134 | 'id': item['id'], 135 | 'messages': [ 136 | {'role': 'user', 'content': question}, 137 | {'role': 'assistant', 'content': ''} 138 | ], 139 | "answer": item['label'] 140 | } 141 | yield item 142 | 143 | def process_agieval_gaokao_mathqa_few_shot_cot_test(item): 144 | question = item['question'].strip().rstrip('\\') 145 | options = " ".join([opt.strip() for opt in item['options']]) 146 | question = f"{question}\n从以下选项中选择: {options}" 147 | item = { 148 | 'dataset': 'agieval-gaokao-mathqa', 149 | 'id': item['id'], 150 | 'messages': [ 151 | {'role': 'user', 'content': question}, 152 | {'role': 'assistant', 'content': ''} 153 | ], 154 | "answer": item['label'] 155 | } 156 | yield item 157 | 158 | def process_minif2f_isabelle(item): 159 | question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} 
*)\n\nFormal:\n{item['formal_statement'].strip()}" 160 | item = { 161 | 'dataset': 'minif2f-isabelle', 162 | 'id': item['id'], 163 | 'messages': [ 164 | {'role': 'user', 'content': question}, 165 | {'role': 'assistant', 'content': ''} 166 | ], 167 | "answer": "placeholder" 168 | } 169 | yield item 170 | -------------------------------------------------------------------------------- /test/PDEcontrol/evaluation/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from . import eval_script -------------------------------------------------------------------------------- /test/PDEcontrol/evaluation/eval/eval_robustness_reasoning_wrapper.py: -------------------------------------------------------------------------------- 1 | # eval_robustness_reasoning_wrapper.py 2 | import sys 3 | import json 4 | try: 5 | from control.femformal.eval_robustness_reasoning import eval_robustness as femformal_eval_robustness 6 | except ImportError: 7 | import sys 8 | import os 9 | sys.path.append(os.path.abspath('/localhome/mms43/scratch/mathcoder2/MathCoder2')) 10 | from control.femformal.eval_robustness_reasoning import eval_robustness as femformal_eval_robustness 11 | 12 | 13 | def main(): 14 | llm_anchor_output = sys.argv[1] 15 | llm_interm_output = sys.argv[2] 16 | robustness, runtime = femformal_eval_robustness(llm_anchor_output, llm_interm_output) 17 | # Note that this code should not print anything else to stdout or it will screw with the evaluation 18 | print(json.dumps({ 19 | "robustness": robustness, 20 | "runtime": runtime 21 | })) 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /test/PDEcontrol/evaluation/eval/eval_robustness_wrapper.py: -------------------------------------------------------------------------------- 1 | # eval_robustness_wrapper.py 2 | import sys 3 | import json 4 | try: 5 | from control.femformal.eval_robustness import eval_robustness as femformal_eval_robustness 6 | except ImportError: 7 | import sys 8 | import os 9 | sys.path.append(os.path.abspath('/localhome/mms43/scratch/mathcoder2/MathCoder2')) 10 | from control.femformal.eval_robustness import eval_robustness as femformal_eval_robustness 11 | 12 | 13 | def main(): 14 | nl_in_prompt = sys.argv[1] 15 | llm_output = sys.argv[2] 16 | robustness, runtime = femformal_eval_robustness(nl_in_prompt, llm_output) 17 | # Note that this code should not print anything else to stdout or it will screw with the evaluation 18 | print(json.dumps({ 19 | "robustness": robustness, 20 | "runtime": runtime 21 | })) 22 | 23 | if __name__ == "__main__": 24 | main() -------------------------------------------------------------------------------- /test/PDEcontrol/evaluation/infer/simulate_gt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import traceback 5 | current_dir = os.path.dirname(os.path.realpath(__file__)) 6 | from tqdm import tqdm 7 | import json 8 | from pebble import ProcessPool 9 | from concurrent.futures import TimeoutError 10 | import random 11 | 12 | 13 | 14 | utils_path = os.path.join(current_dir, '../../../../utils') 15 | sys.path.append(utils_path) 16 | 17 | from few_shot_prompts import CoTOneDHeat, CoTOneDWave 18 | 19 | 20 | evaluation_path = os.path.abspath(os.path.join(current_dir, '..')) 21 | sys.path.append(evaluation_path) 22 | 23 | eval_path = 
os.path.abspath(os.path.join(current_dir, '../eval')) 24 | sys.path.append(eval_path) 25 | 26 | data_processing_path = os.path.abspath(os.path.join(current_dir, '../data_processing')) 27 | sys.path.append(data_processing_path) 28 | 29 | from data_processing.answer_extraction import * 30 | from eval.eval_script import * 31 | 32 | 33 | 34 | 35 | # model_name_or_path cannot be both None or both not None. 36 | model = None 37 | tokenizer = None 38 | pool = None 39 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' 40 | 41 | def evaluate(eval_fn, tasks, _timeout=15): 42 | with ProcessPool() as pool: 43 | timeout_cnt = 0 44 | failure_cnt = 0 45 | iterator = pool.map(eval_fn, tasks, timeout=_timeout).result() 46 | labels = [] 47 | while True: 48 | try: 49 | labels.append((next(iterator))) 50 | except StopIteration: 51 | break 52 | except TimeoutError as error: 53 | labels.append(0) 54 | timeout_cnt += 1 55 | except Exception as error: 56 | print("An error occurred:", error, flush=True) 57 | traceback.print_exception(type(error), error, error.__traceback__, file=sys.stdout) 58 | failure_cnt += 1 59 | labels.append(-100) 60 | return labels, timeout_cnt, failure_cnt 61 | 62 | def evaluate_future(eval_fn, tasks, _timeout=300): 63 | # The runtime is doubled because the gt may have to be simulated also. 64 | num_cpus = os.cpu_count() 65 | max_workers = max(1, int(np.floor(num_cpus * 0.5))) 66 | print(f"Using {max_workers} out of {num_cpus} CPUs") 67 | with ProcessPool(max_workers=max_workers) as pool: 68 | timeout_cnt = 0 69 | futures = [(i, pool.schedule(eval_fn, args=(task,), timeout=_timeout)) for i, task in enumerate(tasks)] 70 | results = [None] * len(tasks) 71 | 72 | with tqdm(total=len(futures), desc="Evaluating robustness (completed)") as pbar: 73 | for i, future in futures: 74 | try: 75 | result = future.result() 76 | gt_robustness, gt_simtime = result 77 | results[i] = (gt_robustness, gt_simtime) 78 | 79 | except TimeoutError: 80 | results[i] = ("timeout", "timeout") 81 | timeout_cnt += 1 82 | except Exception as error: 83 | print("An error occurred:", error, flush=True) 84 | traceback.print_exception(type(error), error, error.__traceback__, file=sys.stdout) 85 | exit() 86 | finally: 87 | pbar.update(1) 88 | list_gt_robustness, list_gt_simtime = zip(*results) 89 | return list_gt_robustness, list_gt_simtime, timeout_cnt 90 | 91 | 92 | def main(args): 93 | if args.gpus is not None: 94 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 95 | random.seed(42) 96 | 97 | print("Loading data...") 98 | test_data = [] 99 | with open(os.path.join(args.data_dir, f"validation.jsonl" if args.infer_on_train_set else f"test.jsonl")) as fin: 100 | for line in fin: 101 | example = json.loads(line) 102 | python_code = example[args.python_key] 103 | sstl = example[args.stl_key] 104 | natural_language = example[args.nl_key] 105 | robustness = example.get(args.robustness_key, None) 106 | example['python'] = python_code.strip() 107 | example['sstl'] = sstl 108 | example['nl'] = natural_language 109 | if robustness is not None: 110 | example['robustness'] = robustness 111 | test_data.append(example) 112 | 113 | if args.max_num_examples and len(test_data) > args.max_num_examples: 114 | test_data = random.sample(test_data, args.max_num_examples) 115 | 116 | if not test_data: 117 | print("Ending. 
There was no data to test.") 118 | return 119 | 120 | args.save_dir = os.path.join(args.data_dir + "_" + str(args.max_num_examples)) 121 | os.makedirs(args.save_dir, exist_ok=True) 122 | 123 | 124 | if args.eval_robustness: 125 | eval_robustness_function = eval("eval_robustness_gt") 126 | sim_gt_robustnesses, sim_gt_simulation_times, eval_timeout_cnt_robustness = evaluate_future(eval_robustness_function, test_data) 127 | print("done evaluating robustness", flush=True) 128 | for item, gt_r, gt_time in zip(test_data, sim_gt_robustnesses, sim_gt_simulation_times): 129 | if 'robustness' not in item: 130 | item['robustness'] = gt_r 131 | item['time'] = gt_time 132 | 133 | print("Calculating accuracy...") 134 | ## track dataset stats: 135 | sum_gt_positive_robustness = 0 136 | sum_gt_negative_robustness = 0 137 | sum_gt_failed_robustness = 0 138 | for item in test_data: 139 | if args.eval_robustness: 140 | ## track dataset stats: 141 | gt_r = item['robustness'] 142 | if gt_r > 0: 143 | sum_gt_positive_robustness += 1 144 | elif gt_r < 0 and gt_r != -100: 145 | sum_gt_negative_robustness += 1 146 | elif gt_r == -100: 147 | sum_gt_failed_robustness += 1 148 | else: 149 | raise ValueError(f"gt_r = {gt_r}") 150 | ## end track dataset stats 151 | 152 | 153 | if args.eval_robustness: 154 | ## track dataset stats: 155 | gt_positive_robustness_rate = sum_gt_positive_robustness / len(test_data) 156 | gt_negative_robustness_rate = sum_gt_negative_robustness / len(test_data) 157 | gt_failed_robustness_rate = sum_gt_failed_robustness / len(test_data) 158 | print(f"gt positive robustness rate = {gt_positive_robustness_rate * 100}", flush=True) 159 | print(f"gt negative robustness rate = {gt_negative_robustness_rate * 100}", flush=True) 160 | print(f"gt failed robustness rate = {gt_failed_robustness_rate * 100}", flush=True) 161 | 162 | 163 | if args.eval_robustness: print(f"Timeout count >>> output eval robustness = {eval_timeout_cnt_robustness}", flush=True) 164 | 165 | with open(os.path.join(args.save_dir, f"validation.jsonl" if args.infer_on_train_set else f"test.jsonl"), "w") as fout: 166 | for item in test_data: 167 | if 'sstl' in item: 168 | sstl = item.pop('sstl') 169 | item['sstl'] = sstl 170 | if 'python' in item: 171 | python = item.pop('python') 172 | item['python'] = python 173 | if 'robustness' in item: 174 | robustness = item.pop('robustness') 175 | item['robustness'] = robustness 176 | if 'time' in item: 177 | time = item.pop('time') 178 | item['time'] = time 179 | json.dump(item, fout, ensure_ascii=True) 180 | fout.write("\n") 181 | 182 | metric_fname = "metrics.json" 183 | with open(os.path.join(args.save_dir, metric_fname), "w") as fout: 184 | metrics = { 185 | "n_samples": len(test_data), 186 | } 187 | if args.eval_robustness: 188 | metrics["gt positive robustness rate"] = gt_positive_robustness_rate 189 | metrics["gt negative robustness rate"] = gt_negative_robustness_rate 190 | metrics["gt failed robustness rate"] = gt_failed_robustness_rate 191 | json.dump(metrics, fout, indent=4) 192 | 193 | if __name__ == "__main__": 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument("--seed", type=int, default=None, help="random seed for the LLM generation only. 
This SHOULD NOT affect the random seed for data selection (I did not debug this so set it at your own peril).") 196 | parser.add_argument("--data_dir", type=str, default="data/mgsm") 197 | parser.add_argument("--python_key", type=str, default="python", help="Key for accessing the Python code in the example.") 198 | parser.add_argument("--stl_key", type=str, default="sstl", help="Key for accessing the STL in the example.") 199 | parser.add_argument("--nl_key", type=str, default="nl", help="Key for accessing the natural language in the example.") 200 | parser.add_argument("--robustness_key", type=str, default="robustness", help="Key for accessing the robustness in the example.") 201 | 202 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.") 203 | parser.add_argument("--save_dir", type=str, default="results/mgsm") 204 | 205 | parser.add_argument("--infer_on_train_set", action="store_true") 206 | 207 | parser.add_argument("--eval_perplexity", action='store_true', help="Whether to evaluate the perplexity of the model outputs.") 208 | parser.add_argument("--eval_robustness", action='store_true', help="Whether to evaluate the robustness of the model outputs by comparing on the FemFormal repo.") 209 | parser.add_argument("--eval_edit_distance", action='store_true', help="Whether to evaluate the edit distance of the model python with the ground truth python.") 210 | parser.add_argument("--eval_iou", action='store_true', help="Whether to evaluate the precision and recall of the model sstl latex generation with the ground truth sstl.") 211 | parser.add_argument("--load_from_file", action='store_true', help="Whether to load the model predictions from a file instead of regenerating them. If this flag is provided but no file exists, the model predictions will be generated and saved to a file.") 212 | parser.add_argument("--skip_existing_scores", action='store_true', help="Whether to skip computing and overwriting the evaluation of a datapoint (ex. --eval_perplexity) that exists already. 
Only relevant if `--load_from_file` is provided.")
    parser.add_argument("--gpus", type=str, default=None, help="Use to set the CUDA_VISIBLE_DEVICES environment variable.")
    args, unparsed_args = parser.parse_known_args()
    if args.gpus is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    print("unparsed args:", flush=True)
    print(unparsed_args, flush=True)
    print("args:", flush=True)
    print(args, flush=True)

    if args.seed is not None:
        if args.seed < 0:
            args.seed = None

    if 'math6' in args.data_dir:
        args.multi_turn = True

    # the basename of the datadir is the dataset name
    args.dataset_name = args.data_dir.split("/")[-1]

    main(args)

    if pool is not None:
        pool.close()
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/scripts/infer_pdecontrol.sh:
--------------------------------------------------------------------------------
eval "$(conda shell.bash hook)"

set -e

conda activate trainenv
echo "activated conda env:" $CONDA_DEFAULT_ENV
python --version


dataset=$1
out_dir=$2
few_shot_number=$3
prompt_format=$4
max_samples=$5
gpus=$6
translator_path=${7}
coder_path=${8}
controller_path=${9}
use_openai=${10}

save_dir=$out_dir/${dataset}_shots=${few_shot_number}


CMD="CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false python test/PDEcontrol/evaluation/infer/run_1d_pdecontrol_eval_full.py \
    --data_dir \"$(dirname $0)/../../test_data/${dataset}\" \
    --save_dir $save_dir \
    --use_vllm \
    --model_name_or_path_translator $translator_path \
    --tokenizer_name_or_path_translator $translator_path \
    --model_name_or_path_coder $coder_path \
    --tokenizer_name_or_path_coder $coder_path \
    --model_name_or_path_controller $controller_path \
    --tokenizer_name_or_path_controller $controller_path \
    --eval_batch_size 1 \
    --temperature 0.2 \
    --seed 0 \
    --n_repeat_sampling 3 \
    --prompt_format few_shot \
    --prompt_dataset CoTOneDCombined \
    --few_shot_number $few_shot_number \
    --few_shot_prompt $prompt_format \
    --eval_robustness \
    --eval_iou \
    --eval_edit_distance"

if [ $max_samples -gt 0 ]; then
    CMD="$CMD --max_num_examples $max_samples"
fi

if [ "$gpus" != -1 ]; then
    CMD="$CMD --gpus $gpus"
fi

if [ -n "$use_openai" ]; then
    CMD="$CMD --use_openai $use_openai"
else
    CMD="$CMD --eval_perplexity"
fi

echo $CMD
eval $CMD
--------------------------------------------------------------------------------
/test/README.md:
--------------------------------------------------------------------------------
# Testing

## Installation

Install conda environments `trainenv` and `pdecontrol` as specified in the [main README](../README.md).

## Evaluation

We provide a test script for both zero-shot and few-shot evaluation on the heat and wave PDE problems used in our paper. Modify the settings in [`scripts/test_pdecontrol.sh`](./scripts/test_pdecontrol.sh).
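The main knobs at the top of that script are plain bash arrays; for instance (illustrative values, using the variables the script actually defines):

```bash
declare -a datasets=("heat_nc1_512" "wave_nc1_512")   # test splits to evaluate
declare -a prompt_formats=("to_python_two_step")      # evaluation pipeline variant
declare -a few_shot_numbers=(0 2)                     # numbers of few-shot examples
max_samples=4                                         # set to -1 to use all samples
```
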
Then to evaluate the models, set the paths in the following script and run it:

```bash
bash run_testing.sh
```

Set `$MODELS` to the path of the directory where the model weights are stored, and set `$OUTPUT_DIR` to the directory where you want the inference results to be saved. This script also creates a `result` directory under `$EVAL_DIR`, where markdown files containing a table of all the results are saved.


## Labels

The following code can be used to simulate labels for the ground-truth natural language & Python inputs. To do so, run:

```bash
bash simulate_gt.sh
```
--------------------------------------------------------------------------------
/test/requirements.txt:
--------------------------------------------------------------------------------
torch==2.4.0
tensorflow
transformers
wandb
beautifulsoup4
torchvision
tqdm
tensorboard
flash-attn
trl
vllm==0.2.0
deepspeed
tf_keras
torchaudio
peft
Pebble
--------------------------------------------------------------------------------
/test/scripts/read_result.py:
--------------------------------------------------------------------------------
import json
from argparse import ArgumentParser
import os
from glob import glob

import numpy as np

def read_json(in_file):
    with open(in_file, "r", encoding="utf-8") as f:
        return json.load(f)


def per_eval_method_read_result(metrics, datasets, in_dir, eval_method, subset_id=None, shots=3, seeds=[-1]):
    # Initialize the output text
    text = ""
    for metric in metrics:
        # Initialize a dictionary to store results
        max_shots = shots
        # results = {i: ["n/a"] * len(datasets) for i in range(max_shots + 1)}
        ## for each shot, there is a dataset. Each dataset contains a list to track the scores for each seed.
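        ## Illustrative shape, assuming two seeds were run: results[2]["heat_nc1_512"] == [0.81, 0.79]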
21 | results = {i: {dataset: [] for dataset in datasets} for i in range(max_shots + 1)} 22 | 23 | # Track dataset columns with valid entries 24 | valid_columns = {dataset: False for dataset in datasets} 25 | 26 | # Iterate over the directories again to populate results 27 | for dirname in os.listdir(in_dir): 28 | if "shots=" in dirname and "seed=" in dirname: 29 | # Extract shot number and seed number 30 | shot_num = int(dirname.split("shots=")[1].split("_")[0]) 31 | seed_num = int(dirname.split("seed=")[1].split("_")[0]) 32 | 33 | if seed_num not in seeds: 34 | continue 35 | 36 | # Determine the dataset and column based on the directory name 37 | dataset = dirname.split("_shots=")[0] 38 | 39 | # Read the metrics.json file 40 | in_file = os.path.join(in_dir, dirname, eval_method, "metrics.json") 41 | if subset_id is not None: 42 | in_file = os.path.join(in_dir, dirname, eval_method, f"metrics.{subset_id}.json") 43 | if os.path.exists(in_file): 44 | data = read_json(in_file) 45 | try: 46 | results[shot_num][dataset].append(data[metric]) 47 | valid_columns[dataset] = True 48 | except KeyError: 49 | pass 50 | 51 | # Ensure all datasets have the same number of scores for each shot number (excluding the empty) 52 | for shot_num in range(max_shots + 1): 53 | lengths = [len(results[shot_num][dataset]) for dataset in datasets if valid_columns[dataset]] 54 | if lengths and not all(length == lengths[0] for length in lengths): 55 | raise ValueError(f"Inconsistent number of scores for shot number {shot_num}: {lengths}") 56 | 57 | 58 | # Filter out columns that are entirely empty 59 | filtered_datasets = [] 60 | filtered_columns = [] 61 | for dataset in datasets: 62 | if valid_columns[dataset]: 63 | filtered_datasets.append(dataset) 64 | filtered_columns.append(dataset) 65 | 66 | # Initialize the table header 67 | header = f"## Metric: {metric}\n\n" 68 | header += "| shots | " + " | ".join([f"{dataset}" for dataset in filtered_datasets]) + " |\n" 69 | separator = "|-------|" + "------------|" * (len(filtered_datasets)) + "\n" 70 | text += header + separator 71 | 72 | 73 | # Construct the table rows 74 | for shot_num in range(max_shots + 1): 75 | row = [] 76 | for dataset in filtered_columns: 77 | scores = results[shot_num][dataset] 78 | if scores: 79 | mean_score = np.mean(scores) 80 | std_score = np.std(scores, ddof=1) 81 | row.append(f"{mean_score:.4f} ({std_score:.4f})") 82 | else: 83 | row.append("n/a") 84 | text += f"| {shot_num} | " + " | ".join(row) + " |\n" 85 | 86 | text += "\n\n" 87 | 88 | return text 89 | 90 | 91 | 92 | def read_result(in_dir, out_file, args): #, metrics=["perplexity"], subset_id=None): 93 | shots = args.shots 94 | metrics=args.metrics 95 | eval_methods=args.eval_methods 96 | seeds = args.seeds 97 | subset_id=args.subset_id 98 | if subset_id is not None: 99 | if subset_id < 0: 100 | subset_id = None 101 | 102 | # Initialize a set to store unique datasets 103 | datasets = set() 104 | 105 | # Iterate over the directories to identify datasets 106 | ## Directories take the form of `dataset_shots=3_seed=0` 107 | for dirname in os.listdir(in_dir): 108 | if "shots=" in dirname and "seed=" in dirname: 109 | dataset = dirname.split("_shots=")[0] 110 | datasets.add(dataset) 111 | 112 | # Sort datasets to maintain a consistent order 113 | datasets = sorted(datasets) 114 | 115 | for eval_method in eval_methods: 116 | text = per_eval_method_read_result(metrics, datasets, os.path.join(in_dir), eval_method, subset_id=subset_id, shots=shots, seeds=seeds) 117 | 118 | 119 | output_file = 
f"{out_file}-{eval_method}" + ".md" 120 | out_dir = os.path.dirname(output_file) 121 | 122 | if not os.path.exists(out_dir): 123 | os.makedirs(out_dir) 124 | 125 | print(eval_method) 126 | print(text) 127 | if text != "": 128 | with open(output_file, "w", encoding="utf-8") as f: 129 | f.write(text) 130 | 131 | 132 | 133 | def main(): 134 | parser = ArgumentParser() 135 | parser.add_argument("--in_dir", type=str) 136 | parser.add_argument("--subset_id", type=int, default=None, help="The dataset to test can be split into `n_subsets`(see run_1d_pdecontrol_eval) and if `subset_id != None` and non-negative, the `subset_id` will be the only subset to find metrics.") 137 | parser.add_argument("--metrics", type=str, nargs="+", default=[ 138 | "robustness accuracy", 139 | "robustness mre", 140 | "robustness failure rate", 141 | "robustness timeout rate", 142 | "simulation time mre", 143 | "edit distance", 144 | "iou", 145 | "iou failures", 146 | "iou timeout rate", 147 | "perplexity", 148 | "perplexity timeout rate", 149 | "gt positive robustness rate", 150 | "gt negative robustness rate", 151 | "gt failed robustness rate", 152 | "adjusted_failure_rate", 153 | ]) 154 | parser.add_argument("--shots", type=int, default=3) 155 | parser.add_argument( 156 | "--eval_methods", 157 | type=str, 158 | choices=["to_python_direct_with_sstl_cot", "to_python_no_STL", "to_python_two_step", "to_STL"], 159 | default=["to_python_direct_with_sstl_cot", "to_python_no_STL", "to_python_two_step", "to_STL"], 160 | nargs='+' 161 | ) 162 | parser.add_argument("--seeds", type=int, nargs="+", default=[-1], help="The seeds to to average results over. A directory may contain multiple seeds. This argument will specify which seeds to average over and the rest will be ignored.") 163 | 164 | 165 | 166 | args = parser.parse_args() 167 | in_dir = args.in_dir 168 | out_file = os.path.join(args.in_dir, "results", os.path.basename(in_dir)) 169 | read_result(in_dir, out_file, args=args) 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /test/scripts/simulate_gt.sh: -------------------------------------------------------------------------------- 1 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 2 | 3 | cd $DIR/../.. 4 | 5 | 6 | 7 | declare -a datasets=( 8 | "heat_nc1" 9 | "heat_nc2" 10 | "heat_nc3" 11 | "wave_nc1" 12 | "wave_nc2" 13 | "wave_nc3" 14 | ) 15 | 16 | 17 | ## set the max samples to be -1 to use all samples. 
18 | max_samples=512 19 | skip_existing=False # True or False 20 | load_from_file=False # True or False 21 | gpus=1 22 | 23 | 24 | # Calculate the total number of iterations 25 | total_iterations=$(( ${#datasets[@]} )) 26 | current_iteration=0 27 | # Record the start time 28 | start_time=$(date +%s) 29 | 30 | for dataset in "${datasets[@]}" 31 | do 32 | echo " PROGRESS: $current_iteration/$total_iterations" 33 | echo " Evaluating $dataset" | tee -a $log_file 34 | 35 | if [[ $prompt_format == "to_STL" ]]; then 36 | eval_robustness=False 37 | fi 38 | 39 | CMD="CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false python test/PDEcontrol/evaluation/infer/simulate_gt.py \ 40 | --data_dir \"$DIR/../../test/PDEcontrol/test_data/${dataset}\" \ 41 | --eval_batch_size 1 \ 42 | --eval_robustness" 43 | 44 | if [ $max_samples -gt 0 ]; then 45 | CMD="$CMD --max_num_examples $max_samples" 46 | fi 47 | 48 | if [ "$load_from_file" = "True" ]; then 49 | CMD="$CMD --load_from_file" 50 | fi 51 | 52 | if [ -n "$gpus" ]; then 53 | CMD="$CMD --gpus $gpus" 54 | fi 55 | 56 | 57 | echo $CMD 58 | eval $CMD 59 | 60 | 61 | 62 | echo " Done evaluating $dataset with $few_shot few shot examples and format=$prompt_format" 63 | current_iteration=$((current_iteration + 1)) 64 | echo " PROGRESS: $current_iteration/$total_iterations" 65 | # Calculate elapsed time 66 | current_time=$(date +%s) 67 | elapsed_time=$((current_time - start_time)) 68 | # Estimate remaining time 69 | average_time_per_iteration=$((elapsed_time / current_iteration)) 70 | remaining_iterations=$((total_iterations - current_iteration)) 71 | estimated_remaining_time=$((remaining_iterations * average_time_per_iteration)) 72 | hours=$((estimated_remaining_time / 3600)) 73 | minutes=$(( (estimated_remaining_time % 3600) / 60 )) 74 | seconds=$((estimated_remaining_time % 60)) 75 | echo " Estimated remaining time: $hours hours, $minutes minutes, $seconds seconds" 76 | done 77 | -------------------------------------------------------------------------------- /test/scripts/test_pdecontrol.sh: -------------------------------------------------------------------------------- 1 | out_dir=${1} 2 | gpus=${2} 3 | translator_path=${3} 4 | coder_path=${4} 5 | controller_path=${5} 6 | use_openai=${6} 7 | 8 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 9 | 10 | declare -a datasets=( 11 | "heat_nc1_512" 12 | "heat_nc2_512" 13 | "heat_nc3_512" 14 | "wave_nc1_512" 15 | "wave_nc2_512" 16 | "wave_nc3_512" 17 | # "dpo_manual_test" 18 | ) 19 | 20 | declare -a prompt_formats=( 21 | # to_python_no_STL 22 | # to_python_direct_with_sstl_cot 23 | to_python_two_step 24 | # to_STL 25 | # full_pipeline 26 | # reasoning_only 27 | ) 28 | 29 | declare -a few_shot_numbers=( 30 | 0 31 | 2 32 | ) 33 | 34 | 35 | ## set the max samples to be -1 to use all samples. 
36 | max_samples=4 37 | gpus=$gpus 38 | # gpus=0 39 | 40 | if [[ -n $use_openai ]]; then 41 | declare -a few_shot_numbers=( 42 | 2 43 | ) 44 | fi 45 | 46 | # appends to the log file 47 | log_file=${out_dir}/eval_output.log 48 | 49 | # Calculate the total number of iterations 50 | total_iterations=$(( ${#datasets[@]} * ${#prompt_formats[@]} * ${#few_shot_numbers[@]})) 51 | current_iteration=0 52 | # Record the start time 53 | start_time=$(date +%s) 54 | 55 | for dataset in "${datasets[@]}" 56 | do 57 | for prompt_format in "${prompt_formats[@]}" 58 | do 59 | for few_shot in "${few_shot_numbers[@]}" 60 | do 61 | echo " PROGRESS: $current_iteration/$total_iterations" | tee -a $log_file 62 | echo " Evaluating $dataset with $few_shot few shot examples and format=$prompt_format" | tee -a $log_file 63 | 64 | echo dataset: $dataset 65 | echo out_dir: $out_dir 66 | echo shots: $few_shot 67 | echo translator_path: $translator_path 68 | echo coder_path: $coder_path 69 | echo controller_path: $controller_path 70 | echo prompt_format: $prompt_format 71 | echo max_samples: $max_samples 72 | echo gpus: $gpus 73 | echo use_openai: $use_openai 74 | 75 | 76 | bash test/PDEcontrol/evaluation/scripts/infer_pdecontrol.sh $dataset $out_dir $few_shot $prompt_format $max_samples $gpus $translator_path $coder_path $controller_path $use_openai | tee -a $log_file 77 | 78 | echo " Done evaluating $dataset with $few_shot few shot examples and format=$prompt_format" | tee -a $log_file 79 | current_iteration=$((current_iteration + 1)) 80 | echo " PROGRESS: $current_iteration/$total_iterations" | tee -a $log_file 81 | # Calculate elapsed time 82 | current_time=$(date +%s) 83 | elapsed_time=$((current_time - start_time)) 84 | # Estimate remaining time 85 | average_time_per_iteration=$((elapsed_time / current_iteration)) 86 | remaining_iterations=$((total_iterations - current_iteration)) 87 | estimated_remaining_time=$((remaining_iterations * average_time_per_iteration)) 88 | # Convert estimated remaining time to human-readable format 89 | hours=$((estimated_remaining_time / 3600)) 90 | minutes=$(( (estimated_remaining_time % 3600) / 60 )) 91 | seconds=$((estimated_remaining_time % 60)) 92 | elapsed_hours=$((elapsed_time / 3600)) 93 | elapsed_minutes=$(( (elapsed_time % 3600) / 60 )) 94 | elapsed_seconds=$((elapsed_time % 60)) 95 | echo " Elapsed time: $elapsed_hours hours, $elapsed_minutes minutes, $elapsed_seconds seconds" | tee -a $log_file 96 | echo " Estimated remaining time: $hours hours, $minutes minutes, $seconds seconds" | tee -a $log_file 97 | done 98 | done 99 | done 100 | 101 | # # Initialize max_shots with the first element of the array 102 | # max_shots=${few_shot_numbers[0]} 103 | 104 | # # Iterate through the array to find the maximum value 105 | # for num in "${few_shot_numbers[@]}"; do 106 | # # Skip commented out values 107 | # if [[ $num =~ ^[0-9]+$ ]]; then 108 | # if (( num > max_shots )); then 109 | # max_shots=$num 110 | # fi 111 | # fi 112 | # done 113 | 114 | # echo "Maximum shots: $max_shots" 115 | 116 | # CMD="python /localhome/mms43/scratch/mathcoder2/MathCoder2/test/scripts/read_result.py --in_dir $out_dir --subset_id $max_samples --shots $max_shots --seeds ${seeds[@]}" 117 | 118 | # CMD="$CMD --eval_methods ${prompt_formats[@]}" 119 | # # if [[ $prompt_format == "to_STL" ]]; then 120 | # # CMD="$CMD --eval_methods to_STL" 121 | # # fi 122 | 123 | # eval $CMD -------------------------------------------------------------------------------- /train/README.md: 
-------------------------------------------------------------------------------- 1 | # Supervised Finetuning and Direct Preference Optimization 2 | 3 | This directory contains the code for SFT and DPO. 4 | 5 | ## Installation 6 | 7 | Follow the instructions on [the main readme](../README.md) 8 | 9 | ## Training 10 | 11 | ### Step 1: save the paths of the jsonl files to be tokenized 12 | 13 | Modify `input_file_paths` in `run_training.sh`. Each element of `input_file_paths` should be the path to a directory that contains `.jsonl` files with the texts to be trained under the key `"text"`. For example: 14 | ``` 15 | datasets/unprocessed/sft/one_d_heat_train 16 | datasets/unprocessed/sft/one_d_wave_train 17 | ``` 18 | 19 | You can modify `out_file_path` in `run_training.sh` to change the output directory for the parquet files. Then run: 20 | 21 | ```shell 22 | bash run_training.sh 23 | ``` 24 | #### Tokenization 25 | This script automatically tokenizes the files by calling `train/scripts/tokenize_data.py`. The paths of the output tokenized parquet files will automatically be written to the file `file_names_to-be-trained.json`. Each element is the path to a tokenized parquet file. 26 | 27 | #### Grouping the texts into given context length 28 | 29 | You can modify `max_len` in `run_training.sh` to change the context length. This context length should equal the context length you wish to use in training. `train/scripts/group_text.py` will get called. This outputs a single parquet file containing token indexes grouped into the given context length. 30 | 31 | ### Step 2: train the model 32 | 33 | 34 | 35 | You can modify the training configs in `train/scripts/train.sh`. The example script runs on a single node with 4 A100 GPUs. You may need to modify the script so that it runs on your cluster. Run the following command (it is also included in the `run_training.sh` script): 36 | 37 | ```shell 38 | bash train/scripts/train.sh 39 | ``` 40 | 41 | This script can train the model with validation, or validation can be turned off by simply running: 42 | ```sh 43 | train 44 | exit_status=$?
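# (illustrative addition, not in the original snippet) propagate the training exit status:
exit $exit_status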
45 | ``` -------------------------------------------------------------------------------- /train/config/deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": 0.1 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "warmup_min_lr": "auto", 18 | "warmup_max_lr": "auto", 19 | "warmup_num_steps": "auto", 20 | "total_num_steps": "auto" 21 | } 22 | }, 23 | "flops_profiler": { 24 | "enabled": true, 25 | "profile_step": 25, 26 | "module_depth": -1, 27 | "top_modules": 1, 28 | "detailed": false, 29 | "output_file": null 30 | }, 31 | "zero_optimization": { 32 | "stage": 3, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "sub_group_size": 1e9, 36 | "reduce_bucket_size": "auto", 37 | "stage3_prefetch_bucket_size": "auto", 38 | "stage3_param_persistence_threshold": "auto", 39 | "stage3_max_live_parameters": 1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": true 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "train_batch_size": "auto", 46 | "train_micro_batch_size_per_gpu": "auto", 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /train/config/deepspeed_dpo.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": 0.1 12 | } 13 | }, 14 | 15 | "flops_profiler": { 16 | "enabled": true, 17 | "profile_step": 25, 18 | "module_depth": -1, 19 | "top_modules": 1, 20 | "detailed": false, 21 | "output_file": null 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "overlap_comm": true, 26 | "contiguous_gradients": true, 27 | "sub_group_size": 1e9, 28 | "reduce_bucket_size": "auto", 29 | "stage3_prefetch_bucket_size": "auto", 30 | "stage3_param_persistence_threshold": "auto", 31 | "stage3_max_live_parameters": 1e9, 32 | "stage3_max_reuse_distance": 1e9, 33 | "stage3_gather_16bit_weights_on_model_save": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "train_batch_size": "auto", 38 | "train_micro_batch_size_per_gpu": "auto", 39 | "wall_clock_breakdown": false 40 | } -------------------------------------------------------------------------------- /train/scripts/group_text.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import json 5 | import random 6 | import numpy as np 7 | from utils.loader import Processor 8 | from datasets import load_dataset, concatenate_datasets 9 | import argparse 10 | from transformers import ( 11 | AutoTokenizer, 12 | ) 13 | 14 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 15 | 16 | def find_common_keywords_update_outfile(file_paths, outfile, ignore_keywords=None): 17 | if ignore_keywords is None: 18 | ignore_keywords = [] 19 | # extract the basenames from the filenames 20 | basenames = [os.path.basename(file_name) for file_name in file_paths] 21 | # split the basenames into words by underscores, ignoring the file extension 22 | words_list = 
[set(os.path.splitext(basename)[0].split('_')) for basename in basenames] 23 | # filter the keywords if any 24 | words_list = [words - set(ignore_keywords) for words in words_list] 25 | # common words in all basenames 26 | common_words = set.intersection(*words_list) 27 | # modify the out_file basename to include these common words 28 | out_file_basename = os.path.splitext(outfile)[0] # remove the extension from outfile 29 | if common_words: 30 | out_file_basename += "_" + "_".join(sorted(common_words)) 31 | out_file_basename += os.path.splitext(outfile)[1] # add the original extension back 32 | return out_file_basename 33 | 34 | def balance_datasets(datasets, balance, total=None, dataset_names=None): 35 | """Balance the datasets based on the balance values. Assumes each dataset was pre-shuffled. 36 | # Args: 37 | datasets: list of datasets to balance. 38 | balance: list(float). The fraction of data to keep from each dataset. The sum of the list must be 1. 39 | total: int. The total number of datapoints to keep. 40 | dataset_names: list(str). The names of the datasets. 41 | --- 42 | All of these (except point 4) assume `total < sum([len(dataset) for dataset in datasets])`. 43 | 44 | 1. `Total = None, Balance = [1]` (**default**): Keep everything. 45 | 2. `Total = int, Balance = [1]`: Keep the first `total` datapoints from each (pre-shuffled) dataset. 46 | 3. `Total = None, len(Balance) > 1`: Keep all data for the smallest dataset; the rest will be determined based on the balance values. 47 | 4. `Total = int, len(Balance) > 1`: Keep `total` datapoints in total, balanced based on the balance values. This may double-sample random datapoints from datasets that are too small. 48 | """ 49 | assert sum(balance) == 1, "The balance values must sum to 1." 50 | if len(balance) > 1: 51 | assert len(datasets) == len(balance), "The number of datasets and `balance` values must be the same." 52 | assert dataset_names is None or len(datasets) == len(dataset_names), "The number of datasets and `dataset_names` must be the same." 53 | 54 | balanced_datasets = [] 55 | if total is None: 56 | if balance == [1]: 57 | # 1. Keep everything 58 | return datasets 59 | else: 60 | # 3. Keep all data for the smallest dataset and balance the rest 61 | min_i, min_dataset = min(enumerate(datasets), key=lambda x: len(x[1])) 62 | total_size = len(min_dataset) / balance[min_i] 63 | 64 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)): 65 | num_to_keep = int(total_size * proportion) 66 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}" 67 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}") 68 | balanced_datasets.append(dataset.select(range(num_to_keep))) 69 | else: 70 | if balance == [1]: 71 | # 2. Keep `total` datapoints. Since the data is assumed to be pre-shuffled, we can just take the first `total` datapoints. 72 | balanced_datasets = [dataset.select(range(total)) for dataset in datasets] 73 | else: 74 | # 4. Keep `total` datapoints, balanced based on the balance values
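            # Worked example (hypothetical sizes): total=100, balance=[0.75, 0.25] with
            # datasets of length 200 and 10: the first dataset contributes 75 randomly
            # sampled rows; the second needs 25, so it is repeated twice in full (20 rows)
            # and 5 more rows are sampled at random to fill the remainder.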
75 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)): 76 | num_to_keep = int(total * proportion) 77 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}" 78 | temp_dataset = None 79 | if len(dataset) < num_to_keep: 80 | full_repeats = num_to_keep // len(dataset) 81 | remainder = num_to_keep % len(dataset) 82 | for _ in range(full_repeats): 83 | if temp_dataset is None: 84 | temp_dataset = dataset 85 | else: 86 | temp_dataset = concatenate_datasets([temp_dataset, dataset]) 87 | indices_to_keep = random.sample(range(len(dataset)), remainder) 88 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)} (with {full_repeats} replications for all points and random sampling for the remainder)") 89 | temp_dataset = concatenate_datasets([temp_dataset, dataset.select(indices_to_keep)]) 90 | balanced_datasets.append(temp_dataset) 91 | else: 92 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}") 93 | indices = random.sample(range(len(dataset)), num_to_keep) 94 | balanced_datasets.append(dataset.select(indices)) 95 | return balanced_datasets 96 | 97 | 98 | 99 | def group_text(tokenizer, out_file, max_len, sft=False, no_grouping=False, no_padding=False, args=None): 100 | seed = 3407 101 | if tokenizer.pad_token is None: 102 | tokenizer.pad_token = tokenizer.eos_token 103 | tokenizer.pad_token_id = tokenizer.eos_token_id 104 | 105 | train_file_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "file_names_to-be-trained.json") 106 | with open(train_file_config, "r") as f: 107 | train_files = json.load(f) 108 | 109 | out_file = find_common_keywords_update_outfile(train_files, out_file) 110 | train_sets = [] 111 | for file in train_files: 112 | print("loading file:", file) 113 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train') 114 | 115 | # shuffle the dataset for the balance_datasets function 116 | _dataset = _dataset.shuffle(seed=seed) 117 | train_sets.append(_dataset) 118 | 119 | train_sets = balance_datasets(train_sets, args.balance, args.total, train_files) 120 | 121 | lengths = np.array([_set.shape[0] for _set in train_sets]) 122 | print(f"Data Lengths: {lengths}") 123 | train_sets = concatenate_datasets(train_sets) 124 | print(f"Total Length: {train_sets.shape[0]}") 125 | process_batch_size = min(1000, len(train_sets)) 126 | 127 | processor = Processor() 128 | train_sets = train_sets.shuffle(seed=seed) 129 | column_names = list(train_sets.features) 130 | 131 | if no_grouping: 132 | if no_padding: 133 | # only truncate 134 | train_sets = train_sets.map( 135 | processor.truncate, 136 | fn_kwargs={ 137 | "max_len": max_len, 138 | "dpo": args.dpo 139 | }, 140 | batched=True, 141 | load_from_cache_file=False, 142 | batch_size=process_batch_size, 143 | num_proc=96, 144 | desc=f"Checking texts for chunks > {max_len}. Will truncate. No padding. No grouping.", 145 | ) 146 | else: 147 | # truncate and add padding 148 | train_sets = train_sets.map( 149 | processor.truncate_and_add_padding, 150 | fn_kwargs={ 151 | "tokenizer": tokenizer, 152 | "max_len": max_len 153 | }, 154 | batched=True, 155 | load_from_cache_file=False, 156 | batch_size=process_batch_size, 157 | num_proc=96, 158 | desc=f"Checking texts for chunks > {max_len}. Will pad and truncate. No grouping.",
No grouping.", 159 | ) 160 | else: 161 | # pretraining 162 | train_sets = train_sets.map( 163 | processor.group_texts, 164 | fn_kwargs={ 165 | "tokenizer": tokenizer, 166 | "max_len": max_len 167 | }, 168 | batched=True, 169 | load_from_cache_file=False, 170 | remove_columns=column_names, 171 | batch_size=process_batch_size, 172 | num_proc=96, 173 | desc=f"Grouping texts in chunks of {max_len}", 174 | ) 175 | train_sets.to_parquet(out_file) 176 | print(f"Saved to {out_file}") 177 | 178 | if __name__ == "__main__": 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument("--sft", action="store_true", help="Run for supervised fine-tuning data") 181 | parser.add_argument("--max_len", type=int, default=4096, help="Max context length (or the length to use for truncation and padding if applicable)") 182 | parser.add_argument("--dpo", action="store_true", help="Run for DPO data") 183 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to output grouped parquet file") 184 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Tokenizer path") 185 | parser.add_argument("--no_grouping", action="store_true", help="Whether to not group smaller sequences and truncate longer ones when grouping.") 186 | parser.add_argument("--no_padding", action="store_true", help="Whether to not pad shorter sequences.") 187 | parser.add_argument("--balance", type=float, nargs="+", default=[1], help="The percentage of data to keep from each dataset. The sum of the list must be 1.") 188 | parser.add_argument("--total", type=int, default=None, help="The total number of data to keep.") 189 | 190 | args = parser.parse_args() 191 | 192 | sft = args.sft 193 | model_path = args.model_path 194 | max_len = args.max_len # context length 195 | out_file = f"{args.out_file_path}/{'not-' if args.no_grouping else ''}grouped_MaxContext{max_len}.parquet" 196 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 197 | group_text(tokenizer, out_file, max_len, sft, no_grouping=args.no_grouping, no_padding=args.no_padding, args=args) -------------------------------------------------------------------------------- /train/scripts/group_text_dpo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import json 5 | import random 6 | import numpy as np 7 | from utils.loader import Processor 8 | from datasets import load_dataset, concatenate_datasets 9 | import argparse 10 | from transformers import ( 11 | AutoTokenizer, 12 | ) 13 | 14 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 15 | 16 | def find_common_keywords_update_outfile(file_paths, outfile, ignore_keywords=None): 17 | if ignore_keywords is None: 18 | ignore_keywords = [] 19 | # extract the basenames from the filenames 20 | basenames = [os.path.basename(file_name) for file_name in file_paths] 21 | # split the basenames into words by underscores, ignoring the file extension 22 | words_list = [set(os.path.splitext(basename)[0].split('_')) for basename in basenames] 23 | # filter the keywords if any 24 | words_list = [words - set(ignore_keywords) for words in words_list] 25 | # common words in all basenames 26 | common_words = set.intersection(*words_list) 27 | # modify the out_file basename to include these common words 28 | out_file_basename = os.path.splitext(outfile)[0] # remove the extension from outfile 29 | if common_words: 30 | out_file_basename += "_" + 
"_".join(sorted(common_words)) 31 | out_file_basename += os.path.splitext(outfile)[1] # add the original extension back 32 | return out_file_basename 33 | 34 | def balance_datasets(datasets, balance, total=None, dataset_names=None): 35 | """Balance the datasets based on the balance values. Assumes each dataset was pre-shuffled. 36 | # Args: 37 | datasets: list of datasets to balance. 38 | balance: list(int). The percentage of data to keep from each dataset. The sum of the list must be 1. 39 | total: int. The total number of data to keep. 40 | dataset_names: list(str). The names of the datasets. 41 | --- 42 | All of these (except point 4) assume `total < sum([len(dataset) for dataset in datasets])`. 43 | 44 | 1. `Total = None, Balance = [1]` (**default**): Keep everything. 45 | 2. `Total = int, Balance = [1]`: Keep total number of data. Will first sample uniformly from all datasets, then sample from the remaining datasets based on the total. 46 | 3. `Total = None, len(Balance) > 1`: Keep all data for the smallest dataset and the rest will determined based of balance. 47 | 4. `Total = int, len(Balance) > 1`: Keep total number of data, and balance based on the balance values. This will may double-sample random datapoints from datasets that are too small. 48 | """ 49 | assert sum(balance) == 1, "The balance values must sum to 1." 50 | if len(balance) > 1: 51 | assert len(datasets) == len(balance), "The number of datasets and `balance` values must be the same." 52 | assert dataset_names is None or len(datasets) == len(dataset_names), "The number of datasets and `dataset_names` must be the same." 53 | 54 | balanced_datasets = [] 55 | if total is None: 56 | if balance == [1]: 57 | # 1. Keep everything 58 | return datasets 59 | else: 60 | # 3. Keep all data for the smallest dataset and balance the rest 61 | min_i, min_dataset = min(enumerate(datasets), key=lambda x: len(x[1])) 62 | total_size = len(min_dataset) / balance[min_i] 63 | 64 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)): 65 | num_to_keep = int(total_size * proportion) 66 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}" 67 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}") 68 | balanced_datasets.append(dataset.select(range(num_to_keep))) 69 | else: 70 | if balance == [1]: 71 | # 2. Keep total number of data. Since assumes data was pre-shuffled, we can just take the first `total` number of data. 72 | balanced_datasets = [dataset.select(range(total)) for dataset in datasets] 73 | else: 74 | # 4. 
75 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)): 76 | num_to_keep = int(total * proportion) 77 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}" 78 | temp_dataset = None 79 | if len(dataset) < num_to_keep: 80 | full_repeats = num_to_keep // len(dataset) 81 | remainder = num_to_keep % len(dataset) 82 | for _ in range(full_repeats): 83 | if temp_dataset is None: 84 | temp_dataset = dataset 85 | else: 86 | temp_dataset = concatenate_datasets([temp_dataset, dataset]) 87 | indices_to_keep = random.sample(range(len(dataset)), remainder) 88 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)} (with {full_repeats} replications for all points and random sampling for the remainder)") 89 | temp_dataset = concatenate_datasets([temp_dataset, dataset.select(indices_to_keep)]) 90 | balanced_datasets.append(temp_dataset) 91 | else: 92 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}") 93 | indices = random.sample(range(len(dataset)), num_to_keep) 94 | balanced_datasets.append(dataset.select(indices)) 95 | return balanced_datasets 96 | 97 | 98 | 99 | def group_text(tokenizer, out_file, max_len, sft=False, no_grouping=False, no_padding=False, args=None): 100 | seed = 3407 101 | if tokenizer.pad_token is None: 102 | tokenizer.pad_token = tokenizer.eos_token 103 | tokenizer.pad_token_id = tokenizer.eos_token_id 104 | 105 | train_file_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "file_names_to-be-trained.json") 106 | with open(train_file_config, "r") as f: 107 | train_files = json.load(f) 108 | 109 | out_file = find_common_keywords_update_outfile(train_files, out_file) 110 | train_sets = [] 111 | for file in train_files: 112 | print("loading file:", file) 113 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train') 114 | 115 | # shuffle the dataset. This is important for the balance_datasets function 116 | _dataset = _dataset.shuffle(seed=seed) 117 | train_sets.append(_dataset) 118 | 119 | train_sets = balance_datasets(train_sets, args.balance, args.total, train_files) 120 | 121 | lengths = np.array([_set.shape[0] for _set in train_sets]) 122 | print(f"Data Lengths: {lengths}") 123 | train_sets = concatenate_datasets(train_sets) 124 | print(f"Total Length: {train_sets.shape[0]}") 125 | process_batch_size = min(1000, len(train_sets)) 126 | 127 | processor = Processor() 128 | train_sets = train_sets.shuffle(seed=seed) 129 | column_names = list(train_sets.features) 130 | 131 | if no_grouping: 132 | if no_padding: 133 | # only truncate 134 | train_sets = train_sets.map( 135 | processor.truncate, 136 | fn_kwargs={ 137 | "max_len": max_len 138 | }, 139 | batched=True, 140 | load_from_cache_file=False, 141 | batch_size=process_batch_size, 142 | num_proc=96, 143 | desc=f"Checking texts for chunks > {max_len}. Will truncate. No padding. No grouping.", 144 | ) 145 | else: 146 | # truncate and add padding 147 | train_sets = train_sets.map( 148 | processor.truncate_and_add_padding, 149 | fn_kwargs={ 150 | "tokenizer": tokenizer, 151 | "max_len": max_len 152 | }, 153 | batched=True, 154 | load_from_cache_file=False, 155 | batch_size=process_batch_size, 156 | num_proc=96, 157 | desc=f"Checking texts for chunks > {max_len}. Will pad and truncate. No grouping.",
No grouping.", 158 | ) 159 | else: 160 | if sft: 161 | # group texts for seq-to-seq 162 | raise NotImplementedError("Seq-to-seq grouping used to be implemented but this code needs to be checked before using again.") 163 | else: 164 | # pretraining: the original grouping function 165 | train_sets = train_sets.map( 166 | processor.group_texts, 167 | fn_kwargs={ 168 | "tokenizer": tokenizer, 169 | "max_len": max_len 170 | }, 171 | batched=True, 172 | load_from_cache_file=False, 173 | remove_columns=column_names, 174 | batch_size=process_batch_size, 175 | num_proc=96, 176 | desc=f"Grouping texts in chunks of {max_len}", 177 | ) 178 | train_sets.to_parquet(out_file) 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument("--sft", action="store_true", help="Run for supervised fine-tuning data") 183 | parser.add_argument("--max_len", type=int, default=4096, help="Max context length (or the length to use for truncation and padding if applicable)") 184 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to output grouped parquet file") 185 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Model path") 186 | parser.add_argument("--no_grouping", action="store_true", help="Whether to not group smaller sequences and truncate longer ones when grouping.") 187 | parser.add_argument("--no_padding", action="store_true", help="Whether to not pad shorter sequences.") 188 | parser.add_argument("--balance", type=float, nargs="+", default=[1], help="The percentage of data to keep from each dataset. The sum of the list must be 1.") 189 | parser.add_argument("--total", type=int, default=None, help="The total number of data to keep.") 190 | 191 | args = parser.parse_args() 192 | 193 | sft = args.sft 194 | model_path = args.model_path 195 | max_len = args.max_len # context length 196 | out_file = f"{args.out_file_path}/{'not-' if args.no_grouping else ''}grouped_MaxContext{max_len}.parquet" 197 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 198 | group_text(tokenizer, out_file, max_len, sft, no_grouping=args.no_grouping, no_padding=args.no_padding, args=args) -------------------------------------------------------------------------------- /train/scripts/merge_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import logging 5 | from transformers import AutoModelForCausalLM 6 | from peft import PeftModel 7 | import argparse 8 | from transformers import ( 9 | AutoTokenizer, 10 | AutoModelForCausalLM, 11 | ) 12 | 13 | def main(args): 14 | model_path = args.model_path 15 | output_dir = args.output_dir 16 | adapter_path = args.adapter_path 17 | 18 | if args.gpus is not None: 19 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | print(f"Using device: {device}") 23 | 24 | print(f'Loading base model: {model_path}') 25 | tokenizer = AutoTokenizer.from_pretrained(model_path) 26 | base_model = AutoModelForCausalLM.from_pretrained( 27 | model_path, 28 | use_safetensors=True, 29 | trust_remote_code=True, 30 | device_map="auto",) 31 | 32 | if hasattr(base_model.config, 'torch_dtype'): 33 | dtype = base_model.config.torch_dtype 34 | print("Will save merged model following the base model dtype: ", dtype) 35 | else: 36 | dtype = torch.float32 37 | print("Couldn't find the base model dtype. 
Will save merged model in:", dtype) 38 | 39 | print(f"loading adapter model: {adapter_path}") 40 | model = PeftModel.from_pretrained(base_model, adapter_path) 41 | model.to(device) 42 | model = model.to(dtype) 43 | print(model) 44 | model.eval() 45 | 46 | print('Saving merged model') 47 | merged_model = model.merge_and_unload() 48 | merged_model_dir = os.path.join(output_dir, "merged_model") 49 | merged_model.save_pretrained(merged_model_dir, safe_serialization=True) 50 | 51 | print('Saving tokenizer') 52 | tokenizer.save_pretrained(merged_model_dir) 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description='Merge Model') 57 | parser.add_argument('--model_path', type=str, help='Path to the model') 58 | parser.add_argument('--adapter_path', type=str, help='Path to the adapter') 59 | parser.add_argument('--output_dir', type=str, help='Output directory') 60 | parser.add_argument("--gpus", type=str, default=None, help="Use to set the CUDA_VISIBLE_DEVICES environment variable.") 61 | args = parser.parse_args() 62 | try: 63 | main(args) 64 | except Exception as e: 65 | logging.exception(e) 66 | exit(-1) 67 | -------------------------------------------------------------------------------- /train/scripts/tokenize_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | from transformers import AutoTokenizer 6 | from datasets import load_dataset 7 | from tqdm import tqdm 8 | from utils.loader import Processor 9 | import random 10 | 11 | 12 | 13 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 14 | 15 | 16 | def create_file_names(file_paths, out_file): 17 | with open(out_file, "w") as f: 18 | json.dump(file_paths, f) 19 | 20 | 21 | def load_jsonl(in_file): 22 | with open(in_file, "r", encoding="utf-8") as f: 23 | datas = [json.loads(line) for line in f] 24 | return datas 25 | 26 | 27 | def save_jsonl(datas, out_file): 28 | with open(out_file, "w", encoding="utf-8") as f: 29 | for data in datas: 30 | f.write(json.dumps(data, ensure_ascii=False) + "\n") 31 | 32 | def get_dataset_class(keywords): 33 | """Get the dataset class from the keywords in the provided dataset name. One of "heat", "wave".""" 34 | if "heat" in keywords: 35 | return "heat" 36 | elif "wave" in keywords: 37 | return "wave" 38 | else: 39 | raise ValueError(f"Dataset {keywords} does not contain one of 'heat' or 'wave'.") 40 | 41 | 42 | def name_output_file(in_file, out_dir): 43 | """Create a name for the output file using keywords in the base-file name and base-directory name. 44 | Does not include the .parquet file extension.""" 45 | parent_folder = os.path.basename(os.path.dirname(in_file)) 46 | basename_without_extension = os.path.splitext(os.path.basename(in_file))[0] 47 | # split into words 48 | parent_words = parent_folder.split('_') 49 | basename_words = basename_without_extension.split('_') 50 | # combine and remove duplicates 51 | unique_keywords = '_'.join(list(dict.fromkeys(parent_words + basename_words))) 52 | dataset_class = get_dataset_class(unique_keywords) 53 | out_file = os.path.join(out_dir, "tokenized_" + "".join(unique_keywords) + ".parquet") 54 | return str(out_file), dataset_class 55 | 56 | 57 | 58 | def prosess_tokenize(in_file, out_dir, tokenizer, prompt_format, sft=False, truncate=False, padding=False): 59 | """Tokenize the data in the input file, and save the tokenized data in the output directory as a parquet file. 
60 | Returns the name of the saved file.""" 61 | out_file, dataset_class = name_output_file(in_file, out_dir) 62 | if os.path.isfile(out_file): 63 | print(f"Warning: {out_file} already exists. Will replace it.") 64 | os.remove(out_file) 65 | 66 | processor = Processor() 67 | _dataset = load_dataset(in_file.split(".")[-1] if in_file.split(".")[-1] != "jsonl" else "json", data_files=in_file, split='train') 68 | process_batch_size = min(2, len(_dataset)) 69 | 70 | print(_dataset) 71 | _dataset = _dataset.map( 72 | processor.create_prompt, 73 | fn_kwargs={ 74 | "prompt_format": prompt_format, 75 | "dataset_class": dataset_class 76 | }, 77 | batched=True, 78 | load_from_cache_file=False, 79 | batch_size=process_batch_size, 80 | num_proc=64, 81 | desc=f"Creating prompt in format: {prompt_format}, in the standard data format {{'text', 'labels'}}.", 82 | ) 83 | 84 | columns_to_remove = [feature for feature in _dataset.features if feature not in ["text", "labels"]] 85 | _dataset = _dataset.remove_columns(columns_to_remove) 86 | 87 | print("Example from the processed dataset:") 88 | random_indices = random.sample(range(len(_dataset)), 2) 89 | for i in random_indices: 90 | print("\nexample", i, ":") 91 | print(_dataset[i]) 92 | 93 | 94 | column_names = list(_dataset.features) 95 | if sft: 96 | _dataset = _dataset.map( 97 | processor.process_tokenize_sft, 98 | fn_kwargs={ 99 | "tokenizer": tokenizer, 100 | "truncate": truncate, 101 | "padding": padding 102 | }, 103 | batched=True, 104 | load_from_cache_file=False, 105 | remove_columns=column_names, 106 | batch_size=process_batch_size, 107 | num_proc=64, 108 | desc=f"Running tokenizer on SFT dataset. padding: {padding}, truncation: {truncate}.", 109 | ) 110 | else: 111 | _dataset = _dataset.map( 112 | processor.process_tokenize, 113 | fn_kwargs={ 114 | "tokenizer": tokenizer, 115 | "truncate": truncate, 116 | "padding": padding 117 | }, 118 | batched=True, 119 | load_from_cache_file=False, 120 | remove_columns=column_names, 121 | batch_size=process_batch_size, 122 | num_proc=64, 123 | desc="Running tokenizer on pretraining dataset", 124 | ) 125 | 126 | print(_dataset) 127 | _dataset.to_parquet(out_file) 128 | return str(out_file) 129 | 130 | def get_all_files_in_directories(directories): 131 | file_paths = [] 132 | for in_dir in directories: 133 | dir_list = os.listdir(in_dir) 134 | for file_name in dir_list: 135 | full_path = os.path.join(in_dir, file_name) 136 | # ignore markdown/jsonl description files of datasets 137 | if os.path.isfile(full_path) and full_path.split(".")[-1] != "md" and "description" not in file_name.split(".")[0]: 138 | file_paths.append(full_path) 139 | 140 | # write file_paths out for debugging 141 | current_dir_path = os.path.dirname(os.path.realpath(__file__)) 142 | out_file = os.path.join(current_dir_path, "file-names_to-be-tokenized.json") 143 | with open(out_file, "w") as f: 144 | json.dump(file_paths, f) 145 | 146 | return file_paths 147 | 148 | def main(): 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--paths", nargs="+", help="List of paths to directories containing json files of training data that will be preprocessed into individual parquet files for training.") 151 | parser.add_argument("--sft", action="store_true", help="""Tokenize as supervised fine-tuning for text-label data. For SFT, output data will contain: 152 | data["input_ids"], data["labels"], where data["labels"] = len(data["input_ids"]) * [-100] + `true labels`. [-100] is the default ignore token for cross-entropy loss.
153 | If not set, then the output data will use the pretraining objective: data["input_ids"], data["labels"], where data["labels"] = data["input_ids"].""") 154 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to directory where to store the output parquet file") 155 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Path to tokenizer.") 156 | parser.add_argument("--prompt_format", choices=["to_python_no_STL", "to_STL", "to_python_GT_STL", "to_python_given_STL", "to_python_misaligned"], default="to_python_no_STL", help="""Choose the prompt format; the dataset requires the corresponding keys. 157 | `to_python_no_STL`: No STL is given in the prompt. The model is given natural language and should directly convert to code. Required keys: 'nl', 'python'. 158 | `to_STL`: No STL is given in the prompt. The model is given natural language and should directly convert to STL. Required keys: 'nl', 'sstl'. 159 | `to_python_GT_STL`: Ground truth STL is given in the prompt. The model is given natural language and STL and should directly convert to code. Required keys: 'nl', 'sstl', 'python'. 160 | `to_python_given_STL`: Must provide STL to be given in the prompt. The model is given natural language and STL and should directly convert to code. Required keys: 'nl', 'predicted_sstl', 'python'. 161 | `to_python_misaligned`: Must provide STL to be given in the prompt. The model is given natural language and STL which do not necessarily describe the same problem and should directly convert to code. Required keys: 'nl', 'predicted_sstl', 'python'. 162 | """) 163 | 164 | args = parser.parse_args() 165 | 166 | model_path = args.model_path 167 | tokenizer = AutoTokenizer.from_pretrained(model_path) 168 | 169 | file_paths = get_all_files_in_directories(args.paths) 170 | 171 | out_dir = args.out_file_path 172 | if not os.path.exists(out_dir): 173 | os.makedirs(out_dir) 174 | 175 | paths_to_be_trained = [] 176 | for in_file in file_paths: 177 | path = process_tokenize(in_file, out_dir, tokenizer, args.prompt_format, args.sft) 178 | if path in paths_to_be_trained: 179 | raise ValueError(f"{path} was made twice in this script. The first has already been overwritten in the function `process_tokenize`.
Check the dataset names and try again.") 180 | else: 181 | paths_to_be_trained.append(path) 182 | 183 | 184 | print(os.getcwd()) 185 | dir_path = os.path.dirname(os.path.realpath(__file__)) 186 | out_file = os.path.join(dir_path, "file_names_to-be-trained.json") 187 | create_file_names(paths_to_be_trained, out_file) 188 | 189 | 190 | if __name__ == "__main__": 191 | main() -------------------------------------------------------------------------------- /train/scripts/tokenize_data_dpo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | from transformers import AutoTokenizer 6 | from datasets import load_dataset 7 | from tqdm import tqdm 8 | from utils.loader import Processor 9 | import random 10 | 11 | 12 | 13 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) 14 | 15 | 16 | def create_file_names(file_paths, out_file): 17 | # Save the list of file paths to out_file 18 | with open(out_file, "w") as f: 19 | json.dump(file_paths, f) 20 | 21 | 22 | def load_jsonl(in_file): 23 | with open(in_file, "r", encoding="utf-8") as f: 24 | datas = [json.loads(line) for line in f] 25 | return datas 26 | 27 | 28 | def save_jsonl(datas, out_file): 29 | with open(out_file, "w", encoding="utf-8") as f: 30 | for data in datas: 31 | f.write(json.dumps(data, ensure_ascii=False) + "\n") 32 | 33 | def get_dataset_class(keywords): 34 | """Get the dataset class from the keywords in the provided dataset name. One of "heat", "wave".""" 35 | if "heat" in keywords: 36 | return "heat" 37 | elif "wave" in keywords: 38 | return "wave" 39 | else: 40 | return None 41 | # raise ValueError(f"Dataset {keywords} does not contain one of 'heat' or 'wave'.") 42 | 43 | 44 | def name_output_file(in_file, out_dir): 45 | """Create a name for the output file using keywords in the base-file name and base-directory name. 46 | Includes the .parquet file extension.""" 47 | parent_folder = os.path.basename(os.path.dirname(in_file)) 48 | basename_without_extension = os.path.splitext(os.path.basename(in_file))[0] 49 | # Split into words 50 | parent_words = parent_folder.split('_') 51 | basename_words = basename_without_extension.split('_') 52 | # Combine and remove duplicates 53 | unique_keywords = '_'.join(list(dict.fromkeys(parent_words + basename_words))) 54 | dataset_class = get_dataset_class(unique_keywords) 55 | out_file = os.path.join(out_dir, "tokenized_" + "".join(unique_keywords) + ".parquet") 56 | return str(out_file), dataset_class 57 | 58 | 59 | 60 | def process_tokenize(in_file, out_dir, tokenizer, prompt_format, truncate=False, padding=False): 61 | """Tokenize the data in the input file, and save the tokenized data in the output directory as a parquet file. 62 | Returns the name of the saved file.""" 63 | out_file, dataset_class = name_output_file(in_file, out_dir) 64 | if os.path.isfile(out_file): 65 | print(f"Warning: {out_file} already exists.
Will replace") 66 | os.remove(out_file) 67 | 68 | processor = Processor() 69 | _dataset = load_dataset(in_file.split(".")[-1] if in_file.split(".")[-1] != "jsonl" else "json", data_files=in_file, split='train') 70 | process_batch_size = min(2, len(_dataset)) 71 | 72 | print(_dataset) 73 | _dataset = _dataset.map( 74 | processor.create_prompt_dpo, 75 | fn_kwargs={ 76 | "prompt_format": prompt_format 77 | }, 78 | batched=True, 79 | load_from_cache_file=False, 80 | batch_size=process_batch_size, 81 | num_proc=64, 82 | desc=f"Creating prompt in format: {prompt_format}, in the standard data'{{'prompt', 'chosen', 'rejected'}}'.", 83 | ) 84 | 85 | columns_to_remove = [feature for feature in _dataset.features if feature not in ["prompt", "chosen", "rejected"]] 86 | _dataset = _dataset.remove_columns(columns_to_remove) 87 | 88 | print("Example from the processed dataset:") 89 | random_indices = random.sample(range(len(_dataset)), 2) 90 | for i in random_indices: 91 | print("\nexample", i, ":") 92 | print(_dataset[i]) 93 | 94 | 95 | column_names = list(_dataset.features) 96 | ## dpo 97 | # DPOTrainer will do the tokenization of the prompt, chosen, and rejected in the training loop. 98 | _dataset = _dataset.map( 99 | processor.process_tokenize_dpo, 100 | fn_kwargs={ 101 | "tokenizer": tokenizer, 102 | }, 103 | batched=True, 104 | load_from_cache_file=False, 105 | remove_columns=column_names, 106 | batch_size=process_batch_size, 107 | num_proc=64, 108 | desc="Running tokenizer on DPO dataset", 109 | ) 110 | 111 | print(_dataset) 112 | _dataset.to_parquet(out_file) 113 | return str(out_file) 114 | 115 | def get_all_files_in_directories(directories): 116 | file_paths = [] 117 | for in_dir in directories: 118 | dir_list = os.listdir(in_dir) 119 | for file_name in dir_list: 120 | full_path = os.path.join(in_dir, file_name) 121 | # ignore markdown/jsonl description files of datasets 122 | if os.path.isfile(full_path) and full_path.split(".")[-1] != "md" and "description" not in file_name.split(".")[0]: 123 | file_paths.append(full_path) 124 | 125 | # Put file_paths out for debugging 126 | current_dir_path = os.path.dirname(os.path.realpath(__file__)) 127 | out_file = os.path.join(current_dir_path, f"file-names_to-be-tokenized.json") 128 | with open(out_file, "w") as f: 129 | json.dump(file_paths, f) 130 | 131 | return file_paths 132 | 133 | def main(): 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument("--paths", nargs="+", help="List of paths directories containing json files of training data that will be preprocessed into individual parquet files for training.") 136 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to directory where to store the output parquet file") 137 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Path to the model that will be trained and whose tokenizer will be use here to preprocess the data.") 138 | parser.add_argument("--prompt_format", choices=["DPO"], default="DPO", help="""Choose the prompt format, the Dataset requires corresponding keys. 139 | `DPO`: No STL is given in the prompt. The model is given natural language and should directly convert to STL. Required keys: 'nl', 'sstl'. 140 | """) 141 | ## Padding and truncating is removed from this code because it should be done in the grouping stage (which can also combine other datasets so they all have the same context length). 
142 | # parser.add_argument("--truncate", action="store_true", help="Whether to truncate the input and labels in tokenizer if either exceeds the context length of the model. If you plan to group the training data together then do not truncate here. Truncation will happen during grouping.") 143 | # parser.add_argument("--padding", action="store_true", help="Whether to pad the input and labels in tokenizer if either is shorter than the context length of the model. If you plan to group the training data together then do not pad here. Padding will happen during grouping. If you don't plan to group the data, padding here is equivalent to padding during grouping stage (just skipping the grouping part).") 144 | 145 | args = parser.parse_args() 146 | 147 | model_path = args.model_path 148 | tokenizer = AutoTokenizer.from_pretrained(model_path) 149 | 150 | file_paths = get_all_files_in_directories(args.paths) 151 | 152 | out_dir = args.out_file_path 153 | if not os.path.exists(out_dir): 154 | os.makedirs(out_dir) 155 | 156 | paths_to_be_trained = [] 157 | for in_file in file_paths: 158 | path = prosess_tokenize(in_file, out_dir, tokenizer, args.prompt_format) 159 | if path in paths_to_be_trained: 160 | raise ValueError(f"{path} was made twice in this script. The fist has already been overwritten in the function `process_tokenize`. Check the dataset names and try again.") 161 | else: 162 | paths_to_be_trained.append(path) 163 | 164 | 165 | print(os.getcwd()) 166 | dir_path = os.path.dirname(os.path.realpath(__file__)) 167 | out_file = os.path.join(dir_path, "file_names_to-be-trained.json") 168 | create_file_names(paths_to_be_trained, out_file) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /train/scripts/train.sh: -------------------------------------------------------------------------------- 1 | python --version 2 | 3 | export NCCL_DEBUG=WARN 4 | 5 | export NCCL_IB_TIMEOUT=22 6 | export NCCL_IB_RETRY_CNT=13 7 | export NCCL_IB_AR_THRESHOLD=0 8 | 9 | wandb login $WANDB_TOKEN 10 | 11 | # Constants for paths 12 | OUTS="outputs" 13 | MODELS="models" 14 | DATA="datasets/sft" 15 | 16 | # variables 17 | base_run_name=test 18 | max_steps=4 19 | per_gpu_batch_size=8 20 | accum_grad=8 21 | epoch_size=2 22 | 23 | save_meged_model=True 24 | 25 | # ["to_python_no_STL", "to_STL", "to_python_GT_STL", "to_python_given_STL", "to_python_misaligned""] 26 | prompt_format=to_STL 27 | 28 | 29 | ### specify path to base models for which an adapter will be attached. 
30 | base_model="$MODELS/MathCoder2-DeepSeekMath-7B" 31 | 32 | 33 | ### Distributed settings: 34 | # export MASTER_ADDR=$(hostname -i) 35 | # export MASTER_PORT=29500 36 | export WORLD_SIZE=1 37 | export RANK=0 38 | export GPUPerNode=2 39 | export CUDA_VISIBLE_DEVICES=2,3 40 | 41 | 42 | 43 | run_name=${base_run_name}_${prompt_format}_s${max_steps} 44 | 45 | datafile_name=${prompt_format}/not-grouped_MaxContext4096_tokenized_train.parquet 46 | 47 | if [ "$prompt_format" = "to_python_given_STL" ]; then 48 | datafile_name=${prompt_format}/not-grouped_MaxContext4096_STL_predictions_to_tokenized_train.parquet 49 | elif [ "$prompt_format" = "to_python_misaligned" ]; then 50 | DATA="/localhome/mms43/scratch/mathcoder2/datasets/pdecontrol/dpo" 51 | datafile_name=${prompt_format}/not-grouped_MaxContext4096_misaligned_aligned_tokenized_train.parquet 52 | fi 53 | 54 | 55 | 56 | find_latest_checkpoint() { 57 | local checkpoint_dir=$1 58 | latest_checkpoint=$(ls -d ${checkpoint_dir}/checkpoint-* 2>/dev/null | sort -V | tail -n 1) 59 | echo $latest_checkpoint 60 | } 61 | 62 | 63 | train() { 64 | # directory where checkpoints are stored 65 | CHECKPOINT_DIR=$OUTS/$run_name 66 | 67 | export WANDB_DIR=$OUTS/$run_name 68 | export WANDB_RUN_ID=$run_name 69 | 70 | cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES OMP_NUM_THREADS=1 torchrun --nnodes $WORLD_SIZE --node_rank $RANK --nproc_per_node $GPUPerNode --rdzv-backend=c10d --rdzv-endpoint=localhost:0 train/train_finetune.py \ 71 | --ddp_timeout 360000 \ 72 | --train_parquet_file ${DATA}/${datafile_name} \ 73 | --run_name $run_name \ 74 | --output_dir $OUTS/${run_name} \ 75 | --no_timestamps \ 76 | --dataloader_num_workers 2 \ 77 | --max_len 4096 \ 78 | --max_steps $max_steps \ 79 | --num_train_epochs -1 \ 80 | --save_steps 5 \ 81 | --save_total_limit 2 \ 82 | --step_save_interval $epoch_size \ 83 | --warmup_steps 50 \ 84 | --logging_steps 10 \ 85 | --learning_rate 4e-5 \ 86 | --weight_decay 0.1 \ 87 | --lr_scheduler_type cosine \ 88 | --per_device_train_batch_size $per_gpu_batch_size \ 89 | --gradient_accumulation_steps $accum_grad \ 90 | --seed 3407 \ 91 | --deepspeed train/config/deepspeed.json \ 92 | --bf16 \ 93 | --stream \ 94 | --do_train \ 95 | --gradient_checkpointing \ 96 | --report_to wandb \ 97 | --lora_r 64 \ 98 | --lora_alpha 256 \ 99 | --lora_dropout 0.1 \ 100 | --model_cfg $base_model" 101 | 102 | if [ -n "$LATEST_CHECKPOINT" ]; then 103 | echo "Will try to load checkpoint from $LATEST_CHECKPOINT" 104 | export WANDB_RESUME="auto" 105 | cmd="$cmd --resume_from $LATEST_CHECKPOINT" 106 | fi 107 | 108 | if [ "$external_validation" == "True" ]; then 109 | cmd="$cmd --external_validation" 110 | fi 111 | 112 | eval $cmd 113 | } 114 | 115 | 116 | 117 | ############## Training loop with validation ################### 118 | ################################################################ 119 | 120 | validate() { 121 | local checkpoint_dir=$1 122 | python train/scripts/merge_model.py --adapter_path $checkpoint_dir --output_dir $checkpoint_dir --model_path $base_model --gpus $CUDA_VISIBLE_DEVICES 123 | echo "validating" 124 | python train/validate.py --checkpoint_dir $checkpoint_dir --base_model $base_model --cuda_visible_devices $CUDA_VISIBLE_DEVICES --validation_data_dir /localhome/mms43/scratch/mathcoder2/MathCoder2/test/PDEcontrol/validation_data --wandb_dir $OUTS/${run_name} 125 | } 126 | 127 | 128 | 129 | validate_interval=$epoch_size 130 | iterations=$((max_steps / validate_interval)) 131 | epoch_size=0 132 | external_validation=True 133 | for ((r=0; 
r/dev/null | sort -V | tail -n 1) 49 | echo $latest_checkpoint 50 | } 51 | 52 | 53 | train() { 54 | # directory where checkpoints are stored 55 | CHECKPOINT_DIR=$OUTS/$run_name 56 | 57 | export WANDB_DIR=$OUTS/$run_name 58 | export WANDB_RUN_ID=$run_name 59 | 60 | cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES OMP_NUM_THREADS=1 torchrun --nnodes $WORLD_SIZE --node_rank $RANK --nproc_per_node $GPUPerNode --rdzv-backend=c10d --rdzv-endpoint=localhost:0 train/train_dpo.py \ 61 | --ddp_timeout 360000 \ 62 | --train_parquet_file ${DATA}/${datafile_name} \ 63 | --run_name $run_name \ 64 | --output_dir $OUTS/${run_name} \ 65 | --no_timestamps \ 66 | --dataloader_num_workers 2 \ 67 | --max_len 4096 \ 68 | --max_steps $max_steps \ 69 | --num_train_epochs -1 \ 70 | --save_steps 20 \ 71 | --save_total_limit 2 \ 72 | --step_save_interval $epoch_size \ 73 | --warmup_steps 50 \ 74 | --logging_steps 10 \ 75 | --learning_rate 4e-5 \ 76 | --weight_decay 0.1 \ 77 | --lr_scheduler_type cosine \ 78 | --per_device_train_batch_size $per_gpu_batch_size \ 79 | --gradient_accumulation_steps $accum_grad \ 80 | --seed 3407 \ 81 | --bf16 \ 82 | --stream \ 83 | --do_train \ 84 | --report_to wandb \ 85 | --lora_r 64 \ 86 | --lora_alpha 256 \ 87 | --lora_dropout 0.1 \ 88 | --adapter_cfg $translator_adapter \ 89 | --model_cfg $base_model" 90 | 91 | if [ -n "$LATEST_CHECKPOINT" ]; then 92 | echo "Will try to load checkpoint from $LATEST_CHECKPOINT" 93 | export WANDB_RESUME="auto" 94 | cmd="$cmd --resume_from $LATEST_CHECKPOINT" 95 | fi 96 | 97 | if [ "$external_validation" == "True" ]; then 98 | cmd="$cmd --external_validation" 99 | fi 100 | 101 | eval $cmd 102 | } 103 | 104 | 105 | ############## Training loop with validation ################### 106 | ################################################################ 107 | validate() { 108 | local checkpoint_dir=$1 109 | python train/scripts/merge_model.py --adapter_path $checkpoint_dir --output_dir $checkpoint_dir --model_path $base_model --gpus $CUDA_VISIBLE_DEVICES 110 | echo "validating" 111 | python train/validate.py --checkpoint_dir $checkpoint_dir --base_model $base_model --cuda_visible_devices $CUDA_VISIBLE_DEVICES --validation_data_dir /localhome/mms43/scratch/mathcoder2/MathCoder2/test/PDEcontrol/validation_data --wandb_dir $OUTS/${run_name} 112 | } 113 | 114 | validate_interval=$epoch_size 115 | iterations=$((max_steps / validate_interval)) 116 | epoch_size=0 117 | external_validation=True 118 | for ((r=0; r %s", keystr, val) 33 | logger.info(f"******************* {name} *******************") 34 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import random 5 | import shutil 6 | import logging 7 | import transformers 8 | import torch 9 | import json 10 | 11 | import numpy as np 12 | import torch.distributed as dist 13 | 14 | from datetime import datetime 15 | from dataclasses import field, dataclass 16 | from utils.util import set_logger, print_args 17 | 18 | from utils.loader import Processor 19 | from utils.trainer import LoggerCallback, RemoveStateCallback 20 | 21 | from transformers.tokenization_utils import AddedToken 22 | from datasets import load_dataset, concatenate_datasets 23 | from transformers import ( 24 | Trainer, 25 | set_seed, 26 | AutoConfig, 27 | AutoTokenizer, 28 | HfArgumentParser, 29 | TrainingArguments, 30 | LlamaForCausalLM, 31 | AutoModelForCausalLM, 32 | 
default_data_collator 33 | ) 34 | 35 | logger = logging.getLogger() 36 | 37 | @dataclass 38 | class DataArguments: 39 | 40 | no_timestamps: bool = field(default=False) 41 | no_load_model_parameters: bool = field(default=False) 42 | 43 | resume_step: int = field(default=None) 44 | resume_batch_size: int = field(default=None) 45 | 46 | # data 47 | train_parquet_file: str = field(default=None) 48 | train_file_config: str = field(default=None) 49 | train_dataset: str = field(default=None) 50 | train_coef: str = field(default=None) 51 | delete_long_sample: bool = field(default=False) 52 | 53 | # process 54 | max_len: int = field(default=2048) 55 | preprocessing_num_workers: int = field(default=64) 56 | 57 | # model 58 | model_cfg: str = field(default="data/models/starcoder") 59 | flash_attention: bool = field(default=False) 60 | 61 | resume_from: str = field(default=None) 62 | 63 | # output 64 | stream: bool = field(default=False) 65 | 66 | 67 | def train(): 68 | parser = HfArgumentParser((DataArguments, TrainingArguments)) 69 | 70 | data_args, training_args = parser.parse_args_into_dataclasses() 71 | 72 | training_args._frozen = False 73 | 74 | if not data_args.no_timestamps: 75 | timestr = datetime.now().strftime("-%m%d%H%M") 76 | training_args.output_dir = training_args.output_dir + timestr 77 | 78 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logging') 79 | 80 | if os.path.exists(training_args.output_dir): 81 | if training_args.overwrite_output_dir: 82 | if training_args.process_index == 0: 83 | shutil.rmtree(training_args.output_dir) 84 | else: 85 | raise ValueError(f"Output directory ({training_args.output_dir}) already exists. Use --overwrite_output_dir to overcome.") 86 | 87 | if training_args.world_size > 1: 88 | dist.barrier() 89 | 90 | if training_args.process_index == 0: 91 | os.makedirs(training_args.output_dir) 92 | 93 | if training_args.world_size > 1: 94 | dist.barrier() 95 | 96 | set_seed(training_args.seed) 97 | 98 | node_rank = int(os.getenv('GROUP_RANK', '0')) 99 | 100 | for _logger in [logger, transformers.utils.logging.get_logger(), logging.getLogger('DeepSpeed')]: 101 | set_logger(_logger, training_args.local_rank, data_args.stream, os.path.join(training_args.output_dir, f'log-node-{node_rank}.log')) 102 | 103 | logger.warning("Device: %s, rank: %s, world size: %s", training_args.device, training_args.process_index, training_args.world_size) 104 | 105 | if training_args.world_size > 1: 106 | dist.barrier() 107 | 108 | print_args(data_args, 'Data Arguments') 109 | print_args(training_args, 'Training Arguments') 110 | 111 | processor = Processor() 112 | 113 | config = AutoConfig.from_pretrained(data_args.model_cfg, trust_remote_code=True) 114 | config._attn_implementation = "flash_attention_2" 115 | config.use_cache = False 116 | 117 | if data_args.no_load_model_parameters: 118 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) 119 | else: 120 | model = AutoModelForCausalLM.from_pretrained(data_args.model_cfg, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True) 121 | # tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, legacy=False, use_fast=True) 122 | tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, trust_remote_code=True) 123 | 124 | if tokenizer.pad_token is None: 125 | tokenizer.pad_token = tokenizer.eos_token 126 | tokenizer.pad_token_id = tokenizer.eos_token_id 127 | 128 | if data_args.train_parquet_file is not None: 129 | train_sets = load_dataset("parquet",
data_files=data_args.train_parquet_file, split='train') 130 | elif data_args.train_file_config is not None: 131 | with open(data_args.train_file_config, "r") as f: 132 | train_files = json.load(f) 133 | 134 | train_sets = [] 135 | for file in train_files: 136 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train') 137 | train_sets.append(_dataset) 138 | 139 | lengths = np.array([_set.shape[0] for _set in train_sets]) 140 | logger.info("Data Lengths: %s", lengths) 141 | 142 | for i in range(1, len(train_sets)): 143 | train_sets[i] = train_sets[i].cast(train_sets[0].features) 144 | 145 | train_sets = concatenate_datasets(train_sets) 146 | else: 147 | raise ValueError("Should provide either 'train_parquet_file' or 'train_file_config'") 148 | 149 | logger.info('Total %d cases', len(train_sets)) 150 | 151 | process_batch_size = min(1000, len(train_sets)) 152 | 153 | with training_args.main_process_first(desc="Log a few random samples from the training set"): 154 | for index in random.sample(range(len(train_sets)), 3): 155 | logger.info( 156 | "Sample %d of the raw training set:\n\ninput_tokens: %s\n\n%s", 157 | index, 158 | train_sets[index]['input_ids'], 159 | tokenizer.convert_ids_to_tokens(train_sets[index]['input_ids']), 160 | ) 161 | 162 | train_sets = train_sets.shuffle(seed=training_args.seed) 163 | column_names = list(train_sets.features) 164 | if data_args.train_parquet_file is None: 165 | with training_args.main_process_first(desc="dataset map grouping"): 166 | train_sets = train_sets.map( 167 | processor.group_texts, 168 | fn_kwargs={ 169 | "tokenizer": tokenizer, 170 | "max_len": data_args.max_len 171 | }, 172 | batched=True, 173 | load_from_cache_file=False, 174 | remove_columns=column_names, 175 | batch_size=process_batch_size, 176 | num_proc=data_args.preprocessing_num_workers, 177 | desc=f"Grouping texts in chunks of {data_args.max_len}", 178 | ) 179 | 180 | with training_args.main_process_first(desc="Log a few random samples from the grouped training set"): 181 | for index in random.sample(range(len(train_sets)), 3): 182 | logger.info( 183 | "Sample %d of the merged training set:\n\n%s", 184 | index, tokenizer.decode(train_sets[index]['input_ids']) 185 | ) 186 | 187 | if data_args.resume_step is not None and data_args.resume_batch_size is not None: 188 | train_sets = train_sets[data_args.resume_step * data_args.resume_batch_size:] 189 | training_args.max_steps -= data_args.resume_step 190 | new_warmup_steps = max(0, training_args.warmup_steps - data_args.resume_step) 191 | new_learning_rate = training_args.learning_rate - max(0, data_args.resume_step - training_args.warmup_steps) * (training_args.learning_rate / (training_args.max_steps - training_args.warmup_steps)) 192 | training_args.warmup_steps = new_warmup_steps 193 | training_args.learning_rate = new_learning_rate 194 | 195 | trainer = Trainer( 196 | args=training_args, 197 | model=model, 198 | tokenizer=tokenizer, 199 | train_dataset=train_sets, 200 | callbacks=[LoggerCallback, RemoveStateCallback], 201 | data_collator=default_data_collator, 202 | ) 203 | 204 | trainer.train(resume_from_checkpoint=data_args.resume_from) 205 | 206 | trainer.save_model(training_args.output_dir) 207 | 208 | if __name__ == "__main__": 209 | 210 | try: 211 | train() 212 | except Exception as e: 213 | logging.exception(e) 214 | exit(-1) 215 | -------------------------------------------------------------------------------- /train/train_dpo.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import inspect 4 | import os 5 | import random 6 | import shutil 7 | import logging 8 | import transformers 9 | import torch 10 | 11 | import torch.distributed as dist 12 | 13 | from datetime import datetime 14 | from dataclasses import field, dataclass 15 | from utils.util import set_logger, print_args 16 | 17 | from utils.loader import Processor 18 | from utils.trainer import LoggerCallback, RemoveStateCallback, StepCheckpointCallback, DataCollatorForDPODataset 19 | 20 | from datasets import load_dataset 21 | from transformers import ( 22 | set_seed, 23 | AutoConfig, 24 | AutoTokenizer, 25 | HfArgumentParser, 26 | TrainingArguments, 27 | AutoModelForCausalLM, 28 | ) 29 | 30 | from peft import LoraConfig, PeftModel 31 | from trl import DPOTrainer, DPOConfig 32 | 33 | logger = logging.getLogger() 34 | 35 | @dataclass 36 | class DataArguments: 37 | 38 | no_timestamps: bool = field(default=False) 39 | 40 | 41 | # data 42 | train_parquet_file: str = field(default=None) 43 | train_file_config: str = field(default=None) 44 | train_dataset: str = field(default=None) 45 | train_coef: str = field(default=None) 46 | delete_long_sample: bool = field(default=False) 47 | 48 | # process 49 | max_len: int = field(default=4096) 50 | preprocessing_num_workers: int = field(default=64) 51 | 52 | # model 53 | model_cfg: str = field(default="data/models/starcoder") 54 | adapter_cfg: str = field(default="data/models/starcoder") 55 | flash_attention: bool = field(default=False) 56 | 57 | # LoRA model 58 | save_merged_model: bool = field(default=False) 59 | no_load_model_pararmeters: bool = field(default=False) 60 | resume_from: str = field(default=None) 61 | 62 | resume_step: int = field(default=None) 63 | resume_batch_size: int = field(default=None) 64 | 65 | # output 66 | stream: bool = field(default=False) 67 | 68 | step_save_interval: int = field(default=1000) # save model every n steps. These will persist. 69 | 70 | external_validation: bool = field(default=False) 71 | 72 | @dataclass 73 | class PeftArguments: 74 | ## See https://github.com/huggingface/peft/blob/f0b066eae888d5dea598e756b7e6d3401d0708e7/src/peft/tuners/lora/config.py#L72 75 | ## for the default values of the fields (or define more or new defaults here). 
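## These fields can be overridden from the command line, since HfArgumentParser auto-generates one flag per dataclass field. An illustrative invocation (a sketch only; the launcher choice and paths are placeholders):
##   torchrun --nproc_per_node=8 train_dpo.py --output_dir <out> --model_cfg data/models/starcoder --adapter_cfg <sft_adapter_dir> --lora_r 16 --lora_alpha 64 --lora_dropout 0.1 --target_modules "q_proj,v_proj"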
76 | # some_field: str = field(default="default_value") 77 | target_modules: str = field(default="k_proj,down_proj,q_proj,v_proj,gate_proj,o_proj,up_proj") 78 | task_type: str = field(default="CAUSAL_LM") 79 | 80 | lora_r: int = field(default=16) 81 | lora_alpha: int = field(default=64) 82 | lora_dropout: float = field(default=0.1) 83 | bias: str = field(default="none") 84 | 85 | def train(): 86 | dist.init_process_group(backend='nccl', init_method='env://') 87 | rank = dist.get_rank() 88 | torch.cuda.set_device(rank) 89 | print("Rank", rank, "Current device", torch.cuda.current_device()) 90 | 91 | parser = HfArgumentParser((DataArguments, TrainingArguments, PeftArguments)) 92 | 93 | data_args, training_args, peft_args = parser.parse_args_into_dataclasses() 94 | 95 | training_args._frozen = False 96 | 97 | if not data_args.no_timestamps: 98 | timestr = datetime.now().strftime("-%m%d%H%M") 99 | training_args.output_dir = training_args.output_dir + timestr 100 | 101 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logging') 102 | 103 | if os.path.exists(training_args.output_dir): 104 | if training_args.overwrite_output_dir: 105 | print(f"Output directory ({training_args.output_dir}) already exists. Overwriting output dir.") 106 | if training_args.process_index == 0: 107 | shutil.rmtree(training_args.output_dir) 108 | else: 109 | print(f"Output directory ({training_args.output_dir}) already exists. Use --overwrite_output_dir to overcome.") 110 | 111 | if training_args.world_size > 1: 112 | dist.barrier(device_ids=[rank]) 113 | 114 | if training_args.process_index == 0: 115 | if not os.path.exists(training_args.output_dir): 116 | os.makedirs(training_args.output_dir) 117 | 118 | if training_args.world_size > 1: 119 | dist.barrier(device_ids=[rank]) 120 | 121 | set_seed(training_args.seed) 122 | 123 | node_rank = int(os.getenv('GROUP_RANK', '0')) 124 | 125 | for _logger in [logger, transformers.utils.logging.get_logger(), logging.getLogger('DeepSpeed')]: 126 | set_logger(_logger, training_args.local_rank, data_args.stream, os.path.join(training_args.output_dir, f'log-node-{node_rank}.log')) 127 | 128 | logger.warning("Device: %s, rank: %s, world size: %s", training_args.device, training_args.process_index, training_args.world_size) 129 | 130 | if training_args.world_size > 1: 131 | dist.barrier(device_ids=[rank]) 132 | 133 | print_args(data_args, 'Data Arguments') 134 | print_args(training_args, 'Training Arguments') 135 | print_args(peft_args, 'LoRA Arguments') 136 | 137 | config = AutoConfig.from_pretrained(data_args.model_cfg, trust_remote_code=True) 138 | config._attn_implementation = "flash_attention_2" 139 | config.use_cache = False 140 | 141 | # load base model 142 | base_model = AutoModelForCausalLM.from_pretrained(data_args.model_cfg, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True,) 143 | logger.info(base_model) 144 | base_model.config.use_cache = False 145 | ## load the same adapter twice. once in training mode and once in eval mode. 146 | model = PeftModel.from_pretrained( 147 | base_model, 148 | data_args.adapter_cfg, 149 | is_trainable=True, 150 | adapter_name='policy', 151 | ) 152 | # load the adapter a second time, with a different name, which will be our reference model. 
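# (Sketch of the mechanism, assuming TRL's adapter-swap support: only the 'policy' adapter receives gradients, while the frozen 'reference' adapter is activated just to compute the reference log-probs; the DPOConfig below selects them via model_adapter_name='policy' and ref_adapter_name='reference', so no second full copy of the base model is kept in memory.)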
153 | model.load_adapter( 154 | data_args.adapter_cfg, 155 | adapter_name="reference", 156 | ) 157 | 158 | tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, trust_remote_code=True) 159 | tokenizer.padding_side = 'right' 160 | logger.info(f"padding side: {tokenizer.padding_side}") 161 | 162 | if tokenizer.pad_token is None: 163 | tokenizer.pad_token = tokenizer.eos_token 164 | tokenizer.pad_token_id = tokenizer.eos_token_id 165 | 166 | 167 | print("Loading training data", data_args.train_parquet_file) 168 | if data_args.train_parquet_file is not None: 169 | train_sets = load_dataset("parquet", data_files=data_args.train_parquet_file, split='train') 170 | else: 171 | raise ValueError("Should provide either 'train_dataset' or 'train_file_config'") 172 | 173 | logger.info('Total %d case', len(train_sets)) 174 | 175 | with training_args.main_process_first(desc="Log a few random samples from the training set"): 176 | for index in random.sample(range(len(train_sets)), 3): 177 | ### DPOTrainer implements the tokenizer itself. no way to turn it off short of editing the source code. 178 | logger.info( 179 | "Sample %d of the raw training set:\n\nprompt_input_ids: %s\n\nprompt_tokens: %s\n\nchosen_input_ids: %s\n\nchosen_tokens: %s\n\nrejected_input_ids: %s\n\nrejected_tokens: %s\n\n", 180 | index, 181 | train_sets[index]['prompt_input_ids'], 182 | tokenizer.convert_ids_to_tokens(train_sets[index]['prompt_input_ids']), 183 | train_sets[index]['chosen_input_ids'], 184 | tokenizer.convert_ids_to_tokens(train_sets[index]['chosen_input_ids']), 185 | train_sets[index]['rejected_input_ids'], 186 | tokenizer.convert_ids_to_tokens(train_sets[index]['rejected_input_ids']), 187 | ) 188 | 189 | train_sets = train_sets.shuffle(seed=training_args.seed) 190 | 191 | epoch_checkpoint_callback = StepCheckpointCallback(save_interval=data_args.step_save_interval, output_dir=training_args.output_dir, external_validation=data_args.external_validation) 192 | 193 | data_collator = DataCollatorForDPODataset(tokenizer=tokenizer) 194 | 195 | target_modules_str = peft_args.target_modules 196 | # convert the comma-separated string to a list of strings 197 | if target_modules_str: 198 | target_modules_list = [module.strip() for module in target_modules_str.split(',')] 199 | else: 200 | target_modules_list = None 201 | 202 | lora_config = LoraConfig( 203 | target_modules=target_modules_list, 204 | task_type=peft_args.task_type, 205 | r=peft_args.lora_r, 206 | lora_alpha=peft_args.lora_alpha, 207 | lora_dropout=peft_args.lora_dropout, 208 | bias=peft_args.bias, 209 | ) 210 | 211 | print(lora_config) 212 | 213 | training_args_main_output_dir = training_args.output_dir 214 | training_args.output_dir = os.path.join(training_args.output_dir, f"backups") 215 | 216 | 217 | ### DPO Config initialized this way since even the defaults will override the transformer.TrainingArguments 218 | # Get the list of parameters that DPOConfig accepts 219 | dpo_config_params = inspect.signature(DPOConfig).parameters 220 | # Filter training_args to include only those parameters 221 | relevant_args = {k: v for k, v in vars(training_args).items() if k in dpo_config_params} 222 | # Add any additional parameters that are not in training_args 223 | relevant_args.update({ 224 | 'ref_adapter_name': "reference", 225 | 'model_adapter_name': "policy", 226 | 'beta': 0.1, 227 | 'loss_type': "sigmoid", 228 | 'is_encoder_decoder': False, 229 | 'rpo_alpha': 1.0, 230 | 'sync_ref_model': False, 231 | 'ref_model_sync_steps': 4, # only used if the 
sync_ref_module is set to True 232 | 'force_use_ref_model': False, 233 | }) 234 | 235 | 236 | dpo_config = DPOConfig(**relevant_args) 237 | logger.info(dpo_config) 238 | 239 | dpo_trainer = DPOTrainer( 240 | model=model, 241 | args=dpo_config, 242 | train_dataset=train_sets, 243 | tokenizer=tokenizer, 244 | data_collator=data_collator, 245 | callbacks=[LoggerCallback, epoch_checkpoint_callback], 246 | ) 247 | 248 | epoch_checkpoint_callback.set_trainer(dpo_trainer) 249 | 250 | dpo_trainer.train(resume_from_checkpoint=data_args.resume_from) 251 | 252 | dpo_trainer.save_model(os.path.join(training_args_main_output_dir, "final")) 253 | 254 | if __name__ == "__main__": 255 | 256 | try: 257 | train() 258 | except Exception as e: 259 | logging.exception(e) 260 | exit(-1) 261 | -------------------------------------------------------------------------------- /train/train_finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import random 5 | import shutil 6 | import logging 7 | import transformers 8 | import torch 9 | import json 10 | 11 | import numpy as np 12 | import torch.distributed as dist 13 | 14 | from datetime import datetime 15 | from dataclasses import field, dataclass 16 | from utils.util import set_logger, print_args 17 | 18 | from utils.loader import Processor 19 | from utils.trainer import LoggerCallback, RemoveStateCallback, StepCheckpointCallback, DataCollatorForSupervisedDataset 20 | 21 | from datasets import load_dataset, concatenate_datasets 22 | from transformers import ( 23 | Trainer, 24 | set_seed, 25 | AutoConfig, 26 | AutoTokenizer, 27 | HfArgumentParser, 28 | TrainingArguments, 29 | AutoModelForCausalLM, 30 | ) 31 | 32 | from peft import LoraConfig 33 | from trl import SFTTrainer 34 | 35 | logger = logging.getLogger() 36 | 37 | @dataclass 38 | class DataArguments: 39 | 40 | no_timestamps: bool = field(default=False) 41 | 42 | 43 | # data 44 | train_parquet_file: str = field(default=None) 45 | train_file_config: str = field(default=None) 46 | train_dataset: str = field(default=None) 47 | train_coef: str = field(default=None) 48 | delete_long_sample: bool = field(default=False) 49 | 50 | # process 51 | max_len: int = field(default=4096) 52 | preprocessing_num_workers: int = field(default=64) 53 | 54 | # model 55 | model_cfg: str = field(default="data/models/starcoder") 56 | flash_attention: bool = field(default=False) 57 | 58 | # LoRA model 59 | save_merged_model: bool = field(default=False) 60 | no_load_model_pararmeters: bool = field(default=False) 61 | resume_from: str = field(default=None) 62 | 63 | resume_step: int = field(default=None) 64 | resume_batch_size: int = field(default=None) 65 | 66 | # output 67 | stream: bool = field(default=False) 68 | 69 | step_save_interval: int = field(default=1000) # save model every n steps. These will persist. 70 | 71 | external_validation: bool = field(default=False) 72 | 73 | @dataclass 74 | class PeftArguments: 75 | ## See https://github.com/huggingface/peft/blob/f0b066eae888d5dea598e756b7e6d3401d0708e7/src/peft/tuners/lora/config.py#L72 76 | ## for the default values of the fields (or define more or new defaults here). 
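## After training, the saved LoRA adapter can be folded back into the base model for inference -- a minimal sketch assuming the standard PEFT API (paths are placeholders; see also scripts/merge_model.py in this repo):
##   from transformers import AutoModelForCausalLM
##   from peft import PeftModel
##   base = AutoModelForCausalLM.from_pretrained("data/models/starcoder")
##   merged = PeftModel.from_pretrained(base, "<adapter_dir>").merge_and_unload()
##   merged.save_pretrained("<merged_model_dir>")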
77 | # some_field: str = field(default="default_value") 78 | target_modules: str = field(default="k_proj,down_proj,q_proj,v_proj,gate_proj,o_proj,up_proj") 79 | task_type: str = field(default="CAUSAL_LM") 80 | 81 | lora_r: int = field(default=16) 82 | lora_alpha: int = field(default=64) 83 | lora_dropout: float = field(default=0.1) 84 | bias: str = field(default="none") 85 | 86 | def train(): 87 | dist.init_process_group(backend='nccl', init_method='env://') 88 | rank = dist.get_rank() 89 | torch.cuda.set_device(rank) 90 | print("Rank", rank, "Current device", torch.cuda.current_device()) 91 | 92 | parser = HfArgumentParser((DataArguments, TrainingArguments, PeftArguments)) 93 | 94 | data_args, training_args, peft_args = parser.parse_args_into_dataclasses() 95 | 96 | training_args._frozen = False 97 | 98 | if not data_args.no_timestamps: 99 | timestr = datetime.now().strftime("-%m%d%H%M") 100 | training_args.output_dir = training_args.output_dir + timestr 101 | 102 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logging') 103 | 104 | if os.path.exists(training_args.output_dir): 105 | if training_args.overwrite_output_dir: 106 | print(f"Output directory ({training_args.output_dir}) already exists. Overwriting output dir.") 107 | if training_args.process_index == 0: 108 | shutil.rmtree(training_args.output_dir) 109 | else: 110 | print(f"Output directory ({training_args.output_dir}) already exists. Use --overwrite_output_dir to overcome.") 111 | 112 | if training_args.world_size > 1: 113 | dist.barrier(device_ids=[rank]) 114 | 115 | if training_args.process_index == 0: 116 | if not os.path.exists(training_args.output_dir): 117 | os.makedirs(training_args.output_dir) 118 | 119 | if training_args.world_size > 1: 120 | dist.barrier(device_ids=[rank]) 121 | 122 | set_seed(training_args.seed) 123 | 124 | node_rank = int(os.getenv('GROUP_RANK', '0')) 125 | 126 | for _logger in [logger, transformers.utils.logging.get_logger(), logging.getLogger('DeepSpeed')]: 127 | set_logger(_logger, training_args.local_rank, data_args.stream, os.path.join(training_args.output_dir, f'log-node-{node_rank}.log')) 128 | 129 | logger.warning("Device: %s, rank: %s, world size: %s", training_args.device, training_args.process_index, training_args.world_size) 130 | 131 | if training_args.world_size > 1: 132 | dist.barrier(device_ids=[rank]) 133 | 134 | print_args(data_args, 'Data Arguments') 135 | print_args(training_args, 'Training Arguments') 136 | print_args(peft_args, 'LoRA Arguments') 137 | 138 | processor = Processor() 139 | 140 | config = AutoConfig.from_pretrained(data_args.model_cfg, trust_remote_code=True) 141 | config._attn_implementation = "flash_attention_2" 142 | config.use_cache = False 143 | 144 | base_model = AutoModelForCausalLM.from_pretrained(data_args.model_cfg, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True) 145 | logger.info(base_model) 146 | 147 | tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, trust_remote_code=True) 148 | tokenizer.padding_side = 'right' 149 | logger.info(f"padding side: {tokenizer.padding_side}") 150 | 151 | if tokenizer.pad_token is None: 152 | tokenizer.pad_token = tokenizer.eos_token 153 | tokenizer.pad_token_id = tokenizer.eos_token_id 154 | 155 | if data_args.train_parquet_file is not None: 156 | train_sets = load_dataset("parquet", data_files=data_args.train_parquet_file, split='train') 157 | elif data_args.train_file_config is not None: 158 | with open(data_args.train_file_config, "r") as f: 159 | train_files = 
json.load(f) 160 | 161 | train_sets = [] 162 | for file in train_files: 163 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train') 164 | train_sets.append(_dataset) 165 | 166 | lengths = np.array([_set.shape[0] for _set in train_sets]) 167 | logger.info("Data Lengths: %s", lengths) 168 | 169 | for i in range(1, len(train_sets)): 170 | train_sets[i] = train_sets[i].cast(train_sets[0].features) 171 | 172 | train_sets = concatenate_datasets(train_sets) 173 | else: 174 | raise ValueError("Should provide either 'train_parquet_file' or 'train_file_config'") 175 | 176 | logger.info('Total %d cases', len(train_sets)) 177 | 178 | process_batch_size = min(1000, len(train_sets)) 179 | 180 | with training_args.main_process_first(desc="Log a few random samples from the training set"): 181 | for index in random.sample(range(len(train_sets)), 3): 182 | logger.info( 183 | "Sample %d of the raw training set:\n\ninput_tokens: %s\n\n%s\n\n", 184 | index, 185 | train_sets[index]['input_ids'], 186 | tokenizer.convert_ids_to_tokens(train_sets[index]['input_ids']), 187 | ) 188 | 189 | train_sets = train_sets.shuffle(seed=training_args.seed) 190 | column_names = list(train_sets.features) 191 | if data_args.train_parquet_file is None: 192 | with training_args.main_process_first(desc="dataset map grouping"): 193 | train_sets = train_sets.map( 194 | processor.group_texts, 195 | fn_kwargs={ 196 | "tokenizer": tokenizer, 197 | "max_len": data_args.max_len 198 | }, 199 | batched=True, 200 | load_from_cache_file=False, 201 | remove_columns=column_names, 202 | batch_size=process_batch_size, 203 | num_proc=data_args.preprocessing_num_workers, 204 | desc=f"Grouping texts in chunks of {data_args.max_len}", 205 | ) 206 | 207 | with training_args.main_process_first(desc="Log a few random samples from the grouped training set"): 208 | for index in random.sample(range(len(train_sets)), 3): 209 | logger.info( 210 | "Sample %d of the merged training set:\n\n%s", 211 | index, tokenizer.decode(train_sets[index]['input_ids']) 212 | ) 213 | 214 | if data_args.resume_step is not None and data_args.resume_batch_size is not None: 215 | train_sets = train_sets[data_args.resume_step * data_args.resume_batch_size:] 216 | training_args.max_steps -= data_args.resume_step 217 | new_warmup_steps = max(0, training_args.warmup_steps - data_args.resume_step) 218 | new_learning_rate = training_args.learning_rate 219 | new_learning_rate -= max(0, data_args.resume_step - training_args.warmup_steps) * (training_args.learning_rate / (training_args.max_steps - training_args.warmup_steps)) # subtract the linear decay already applied to steps past warmup 220 | training_args.warmup_steps = new_warmup_steps 221 | training_args.learning_rate = new_learning_rate 222 | 223 | epoch_checkpoint_callback = StepCheckpointCallback(save_interval=data_args.step_save_interval, output_dir=training_args.output_dir, external_validation=data_args.external_validation) 224 | 225 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 226 | 227 | target_modules_str = peft_args.target_modules 228 | # convert the comma-separated string to a list of strings 229 | if target_modules_str: 230 | target_modules_list = [module.strip() for module in target_modules_str.split(',')] 231 | else: 232 | target_modules_list = None 233 | 234 | lora_config = LoraConfig( 235 | target_modules=target_modules_list, 236 | task_type=peft_args.task_type, 237 | r=peft_args.lora_r, 238 | lora_alpha=peft_args.lora_alpha, 239 | lora_dropout=peft_args.lora_dropout, 240 | bias=peft_args.bias, 
241 | ) 242 | 243 | training_args_main_output_dir = training_args.output_dir 244 | training_args.output_dir = os.path.join(training_args.output_dir, "backups") 245 | 246 | 247 | trainer = SFTTrainer( 248 | model=base_model, 249 | train_dataset=train_sets, 250 | max_seq_length=data_args.max_len, 251 | peft_config=lora_config, 252 | tokenizer=tokenizer, 253 | args=training_args, 254 | data_collator=data_collator, 255 | callbacks=[LoggerCallback, epoch_checkpoint_callback], 256 | # RemoveStateCallback can be used to save disk space but you may not be able to resume runs from ckpt 257 | ) 258 | 259 | epoch_checkpoint_callback.set_trainer(trainer) 260 | print(lora_config) 261 | 262 | trainer.train(resume_from_checkpoint=data_args.resume_from) 263 | 264 | trainer.save_model(os.path.join(training_args_main_output_dir, "final")) 265 | 266 | if __name__ == "__main__": 267 | 268 | try: 269 | train() 270 | except Exception as e: 271 | logging.exception(e) 272 | exit(-1) 273 | -------------------------------------------------------------------------------- /train/utils/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import re 4 | import torch 5 | import logging 6 | 7 | logger = logging.getLogger() 8 | 9 | IGNORE_INDEX = -100 10 | 11 | class Processor: 12 | 13 | def group_texts(self, examples, tokenizer, max_len): 14 | input_ids, labels = [], [] 15 | final_input_ids, final_labels = [], [] 16 | 17 | for idx in range(len(examples['input_ids'])): 18 | _input_ids = examples['input_ids'][idx] 19 | _labels = examples['input_ids'][idx] # labels mirror input_ids for causal LM; padded positions get IGNORE_INDEX below 20 | examples['input_ids'][idx] = None 21 | if len(_input_ids) > max_len: 22 | # if a single sample is longer than max_len, break it into several chunks 23 | divided_input_ids, divided_labels = [], [] 24 | for i in range(0, len(_input_ids), max_len): 25 | divided_input_ids = _input_ids[i: i + max_len] 26 | divided_labels = _labels[i: i + max_len] 27 | if len(divided_input_ids) < max_len: 28 | divided_pad_num = max_len - len(divided_input_ids) 29 | divided_input_ids += [tokenizer.pad_token_id] * divided_pad_num 30 | divided_labels += [IGNORE_INDEX] * divided_pad_num 31 | final_input_ids.append(divided_input_ids) 32 | final_labels.append(divided_labels) 33 | continue 34 | 35 | # if a single sample is shorter than max_len, pack samples together 36 | if len(input_ids) + len(_input_ids) > max_len: 37 | pad_num = max_len - len(input_ids) 38 | final_input_ids.append(input_ids + [tokenizer.pad_token_id] * pad_num) 39 | final_labels.append(labels + [IGNORE_INDEX] * pad_num) 40 | 41 | input_ids, labels = [], [] 42 | 43 | input_ids.extend(_input_ids) 44 | labels.extend(_labels) 45 | 46 | if len(input_ids) > 0: 47 | pad_num = max_len - len(input_ids) 48 | final_input_ids.append(input_ids + [tokenizer.pad_token_id] * pad_num) 49 | final_labels.append(labels + [IGNORE_INDEX] * pad_num) 50 | 51 | return { 52 | "input_ids": torch.tensor(final_input_ids).long(), 53 | "labels": torch.tensor(final_labels).long() 54 | } 55 | 56 | def process_tokenize(self, examples, tokenizer): 57 | """ 58 | tokenize samples and add bos and eos tokens 59 | """ 60 | inputs = tokenizer(examples['text'], truncation=False, padding=False) 61 | 62 | input_ids = [] 63 | for input_id in inputs['input_ids']: 64 | if tokenizer.bos_token_id is not None: 65 | input_ids.append([tokenizer.bos_token_id] + input_id + [tokenizer.eos_token_id]) 66 | else: 67 | input_ids.append(input_id + [tokenizer.eos_token_id]) 68 | 69 | return { 70 | "input_ids": input_ids, 71 | 
} 72 | -------------------------------------------------------------------------------- /train/utils/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import glob 4 | import logging 5 | import datetime 6 | 7 | from transformers import TrainerCallback 8 | from transformers import TrainingArguments, TrainerState, TrainerControl 9 | import os 10 | 11 | import torch 12 | 13 | from typing import Dict, Sequence 14 | from dataclasses import dataclass 15 | import transformers 16 | 17 | 18 | IGNORE_INDEX = -100 19 | 20 | logger = logging.getLogger() 21 | 22 | class LoggerCallback(TrainerCallback): 23 | 24 | def on_train_begin(self, args, state, control, **kwargs): 25 | 26 | self.start_time = datetime.datetime.now() 27 | 28 | def on_log(self, args, state, control, logs=None, **kwargs): 29 | if not state.is_local_process_zero: 30 | return 31 | 32 | if 'loss' not in logs: 33 | return 34 | 35 | loss_msg = ' '.join(["%s: %s" % (k, v) for k, v in logs.items() if 'loss' in k]) 36 | now = datetime.datetime.now() 37 | pass_time = now - self.start_time 38 | rest_time = pass_time * (state.max_steps - state.global_step) / state.global_step 39 | eta = now + rest_time 40 | 41 | pt_min = pass_time.seconds // 60 42 | pass_time = '%.2d:%.2d' % (pt_min // 60 + pass_time.days * 24, pt_min % 60) 43 | 44 | rt_min = rest_time.seconds // 60 45 | rest_time = '%.2d:%.2d' % (rt_min // 60 + rest_time.days * 24, rt_min % 60) 46 | 47 | logger.info( 48 | 'step: %d epoch: %.2f %s lr: %.4g passed time: %s rest time: %s eta: %s', 49 | state.global_step, state.epoch, loss_msg, logs.get('learning_rate', 0), 50 | pass_time, rest_time, eta.strftime('%m/%d %H:%M') 51 | ) 52 | 53 | class RemoveStateCallback(TrainerCallback): 54 | 55 | def remove_state(self, args, step): 56 | step = int(step) 57 | 58 | if step <= 0: 59 | return 60 | 61 | step_dir = os.path.join(args.output_dir, f'checkpoint-{step}') 62 | logger.info('Remove state in %s', step_dir) 63 | 64 | remove_paths = [ 65 | os.path.join(step_dir, 'latest'), # deepspeed state 66 | os.path.join(step_dir, f'global_step{step}'), # deepspeed state 67 | os.path.join(step_dir, 'optimizer.pt'), # optimizer state 68 | os.path.join(step_dir, 'scheduler.pt'), # scheduler state 69 | os.path.join(step_dir, 'generation_config.json'), # generation config 70 | os.path.join(step_dir, 'trainer_state.json'), # trainer state 71 | os.path.join(step_dir, 'training_args.bin'), # training args 72 | os.path.join(step_dir, 'zero_to_fp32.py') 73 | ] 74 | 75 | remove_paths.extend(glob.glob(os.path.join(step_dir, 'rng_state_*.pth'))) # numpy random state 76 | 77 | for path in remove_paths: 78 | if os.path.exists(path): 79 | os.system('rm -rf %s' % path) 80 | 81 | def on_save(self, args, state, control, **kwargs): 82 | 83 | if not state.is_world_process_zero: 84 | return 85 | 86 | self.remove_state(args, state.global_step - state.save_steps) 87 | 88 | def on_train_end(self, args, state, control, **kwargs): 89 | 90 | if not state.is_world_process_zero: 91 | return 92 | 93 | self.remove_state(args, state.global_step) 94 | 95 | 96 | class StepCheckpointCallback(TrainerCallback): 97 | def __init__(self, save_interval, output_dir, external_validation): 98 | self.save_interval = save_interval 99 | self.output_dir = output_dir 100 | self.trainer = None 101 | self.external_validation = external_validation 102 | 103 | def set_trainer(self, trainer): 104 | self.trainer = trainer 105 | 106 | def save_model(self, checkpoint_dir, state: 
TrainerState): 107 | self.trainer.save_model(checkpoint_dir) 108 | print(f"Global step: {state.global_step} (Epoch {int(state.epoch)}). Saved checkpoint at {checkpoint_dir}") 109 | 110 | 111 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 112 | if state.global_step % self.save_interval == 0: 113 | checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-step-{state.global_step}") 114 | self.save_model(checkpoint_dir, state) 115 | control.should_save = True # save to resume later 116 | if self.external_validation: 117 | control.should_training_stop = True # signal to stop training 118 | return control 119 | 120 | 121 | 122 | @dataclass 123 | class DataCollatorForSupervisedDataset(object): 124 | """Collate examples for supervised fine-tuning. 125 | 126 | Will do right padding for input_ids and labels up to the longest sequence in the batch.""" 127 | 128 | tokenizer: transformers.PreTrainedTokenizer 129 | 130 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 131 | input_ids, labels = tuple([torch.tensor(instance[key]) for instance in instances] for key in ("input_ids", "labels")) 132 | input_ids = torch.nn.utils.rnn.pad_sequence( 133 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 134 | ) 135 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 136 | return dict( 137 | input_ids=input_ids, 138 | labels=labels, 139 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 140 | ) 141 | 142 | 143 | 144 | 145 | @dataclass 146 | class DataCollatorForDPODataset(object): 147 | """Collate examples for DPO fine-tuning. 148 | 149 | Will do right padding for input_ids and labels up to the longest sequence in the batch.""" 150 | 151 | tokenizer: transformers.PreTrainedTokenizer 152 | 153 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 154 | input_ids, chosen_ids, rejected_ids = tuple([torch.tensor(instance[key]) for instance in instances] for key in ("prompt_input_ids", "chosen_input_ids", "rejected_input_ids")) 155 | input_ids = torch.nn.utils.rnn.pad_sequence( 156 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 157 | ) 158 | chosen_ids = torch.nn.utils.rnn.pad_sequence( 159 | chosen_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 160 | ) 161 | rejected_ids = torch.nn.utils.rnn.pad_sequence( 162 | rejected_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 163 | ) 164 | return dict( 165 | prompt_input_ids=input_ids, 166 | prompt_attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 167 | chosen_input_ids=chosen_ids, 168 | chosen_attention_mask=chosen_ids.ne(self.tokenizer.pad_token_id), 169 | rejected_input_ids=rejected_ids, 170 | rejected_attention_mask=rejected_ids.ne(self.tokenizer.pad_token_id), 171 | ) 172 | -------------------------------------------------------------------------------- /train/utils/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger() 4 | 5 | def set_logger(_logger, local_rank, stream=True, log_file=None): 6 | _logger.handlers.clear() 7 | 8 | if local_rank in [-1, 0]: 9 | _logger.setLevel(logging.INFO) 10 | else: 11 | _logger.setLevel(logging.WARN) 12 | 13 | log_format = '[%(asctime)s] [Rank {} - %(levelname)s] [%(filename)s - %(lineno)d] %(message)s'.format(local_rank) 14 | log_format = logging.Formatter(log_format, '%Y-%m-%d %H:%M:%S') 15 | 16 | 
if stream: 17 | console = logging.StreamHandler() 18 | console.setFormatter(log_format) 19 | _logger.addHandler(console) 20 | 21 | if log_file is not None: 22 | 23 | file = logging.FileHandler(log_file, mode='a') 24 | file.setFormatter(log_format) 25 | _logger.addHandler(file) 26 | 27 | def print_args(args, name): 28 | max_len = max([len(k) for k in vars(args).keys()]) + 4 29 | logger.info(f"******************* {name} *******************") 30 | for key, val in sorted(vars(args).items()): 31 | keystr = "{}".format(key) + (" " * (max_len - len(key))) 32 | logger.info("%s --> %s", keystr, val) 33 | logger.info(f"******************* {name} *******************") 34 | -------------------------------------------------------------------------------- /train/validate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | import wandb 6 | import concurrent.futures 7 | 8 | current_dir = os.path.dirname(os.path.realpath(__file__)) 9 | merge_path = os.path.join(current_dir, "scripts") 10 | sys.path.append(merge_path) 11 | import merge_model 12 | 13 | eval_path = os.path.join(current_dir, "../test/PDEcontrol/evaluation/infer") 14 | sys.path.append(eval_path) 15 | import run_1d_pdecontrol_eval_full as run_1d_pdecontrol_eval 16 | 17 | 18 | 19 | def create_merge_args(args): 20 | class merge_Args: 21 | model_path = args.base_model 22 | output_dir = f"{args.checkpoint_dir}" 23 | adapter_path = args.checkpoint_dir 24 | gpus = args.cuda_visible_devices 25 | return merge_Args() 26 | 27 | 28 | def create_eval_args(args, data_type, p_format, validation_dir, save_directory, shots=2): 29 | class eval_Args: 30 | data_dir = validation_dir 31 | python_key = "python" 32 | stl_key = "sstl" 33 | nl_key = "nl" 34 | robustness_key = "robustness" 35 | max_num_examples = args.valid_num_examples 36 | save_dir = f"{save_directory}/{shots}_shots" 37 | model_name_or_path_coder = f"{args.checkpoint_dir}/merged_model" 38 | tokenizer_name_or_path_coder = args.base_model 39 | eval_batch_size = len(args.cuda_visible_devices) 40 | load_in_8bits = False 41 | gptq= False 42 | use_vllm = True 43 | load_in_half = False 44 | infer_on_train_set=True 45 | n_subsets=1 46 | subset_id=0 47 | temperature = 0 48 | repeat_id_start = 0 49 | n_repeat_sampling = 1 50 | prompt_format = "few_shot" 51 | few_shot_prompt = p_format 52 | prompt_dataset = data_type 53 | few_shot_number = shots 54 | answer_extraction_fn=None 55 | eval_perplexity=True 56 | eval_robustness = True 57 | eval_edit_distance = True 58 | eval_iou = True 59 | load_from_file=False 60 | skip_existing_scores=False 61 | gpus = args.cuda_visible_devices 62 | seed = None 63 | use_openai = False 64 | return eval_Args() 65 | 66 | 67 | def delete_merged_model(checkpoint_dir): 68 | os.system(f"rm -rf {checkpoint_dir}/merged_model") 69 | 70 | def log_to_wandb(results_dir, args): 71 | for directory in os.listdir(results_dir): 72 | eval_shots_dir = os.path.join(results_dir, directory) 73 | if os.path.isdir(eval_shots_dir): 74 | for eval_method_subdir in os.listdir(eval_shots_dir): 75 | eval_method_subdir_path = os.path.join(eval_shots_dir, eval_method_subdir) 76 | if os.path.isdir(eval_method_subdir_path): 77 | metrics_file = os.path.join(eval_method_subdir_path, f"metrics.{args.valid_num_examples}.json") 78 | if os.path.isfile(metrics_file): 79 | with open(metrics_file, 'r') as f: 80 | metrics = json.load(f) 81 | 82 | few_shot_prompt = metrics.get("few_shot_prompt", "") 83 | few_shot_number = 
metrics.get("few_shot_number", None) 84 | 85 | for metric, value in metrics.items(): 86 | if metric not in ["few_shot_prompt", "few_shot_number"]: 87 | wandb.log({ 88 | f"validation_{few_shot_prompt}_{few_shot_number}shots/{metric}": value 89 | }) 90 | 91 | 92 | def validate_model(args): 93 | wandb.init(project=args.project_name, resume=True, dir=args.wandb_dir) 94 | 95 | merge_args = create_merge_args(args) 96 | merge_model.main(merge_args) 97 | 98 | for dataset in os.listdir(args.validation_data_dir): 99 | dataset_path = os.path.join(args.validation_data_dir, dataset) 100 | if os.path.isdir(dataset_path): 101 | for shots in [0, 2]: 102 | for p_format in ["to_python_no_STL"]: 103 | if 'heat' in dataset: 104 | data_type = "CoTOneDHeat" 105 | elif 'wave' in dataset: 106 | data_type = "CoTOneDWave" 107 | 108 | save_dir = f"{args.checkpoint_dir}/validation" 109 | eval_args = create_eval_args(args, data_type, p_format, dataset_path, save_dir, shots) 110 | run_1d_pdecontrol_eval.main(eval_args) 111 | 112 | with concurrent.futures.ThreadPoolExecutor() as executor: 113 | future = executor.submit(run_1d_pdecontrol_eval.run_main, eval_args) 114 | try: 115 | future.result(timeout=900) # Set timeout in seconds 116 | except concurrent.futures.TimeoutError: 117 | print(f"Timeout occurred for dataset: {dataset}, shots: {shots}, p_format: {p_format}") 118 | 119 | 120 | 121 | delete_merged_model(args.checkpoint_dir) 122 | log_to_wandb(save_dir, args) 123 | 124 | 125 | if __name__ == "__main__": 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('--checkpoint_dir', type=str, required=True) 128 | parser.add_argument('--base_model', type=str, required=True) 129 | parser.add_argument('--cuda_visible_devices', type=str, required=True) 130 | parser.add_argument('--valid_num_examples', type=int, default=8) 131 | parser.add_argument('--validation_data_dir', type=str, required=True) 132 | parser.add_argument('--project_name', type=str, default="huggingface") 133 | parser.add_argument('--wandb_dir', type=str, required=True) 134 | 135 | 136 | args = parser.parse_args() 137 | 138 | # vllm may freeze on multi-gpu 139 | args.cuda_visible_devices = args.cuda_visible_devices.split(',')[0] 140 | 141 | validate_model(args) -------------------------------------------------------------------------------- /train_test_combined.yml: -------------------------------------------------------------------------------- 1 | name: trainenv 2 | channels: 3 | - conda-forge 4 | - https://repo.anaconda.com/pkgs/main 5 | - https://repo.anaconda.com/pkgs/r 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - bzip2=1.0.8=h4bc722e_7 10 | - ca-certificates=2024.8.30=hbcca054_0 11 | - ld_impl_linux-64=2.43=h712a8e2_1 12 | - libffi=3.4.2=h7f98852_5 13 | - libgcc=14.1.0=h77fa898_1 14 | - libgcc-ng=14.1.0=h69a702a_1 15 | - libgomp=14.1.0=h77fa898_1 16 | - libnsl=2.0.1=hd590300_0 17 | - libsqlite=3.46.1=hadc24fc_0 18 | - libuuid=2.38.1=h0b41bf4_0 19 | - libxcrypt=4.4.36=hd590300_1 20 | - libzlib=1.3.1=hb9d3cd8_2 21 | - ncurses=6.5=he02047a_1 22 | - openssl=3.3.2=hb9d3cd8_0 23 | - pip=24.2=pyh8b19718_1 24 | - python=3.10.15=h4a871b0_1_cpython 25 | - readline=8.2=h8228510_1 26 | - setuptools=75.1.0=pyhd8ed1ab_0 27 | - tk=8.6.13=noxft_h4845f30_101 28 | - wheel=0.44.0=pyhd8ed1ab_0 29 | - xz=5.2.6=h166bdaf_0 30 | - pip: 31 | - absl-py==2.0.0 32 | - accelerate==1.2.0 33 | - aiofiles==23.2.1 34 | - aiohttp==3.9.1 35 | - aiosignal==1.3.1 36 | - airportsdata==20241001 37 | - annotated-types==0.7.0 38 | - 
anyio==4.7.0 39 | - async-timeout==4.0.3 40 | - attrs==23.1.0 41 | - beautifulsoup4==4.12.3 42 | - bitarray==3.0.0 43 | - bs4==0.0.2 44 | - cachetools==5.3.2 45 | - certifi==2022.12.7 46 | - charset-normalizer==2.1.1 47 | - click==8.1.7 48 | - cloudpickle==3.1.0 49 | - compressed-tensors==0.6.0 50 | - datasets==3.2.0 51 | - deepspeed==0.15.4 52 | - dill==0.3.7 53 | - diskcache==5.6.3 54 | - distro==1.9.0 55 | - docker-pycreds==0.4.0 56 | - docstring-parser==0.16 57 | - editdistance==0.8.1 58 | - einops==0.7.0 59 | - exceptiongroup==1.2.2 60 | - fastapi==0.115.6 61 | - ffmpy==0.5.0 62 | - filelock==3.16.1 63 | - fire==0.5.0 64 | # - flash-attn==2.7.2.post1 65 | - frozenlist==1.4.0 66 | - fsspec==2023.10.0 67 | - gguf==0.10.0 68 | - gitdb==4.0.11 69 | - gitpython==3.1.43 70 | - google-auth==2.25.2 71 | - google-auth-oauthlib==1.0.0 72 | - gradio==5.9.1 73 | - gradio-client==1.5.2 74 | - grpcio==1.60.0 75 | - h11==0.14.0 76 | - hjson==3.1.0 77 | - httpcore==1.0.7 78 | - httptools==0.6.4 79 | - httpx==0.28.1 80 | - huggingface-hub==0.26.5 81 | - idna==3.4 82 | - importlib-metadata==8.5.0 83 | - interegular==0.3.3 84 | - jinja2==3.1.2 85 | - jiter==0.8.2 86 | - jsonlines==4.0.0 87 | - jsonschema==4.23.0 88 | - jsonschema-specifications==2024.10.1 89 | - lark==1.2.2 90 | - llvmlite==0.43.0 91 | - lm-format-enforcer==0.10.6 92 | - loader==2017.9.11 93 | - loaderr==1.0.0 94 | - markdown==3.5.1 95 | - markdown-it-py==3.0.0 96 | - markupsafe==2.1.3 97 | - mdurl==0.1.2 98 | - mistral-common==1.4.4 99 | - mpi4py==4.0.1 100 | - mpmath==1.3.0 101 | - msgpack==1.1.0 102 | - msgspec==0.18.6 103 | - multidict==6.0.4 104 | - multiprocess==0.70.15 105 | - nest-asyncio==1.6.0 106 | - networkx==3.0 107 | - ninja==1.11.1.1 108 | - numba==0.60.0 109 | - numpy==1.25.0 110 | - nvidia-cublas-cu12==12.1.3.1 111 | - nvidia-cuda-cupti-cu12==12.1.105 112 | - nvidia-cuda-nvrtc-cu12==12.1.105 113 | - nvidia-cuda-runtime-cu12==12.1.105 114 | - nvidia-cudnn-cu12==9.1.0.70 115 | - nvidia-cufft-cu12==11.0.2.54 116 | - nvidia-curand-cu12==10.3.2.106 117 | - nvidia-cusolver-cu12==11.4.5.107 118 | - nvidia-cusparse-cu12==12.1.0.106 119 | - nvidia-ml-py==12.560.30 120 | - nvidia-nccl-cu12==2.20.5 121 | - nvidia-nvjitlink-cu12==12.4.127 122 | - nvidia-nvtx-cu12==12.1.105 123 | - oauthlib==3.2.2 124 | - openai==1.58.1 125 | - orjson==3.10.12 126 | - outlines==0.0.43 127 | - outlines-core==0.1.26 128 | - packaging==23.2 129 | - pandas==2.1.4 130 | - partial-json-parser==0.2.1.1.post4 131 | - pebble==5.1.0 132 | - peft==0.14.0 133 | - pillow==10.3.0 134 | - platformdirs==4.2.2 135 | - prometheus-client==0.21.1 136 | - prometheus-fastapi-instrumentator==7.0.0 137 | - protobuf==4.25.1 138 | - psutil==5.9.6 139 | - py-cpuinfo==9.0.0 140 | - pyairports==2.1.1 141 | - pyarrow==18.1.0 142 | - pyarrow-hotfix==0.6 143 | - pyasn1==0.5.1 144 | - pyasn1-modules==0.3.0 145 | - pycountry==24.6.1 146 | - pydantic==2.10.3 147 | - pydantic-core==2.27.1 148 | - pydub==0.25.1 149 | - pygments==2.18.0 150 | - python-dateutil==2.8.2 151 | - python-dotenv==1.0.1 152 | - python-multipart==0.0.20 153 | - pytz==2023.3.post1 154 | - pyyaml==6.0.1 155 | - pyzmq==26.2.0 156 | - ray==2.40.0 157 | - referencing==0.35.1 158 | - regex==2023.10.3 159 | - requests==2.32.3 160 | - requests-oauthlib==1.3.1 161 | - rich==13.9.3 162 | - rpds-py==0.22.3 163 | - rsa==4.9 164 | - ruff==0.8.4 165 | - safehttpx==0.1.6 166 | - safetensors==0.4.5 167 | - semantic-version==2.10.0 168 | - sentencepiece==0.2.0 169 | - sentry-sdk==2.5.1 170 | - setproctitle==1.3.3 171 | - 
shellingham==1.5.4 172 | - shtab==1.7.1 173 | - six==1.16.0 174 | - smmap==5.0.1 175 | - sniffio==1.3.1 176 | - soupsieve==2.5 177 | - starlette==0.41.3 178 | - sympy==1.13.1 179 | - tensorboard==2.14.0 180 | - tensorboard-data-server==0.7.2 181 | - termcolor==2.4.0 182 | - tiktoken==0.7.0 183 | - tokenizers==0.20.3 184 | - tomlkit==0.13.2 185 | - torch==2.4.0 186 | - torchaudio==2.4.0 187 | - torchvision==0.19.0 188 | - tqdm==4.67.1 189 | - transformers==4.46.3 190 | - triton==3.0.0 191 | - trl==0.12.2 192 | - typer==0.15.1 193 | - typing-extensions==4.12.2 194 | - tyro==0.8.14 195 | - tzdata==2023.3 196 | - urllib3==1.26.13 197 | - uvicorn==0.32.1 198 | - uvloop==0.21.0 199 | - vllm==0.6.3.post1 200 | - wandb==0.17.1 201 | - watchfiles==1.0.0 202 | - websockets==14.1 203 | - werkzeug==3.0.1 204 | # - xformers==0.0.27.post2 205 | - xxhash==3.4.1 206 | - yarl==1.9.4 207 | - zipp==3.21.0 208 | prefix: miniconda3/envs/trainenv 209 | -------------------------------------------------------------------------------- /utils/few_shot_prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from .cot_one_d_heat_fewshot import CoTOneDHeat 2 | from .cot_one_d_wave_fewshot import CoTOneDWave 3 | from .cot_one_d_combined_fewshot import CoTOneDCombined 4 | from .few_shot_train import FewShotTrain 5 | from .few_shot_train_dpo import FewShotDPO -------------------------------------------------------------------------------- /utils/few_shot_prompts/cot_one_d_combined_fewshot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_test import FewShotTest 2 | import json 3 | import os 4 | 5 | class CoTOneDCombined(FewShotTest): 6 | def __init__(self, num_shots, format, dataset="combined"): 7 | assert dataset in ["combined", "heat", "wave"] 8 | if dataset== "combined" and num_shots not in [0, 2]: 9 | raise ValueError(f"Number of shots must be 0 or 2 for dataset {dataset}") 10 | super().__init__(num_shots) 11 | self.format=format 12 | current_dir = os.path.dirname(os.path.abspath(__file__)) 13 | jsonl_file_path = os.path.join(current_dir, "examples", f"one_d_{dataset}", "examples.jsonl") 14 | self.examples = self.load_examples(jsonl_file_path, format) 15 | 16 | def load_examples(self, jsonl_file_path, format): 17 | examples = [] 18 | with open(jsonl_file_path, 'r') as file: 19 | for line in file: 20 | data = json.loads(line) 21 | nl = data['nl'] 22 | python = data['python'] 23 | sstl = data['sstl'] 24 | example = super().format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip()) 25 | examples.append(example) 26 | return examples 27 | 28 | def format_prompt(self, nl="", sstl="", python=""): 29 | few_shots = self._get_few_shot_prompt() 30 | curr_prompt = super().format_prompt(self.format, nl, sstl, python) 31 | return f"{few_shots}{curr_prompt}" -------------------------------------------------------------------------------- /utils/few_shot_prompts/cot_one_d_heat_fewshot.py: -------------------------------------------------------------------------------- 1 | from .few_shot_test import FewShotTest 2 | import json 3 | import os 4 | 5 | class CoTOneDHeat(FewShotTest): 6 | def __init__(self, num_shots, format, dataset="heat"): 7 | if dataset != "heat" and num_shots != 0: 8 | raise ValueError(f"Dataset {dataset} not supported. 
Only 'heat' dataset is supported") 9 | super().__init__(num_shots) 10 | self.format=format 11 | current_dir = os.path.dirname(os.path.abspath(__file__)) 12 | jsonl_file_path = os.path.join(current_dir, "examples", "one_d_heat", "examples.jsonl") 13 | self.examples = self.load_examples(jsonl_file_path, format) 14 | 15 | def load_examples(self, jsonl_file_path, format): 16 | examples = [] 17 | with open(jsonl_file_path, 'r') as file: 18 | for line in file: 19 | data = json.loads(line) 20 | nl = data['nl'] 21 | python = data['python'] 22 | sstl = data['sstl'] 23 | example = super().format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip()) 24 | examples.append(example) 25 | return examples 26 | 27 | def format_prompt(self, nl="", sstl="", python=""): 28 | few_shots = self._get_few_shot_prompt() 29 | curr_prompt = super().format_prompt(self.format, nl, sstl, python) 30 | return f"{few_shots}{curr_prompt}" -------------------------------------------------------------------------------- /utils/few_shot_prompts/cot_one_d_wave_fewshot.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from .few_shot_test import FewShotTest 4 | 5 | class CoTOneDWave(FewShotTest): 6 | def __init__(self, num_shots, format, dataset="wave"): 7 | if dataset != "wave" and num_shots != 0: 8 | raise ValueError(f"Dataset {dataset} not supported. Only 'wave' dataset is supported") 9 | super().__init__(num_shots) 10 | self.format = format 11 | current_dir = os.path.dirname(os.path.abspath(__file__)) 12 | jsonl_file_path = os.path.join(current_dir, "examples", "one_d_wave", "examples.jsonl") 13 | self.examples = self.load_examples(jsonl_file_path, format) 14 | 15 | def load_examples(self, jsonl_file_path, format): 16 | examples = [] 17 | with open(jsonl_file_path, 'r') as file: 18 | for line in file: 19 | data = json.loads(line) 20 | nl = data['nl'] 21 | python = data['python'] 22 | sstl = data['sstl'] 23 | example = super().format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip()) 24 | examples.append(example) 25 | return examples 26 | 27 | def format_prompt(self, nl="", sstl="", python=""): 28 | few_shots = self._get_few_shot_prompt() 29 | curr_prompt = super().format_prompt(self.format, nl, sstl, python) 30 | return f"{few_shots}{curr_prompt}" -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/DPO_one_d_combined/dataset_description.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "few_shot_examples_one_d_wave", 3 | "seed_used": 0, 4 | "test_split": 0.1, 5 | "training_set_size": 2, 6 | "test_set_size": 0, 7 | "training_set_keys": [ 8 | "nl", 9 | "sstl", 10 | "python" 11 | ], 12 | "test_set_keys": [], 13 | "combined_files": [ 14 | "examples.jsonl" 15 | ] 16 | } -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/DPO_one_d_combined/examples.jsonl: -------------------------------------------------------------------------------- 1 | {"nl": "Consider a metallic rod of 100 mm. The temperature at one end of the rod is fixed at 300k, while a heat source is applied to the other end. 
The temperature of the rod follows the 1D linear heat equation.At one specific instance within the interval 1.243045282863269 and 3.403832000478401, the temperature profile of the rod should exceed that of the linear expression mu0(x) = 0.12296790254167801 * x + 300.3119346197452 across the section extending from 13 to 27. Furthermore, within another time frame defined by 4.246509016670876 and 5.31524017104925, the temperature profile must be less than the linear expression mu1(x) = 0.47269239908880306 * x + 311.0471706338625 in the section between 43.0 and 61.0. Additionally, throughout the entire interval from 5.88805508229499 to 9.973974528713061, the temperature profile needs to be lower than the linear expression mu2(x) = 0.390105515041803 * x + 306.62859951293456 as it pertains to the segment between 75.0 and 93.0. We assume the rod is made of two different materials: the section from 30 to 60 mm is made of a material with parameters E_a = 1500000.0, rho_a = 4.5e-06 and c_a = 380000000.0, while the rest of the rod is made of a material with parameters E_b = 800000.0, rho_b = 4e-06 and c_b = 466000000.0. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 10.445529233891722 seconds. Assume a 30-element mesh is used.", "sstl": "G_[[0.39209616323405205, 0.6510075872323481]] (\\forall x \\in [12, 12] (u(x) - (0.11378382357293601 \\cdot x + 284.18670158160194) > 0)) \\land G_[[0.10999846031809801, 1.107322231186995]] (\\forall x \\in [48.0, 63.0] (u(x) - (-0.06923842950142 \\cdot x + 307.9263759927145) > 0)) \\land G_[[0.12370042271526001, 0.43663007356696704]] (\\forall x \\in [73.0, 94.0] (u(x) - (0.12971027690449002 \\cdot x + 286.1162890994643) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 100\nrho = lambda x: 4e-06*466000000.0 if x < 30 or x > 60 else 4.5e-06*380000000.0\nE = lambda x: 800000.0 if x < 30 or x > 60 else 1500000.0\nxpart = np.linspace(0, L, N + 1)\ng = [300, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 1.107322231186995\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([12, 12], \">\", lambda x: 0.11378382357293601 * x + 284.18670158160194, lambda x: 0.11378382357293601)\napc1 = logic.APCont([48.0, 63.0], \">\", lambda x: -0.06923842950142 * x + 307.9263759927145, lambda x: -0.06923842950142)\napc2 = logic.APCont([73.0, 94.0], \">\", lambda x: 0.12971027690449002 * x + 286.1162890994643, lambda x: 0.12971027690449002)\ncregions = {\"A\" : apc0, \"B\" : apc1, \"C\" : apc2}\ncspec = \"((G_[0.39209616323405205, 0.6510075872323481] (A)) & (G_[0.10999846031809801, 1.107322231186995] (B)) & (G_[0.12370042271526001, 0.43663007356696704] (C)))\""} 2 | {"nl": "Consider a steel and brass rod of length L = 100000 mm, the section between 30000 mm and 60000 mm being brass, with densities rho_steel = 8e-06, rho_brass = 8.5e-06 and Young's modulus E_steel = 200000000.0, E_brass = 100000000.0, with one end fixed and a time-variant force applied to the other end. This is a 1D elastic wave propagation problem. Consider the displacement of the rod as u(x).For the full span of time spanning 0.17743529415762402 to 0.30799846894522903, the rod must have its displacement compressed following the linear profile mu0(x) = 4.103648149081095e-05 * x + 2.595659429574442 in the area between sections 4888 and 24593. 
Similarly, from 0.40700869237930004 to 0.42879980260728906, the rod's displacement ought to be compressed according to the linear profile mu1(x) = 2.03250331324289e-05 * x + 1.686454607140632 across sections 31956.0 and 61131.0. Alternatively, within the period from 0.734575276782226 to 0.7784237515030701, there should be an instance where the rod's displacement is stretched in accordance with the linear profile mu2(x) = 7.539036123328967e-06 * x + 2.67167604686356, between sections 81192.0 and 92008.0. Assume that the discretized time interval is 0.0025s and the max time is 0.9783559046532111 seconds. Assume a 20-element mesh is used.", "sstl": "G_[[0.066415899408669, 0.172261358338638]] (\\forall x \\in [29936, 85026] (u(x) - (3.137321609165203e-06 \\cdot x + -1.526785067872162) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 100000\nrho = lambda x: 8e-06 if x < 30000 or x > 60000 else 8.5e-06\nE = lambda x: 200000000.0 if x < 30000 or x > 60000 else 100000000.0\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 0.172261358338638\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([29936, 85026], \"<\", lambda x: 3.137321609165203e-06 * x + -1.526785067872162, lambda x: 3.137321609165203e-06)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[0.066415899408669, 0.172261358338638] (A)))\""} -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/DPO_one_d_heat/dataset_description.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "few_shot_examples_one_d_heat", 3 | "seed_used": 0, 4 | "test_split": 0.1, 5 | "training_set_size": 2, 6 | "test_set_size": 0, 7 | "training_set_keys": [ 8 | "nl", 9 | "sstl", 10 | "python" 11 | ], 12 | "test_set_keys": [], 13 | "combined_files": [ 14 | "examples.jsonl" 15 | ] 16 | } -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/DPO_one_d_heat/examples.jsonl: -------------------------------------------------------------------------------- 1 | {"nl": "Consider a metallic rod of 100 mm. The temperature at one end of the rod is fixed at 300k, while a heat source is applied to the other end. The temperature of the rod follows the 1D linear heat equation.At a specific moment within the time frame 1.821027314615614 and 3.523724320287923, the temperature distribution of the rod must exceed the linear profile mu0(x) = -0.13868715203829302 * x + 309.2162061347626 in the section spanning from 3 to 20. Additionally, throughout the entirety of the time interval 10.222409852153994 to 10.228356455872138, the temperature distribution should surpass the linear profile mu1(x) = -0.018962241717187002 * x + 299.2091654709691 for the section between 45.0 and 51.0. We assume the rod is made of two different materials: the section from 30 to 60 mm is made of a material with parameters E_a = 1500000.0, rho_a = 4.5e-06 and c_a = 380000000.0, while the rest of the rod is made of a material with parameters E_b = 800000.0, rho_b = 4e-06 and c_b = 466000000.0. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 14.198467092759284 seconds. 
Assume a 30-element mesh is used.", "sstl": "G_[[0.8608569523120131, 1.308637095866543]] (\\forall x \\in [14, 25] (u(x) - (0.30619062826944204 \\cdot x + 307.88718554535666) < 0)) \\land G_[[0.8498494273514371, 1.7936246423777131]] (\\forall x \\in [62.0, 100.0] (u(x) - (0.32396276508224203 \\cdot x + 309.74287180858954) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 100\nrho = lambda x: 4e-06*466000000.0 if x < 30 or x > 60 else 4.5e-06*380000000.0\nE = lambda x: 800000.0 if x < 30 or x > 60 else 1500000.0\nxpart = np.linspace(0, L, N + 1)\ng = [300, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 1.7936246423777131\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([14, 25], \"<\", lambda x: 0.30619062826944204 * x + 307.88718554535666, lambda x: 0.30619062826944204)\napc1 = logic.APCont([62.0, 100.0], \">\", lambda x: 0.32396276508224203 * x + 309.74287180858954, lambda x: 0.32396276508224203)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((G_[0.8608569523120131, 1.308637095866543] (A)) & (G_[0.8498494273514371, 1.7936246423777131] (B)))\""} 2 | {"nl": "Consider a metallic rod of 100 mm. The temperature at one end of the rod is fixed at 300k, while a heat source is applied to the other end. The temperature of the rod follows the 1D linear heat equation.At one specific instance within the interval 1.243045282863269 and 3.403832000478401, the temperature profile of the rod should exceed that of the linear expression mu0(x) = 0.12296790254167801 * x + 300.3119346197452 across the section extending from 13 to 27. Furthermore, within another time frame defined by 4.246509016670876 and 5.31524017104925, the temperature profile must be less than the linear expression mu1(x) = 0.47269239908880306 * x + 311.0471706338625 in the section between 43.0 and 61.0. Additionally, throughout the entire interval from 5.88805508229499 to 9.973974528713061, the temperature profile needs to be lower than the linear expression mu2(x) = 0.390105515041803 * x + 306.62859951293456 as it pertains to the segment between 75.0 and 93.0. We assume the rod is made of two different materials: the section from 30 to 60 mm is made of a material with parameters E_a = 1500000.0, rho_a = 4.5e-06 and c_a = 380000000.0, while the rest of the rod is made of a material with parameters E_b = 800000.0, rho_b = 4e-06 and c_b = 466000000.0. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 10.445529233891722 seconds. 
Assume a 30-element mesh is used.", "sstl": "G_[[0.39209616323405205, 0.6510075872323481]] (\\forall x \\in [12, 12] (u(x) - (0.11378382357293601 \\cdot x + 284.18670158160194) > 0)) \\land G_[[0.10999846031809801, 1.107322231186995]] (\\forall x \\in [48.0, 63.0] (u(x) - (-0.06923842950142 \\cdot x + 307.9263759927145) > 0)) \\land G_[[0.12370042271526001, 0.43663007356696704]] (\\forall x \\in [73.0, 94.0] (u(x) - (0.12971027690449002 \\cdot x + 286.1162890994643) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 100\nrho = lambda x: 4e-06*466000000.0 if x < 30 or x > 60 else 4.5e-06*380000000.0\nE = lambda x: 800000.0 if x < 30 or x > 60 else 1500000.0\nxpart = np.linspace(0, L, N + 1)\ng = [300, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 1.107322231186995\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([12, 12], \">\", lambda x: 0.11378382357293601 * x + 284.18670158160194, lambda x: 0.11378382357293601)\napc1 = logic.APCont([48.0, 63.0], \">\", lambda x: -0.06923842950142 * x + 307.9263759927145, lambda x: -0.06923842950142)\napc2 = logic.APCont([73.0, 94.0], \">\", lambda x: 0.12971027690449002 * x + 286.1162890994643, lambda x: 0.12971027690449002)\ncregions = {\"A\" : apc0, \"B\" : apc1, \"C\" : apc2}\ncspec = \"((G_[0.39209616323405205, 0.6510075872323481] (A)) & (G_[0.10999846031809801, 1.107322231186995] (B)) & (G_[0.12370042271526001, 0.43663007356696704] (C)))\""} -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/DPO_one_d_wave/dataset_description.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "few_shot_examples_one_d_wave", 3 | "seed_used": 0, 4 | "test_split": 0.1, 5 | "training_set_size": 2, 6 | "test_set_size": 0, 7 | "training_set_keys": [ 8 | "nl", 9 | "sstl", 10 | "python" 11 | ], 12 | "test_set_keys": [], 13 | "combined_files": [ 14 | "examples.jsonl" 15 | ] 16 | } -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/DPO_one_d_wave/examples.jsonl: -------------------------------------------------------------------------------- 1 | {"nl": "Consider a steel and brass rod of length L = 100000 mm, the section between 30000 mm and 60000 mm being brass, with densities rho_steel = 8e-06, rho_brass = 8.5e-06 and Young's modulus E_steel = 200000000.0, E_brass = 100000000.0, with one end fixed and a time-variant force applied to the other end. This is a 1D elastic wave propagation problem. Consider the displacement of the rod as u(x).For the complete duration between the intervals 0.289483287576353 and 0.36810738390869, the rod's displacement should comply with the linear profile mu0(x) = 4.585477841781323e-05 * x + 2.331946581525069 for the sections ranging from 11861 to 34211. In contrast, at a specific point in time within the interval 0.6018777959138361 to 1.251736812496592, the displacement should reflect compression according to the linear profile mu1(x) = 9.629101387124293e-06 * x + -2.660932944224053 from section 73701.0 to 91138.0.Assume that the discretized time interval is 0.0025s and the max time is 1.348443583660752 seconds. 
Assume a 20-element mesh is used.", "sstl": "G_[[0.022889123990597, 0.23971012731697103]] (\\forall x \\in [10916, 42249] (u(x) - (-3.5012259820233114e-05 \\cdot x + -1.180840333310321) > 0)) \\land G_[[0.134370803878005, 0.223822164312479]] (\\forall x \\in [57461.0, 64690.0] (u(x) - (-3.293125603528961e-05 \\cdot x + -1.51293238845689) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 100000\nrho = lambda x: 8e-06 if x < 30000 or x > 60000 else 8.5e-06\nE = lambda x: 200000000.0 if x < 30000 or x > 60000 else 100000000.0\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 0.23971012731697103\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([10916, 42249], \">\", lambda x: -3.5012259820233114e-05 * x + -1.180840333310321, lambda x: -3.5012259820233114e-05)\napc1 = logic.APCont([57461.0, 64690.0], \"<\", lambda x: -3.293125603528961e-05 * x + -1.51293238845689, lambda x: -3.293125603528961e-05)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((G_[0.022889123990597, 0.23971012731697103] (A)) & (G_[0.134370803878005, 0.223822164312479] (B)))\""} 2 | {"nl": "Consider a steel and brass rod of length L = 100000 mm, the section between 30000 mm and 60000 mm being brass, with densities rho_steel = 8e-06, rho_brass = 8.5e-06 and Young's modulus E_steel = 200000000.0, E_brass = 100000000.0, with one end fixed and a time-variant force applied to the other end. This is a 1D elastic wave propagation problem. Consider the displacement of the rod as u(x).For the full span of time spanning 0.17743529415762402 to 0.30799846894522903, the rod must have its displacement compressed following the linear profile mu0(x) = 4.103648149081095e-05 * x + 2.595659429574442 in the area between sections 4888 and 24593. Similarly, from 0.40700869237930004 to 0.42879980260728906, the rod's displacement ought to be compressed according to the linear profile mu1(x) = 2.03250331324289e-05 * x + 1.686454607140632 across sections 31956.0 and 61131.0. Alternatively, within the period from 0.734575276782226 to 0.7784237515030701, there should be an instance where the rod's displacement is stretched in accordance with the linear profile mu2(x) = 7.539036123328967e-06 * x + 2.67167604686356, between sections 81192.0 and 92008.0. Assume that the discretized time interval is 0.0025s and the max time is 0.9783559046532111 seconds. 
Assume a 20-element mesh is used.", "sstl": "G_[[0.066415899408669, 0.172261358338638]] (\\forall x \\in [29936, 85026] (u(x) - (3.137321609165203e-06 \\cdot x + -1.526785067872162) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 100000\nrho = lambda x: 8e-06 if x < 30000 or x > 60000 else 8.5e-06\nE = lambda x: 200000000.0 if x < 30000 or x > 60000 else 100000000.0\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 0.172261358338638\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([29936, 85026], \"<\", lambda x: 3.137321609165203e-06 * x + -1.526785067872162, lambda x: 3.137321609165203e-06)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[0.066415899408669, 0.172261358338638] (A)))\""} 3 | -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/one_d_combined/dataset_description.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "few_shot_examples_one_d_combined", 3 | "seed_used": 0, 4 | "test_split": 0.1, 5 | "training_set_size": 2, 6 | "test_set_size": 0, 7 | "training_set_keys": [ 8 | "nl", 9 | "sstl", 10 | "python" 11 | ], 12 | "test_set_keys": [], 13 | "combined_files": [ 14 | "examples.jsonl" 15 | ] 16 | } -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/one_d_combined/examples.jsonl: -------------------------------------------------------------------------------- 1 | {"nl": "Consider a metallic rod with a maximum length of 112 mm, where the temperature at one extremity is held at 321k, and the opposite extremity is exposed to a heat source. The temperature profile of the rod is described by the 1D linear heat equation.Throughout the duration from 1.8288 to 4.6769, there exists a point in time where the temperature distribution of the rod should be greater than the linear profile mu0(x) = 0.0771 * x + 326.154 applicable between the sections 5 and 97.The rod is presumed to be fabricated from two varieties of materials: from 3 to 49 mm, a material with parameters E_a = 1682393, \rrho_a = 5.952e-06, and c_a = 438533237 is utilized, while the remainder of the rod features a material with parameters E_b = 410042, \rrho_b = 3.977e-06, and c_b = 470729859. We define the temperature at position x as u(x). We will consider a discretized time interval of 0.05 seconds and a total time of 8 seconds, employing a 30-element mesh.", "sstl": "F_[[1.8288, 4.6769]] (\\forall x \\in [5, 97] (u(x) - (0.0771 \\cdot x + 326.154) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 112\nrho = lambda x: 3.977e-06*470729859 if x < 3 or x > 49 else 5.952e-06*438533237\nE = lambda x: 410042 if x < 3 or x > 49 else 1682393\nxpart = np.linspace(0, L, N + 1)\ng = [321, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 8\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([5, 97], \">\", lambda x: 0.0771 * x + 326.154, lambda x: 0.0771)\ncregions = {\"A\" : apc0}\ncspec = \"((F_[1.8288, 4.6769] (A)))\""} 2 | {"nl": "Consider a rod composed of steel and brass with a length of L = 79651 mm, where the brass section is located between 33634 mm and 43799 mm. The densities are defined as \rrho_steel = 7.927e-06 and \rrho_brass = 8.452e-06, and the Young's moduli are E_steel = 222786951 and E_brass = 102749268. 
One end is held in place, and a time-dependent force is applied to the other end. This setup is focused on a 1D elastic wave propagation challenge. Let u(x) denote the displacement of the rod.At a given time in the interval from 0.0541 to 0.2621, the rod's displacement must match the linear profile mu0(x) = -1.4897e-05 * x + -1.7281 throughout the range of sections 8330 and 30692. In addition, throughout the interval from 0.2845 to 0.8982, the rod's displacement is expected to be compressed to fit the linear profile mu1(x) = 1.029e-06 * x + -0.3131 across the sections 56782 and 69640.We will assume that the time interval is discretized at 0.0025s, with the maximum time of 1.9424 seconds, using a mesh that contains 20 elements.", "sstl": "F_[[0.0541, 0.2621]] (\\forall x \\in [8330, 30692] (u(x) - (-1.4897e-05 \\cdot x + -1.7281) = 0)) \\land G_[[0.2845, 0.8982]] (\\forall x \\in [56782, 69640] (u(x) - (1.029e-06 \\cdot x + -0.3131) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 79651\nrho = lambda x: 7.927e-06 if x < 33634 or x > 43799 else 8.452e-06\nE = lambda x: 222786951 if x < 33634 or x > 43799 else 102749268\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.9424\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([8330, 30692], \"=\", lambda x: -1.4897e-05 * x + -1.7281, lambda x: -1.4897e-05)\napc1 = logic.APCont([56782, 69640], \"<\", lambda x: 1.029e-06 * x + -0.3131, lambda x: 1.029e-06)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((F_[0.0541, 0.2621] (A)) & (G_[0.2845, 0.8982] (B)))\""} -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/one_d_heat/dataset_description.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "few_shot_examples_one_d_heat", 3 | "seed_used": 0, 4 | "test_split": 0.1, 5 | "training_set_size": 3, 6 | "test_set_size": 0, 7 | "training_set_keys": [ 8 | "nl", 9 | "sstl", 10 | "python" 11 | ], 12 | "test_set_keys": [], 13 | "combined_files": [ 14 | "examples.jsonl" 15 | ] 16 | } -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/one_d_heat/examples.jsonl: -------------------------------------------------------------------------------- 1 | {"nl": "Consider a metallic rod with a maximum length of 112 mm, where the temperature at one extremity is held at 321k, and the opposite extremity is exposed to a heat source. The temperature profile of the rod is described by the 1D linear heat equation.Throughout the duration from 1.8288 to 4.6769, there exists a point in time where the temperature distribution of the rod should be greater than the linear profile mu0(x) = 0.0771 * x + 326.154 applicable between the sections 5 and 97.The rod is presumed to be fabricated from two varieties of materials: from 3 to 49 mm, a material with parameters E_a = 1682393, \rrho_a = 5.952e-06, and c_a = 438533237 is utilized, while the remainder of the rod features a material with parameters E_b = 410042, \rrho_b = 3.977e-06, and c_b = 470729859. We define the temperature at position x as u(x). 
We will consider a discretized time interval of 0.05 seconds and a total time of 8 seconds, employing a 30-element mesh.", "sstl": "F_[[1.8288, 4.6769]] (\\forall x \\in [5, 97] (u(x) - (0.0771 \\cdot x + 326.154) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 112\nrho = lambda x: 3.977e-06*470729859 if x < 3 or x > 49 else 5.952e-06*438533237\nE = lambda x: 410042 if x < 3 or x > 49 else 1682393\nxpart = np.linspace(0, L, N + 1)\ng = [321, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 8\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([5, 97], \">\", lambda x: 0.0771 * x + 326.154, lambda x: 0.0771)\ncregions = {\"A\" : apc0}\ncspec = \"((F_[1.8288, 4.6769] (A)))\""} 2 | {"nl": "Consider a metallic rod of 180 mm. The temperature at one end of the rod is fixed at 281k, while a heat source is applied to the other end. The temperature of the rod follows the 1D linear heat equation.For one point during the time interval 0.2591 and 2.7813, the temperature distribution of the rod should be the same as the linear profile mu0(x) = 0.3167 * x + 263.3785 between section 19 and 27. Moreover, for all time between the time interval 5.536 and 7.2884, the temperature distribution of the rod should be lower than the linear profile mu1(x) = -0.0214 * x + 265.8454 between section 132 and 145.We assume the rod is made of two different materials: the section from 137 to 159 mm is made of a material with parameters E_a = 1310187, rho_a = 5.456e-06 and c_a = 408307119, while the rest of the rod is made of a material with parameters E_b = 1194686, rho_b = 5.687e-06 and c_b = 476443964. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 11 seconds. Assume a 30-element mesh is used.", "sstl": "F_[[0.2591, 2.7813]] (\\forall x \\in [19, 27] (u(x) - (0.3167 \\cdot x + 263.3785) = 0)) \\land G_[[5.536, 7.2884]] (\\forall x \\in [132, 145] (u(x) - (-0.0214 \\cdot x + 265.8454) < 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 180\nrho = lambda x: 5.687e-06*476443964 if x < 137 or x > 159 else 5.456e-06*408307119\nE = lambda x: 1194686 if x < 137 or x > 159 else 1310187\nxpart = np.linspace(0, L, N + 1)\ng = [281, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 11\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([19, 27], \"=\", lambda x: 0.3167 * x + 263.3785, lambda x: 0.3167)\napc1 = logic.APCont([132, 145], \"<\", lambda x: -0.0214 * x + 265.8454, lambda x: -0.0214)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((F_[0.2591, 2.7813] (A)) & (G_[5.536, 7.2884] (B)))\""} 3 | {"nl": "Imagine a metallic rod with a maximum length of 112 mm, where one end is kept at a stable temperature of 321k, while the other end is subject to a heat source. The temperature within the rod obeys the 1D linear heat equation.During the entire duration from time 1.8288 to 4.6769, the temperature distribution along the rod must match the linear profile mu0(x) = 0.0771 * x + 326.154 in the segment between 5 and 97.We are assuming the rod comprises two types of materials: the segment from 3 to 49 mm is surrounding a material defined by parameters E_a = 1682393, \rrho_a = 5.952e-06, and c_a = 438533237, whereas the remaining length of the rod is fabricated from a material characterized by parameters E_b = 410042, \rrho_b = 3.977e-06, and c_b = 470729859. 
The temperature at position x will be denoted as u(x). The discretized time interval is set at 0.05 seconds, with a maximum duration of 8 seconds, while a mesh with 30 elements has been adopted.", "sstl": "G_[[1.8288, 4.6769]] (\\forall x \\in [5, 97] (u(x) - (0.0771 \\cdot x + 326.154) = 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 112\nrho = lambda x: 3.977e-06*470729859 if x < 3 or x > 49 else 5.952e-06*438533237\nE = lambda x: 410042 if x < 3 or x > 49 else 1682393\nxpart = np.linspace(0, L, N + 1)\ng = [321, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 8\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([5, 97], \"=\", lambda x: 0.0771 * x + 326.154, lambda x: 0.0771)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[1.8288, 4.6769] (A)))\""} -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/one_d_wave/dataset_description.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "few_shot_examples_one_d_wave", 3 | "seed_used": 0, 4 | "test_split": 0.1, 5 | "training_set_size": 3, 6 | "test_set_size": 0, 7 | "training_set_keys": [ 8 | "nl", 9 | "sstl", 10 | "python" 11 | ], 12 | "test_set_keys": [], 13 | "combined_files": [ 14 | "examples.jsonl" 15 | ] 16 | } -------------------------------------------------------------------------------- /utils/few_shot_prompts/examples/one_d_wave/examples.jsonl: -------------------------------------------------------------------------------- 1 | {"nl": "Let us examine a rod made of steel and brass, measuring L = 76182 mm in length, where the segment between 19212 mm and 48319 mm consists of brass. The densities are given as \rrho_steel = 7.628e-06 and \rrho_brass = 8.473e-06, with Young's moduli provided as E_steel = 225415054 and E_brass = 179787202. One end of the rod is fixed, while a force that varies with time is applied to the opposite end. This presents a 1D problem regarding the propagation of elastic waves. Denote the displacement of the rod as u(x).Throughout a particular moment in the duration from 0.09 to 0.192, the rod's displacement ought to correspond to the linear profile mu0(x) = -4.692e-05 * x + 1.3255 across the regions defined by 32712 and 42454.Assume the time discretization is 0.0025 seconds, and that the maximum time is 1.5266 seconds, with a 20-element mesh employed for this analysis.", "sstl": "F_[[0.09, 0.192]] (\\forall x \\in [32712, 42454] (u(x) - (-4.692e-05 \\cdot x + 1.3255) > 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 76182\nrho = lambda x: 7.628e-06 if x < 19212 or x > 48319 else 8.473e-06\nE = lambda x: 225415054 if x < 19212 or x > 48319 else 179787202\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.5266\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([32712, 42454], \">\", lambda x: -4.692e-05 * x + 1.3255, lambda x: -4.692e-05)\ncregions = {\"A\" : apc0}\ncspec = \"((F_[0.09, 0.192] (A)))\""} 2 | {"nl": "Consider a rod composed of steel and brass with a length of L = 79651 mm, where the brass section is located between 33634 mm and 43799 mm. The densities are defined as \rrho_steel = 7.927e-06 and \rrho_brass = 8.452e-06, and the Young's moduli are E_steel = 222786951 and E_brass = 102749268. One end is held in place, and a time-dependent force is applied to the other end. 
This setup is focused on a 1D elastic wave propagation challenge. Let u(x) denote the displacement of the rod.At a given time in the interval from 0.0541 to 0.2621, the rod's displacement must match the linear profile mu0(x) = -1.4897e-05 * x + -1.7281 throughout the range of sections 8330 and 30692. In addition, throughout the interval from 0.2845 to 0.8982, the rod's displacement is expected to be compressed to fit the linear profile mu1(x) = 1.029e-06 * x + -0.3131 across the sections 56782 and 69640.We will assume that the time interval is discretized at 0.0025s, with the maximum time of 1.9424 seconds, using a mesh that contains 20 elements.", "sstl": "F_[[0.0541, 0.2621]] (\\forall x \\in [8330, 30692] (u(x) - (-1.4897e-05 \\cdot x + -1.7281) = 0)) \\land G_[[0.2845, 0.8982]] (\\forall x \\in [56782, 69640] (u(x) - (1.029e-06 \\cdot x + -0.3131) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 79651\nrho = lambda x: 7.927e-06 if x < 33634 or x > 43799 else 8.452e-06\nE = lambda x: 222786951 if x < 33634 or x > 43799 else 102749268\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.9424\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([8330, 30692], \"=\", lambda x: -1.4897e-05 * x + -1.7281, lambda x: -1.4897e-05)\napc1 = logic.APCont([56782, 69640], \"<\", lambda x: 1.029e-06 * x + -0.3131, lambda x: 1.029e-06)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((F_[0.0541, 0.2621] (A)) & (G_[0.2845, 0.8982] (B)))\""} 3 | {"nl": "Consider a rod composed of steel and brass with a length of L = 76182 mm, where the brass section is located between 19212 mm and 48319 mm. The densities are defined as \rrho_steel = 7.628e-06 and \rrho_brass = 8.473e-06, and the Young's moduli are E_steel = 225415054 and E_brass = 179787202. One end is held in place, and a time-dependent force is applied to the other end. This setup is focused on a 1D elastic wave propagation challenge. 
Let u(x) denote the displacement of the rod.Across the time interval from 0.09 to 0.192, the displacement of the rod should conform to the linear profile mu0(x) = -4.692e-05 * x + 1.3255 over the range specified by 32712 and 42454.We will assume that the time interval is discretized at 0.0025s, with the maximum time of 1.5266 seconds, using a mesh that contains 20 elements.", "sstl": "G_[[0.09, 0.192]] (\\forall x \\in [32712, 42454] (u(x) - (-4.692e-05 \\cdot x + 1.3255) = 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 76182\nrho = lambda x: 7.628e-06 if x < 19212 or x > 48319 else 8.473e-06\nE = lambda x: 225415054 if x < 19212 or x > 48319 else 179787202\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.5266\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([32712, 42454], \"=\", lambda x: -4.692e-05 * x + 1.3255, lambda x: -4.692e-05)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[0.09, 0.192] (A)))\""} -------------------------------------------------------------------------------- /utils/few_shot_prompts/few_shot_prompting.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | 4 | 5 | class FewShotPrompting: 6 | def __init__(self, num_shots): 7 | self.num_shots = num_shots 8 | if num_shots > 3: 9 | raise ValueError("Only 0 to 3 shots are supported.") 10 | pass 11 | 12 | def get_alpaca_format(self, instruction, task_input, task_output="", wrap_in_code_block=None) -> str: 13 | if wrap_in_code_block == 'python': 14 | prompt = f"""### Instruction:\n{instruction}\n\n### Input:\n{task_input}\n\n### Response:\n```python\n{task_output}""" 15 | if task_output != "": 16 | prompt += "\n```\n\n" 17 | return prompt 18 | elif wrap_in_code_block == 'latex': 19 | prompt = f"""### Instruction:\n{instruction}\n\n### Input:\n{task_input}\n\n### Response:\n```latex\n{task_output}""" 20 | if task_output != "": 21 | prompt += "\n```\n\n" 22 | return prompt 23 | else: 24 | if task_output != "": 25 | return f"""### Instruction:\n{instruction}\n\n### Input:\n{task_input}\n\n### Response:\n{task_output}\n\n""" 26 | else: 27 | return f"""### Instruction:\n{instruction}\n\n### Input:\n{task_input}\n\n### Response:\n""" 28 | 29 | 30 | 31 | def _get_few_shot_prompt(self): 32 | if not hasattr(self, 'examples'): 33 | raise AttributeError("Subclasses of FewShotPrompting must define 'self.examples'") 34 | if getattr(self, 'shuffle', False): 35 | # Shuffle a copy so repeated calls draw different shots without mutating self.examples 36 | temp_examples = copy.copy(self.examples) 37 | random.shuffle(temp_examples) 38 | return "".join(temp_examples[:self.num_shots]) 39 | # Deterministic fallback: take the first num_shots examples in file order 40 | prompt_examples = self.examples[:self.num_shots] 41 | return "".join(prompt_examples) 42 | 43 | def get_instruction(self, format): 44 | if format == "nl_to_python": 45 | return """Below is a natural language description of partial differential equation optimization problem. Translate the problem into Python code following spatial-signal temporal logic.""" 46 | elif format == "nl_to_sstl": 47 | return """Below is a natural language description of partial differential equation optimization problem. Translate the problem into Latex code following spatial-signal temporal logic.""" 48 | elif format == "train_nl_and_sstl_to_python": 49 | return """Below is a natural language description of partial differential equation optimization problem, paired with a spatial-signal temporal logic description of the same problem.
Translate the problem into Python code following spatial-signal temporal logic.""" 50 | elif format == "test_nl_to_python_with_sstl_cot": 51 | return """Below is a natural language description of partial differential equation optimization problem. Translate the problem into Python code following spatial-signal temporal logic. Explain your reasoning by first providing spatial signal temporal logic statement in Latex. Let's think step by step.""" 52 | elif format == "test_nl_with_given_sstl_to_python" or format == "train_nl_with_given_sstl_to_python": 53 | return """Below is a natural language description of partial differential equation optimization problem, paired with your spatial-signal temporal logic description of the same problem provided earlier. Note that there may be mistakes in the spatial-signal temporal logic statement but the natural language description is accurate. Translate the problem into Python code following spatial-signal temporal logic.""" 54 | elif format == "dpo_train_nl_to_sstl": 55 | return """Below is a natural language description of partial differential equation optimization problem. Instead of optimizing the provided problem directly, we want to optimize an intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Generate a spatial-signal temporal logic description in Latex code for such an intermediate problem.""" 56 | elif format == "dpo_test_sstl_to_python": 57 | return """Below is a natural language description of partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. Instead of optimizing the natural language problem directly, we want to optimize the intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Your spatial-signal temporal logic description in latex paired to the original problem describes this intermediate problem. Translate the intermediate problem into Python code following spatial-signal temporal logic.""" 58 | # return """Below is a natural language description of partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. Translate the intermediate problem into Python code following spatial-signal temporal logic.""" 59 | 60 | # """Below is a natural language description of partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. Instead of optimizing the provided problem directly, we want to optimize an intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Your spatial-signal temporal logic description paired to the original problem describes this intermediate problem. Translate the intermediate problem into Python code following spatial-signal temporal logic.""" 61 | 62 | # try the prompt below if the above one doesn't work. 63 | 64 | # """Below is a natural language description of a partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. 
Instead of optimizing the provided problem directly, we want to optimize this intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Use both the natural language description and the LaTeX code of the intermediate problem to translate the intermediate problem into Python code following spatial-signal temporal logic.""" 65 | else: 66 | raise ValueError(f"Invalid format: {format}") -------------------------------------------------------------------------------- /utils/few_shot_prompts/few_shot_test.py: -------------------------------------------------------------------------------- 1 | from .few_shot_prompting import FewShotPrompting 2 | 3 | class FewShotTest(FewShotPrompting): 4 | def __init__(self, num_shots) -> None: 5 | super().__init__(num_shots=num_shots) 6 | 7 | def format_prompt(self, format, nl, sstl="", python=""): 8 | instruction = self.get_instruction(format) 9 | nl = nl.strip() 10 | sstl = sstl.strip() 11 | python = python.strip() 12 | if format == "nl_to_python": 13 | prompt = self.get_alpaca_format(instruction, task_input=nl, task_output=python, wrap_in_code_block='python') 14 | return prompt 15 | elif format == "test_nl_to_python_with_sstl_cot": 16 | # Don't wrap sstl in a code block because the output is custom formatted 17 | if sstl != "" and python != "": 18 | # few-shot prompt 19 | task_output = f"""Spatial Signal Temporal Logic:\n```latex\n{sstl}\n```\n\nPython:\n```python\n{python}\n```""" 20 | else: 21 | # actual test case 22 | task_output = "" 23 | prompt = self.get_alpaca_format(instruction, task_input=nl, task_output=task_output) 24 | return prompt 25 | elif format == "nl_to_sstl": 26 | 27 | task_output = f"""Spatial Signal Temporal Logic:\n```latex\n{sstl}""" 28 | if sstl != "": 29 | # few-shot prompt 30 | task_output += """\n```""" 31 | prompt = self.get_alpaca_format(instruction, task_input=nl, task_output=task_output, wrap_in_code_block=None) 32 | return prompt 33 | elif format == "test_nl_with_given_sstl_to_python": 34 | task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```""" 35 | prompt = self.get_alpaca_format(instruction, task_input=task_input, task_output=python, wrap_in_code_block='python') 36 | return prompt 37 | 38 | 39 | def stop_words(self): 40 | return ["\n### Instruction:", "### Instruction:"] 41 | 42 | 43 | -------------------------------------------------------------------------------- /utils/few_shot_prompts/few_shot_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from .few_shot_prompting import FewShotPrompting 4 | 5 | 6 | class FewShotTrain(FewShotPrompting): 7 | def __init__(self, num_shots=0, format=None, dataset=None) -> None: 8 | super().__init__(num_shots=num_shots) 9 | self.examples = [] 10 | 11 | if num_shots > 0: 12 | # Used only for few-shot training evaluation: the predicted SSTL gathered here builds the second training set of the two-part training process.
13 | assert format is not None, "Please provide format for few-shot training evaluation" 14 | assert dataset is not None, "Please provide dataset for few-shot training evaluation" 15 | self.format = format 16 | self.shuffle = True 17 | current_dir = os.path.dirname(os.path.abspath(__file__)) 18 | jsonl_file_path = os.path.join(current_dir, "examples", f"one_d_{dataset}", "examples.jsonl") 19 | self.examples = self.load_examples(jsonl_file_path, format) 20 | 21 | def format_prompt(self, format, nl, sstl="", python=""): 22 | instruction = self.get_instruction(format) 23 | nl = nl.strip() 24 | sstl = sstl.strip() 25 | python = python.strip() 26 | if format == "nl_to_python": 27 | prompt = self.get_alpaca_format(instruction, task_input=nl, task_output=python, wrap_in_code_block='python') 28 | return prompt 29 | elif format == "nl_to_sstl": 30 | prompt = self.get_alpaca_format(instruction, task_input=nl, task_output=sstl, wrap_in_code_block='latex') 31 | return prompt 32 | elif format == "train_nl_and_sstl_to_python" or format == "train_nl_with_given_sstl_to_python": 33 | task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```""" 34 | prompt = self.get_alpaca_format(instruction, task_input=task_input, task_output=python, wrap_in_code_block='python') 35 | return prompt 36 | 37 | def format_prompt_test(self, nl, sstl="", python=""): 38 | return self.format_prompt(self.format, nl.strip(), sstl.strip(), python.strip()) 39 | 40 | def load_examples(self, jsonl_file_path, format): 41 | examples = [] 42 | with open(jsonl_file_path, 'r') as file: 43 | for line in file: 44 | data = json.loads(line) 45 | nl = data['nl'] 46 | python = data['python'] 47 | sstl = data['sstl'] 48 | example = self.format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip()) 49 | examples.append(example) 50 | return examples 51 | 52 | def stop_words(self): 53 | return ["\n### Instruction:", "### Instruction:"] -------------------------------------------------------------------------------- /utils/few_shot_prompts/few_shot_train_dpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from .few_shot_prompting import FewShotPrompting 4 | 5 | 6 | class FewShotDPO(FewShotPrompting): 7 | def __init__(self, num_shots=0, format=None, dataset=None) -> None: 8 | super().__init__(num_shots=num_shots) 9 | self.examples = [] 10 | self.format = format 11 | 12 | if num_shots > 0: 13 | # Used only for few-shot training evaluation: the predicted SSTL gathered here builds the second training set of the two-part training process.
14 | assert format is not None, "Please provide format for few-shot training evaluation" 15 | assert dataset is not None, "Please provide dataset for few-shot training evaluation" 16 | self.format = format 17 | self.shuffle = True 18 | current_dir = os.path.dirname(os.path.abspath(__file__)) 19 | jsonl_file_path = os.path.join(current_dir, "examples", f"DPO_one_d_{dataset}", "examples.jsonl") 20 | self.examples = self.load_examples(jsonl_file_path, format) 21 | 22 | def format_prompt(self, format, nl, sstl="", python=""): 23 | instruction = self.get_instruction(format) 24 | nl = nl.strip() 25 | sstl = sstl.strip() 26 | python = python.strip() 27 | # if format == "nl_to_python": 28 | # prompt = self.get_alpaca_format(instruction, task_input=nl, task_output=python, wrap_in_code_block='python') 29 | # return prompt 30 | if format == "dpo_train_nl_to_sstl": 31 | prompt = self.get_alpaca_format(instruction, task_input=nl, task_output=sstl, wrap_in_code_block='latex') 32 | return prompt 33 | elif format == "dpo_test_sstl_to_python": 34 | task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```""" 35 | prompt = self.get_alpaca_format(instruction, task_input=task_input, task_output=python, wrap_in_code_block='python') 36 | return prompt 37 | # elif format == "train_nl_and_sstl_to_python" or format == "train_nl_with_given_sstl_to_python": 38 | # task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```""" 39 | # prompt = self.get_alpaca_format(instruction, task_input=task_input, task_output=python, wrap_in_code_block='python') 40 | # return prompt 41 | 42 | def format_prompt_test(self, nl, sstl="", python=""): 43 | few_shots = self._get_few_shot_prompt() 44 | curr_prompt = self.format_prompt(self.format, nl.strip(), sstl.strip(), python.strip()) 45 | return f"{few_shots}{curr_prompt}" 46 | # return self.format_prompt(self.format, nl.strip(), sstl.strip(), python.strip()) 47 | 48 | def load_examples(self, jsonl_file_path, format): 49 | examples = [] 50 | with open(jsonl_file_path, 'r') as file: 51 | for line in file: 52 | data = json.loads(line) 53 | nl = data['nl'] 54 | python = data['python'] 55 | sstl = data['sstl'] 56 | example = self.format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip()) 57 | examples.append(example) 58 | return examples 59 | 60 | def stop_words(self): 61 | return ["\n### Instruction:", "### Instruction:"] --------------------------------------------------------------------------------
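--------------------------------------------------------------------------------
Usage sketch (illustrative; not a file in this repository): a minimal example of how FewShotTrain turns one dataset record into a supervised training string, assuming the repository root is on PYTHONPATH. The record literal below is a hypothetical placeholder, not a real line from any examples.jsonl.

from utils.few_shot_prompts.few_shot_train import FewShotTrain

trainer = FewShotTrain(num_shots=0)  # zero-shot: no examples.jsonl is loaded

# Hypothetical record standing in for one line of an examples.jsonl file.
record = {
    "nl": "Consider a metallic rod of 100 mm. ...",
    "sstl": "G_[[0.1, 0.5]] (\\forall x \\in [0, 50] (u(x) - (0.1 \\cdot x + 300) > 0))",
    "python": "N = 30  # truncated placeholder program",
}

# Produces one Alpaca-style block: instruction, NL + SSTL input, and the Python
# answer wrapped in a closed ```python fence, ready for tokenization.
text = trainer.format_prompt("train_nl_and_sstl_to_python", record["nl"], sstl=record["sstl"], python=record["python"])
print(text)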
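Usage sketch (illustrative; not a file in this repository): the test-time path where FewShotTest pairs a natural-language problem with an SSTL string predicted earlier and leaves the Response section open for the model to complete. Both input strings are hypothetical placeholders.

from utils.few_shot_prompts.few_shot_test import FewShotTest

tester = FewShotTest(num_shots=0)

nl = "Consider a metallic rod of 100 mm. ..."  # hypothetical problem text
predicted_sstl = "G_[[0.1, 0.5]] (\\forall x \\in [0, 50] (u(x) - (0.1 \\cdot x + 300) > 0))"  # hypothetical earlier prediction

# The SSTL is embedded in the Input section and the prompt ends with an open
# ```python fence, so the model's continuation is parsed as Python; decoding
# should stop at one of tester.stop_words().
prompt = tester.format_prompt("test_nl_with_given_sstl_to_python", nl, sstl=predicted_sstl)
print(prompt)
print(tester.stop_words())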
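Usage sketch (illustrative; not a file in this repository): how FewShotDPO prepends shuffled few-shot examples from the shipped DPO example files before the open prompt. Assumes the repository root is on PYTHONPATH so that utils/few_shot_prompts/examples/DPO_one_d_heat/examples.jsonl (two records in this dump) resolves; the problem text is a hypothetical placeholder.

from utils.few_shot_prompts.few_shot_train_dpo import FewShotDPO

# "heat" selects examples/DPO_one_d_heat/examples.jsonl relative to the package.
prompter = FewShotDPO(num_shots=2, format="dpo_train_nl_to_sstl", dataset="heat")

nl_problem = "Consider a metallic rod of 100 mm. ..."  # hypothetical input

# Yields two completed Alpaca-style examples followed by an open ```latex fence
# for the model to continue with the intermediate-problem SSTL.
prompt = prompter.format_prompt_test(nl=nl_problem)
print(prompt)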