├── .gitignore
├── README.md
├── environment.sh
├── figures
│   ├── case_study.png
│   ├── logo.png
│   └── pipeline.png
├── requirements.txt
├── run_dpo_train.sh
├── run_testing.sh
├── run_train_testing.sh
├── run_training.sh
├── test
│   ├── PDEcontrol
│   │   ├── .gitignore
│   │   ├── cog.yaml
│   │   └── evaluation
│   │       ├── data_processing
│   │       │   ├── __init__.py
│   │       │   ├── answer_extraction.py
│   │       │   └── process_utils.py
│   │       ├── eval
│   │       │   ├── __init__.py
│   │       │   ├── eval_robustness_reasoning_wrapper.py
│   │       │   ├── eval_robustness_wrapper.py
│   │       │   ├── eval_script.py
│   │       │   └── utils.py
│   │       ├── infer
│   │       │   └── simulate_gt.py
│   │       └── scripts
│   │           └── infer_pdecontrol.sh
│   ├── README.md
│   ├── requirements.txt
│   └── scripts
│       ├── read_result.py
│       ├── simulate_gt.sh
│       └── test_pdecontrol.sh
├── train
│   ├── README.md
│   ├── config
│   │   ├── deepspeed.json
│   │   └── deepspeed_dpo.json
│   ├── scripts
│   │   ├── group_text.py
│   │   ├── group_text_dpo.py
│   │   ├── merge_model.py
│   │   ├── tokenize_data.py
│   │   ├── tokenize_data_dpo.py
│   │   ├── train.sh
│   │   ├── train_dpo.sh
│   │   └── utils
│   │       ├── loader.py
│   │       ├── trainer.py
│   │       └── util.py
│   ├── train.py
│   ├── train_dpo.py
│   ├── train_finetune.py
│   ├── utils
│   │   ├── loader.py
│   │   ├── trainer.py
│   │   └── util.py
│   └── validate.py
├── train_test_combined.yml
└── utils
    └── few_shot_prompts
        ├── __init__.py
        ├── cot_one_d_combined_fewshot.py
        ├── cot_one_d_heat_fewshot.py
        ├── cot_one_d_wave_fewshot.py
        ├── examples
        │   ├── DPO_one_d_combined
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── DPO_one_d_heat
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── DPO_one_d_wave
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── one_d_combined
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   ├── one_d_heat
        │   │   ├── dataset_description.json
        │   │   └── examples.jsonl
        │   └── one_d_wave
        │       ├── dataset_description.json
        │       └── examples.jsonl
        ├── few_shot_prompting.py
        ├── few_shot_test.py
        ├── few_shot_train.py
        └── few_shot_train_dpo.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | */__pycache__/
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
111 | .pdm.toml
112 | .pdm-python
113 | .pdm-build/
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
165 | # WandB
166 | wandb/
167 |
168 | # extra outputs:
169 | train/scripts/file-names_to-be-tokenized.json
170 | train/scripts/file_names_to-be-trained.json
171 |
172 | # Femformal
173 |
174 | nusmv/
175 | doc/**/*.pdf
176 | .DS_Store
177 | *.pyc
178 | *.lp
179 | .srclist
180 | .tmuxp.yaml
181 | tags
182 | *.egg-info
183 | out_vars.txt
184 | temp_plots/
185 | doc/source/femformal/stubs
186 | doc/build
187 | out.ilp
188 | out.txt
189 |
190 |
191 | # robustness stuff from delta:
192 | *test/PDEcontrol/evaluation/eval/more_heat_robust_.npy
193 | *test/PDEcontrol/evaluation/eval/more_heat_time_.npy
194 |
195 | *more_heat_robust_.npy
196 | *more_heat_time_.npy
197 |
198 | *control/femformal/examples/heat_mix/*eval_llm*.py
199 | *control/femformal/examples/mech_mix/*eval_llm*.py
200 | *control/femformal/examples/heat_mix/*eval_llm*.pyc
201 | *control/femformal/examples/mech_mix/*eval_llm*.pyc
202 |
203 |
204 | # unprocessed data because it may be too large
205 | data_processing/*
206 |
207 | datasets/
208 | models/
209 | outputs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PDE-Controller
2 |
3 | This repository contains the implementation of [PDE-Controller: LLMs for Autoformalization and Reasoning of PDEs](http://arxiv.org/abs/2502.00963).
4 |
5 |
6 |
7 | ![PDE-Controller logo](figures/logo.png)
8 |
9 |
10 |
11 | [🤗 Hugging Face Datasets](https://huggingface.co/datasets/delta-lab-ai/pde-controller) | [📑 Paper](http://arxiv.org/abs/2502.00963) | 📖 Project Page
12 |
13 |
14 | ## Dataset and Models
15 | Our datasets are released on [Hugging Face](https://huggingface.co/datasets/delta-lab-ai/pde-controller).
16 |
17 |
19 |
20 |
21 |
22 | The PDE-Controller models are as follows:
23 |
24 | |Model Name|Huggingface Link|
25 | |:--|:--|
26 | |Translator|🤗 [link](https://huggingface.co/delta-lab-ai/translator)|
27 | |Coder|🤗 [link](https://huggingface.co/delta-lab-ai/coder)|
28 | |Controller|🤗 [link](https://huggingface.co/delta-lab-ai/controller)|
29 | |Fine-tuned Coder for Controller|🤗 [link](https://huggingface.co/delta-lab-ai/finetuned_coder)|
30 |
31 |
32 | ## Introduction
33 |
34 | We present [PDE-Controller](http://arxiv.org/abs/2502.00963), a framework that enables large language models (LLMs) to control systems governed by partial differential equations (PDEs). Traditional LLMs have excelled in commonsense reasoning but fall short in rigorous logical reasoning. While recent AI-for-math has made strides in pure mathematics, areas of applied mathematics, particularly PDEs, remain underexplored despite their significant real-world applications. Our approach enables LLMs to transform informal natural language instructions into formal specifications, and then execute reasoning and planning steps to improve the utility of PDE control. We build a holistic solution comprising datasets (both human-written cases and 2 million synthetic samples), math-reasoning models, and novel evaluation metrics, all of which require significant effort. Our PDE-Controller significantly outperforms the latest open-source and GPT models in reasoning, autoformalization, and program synthesis, achieving up to a 62% improvement in utility gain for PDE control. By bridging the gap between language generation and PDE systems, we demonstrate the potential of LLMs in addressing complex scientific and engineering challenges.
35 |
36 |
37 |
38 | ![PDE-Controller pipeline](figures/pipeline.png)
39 |
40 |
41 | ## Case Study
42 |
43 | An example of LLM reasoning for PDE control on heat (top) and wave (bottom) problems.
44 |
45 |
46 |
47 |
48 |
49 | ![Case study: LLM reasoning for PDE control on heat (top) and wave (bottom) problems](figures/case_study.png)
50 |
51 |
52 | ## Installation
53 | For training, only the `trainenv` environment is required. For testing, `trainenv` is enough to obtain model outputs; simulating utilities with the Gurobi optimizer and the Femformal repository (Python 2) additionally requires the `pdecontrol` environment. Only `trainenv` needs to be activated to run the code; `pdecontrol` is activated by the [main test script](./test/PDEcontrol/evaluation/infer/run_1d_pdecontrol_eval_full.py).
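
For reference, the hand-off between environments works by running small Python 2 wrapper scripts (e.g. [`eval_robustness_wrapper.py`](./test/PDEcontrol/evaluation/eval/eval_robustness_wrapper.py)) as subprocesses and parsing the single JSON object they print to stdout. A minimal sketch, assuming the `pdecontrol` environment is entered via `conda run` (the exact invocation used by the test script may differ):

```python
# Sketch only: delegate a robustness simulation to the Python 2 `pdecontrol` env.
# The wrapper's argv/JSON contract is from eval_robustness_wrapper.py; the
# `conda run` invocation is an assumption.
import json
import subprocess

def eval_robustness_in_pdecontrol(nl_in_prompt: str, llm_output: str) -> dict:
    result = subprocess.run(
        ["conda", "run", "-n", "pdecontrol", "python",
         "test/PDEcontrol/evaluation/eval/eval_robustness_wrapper.py",
         nl_in_prompt, llm_output],
        capture_output=True, text=True, check=True,
    )
    # The wrapper prints exactly one JSON object: {"robustness": ..., "runtime": ...}
    return json.loads(result.stdout)
```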
54 |
55 | ## Training and Testing in Python 3
56 | Create a conda environment:
57 |
58 | ```shell
59 | conda create -n trainenv python=3.10
60 | ```
61 |
62 | Activate the environment:
63 |
64 | ```shell
65 | conda activate trainenv
66 | ```
67 |
68 | Run the commands in `environment.sh` to install the required packages.
69 |
70 |
71 |
77 |
78 |
83 |
84 | ## Training
85 |
86 | The documentation for training is at: [training](train/README.md)
87 |
88 | ## Testing
89 |
90 | The documentation for testing is at: [evaluation](test/README.md).
91 |
92 | ## Citation
93 |
94 | If you find this repository helpful, please consider citing our paper:
95 |
96 | ```
97 | @article{soroco2025pdecontroller,
98 | title={PDE-Controller: LLMs for Autoformalization and Reasoning of PDEs},
99 | author={Soroco, Mauricio and Song, Jialin and Xia, Mengzhou and Emond, Kye and Sun, Weiran and Chen, Wuyang},
100 | journal={arXiv preprint arXiv:2502.00963},
101 | year={2025}
102 | }
103 | ```
--------------------------------------------------------------------------------
/environment.sh:
--------------------------------------------------------------------------------
1 | # Install the Python packages required by the `trainenv` environment.
2 | # Run these commands after activating the environment (see README.md).
3 |
4 |
5 | pip install torch==2.4.0
6 | pip install torchaudio==2.4.0
7 |
8 | pip install xformers==0.0.27.post2
9 |
10 | # conda env update -f train_test_combined.yml
11 | pip install -r requirements.txt
12 |
13 | pip install gradio==5.9.1
14 |
15 | pip install flash-attn --no-build-isolation
16 |
--------------------------------------------------------------------------------
/figures/case_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/figures/case_study.png
--------------------------------------------------------------------------------
/figures/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/figures/logo.png
--------------------------------------------------------------------------------
/figures/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/figures/pipeline.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.0.0
2 | accelerate==1.2.0
3 | aiofiles==23.2.1
4 | aiohttp==3.9.1
5 | aiosignal==1.3.1
6 | airportsdata==20241001
7 | annotated-types==0.7.0
8 | anyio==4.7.0
9 | async-timeout==4.0.3
10 | attrs==23.1.0
11 | beautifulsoup4==4.12.3
12 | bitarray==3.0.0
13 | bs4==0.0.2
14 | cachetools==5.3.2
15 | certifi==2022.12.7
16 | charset-normalizer==2.1.1
17 | click==8.1.7
18 | cloudpickle==3.1.0
19 | compressed-tensors==0.6.0
20 | datasets==3.2.0
21 | deepspeed==0.15.4
22 | dill==0.3.7
23 | diskcache==5.6.3
24 | distro==1.9.0
25 | docker-pycreds==0.4.0
26 | docstring_parser==0.16
27 | editdistance==0.8.1
28 | einops==0.7.0
29 | exceptiongroup==1.2.2
30 | fastapi==0.115.6
31 | ffmpy==0.5.0
32 | filelock==3.16.1
33 | fire==0.5.0
34 | # flash-attn==2.7.2.post1
35 | frozenlist==1.4.0
36 | fsspec==2023.10.0
37 | gguf==0.10.0
38 | gitdb==4.0.11
39 | GitPython==3.1.43
40 | google-auth==2.25.2
41 | google-auth-oauthlib==1.0.0
42 | gradio==5.9.1
43 | gradio_client==1.5.2
44 | grpcio==1.60.0
45 | h11==0.14.0
46 | hjson==3.1.0
47 | httpcore==1.0.7
48 | httptools==0.6.4
49 | httpx==0.28.1
50 | huggingface-hub==0.26.5
51 | idna==3.4
52 | importlib_metadata==8.5.0
53 | interegular==0.3.3
54 | Jinja2==3.1.2
55 | jiter==0.8.2
56 | jsonlines==4.0.0
57 | jsonschema==4.23.0
58 | jsonschema-specifications==2024.10.1
59 | lark==1.2.2
60 | llvmlite==0.43.0
61 | lm-format-enforcer==0.10.6
62 | loader==2017.9.11
63 | loaderr==1.0.0
64 | Markdown==3.5.1
65 | markdown-it-py==3.0.0
66 | MarkupSafe==2.1.3
67 | mdurl==0.1.2
68 | mistral_common==1.4.4
69 | mpi4py==4.0.1
70 | mpmath==1.3.0
71 | msgpack==1.1.0
72 | msgspec==0.18.6
73 | multidict==6.0.4
74 | multiprocess==0.70.15
75 | nest-asyncio==1.6.0
76 | networkx==3.0
77 | ninja==1.11.1.1
78 | numba==0.60.0
79 | numpy==1.25.0
80 | nvidia-cublas-cu12==12.1.3.1
81 | nvidia-cuda-cupti-cu12==12.1.105
82 | nvidia-cuda-nvrtc-cu12==12.1.105
83 | nvidia-cuda-runtime-cu12==12.1.105
84 | nvidia-cudnn-cu12==9.1.0.70
85 | nvidia-cufft-cu12==11.0.2.54
86 | nvidia-curand-cu12==10.3.2.106
87 | nvidia-cusolver-cu12==11.4.5.107
88 | nvidia-cusparse-cu12==12.1.0.106
89 | nvidia-ml-py==12.560.30
90 | nvidia-nccl-cu12==2.20.5
91 | nvidia-nvjitlink-cu12==12.4.127
92 | nvidia-nvtx-cu12==12.1.105
93 | oauthlib==3.2.2
94 | openai==1.58.1
95 | opencv-python-headless==4.10.0.84
96 | orjson==3.10.12
97 | outlines==0.0.43
98 | outlines_core==0.1.26
99 | packaging==23.2
100 | pandas==2.1.4
101 | partial-json-parser==0.2.1.1.post4
102 | Pebble==5.1.0
103 | peft==0.14.0
104 | pillow==10.3.0
105 | platformdirs==4.2.2
106 | prometheus-fastapi-instrumentator==7.0.0
107 | prometheus_client==0.21.1
108 | protobuf==4.25.1
109 | psutil==5.9.6
110 | py-cpuinfo==9.0.0
111 | pyairports==2.1.1
112 | pyarrow==18.1.0
113 | pyarrow-hotfix==0.6
114 | pyasn1==0.5.1
115 | pyasn1-modules==0.3.0
116 | pycountry==24.6.1
117 | pydantic==2.10.3
118 | pydantic_core==2.27.1
119 | pydub==0.25.1
120 | Pygments==2.18.0
121 | python-dateutil==2.8.2
122 | python-dotenv==1.0.1
123 | python-multipart==0.0.20
124 | pytz==2023.3.post1
125 | PyYAML==6.0.1
126 | pyzmq==26.2.0
127 | ray==2.40.0
128 | referencing==0.35.1
129 | regex==2023.10.3
130 | requests==2.32.3
131 | requests-oauthlib==1.3.1
132 | rich==13.9.3
133 | rpds-py==0.22.3
134 | rsa==4.9
135 | ruff==0.8.4
136 | safehttpx==0.1.6
137 | safetensors==0.4.5
138 | semantic-version==2.10.0
139 | sentencepiece==0.2.0
140 | sentry-sdk==2.5.1
141 | setproctitle==1.3.3
142 | shellingham==1.5.4
143 | shtab==1.7.1
144 | six==1.16.0
145 | smmap==5.0.1
146 | sniffio==1.3.1
147 | soupsieve==2.5
148 | starlette==0.41.3
149 | sympy==1.13.1
150 | tensorboard==2.14.0
151 | tensorboard-data-server==0.7.2
152 | termcolor==2.4.0
153 | tiktoken==0.7.0
154 | tokenizers==0.20.3
155 | tomlkit==0.13.2
156 | torch==2.4.0
157 | torchaudio==2.4.0
158 | torchvision==0.19.0
159 | tqdm==4.67.1
160 | transformers==4.46.3
161 | triton==3.0.0
162 | trl==0.12.2
163 | typer==0.15.1
164 | typing_extensions==4.12.2
165 | tyro==0.8.14
166 | tzdata==2023.3
167 | urllib3==1.26.13
168 | uvicorn==0.32.1
169 | uvloop==0.21.0
170 | vllm==0.6.3.post1
171 | wandb==0.17.1
172 | watchfiles==1.0.0
173 | websockets==14.1
174 | Werkzeug==3.0.1
175 | # xformers==0.0.27.post2
176 | xxhash==3.4.1
177 | yarl==1.9.4
178 | zipp==3.21.0
179 |
--------------------------------------------------------------------------------
/run_dpo_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # DeepSeekMath-7B MathCoder2 version
5 | MODEL_PATH=models/base_model/MathCoder2-DeepSeekMath-7B
6 | OUTPUT_DIR=outputs/dpo
7 |
8 | ##### model parameters #####
9 | context_length=4096
10 |
11 | ##### Data parameters #####
12 | input_file_paths=(
13 | datasets/unprocessed/dpo/dpo_d0_train
14 | datasets/unprocessed/dpo/dpo_d1_train
15 | datasets/unprocessed/dpo/dpo_d2_train
16 | )
17 | out_file_path=datasets/dpo
18 |
19 |
20 |
21 | python train/scripts/tokenize_data_dpo.py --paths "${input_file_paths[@]}" --out_file_path ${out_file_path}/train --model_path $MODEL_PATH
22 | python train/scripts/group_text.py --max_len $context_length --out_file_path ${out_file_path}/train --model_path $MODEL_PATH --no_grouping --no_padding --balance 1 --dpo
23 |
24 |
25 | echo Starting training...
26 |
27 | bash train/scripts/train_dpo.sh
--------------------------------------------------------------------------------
/run_testing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | eval "$(conda shell.bash hook)"
4 |
5 | PROJ_DIR=/localhome/mms43/scratch/mathcoder2/MathCoder2
6 | MODELS="/localhome/mms43/scratch/mathcoder2/outputs"
7 | BASE_MODEL_PATH=/localhome/mms43/scratch/mathcoder2/model_ckpts/MathCoder2-DeepSeekMath-7B
8 | OUTPUT_DIR=/localhome/mms43/scratch/mathcoder2/outputs_new
9 |
10 | ########
11 | DIR_TRANSLATOR=$MODELS/30233_ds_to_STL_s16800/checkpoint-step-3000/merged_model
12 | # DIR_CODER=$MODELS/30236_ds_to_python_given_STL_s16800/checkpoint-step-3000/merged_model ### coder (model to evaluate)
13 | # DIR_CODER=$BASE_MODEL_PATH ### MathCoder2-DeepSeekMath-7B
14 | # DIR_CODER=$MODELS/30234_ds_to_python_no_STL_s16800/checkpoint-step-6000/merged_model ### trained baseline
15 | # DIR_CODER=$MODELS/30235_ds_to_python_GT_STL_s16800/checkpoint-step-3000/merged_model ### oracle
16 | DIR_CODER=$MODELS/10136_ds_to_python_misaligned_s16800/checkpoint-step-1500/merged_model ### finetuned coder (model to evaluate)
17 | DIR_CONTROLLER=$MODELS/10138_ds_DPO__s16800/checkpoint-step-16000/policy/merged_model ### controller
18 |
19 |
20 |
21 |
22 |
23 | gpus="1"
24 |
25 |
26 | use_openai=
27 | # use_openai='gpt-4o'
28 | # use_openai='o1-mini'
29 |
30 |
31 | cd $PROJ_DIR
32 |
33 |
34 | if [ -n "$use_openai" ]; then
35 | EVAL_DIR=$OUTPUT_DIR/$use_openai/eval
36 | else
37 | EVAL_DIR=${OUTPUT_DIR}/eval
38 | fi
39 |
40 |
41 | bash test/scripts/test_pdecontrol.sh $EVAL_DIR $gpus $DIR_TRANSLATOR $DIR_CODER $DIR_CONTROLLER $use_openai
42 |
--------------------------------------------------------------------------------
/run_train_testing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | dataset_list=$1
4 | prompt_format=$2
5 | few_shot_number=$3
6 | checkpoint_dir=$4
7 | out_file_path=$5
8 | max_samples=$6
9 |
10 | # Split the dataset_list into an array
11 | IFS=' ' read -r -a datasets <<< "$dataset_list"
12 |
13 |
14 | for dataset in "${datasets[@]}"
15 | do
16 | CMD="CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false python test/PDEcontrol/evaluation/infer/run_1d_pdecontrol_eval_train.py \
17 | --data_dir $dataset \
18 | --save_dir $out_file_path \
19 | --use_vllm \
20 | --model_name_or_path $checkpoint_dir \
21 | --tokenizer_name_or_path $checkpoint_dir \
22 | --eval_batch_size 1 \
23 | --temperature 0.0 \
24 | --prompt_format few_shot \
25 | --few_shot_number $few_shot_number \
26 | --few_shot_prompt $prompt_format"
27 |
28 | if [[ $dataset == *"heat"* ]]; then
29 | CMD="$CMD --prompt_dataset heat"
30 | elif [[ $dataset == *"wave"* ]]; then
31 | CMD="$CMD --prompt_dataset wave"
32 | fi
33 |
34 | if [ $max_samples -gt 0 ]; then
35 | CMD="$CMD --max_num_examples $max_samples"
36 | fi
37 |
38 | echo $CMD
39 | eval $CMD
40 | done
41 |
42 |
43 |
--------------------------------------------------------------------------------
/run_training.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python --version
3 |
4 | SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
5 |
6 |
7 | ### --------- Set Parameters --------- ###
8 |
9 |
10 | ##### model parameters #####
11 |
12 | tokenizer_path=models/base_model/MathCoder2-DeepSeekMath-7B
13 | context_length=4096
14 |
15 | ##### Data parameters #####
16 | input_file_paths=(
17 | datasets/unprocessed/sft/heat_nc1_train
18 | datasets/unprocessed/sft/heat_nc2_train
19 | datasets/unprocessed/sft/heat_nc3_train
20 | datasets/unprocessed/sft/wave_nc1_train
21 | datasets/unprocessed/sft/wave_nc2_train
22 | datasets/unprocessed/sft/wave_nc3_train
23 | )
24 |
25 | # ["to_python_no_STL", "to_STL", "to_python_GT_STL", "to_python_given_STL"]
26 | prompt_format=to_STL
27 |
28 | out_file_path=datasets/sft/${prompt_format}
29 |
30 | ### To set for "to_python_given_STL" ####
31 | max_samples=-1 # -1 to test all datapoints.
32 | TRANSLATOR_DIR=models/translator
33 |
34 |
35 | ######################################
36 | ### --------- Data processing --------- ###
37 | ######################################
38 |
39 |
40 | if [ "$prompt_format" = "to_python_given_STL" ]; then
41 | ### --------- prompt the model for its predictions and add this to the training data. --------- ###
42 |
43 |
44 | ### --------- First step: train the model on the original data. --------- ###
45 | prompt_format=to_STL
46 |
47 | few_shot_number=2
48 |
49 | # join the input file paths into a single string
50 | input_file_paths_str=$(IFS=" "; echo "${input_file_paths[*]}")
51 |
52 | predictions_output_dir=$out_file_path/train_eval
53 | # obtain the model's predictions for the second step
54 | "$SCRIPT_DIR/run_train_testing.sh" "$input_file_paths_str" "$prompt_format" "$few_shot_number" $TRANSLATOR_DIR $predictions_output_dir $max_samples
55 | # the second step continues below
56 | prompt_format=to_python_given_STL
57 |
58 | input_file_paths=(
59 | $predictions_output_dir/to_STL/
60 | )
61 | ### --------- End of first step --------- ###
62 | fi
63 |
64 |
65 | python train/scripts/tokenize_data.py --paths "${input_file_paths[@]}" --out_file_path ${out_file_path} --model_path $tokenizer_path --sft --prompt_format $prompt_format
66 | python train/scripts/group_text.py --max_len $context_length --out_file_path ${out_file_path} --model_path $tokenizer_path --sft --no_grouping --no_padding --balance 0.05 0.22 0.23 0.05 0.22 0.23 --total 128000
67 |
68 |
69 | ######################################
70 | ### --------- Training --------- ###
71 | ######################################
72 |
73 | echo Starting training...
74 |
75 | bash train/scripts/train.sh
76 |
--------------------------------------------------------------------------------
/test/PDEcontrol/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__
2 |
--------------------------------------------------------------------------------
/test/PDEcontrol/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 | python_version: "3.11"
7 | python_packages:
8 | - torch==2.0.1
9 | - torchvision==0.15.2
10 | - transformers==4.37.2
11 | - accelerate==0.27.0
12 | - hf_transfer
13 |
14 | # predict.py defines how predictions are run on your model
15 | predict: "replicate/predict.py:Predictor"
16 |
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/data_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/delta-lab-ai/pde-controller/e680aad6635994fa6fd6af0e809f9064dfae2aea/test/PDEcontrol/evaluation/data_processing/__init__.py
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/data_processing/process_utils.py:
--------------------------------------------------------------------------------
1 | import regex
2 |
3 | from data_processing.answer_extraction import extract_math_answer, strip_string
4 |
5 | def process_gsm8k_test(item):
6 | sample = {
7 | 'dataset': 'gsm8k-cot',
8 | 'id': item['id'],
9 | 'messages': [
10 | {'role': 'user', 'content': item['question']},
11 | {'role': 'assistant', 'content': regex.sub(r"<<[^<>]*>>", "", item['cot']) + "\nSo the answer is $\\boxed{" + item['answer'].strip() + "}$."}
12 | ],
13 | 'answer': item['answer'].replace(',', '')
14 | }
15 | yield sample
16 |
17 | def process_math_test(item):
18 | question = item["problem"]
19 | try:
20 | answer = extract_math_answer(question, item['solution'], task="cot")
21 | except:
22 | return
23 | sample = {
24 | "dataset": "math-cot",
25 | "id": item['id'],
26 | "level": item["level"],
27 | "type": item["type"],
28 | "category": item["category"],
29 | "messages": [
30 | {"role": "user", "content": question},
31 | {"role": "assistant", "content": "\n".join(regex.split(r"(?<=\.) (?=[A-Z])", item["solution"]))}
32 | ],
33 | "answer": answer
34 | }
35 | yield sample
36 |
37 | def process_math_sat(item):
38 | options = item['options'].strip()
39 | assert 'A' == options[0]
40 | options = '(' + options
41 | for ch in 'BCDEFG':
42 | if f' {ch}) ' in options:
43 | options = regex.sub(rf" {ch}\) ", f" ({ch}) ", options)
44 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
45 | messages = [
46 | {'role': 'user', 'content': question},
47 | {'role': 'assistant', 'content': item['Answer']}
48 | ]
49 | item = {
50 | 'dataset': 'math_sat',
51 | 'id': item['id'],
52 | 'language': 'en',
53 | 'messages': messages,
54 | 'answer': item['Answer'],
55 | }
56 | yield item
57 |
58 | def process_ocwcourses(item):
59 | messages = [
60 | {'role': 'user', 'content': item['problem'].strip()},
61 | {'role': 'assistant', 'content': item['solution'].strip()}
62 | ]
63 | item = {
64 | "dataset": "OCWCourses",
65 | "id": item['id'],
66 | "language": "en",
67 | "messages": messages,
68 | "answer": item['answer']
69 | }
70 | yield item
71 |
72 | def process_mmlu_stem(item):
73 | options = item['options']
74 | for i, (label, option) in enumerate(zip('ABCD', options)):
75 | options[i] = f"({label}) {str(option).strip()}"
76 | options = ", ".join(options)
77 | question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
78 | messages = [
79 | {'role': 'user', 'content': question},
80 | {'role': 'assistant', 'content': item['answer']}
81 | ]
82 | item = {
83 | "dataset": "MMLU-STEM",
84 | "id": item['id'],
85 | "language": "en",
86 | "messages": messages,
87 | "answer": item['answer']
88 | }
89 | yield item
90 |
91 | def process_mgsm_zh(item):
92 | item['answer'] = item['answer'].replace(',', '')
93 | yield item
94 |
95 | def process_cmath(item):
96 | item = {
97 | 'dataset': 'cmath',
98 | 'id': item['id'],
99 | 'grade': item['grade'],
100 | 'reasoning_step': item['reasoning_step'],
101 | 'messages': [
102 | {'role': 'user', 'content': item['question'].strip()},
103 | {'role': 'assistant', 'content': ''}
104 | ],
105 | 'answer': item['golden'].strip().replace(",", "")
106 | }
107 | yield item
108 |
109 | def process_agieval_gaokao_math_cloze(item):
110 | item = {
111 | 'dataset': 'agieval-gaokao-math-cloze',
112 | 'id': item['id'],
113 | 'messages': [
114 | {'role': 'user', 'content': item['question'].strip()},
115 | {'role': 'assistant', 'content': ''}
116 | ],
117 | 'answer': [strip_string(ans) for ans in item['answer'].strip().split(";")]
118 | }
119 | yield item
120 |
121 | def process_agieval_gaokao_mathqa(item):
122 | question = item['question'].strip()
123 | options = []
124 | for option in item['options']:
125 | option = option.strip()
126 | assert option[0] == '('
127 | assert option[2] == ')'
128 | assert option[1] in 'ABCD'
129 | option = f"{option[1]}: {option[3:].strip()}"
130 | options.append(option.strip())
131 | question = f"{question}\n{options}"
132 | item = {
133 | 'dataset': 'agieval-gaokao-mathqa',
134 | 'id': item['id'],
135 | 'messages': [
136 | {'role': 'user', 'content': question},
137 | {'role': 'assistant', 'content': ''}
138 | ],
139 | "answer": item['label']
140 | }
141 | yield item
142 |
143 | def process_agieval_gaokao_mathqa_few_shot_cot_test(item):
144 | question = item['question'].strip().rstrip('\\')
145 | options = " ".join([opt.strip() for opt in item['options']])
146 | question = f"{question}\n从以下选项中选择: {options}"
147 | item = {
148 | 'dataset': 'agieval-gaokao-mathqa',
149 | 'id': item['id'],
150 | 'messages': [
151 | {'role': 'user', 'content': question},
152 | {'role': 'assistant', 'content': ''}
153 | ],
154 | "answer": item['label']
155 | }
156 | yield item
157 |
158 | def process_minif2f_isabelle(item):
159 | question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} *)\n\nFormal:\n{item['formal_statement'].strip()}"
160 | item = {
161 | 'dataset': 'minif2f-isabelle',
162 | 'id': item['id'],
163 | 'messages': [
164 | {'role': 'user', 'content': question},
165 | {'role': 'assistant', 'content': ''}
166 | ],
167 | "answer": "placeholder"
168 | }
169 | yield item
170 |
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/eval/__init__.py:
--------------------------------------------------------------------------------
1 | from . import utils
2 | from . import eval_script
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/eval/eval_robustness_reasoning_wrapper.py:
--------------------------------------------------------------------------------
1 | # eval_robustness_reasoning_wrapper.py
2 | import sys
3 | import json
4 | try:
5 | from control.femformal.eval_robustness_reasoning import eval_robustness as femformal_eval_robustness
6 | except ImportError:
7 | import sys
8 | import os
9 | sys.path.append(os.path.abspath('/localhome/mms43/scratch/mathcoder2/MathCoder2'))
10 | from control.femformal.eval_robustness_reasoning import eval_robustness as femformal_eval_robustness
11 |
12 |
13 | def main():
14 | llm_anchor_output = sys.argv[1]
15 | llm_interm_output = sys.argv[2]
16 | robustness, runtime = femformal_eval_robustness(llm_anchor_output, llm_interm_output)
17 | # Note that this code should not print anything else to stdout or it will screw with the evaluation
18 | print(json.dumps({
19 | "robustness": robustness,
20 | "runtime": runtime
21 | }))
22 |
23 | if __name__ == "__main__":
24 | main()
25 |
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/eval/eval_robustness_wrapper.py:
--------------------------------------------------------------------------------
1 | # eval_robustness_wrapper.py
2 | import sys
3 | import json
4 | try:
5 | from control.femformal.eval_robustness import eval_robustness as femformal_eval_robustness
6 | except ImportError:
7 | import sys
8 | import os
9 | sys.path.append(os.path.abspath('/localhome/mms43/scratch/mathcoder2/MathCoder2'))
10 | from control.femformal.eval_robustness import eval_robustness as femformal_eval_robustness
11 |
12 |
13 | def main():
14 | nl_in_prompt = sys.argv[1]
15 | llm_output = sys.argv[2]
16 | robustness, runtime = femformal_eval_robustness(nl_in_prompt, llm_output)
17 | # Note that this code should not print anything else to stdout or it will screw with the evaluation
18 | print(json.dumps({
19 | "robustness": robustness,
20 | "runtime": runtime
21 | }))
22 |
23 | if __name__ == "__main__":
24 | main()
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/infer/simulate_gt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import traceback
5 | current_dir = os.path.dirname(os.path.realpath(__file__))
6 | from tqdm import tqdm
7 | import json
8 | from pebble import ProcessPool
9 | from concurrent.futures import TimeoutError
10 | import random
11 | import numpy as np
12 |
13 |
14 | utils_path = os.path.join(current_dir, '../../../../utils')
15 | sys.path.append(utils_path)
16 |
17 | from few_shot_prompts import CoTOneDHeat, CoTOneDWave
18 |
19 |
20 | evaluation_path = os.path.abspath(os.path.join(current_dir, '..'))
21 | sys.path.append(evaluation_path)
22 |
23 | eval_path = os.path.abspath(os.path.join(current_dir, '../eval'))
24 | sys.path.append(eval_path)
25 |
26 | data_processing_path = os.path.abspath(os.path.join(current_dir, '../data_processing'))
27 | sys.path.append(data_processing_path)
28 |
29 | from data_processing.answer_extraction import *
30 | from eval.eval_script import *
31 |
32 |
33 |
34 |
35 | # Module-level globals; `pool` is closed at exit if it was ever created.
36 | model = None
37 | tokenizer = None
38 | pool = None
39 | os.environ['TOKENIZERS_PARALLELISM'] = 'true'
40 |
41 | def evaluate(eval_fn, tasks, _timeout=15):
42 | with ProcessPool() as pool:
43 | timeout_cnt = 0
44 | failure_cnt = 0
45 | iterator = pool.map(eval_fn, tasks, timeout=_timeout).result()
46 | labels = []
47 | while True:
48 | try:
49 | labels.append((next(iterator)))
50 | except StopIteration:
51 | break
52 | except TimeoutError as error:
53 | labels.append(0)
54 | timeout_cnt += 1
55 | except Exception as error:
56 | print("An error occurred:", error, flush=True)
57 | traceback.print_exception(type(error), error, error.__traceback__, file=sys.stdout)
58 | failure_cnt += 1
59 | labels.append(-100)
60 | return labels, timeout_cnt, failure_cnt
61 |
62 | def evaluate_future(eval_fn, tasks, _timeout=300):
63 | # The runtime is doubled because the gt may have to be simulated also.
64 | num_cpus = os.cpu_count()
65 | max_workers = max(1, int(np.floor(num_cpus * 0.5)))
66 | print(f"Using {max_workers} out of {num_cpus} CPUs")
67 | with ProcessPool(max_workers=max_workers) as pool:
68 | timeout_cnt = 0
69 | futures = [(i, pool.schedule(eval_fn, args=(task,), timeout=_timeout)) for i, task in enumerate(tasks)]
70 | results = [None] * len(tasks)
71 |
72 | with tqdm(total=len(futures), desc="Evaluating robustness (completed)") as pbar:
73 | for i, future in futures:
74 | try:
75 | result = future.result()
76 | gt_robustness, gt_simtime = result
77 | results[i] = (gt_robustness, gt_simtime)
78 |
79 | except TimeoutError:
80 | results[i] = ("timeout", "timeout")
81 | timeout_cnt += 1
82 | except Exception as error:
83 | print("An error occurred:", error, flush=True)
84 | traceback.print_exception(type(error), error, error.__traceback__, file=sys.stdout)
85 | exit()
86 | finally:
87 | pbar.update(1)
88 | list_gt_robustness, list_gt_simtime = zip(*results)
89 | return list_gt_robustness, list_gt_simtime, timeout_cnt
90 |
91 |
92 | def main(args):
93 | if args.gpus is not None:
94 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
95 | random.seed(42)
96 |
97 | print("Loading data...")
98 | test_data = []
99 | with open(os.path.join(args.data_dir, f"validation.jsonl" if args.infer_on_train_set else f"test.jsonl")) as fin:
100 | for line in fin:
101 | example = json.loads(line)
102 | python_code = example[args.python_key]
103 | sstl = example[args.stl_key]
104 | natural_language = example[args.nl_key]
105 | robustness = example.get(args.robustness_key, None)
106 | example['python'] = python_code.strip()
107 | example['sstl'] = sstl
108 | example['nl'] = natural_language
109 | if robustness is not None:
110 | example['robustness'] = robustness
111 | test_data.append(example)
112 |
113 | if args.max_num_examples and len(test_data) > args.max_num_examples:
114 | test_data = random.sample(test_data, args.max_num_examples)
115 |
116 | if not test_data:
117 | print("Ending. There was no data to test.")
118 | return
119 |
120 | args.save_dir = args.data_dir + "_" + str(args.max_num_examples)
121 | os.makedirs(args.save_dir, exist_ok=True)
122 |
123 |
124 | if args.eval_robustness:
125 | eval_robustness_function = eval_robustness_gt  # imported from eval.eval_script
126 | sim_gt_robustnesses, sim_gt_simulation_times, eval_timeout_cnt_robustness = evaluate_future(eval_robustness_function, test_data)
127 | print("done evaluating robustness", flush=True)
128 | for item, gt_r, gt_time in zip(test_data, sim_gt_robustnesses, sim_gt_simulation_times):
129 | if 'robustness' not in item:
130 | item['robustness'] = gt_r
131 | item['time'] = gt_time
132 |
133 | print("Calculating accuracy...")
134 | ## track dataset stats:
135 | sum_gt_positive_robustness = 0
136 | sum_gt_negative_robustness = 0
137 | sum_gt_failed_robustness = 0
138 | for item in test_data:
139 | if args.eval_robustness:
140 | ## track dataset stats:
141 | gt_r = item['robustness']
142 | if gt_r == "timeout" or gt_r == -100:  # timed-out or failed simulations both count as failed
143 | sum_gt_failed_robustness += 1
144 | elif gt_r > 0:
145 | sum_gt_positive_robustness += 1
146 | elif gt_r < 0:
147 | sum_gt_negative_robustness += 1
148 | else:
149 | raise ValueError(f"gt_r = {gt_r}")
150 | ## end track dataset stats
151 |
152 |
153 | if args.eval_robustness:
154 | ## track dataset stats:
155 | gt_positive_robustness_rate = sum_gt_positive_robustness / len(test_data)
156 | gt_negative_robustness_rate = sum_gt_negative_robustness / len(test_data)
157 | gt_failed_robustness_rate = sum_gt_failed_robustness / len(test_data)
158 | print(f"gt positive robustness rate = {gt_positive_robustness_rate * 100}", flush=True)
159 | print(f"gt negative robustness rate = {gt_negative_robustness_rate * 100}", flush=True)
160 | print(f"gt failed robustness rate = {gt_failed_robustness_rate * 100}", flush=True)
161 |
162 |
163 | if args.eval_robustness: print(f"Timeout count >>> output eval robustness = {eval_timeout_cnt_robustness}", flush=True)
164 |
165 | with open(os.path.join(args.save_dir, f"validation.jsonl" if args.infer_on_train_set else f"test.jsonl"), "w") as fout:
166 | for item in test_data:
167 | if 'sstl' in item:  # pop and re-insert selected keys so they appear last in the saved JSON
168 | sstl = item.pop('sstl')
169 | item['sstl'] = sstl
170 | if 'python' in item:
171 | python = item.pop('python')
172 | item['python'] = python
173 | if 'robustness' in item:
174 | robustness = item.pop('robustness')
175 | item['robustness'] = robustness
176 | if 'time' in item:
177 | time = item.pop('time')
178 | item['time'] = time
179 | json.dump(item, fout, ensure_ascii=True)
180 | fout.write("\n")
181 |
182 | metric_fname = "metrics.json"
183 | with open(os.path.join(args.save_dir, metric_fname), "w") as fout:
184 | metrics = {
185 | "n_samples": len(test_data),
186 | }
187 | if args.eval_robustness:
188 | metrics["gt positive robustness rate"] = gt_positive_robustness_rate
189 | metrics["gt negative robustness rate"] = gt_negative_robustness_rate
190 | metrics["gt failed robustness rate"] = gt_failed_robustness_rate
191 | json.dump(metrics, fout, indent=4)
192 |
193 | if __name__ == "__main__":
194 | parser = argparse.ArgumentParser()
195 | parser.add_argument("--seed", type=int, default=None, help="random seed for the LLM generation only. This SHOULD NOT affect the random seed for data selection (I did not debug this so set it at your own peril).")
196 | parser.add_argument("--data_dir", type=str, default="data/mgsm")
197 | parser.add_argument("--python_key", type=str, default="python", help="Key for accessing the Python code in the example.")
198 | parser.add_argument("--stl_key", type=str, default="sstl", help="Key for accessing the STL in the example.")
199 | parser.add_argument("--nl_key", type=str, default="nl", help="Key for accessing the natural language in the example.")
200 | parser.add_argument("--robustness_key", type=str, default="robustness", help="Key for accessing the robustness in the example.")
201 |
202 | parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate.")
203 | parser.add_argument("--save_dir", type=str, default="results/mgsm")
204 |
205 | parser.add_argument("--infer_on_train_set", action="store_true")
206 |
207 | parser.add_argument("--eval_perplexity", action='store_true', help="Whether to evaluate the perplexity of the model outputs.")
208 | parser.add_argument("--eval_robustness", action='store_true', help="Whether to evaluate the robustness of the model outputs by comparing on the FemFormal repo.")
209 | parser.add_argument("--eval_edit_distance", action='store_true', help="Whether to evaluate the edit distance of the model python with the ground truth python.")
210 | parser.add_argument("--eval_iou", action='store_true', help="Whether to evaluate the precision and recall of the model sstl latex generation with the ground truth sstl.")
211 | parser.add_argument("--load_from_file", action='store_true', help="Whether to load the model predictions from a file instead of regenerating them. If this flag is provided but no file exists, the model predictions will be generated and saved to a file.")
212 | parser.add_argument("--skip_existing_scores", action='store_true', help="Whether to skip computing and overwriting the evaluation of a datapoint (ex. --eval_perplexity) that exists already. Only relevant if `--load_from_file` is provided.")
213 | parser.add_argument("--gpus", type=str, default=None, help="Use to set the CUDA_VISIBLE_DEVICES environment variable.")
214 | args, unparsed_args = parser.parse_known_args()
215 | if args.gpus is not None:
216 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
217 |
218 | print("unparsed args:", flush=True)
219 | print(unparsed_args, flush=True)
220 | print("args:", flush=True)
221 | print(args, flush=True)
222 |
223 | if args.seed is not None:
224 | if args.seed < 0:
225 | args.seed = None
226 |
227 | if 'math6' in args.data_dir:
228 | args.multi_turn = True
229 |
230 | # the basename of the datadir is the dataset name
231 | args.dataset_name = args.data_dir.split("/")[-1]
232 |
233 | main(args)
234 |
235 | if pool is not None:
236 | pool.close()
237 |
--------------------------------------------------------------------------------
/test/PDEcontrol/evaluation/scripts/infer_pdecontrol.sh:
--------------------------------------------------------------------------------
1 | eval "$(conda shell.bash hook)"
2 |
3 | set -e
4 |
5 | conda activate trainenv
6 | echo "activated conda env:" $CONDA_DEFAULT_ENV
7 | python --version
8 |
9 |
10 |
11 | dataset=$1
12 | out_dir=$2
13 | few_shot_number=$3
14 | prompt_format=$4
15 | max_samples=$5
16 | gpus=$6
17 | translator_path=${7}
18 | coder_path=${8}
19 | controller_path=${9}
20 | use_openai=${10}
21 |
22 | save_dir=$out_dir/${dataset}_shots=${few_shot_number}
23 |
24 |
25 | CMD="CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false python test/PDEcontrol/evaluation/infer/run_1d_pdecontrol_eval_full.py \
26 | --data_dir \"$(dirname $0)/../../test_data/${dataset}\" \
27 | --save_dir $save_dir \
28 | --use_vllm \
29 | --model_name_or_path_translator $translator_path \
30 | --tokenizer_name_or_path_translator $translator_path \
31 | --model_name_or_path_coder $coder_path \
32 | --tokenizer_name_or_path_coder $coder_path \
33 | --model_name_or_path_controller $controller_path \
34 | --tokenizer_name_or_path_controller $controller_path \
35 | --eval_batch_size 1 \
36 | --temperature 0.2 \
37 | --seed 0 \
38 | --n_repeat_sampling 3 \
39 | --prompt_format few_shot \
40 | --prompt_dataset CoTOneDCombined \
41 | --few_shot_number $few_shot_number \
42 | --few_shot_prompt $prompt_format \
43 | --eval_robustness \
44 | --eval_iou \
45 | --eval_edit_distance"
46 |
47 | if [ $max_samples -gt 0 ]; then
48 | CMD="$CMD --max_num_examples $max_samples"
49 | fi
50 |
51 | if [ "$gpus" != -1 ]; then
52 | CMD="$CMD --gpus $gpus"
53 | fi
54 |
55 | if [ -n "$use_openai" ]; then
56 | CMD="$CMD --use_openai $use_openai"
57 | else
58 | CMD="$CMD --eval_perplexity"
59 | fi
60 |
61 | echo $CMD
62 | eval $CMD
63 |
64 | exit 0
--------------------------------------------------------------------------------
/test/README.md:
--------------------------------------------------------------------------------
1 | # Testing
2 |
3 | ## Installation
4 |
5 | Install the conda environments `trainenv` and `pdecontrol` as specified in the [main README](../README.md).
6 |
7 | ## Evaluation
8 |
9 |
10 | We provide a test script for both zero-shot and few-shot evaluation on the heat and wave PDEs used in our paper. Modify the settings in [`scripts/test_pdecontrol.sh`](./scripts/test_pdecontrol.sh).
11 |
12 | Then to evaluate the models, set the paths in the following script and run it:
13 |
14 | ```bash
15 | bash run_testing.sh
16 | ```
17 |
18 | Set `$MODELS` to the directory where the model weights are stored, and set `$OUTPUT_DIR` to the directory where the inference results should be saved. This script also creates a `results` directory under `EVAL_DIR`, where markdown files containing tables of all the results are saved.
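
The result tables can also be regenerated manually with [`scripts/read_result.py`](./scripts/read_result.py); a sketch with placeholder paths and settings (the flags correspond to the argument parser in that script):

```bash
python test/scripts/read_result.py \
    --in_dir outputs/eval \
    --shots 2 \
    --seeds 0 \
    --eval_methods to_python_two_step
```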
19 |
20 |
21 | ## Labels
22 | The following code can be used to simulate labels for the ground-truth natural language & python inputs. To do so, run:
23 |
24 | ```bash
25 | bash simulate_gt.sh
26 | ```
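
Per the defaults in [`simulate_gt.py`](./PDEcontrol/evaluation/infer/simulate_gt.py), the simulated labels are written to a copy of each dataset directory suffixed with the number of sampled examples, alongside a `metrics.json` summary of the ground-truth robustness statistics.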
--------------------------------------------------------------------------------
/test/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | tensorflow
3 | transformers
4 | wandb
5 | beautifulsoup4
6 | torchvision
7 | tqdm
8 | tensorboard
9 | flash-attn
10 | trl
11 | vllm==0.2.0
12 | deepspeed
13 | tf_keras
14 | torchaudio
15 | peft
16 | Pebble
17 |
--------------------------------------------------------------------------------
/test/scripts/read_result.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | import os
4 | from glob import glob
5 |
6 | import numpy as np
7 |
8 | def read_json(in_file):
9 | with open(in_file, "r", encoding="utf-8") as f:
10 | return json.load(f)
11 |
12 |
13 | def per_eval_method_read_result(metrics, datasets, in_dir, eval_method, subset_id=None, shots=3, seeds=[-1]):
14 | # Initialize the output text
15 | text = ""
16 | for metric in metrics:
17 | # Initialize a dictionary to store results
18 | max_shots = shots
19 | # results = {i: ["n/a"] * len(datasets) for i in range(max_shots + 1)}
20 | ## for each shot, there is a dataset. Each dataset contains a list to track the scores for each seed.
21 | results = {i: {dataset: [] for dataset in datasets} for i in range(max_shots + 1)}
22 |
23 | # Track dataset columns with valid entries
24 | valid_columns = {dataset: False for dataset in datasets}
25 |
26 | # Iterate over the directories again to populate results
27 | for dirname in os.listdir(in_dir):
28 | if "shots=" in dirname and "seed=" in dirname:
29 | # Extract shot number and seed number
30 | shot_num = int(dirname.split("shots=")[1].split("_")[0])
31 | seed_num = int(dirname.split("seed=")[1].split("_")[0])
32 |
33 | if seed_num not in seeds:
34 | continue
35 |
36 | # Determine the dataset and column based on the directory name
37 | dataset = dirname.split("_shots=")[0]
38 |
39 | # Read the metrics.json file
40 | in_file = os.path.join(in_dir, dirname, eval_method, "metrics.json")
41 | if subset_id is not None:
42 | in_file = os.path.join(in_dir, dirname, eval_method, f"metrics.{subset_id}.json")
43 | if os.path.exists(in_file):
44 | data = read_json(in_file)
45 | try:
46 | results[shot_num][dataset].append(data[metric])
47 | valid_columns[dataset] = True
48 | except KeyError:
49 | pass
50 |
51 | # Ensure all datasets have the same number of scores for each shot number (excluding the empty)
52 | for shot_num in range(max_shots + 1):
53 | lengths = [len(results[shot_num][dataset]) for dataset in datasets if valid_columns[dataset]]
54 | if lengths and not all(length == lengths[0] for length in lengths):
55 | raise ValueError(f"Inconsistent number of scores for shot number {shot_num}: {lengths}")
56 |
57 |
58 | # Filter out columns that are entirely empty
59 | filtered_datasets = []
60 | filtered_columns = []
61 | for dataset in datasets:
62 | if valid_columns[dataset]:
63 | filtered_datasets.append(dataset)
64 | filtered_columns.append(dataset)
65 |
66 | # Initialize the table header
67 | header = f"## Metric: {metric}\n\n"
68 | header += "| shots | " + " | ".join([f"{dataset}" for dataset in filtered_datasets]) + " |\n"
69 | separator = "|-------|" + "------------|" * (len(filtered_datasets)) + "\n"
70 | text += header + separator
71 |
72 |
73 | # Construct the table rows
74 | for shot_num in range(max_shots + 1):
75 | row = []
76 | for dataset in filtered_columns:
77 | scores = results[shot_num][dataset]
78 | if scores:
79 | mean_score = np.mean(scores)
80 | std_score = np.std(scores, ddof=1)
81 | row.append(f"{mean_score:.4f} ({std_score:.4f})")
82 | else:
83 | row.append("n/a")
84 | text += f"| {shot_num} | " + " | ".join(row) + " |\n"
85 |
86 | text += "\n\n"
87 |
88 | return text
89 |
90 |
91 |
92 | def read_result(in_dir, out_file, args): #, metrics=["perplexity"], subset_id=None):
93 | shots = args.shots
94 | metrics=args.metrics
95 | eval_methods=args.eval_methods
96 | seeds = args.seeds
97 | subset_id=args.subset_id
98 | if subset_id is not None:
99 | if subset_id < 0:
100 | subset_id = None
101 |
102 | # Initialize a set to store unique datasets
103 | datasets = set()
104 |
105 | # Iterate over the directories to identify datasets
106 | ## Directories take the form of `dataset_shots=3_seed=0`
107 | for dirname in os.listdir(in_dir):
108 | if "shots=" in dirname and "seed=" in dirname:
109 | dataset = dirname.split("_shots=")[0]
110 | datasets.add(dataset)
111 |
112 | # Sort datasets to maintain a consistent order
113 | datasets = sorted(datasets)
114 |
115 | for eval_method in eval_methods:
116 | text = per_eval_method_read_result(metrics, datasets, os.path.join(in_dir), eval_method, subset_id=subset_id, shots=shots, seeds=seeds)
117 |
118 |
119 | output_file = f"{out_file}-{eval_method}" + ".md"
120 | out_dir = os.path.dirname(output_file)
121 |
122 | if not os.path.exists(out_dir):
123 | os.makedirs(out_dir)
124 |
125 | print(eval_method)
126 | print(text)
127 | if text != "":
128 | with open(output_file, "w", encoding="utf-8") as f:
129 | f.write(text)
130 |
131 |
132 |
133 | def main():
134 | parser = ArgumentParser()
135 | parser.add_argument("--in_dir", type=str)
136 | parser.add_argument("--subset_id", type=int, default=None, help="The dataset to test can be split into `n_subsets`(see run_1d_pdecontrol_eval) and if `subset_id != None` and non-negative, the `subset_id` will be the only subset to find metrics.")
137 | parser.add_argument("--metrics", type=str, nargs="+", default=[
138 | "robustness accuracy",
139 | "robustness mre",
140 | "robustness failure rate",
141 | "robustness timeout rate",
142 | "simulation time mre",
143 | "edit distance",
144 | "iou",
145 | "iou failures",
146 | "iou timeout rate",
147 | "perplexity",
148 | "perplexity timeout rate",
149 | "gt positive robustness rate",
150 | "gt negative robustness rate",
151 | "gt failed robustness rate",
152 | "adjusted_failure_rate",
153 | ])
154 | parser.add_argument("--shots", type=int, default=3)
155 | parser.add_argument(
156 | "--eval_methods",
157 | type=str,
158 | choices=["to_python_direct_with_sstl_cot", "to_python_no_STL", "to_python_two_step", "to_STL"],
159 | default=["to_python_direct_with_sstl_cot", "to_python_no_STL", "to_python_two_step", "to_STL"],
160 | nargs='+'
161 | )
162 | parser.add_argument("--seeds", type=int, nargs="+", default=[-1], help="The seeds to to average results over. A directory may contain multiple seeds. This argument will specify which seeds to average over and the rest will be ignored.")
163 |
164 |
165 |
166 | args = parser.parse_args()
167 | in_dir = args.in_dir
168 | out_file = os.path.join(args.in_dir, "results", os.path.basename(in_dir))
169 | read_result(in_dir, out_file, args=args)
170 |
171 | if __name__ == "__main__":
172 | main()
173 |
--------------------------------------------------------------------------------
/test/scripts/simulate_gt.sh:
--------------------------------------------------------------------------------
1 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
2 |
3 | cd $DIR/../..
4 |
5 |
6 |
7 | declare -a datasets=(
8 | "heat_nc1"
9 | "heat_nc2"
10 | "heat_nc3"
11 | "wave_nc1"
12 | "wave_nc2"
13 | "wave_nc3"
14 | )
15 |
16 |
17 | ## set the max samples to be -1 to use all samples.
18 | max_samples=512
19 | skip_existing=False # True or False
20 | load_from_file=False # True or False
21 | gpus=1
22 |
23 |
24 | # Calculate the total number of iterations
25 | total_iterations=$(( ${#datasets[@]} ))
26 | current_iteration=0
27 | # Record the start time
28 | start_time=$(date +%s)
29 |
30 | for dataset in "${datasets[@]}"
31 | do
32 | echo " PROGRESS: $current_iteration/$total_iterations"
33 | echo " Evaluating $dataset" | tee -a $log_file
34 |
35 | if [[ $prompt_format == "to_STL" ]]; then
36 | eval_robustness=False
37 | fi
38 |
39 | CMD="CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false python test/PDEcontrol/evaluation/infer/simulate_gt.py \
40 | --data_dir \"$DIR/../../test/PDEcontrol/test_data/${dataset}\" \
41 | --eval_batch_size 1 \
42 | --eval_robustness"
43 |
44 | if [ $max_samples -gt 0 ]; then
45 | CMD="$CMD --max_num_examples $max_samples"
46 | fi
47 |
48 | if [ "$load_from_file" = "True" ]; then
49 | CMD="$CMD --load_from_file"
50 | fi
51 |
52 | if [ -n "$gpus" ]; then
53 | CMD="$CMD --gpus $gpus"
54 | fi
55 |
56 |
57 | echo $CMD
58 | eval $CMD
59 |
60 |
61 |
62 | echo " Done evaluating $dataset with $few_shot few shot examples and format=$prompt_format"
63 | current_iteration=$((current_iteration + 1))
64 | echo " PROGRESS: $current_iteration/$total_iterations"
65 | # Calculate elapsed time
66 | current_time=$(date +%s)
67 | elapsed_time=$((current_time - start_time))
68 | # Estimate remaining time
69 | average_time_per_iteration=$((elapsed_time / current_iteration))
70 | remaining_iterations=$((total_iterations - current_iteration))
71 | estimated_remaining_time=$((remaining_iterations * average_time_per_iteration))
72 | hours=$((estimated_remaining_time / 3600))
73 | minutes=$(( (estimated_remaining_time % 3600) / 60 ))
74 | seconds=$((estimated_remaining_time % 60))
75 | echo " Estimated remaining time: $hours hours, $minutes minutes, $seconds seconds"
76 | done
77 |
--------------------------------------------------------------------------------
/test/scripts/test_pdecontrol.sh:
--------------------------------------------------------------------------------
1 | out_dir=${1}
2 | gpus=${2}
3 | translator_path=${3}
4 | coder_path=${4}
5 | controller_path=${5}
6 | use_openai=${6}
7 |
8 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
9 |
10 | declare -a datasets=(
11 | "heat_nc1_512"
12 | "heat_nc2_512"
13 | "heat_nc3_512"
14 | "wave_nc1_512"
15 | "wave_nc2_512"
16 | "wave_nc3_512"
17 | # "dpo_manual_test"
18 | )
19 |
20 | declare -a prompt_formats=(
21 | # to_python_no_STL
22 | # to_python_direct_with_sstl_cot
23 | to_python_two_step
24 | # to_STL
25 | # full_pipeline
26 | # reasoning_only
27 | )
28 |
29 | declare -a few_shot_numbers=(
30 | 0
31 | 2
32 | )
33 |
34 |
35 | ## set the max samples to be -1 to use all samples.
36 | max_samples=4
37 | gpus=$gpus
38 | # gpus=0
39 |
40 | if [[ -n $use_openai ]]; then
41 | declare -a few_shot_numbers=(
42 | 2
43 | )
44 | fi
45 |
46 | # appends to the log file
47 | log_file=${out_dir}/eval_output.log
48 |
49 | # Calculate the total number of iterations
50 | total_iterations=$(( ${#datasets[@]} * ${#prompt_formats[@]} * ${#few_shot_numbers[@]}))
51 | current_iteration=0
52 | # Record the start time
53 | start_time=$(date +%s)
54 |
55 | for dataset in "${datasets[@]}"
56 | do
57 | for prompt_format in "${prompt_formats[@]}"
58 | do
59 | for few_shot in "${few_shot_numbers[@]}"
60 | do
61 | echo " PROGRESS: $current_iteration/$total_iterations" | tee -a $log_file
62 | echo " Evaluating $dataset with $few_shot few-shot examples and format=$prompt_format" | tee -a $log_file
63 |
64 | echo dataset: $dataset
65 | echo out_dir: $out_dir
66 | echo shots: $few_shot
67 | echo translator_path: $translator_path
68 | echo coder_path: $coder_path
69 | echo controller_path: $controller_path
70 | echo prompt_format: $prompt_format
71 | echo max_samples: $max_samples
72 | echo gpus: $gpus
73 | echo use_openai: $use_openai
74 |
75 |
76 | bash test/PDEcontrol/evaluation/scripts/infer_pdecontrol.sh $dataset $out_dir $few_shot $prompt_format $max_samples $gpus $translator_path $coder_path $controller_path $use_openai | tee -a $log_file
77 |
78 | echo " Done evaluating $dataset with $few_shot few-shot examples and format=$prompt_format" | tee -a $log_file
79 | current_iteration=$((current_iteration + 1))
80 | echo " PROGRESS: $current_iteration/$total_iterations" | tee -a $log_file
81 | # Calculate elapsed time
82 | current_time=$(date +%s)
83 | elapsed_time=$((current_time - start_time))
84 | # Estimate remaining time
85 | average_time_per_iteration=$((elapsed_time / current_iteration))
86 | remaining_iterations=$((total_iterations - current_iteration))
87 | estimated_remaining_time=$((remaining_iterations * average_time_per_iteration))
88 | # Convert estimated remaining time to human-readable format
89 | hours=$((estimated_remaining_time / 3600))
90 | minutes=$(( (estimated_remaining_time % 3600) / 60 ))
91 | seconds=$((estimated_remaining_time % 60))
92 | elapsed_hours=$((elapsed_time / 3600))
93 | elapsed_minutes=$(( (elapsed_time % 3600) / 60 ))
94 | elapsed_seconds=$((elapsed_time % 60))
95 | echo " Elapsed time: $elapsed_hours hours, $elapsed_minutes minutes, $elapsed_seconds seconds" | tee -a $log_file
96 | echo " Estimated remaining time: $hours hours, $minutes minutes, $seconds seconds" | tee -a $log_file
97 | done
98 | done
99 | done
100 |
101 | # # Initialize max_shots with the first element of the array
102 | # max_shots=${few_shot_numbers[0]}
103 |
104 | # # Iterate through the array to find the maximum value
105 | # for num in "${few_shot_numbers[@]}"; do
106 | # # Skip commented out values
107 | # if [[ $num =~ ^[0-9]+$ ]]; then
108 | # if (( num > max_shots )); then
109 | # max_shots=$num
110 | # fi
111 | # fi
112 | # done
113 |
114 | # echo "Maximum shots: $max_shots"
115 |
116 | # CMD="python /localhome/mms43/scratch/mathcoder2/MathCoder2/test/scripts/read_result.py --in_dir $out_dir --subset_id $max_samples --shots $max_shots --seeds ${seeds[@]}"
117 |
118 | # CMD="$CMD --eval_methods ${prompt_formats[@]}"
119 | # # if [[ $prompt_format == "to_STL" ]]; then
120 | # # CMD="$CMD --eval_methods to_STL"
121 | # # fi
122 |
123 | # eval $CMD
--------------------------------------------------------------------------------
/train/README.md:
--------------------------------------------------------------------------------
1 | # Supervised Finetuning and Direct Preference Optimization
2 |
3 | This directory contains the code for SFT and DPO.
4 |
5 | ## Installation
6 |
7 | Follow the instructions on [the main readme](../README.md)
8 |
9 | ## Training
10 |
11 | ### Step 1: save the paths of the jsonl files to be tokenized
12 |
13 | Modify `input_file_paths` in `run_training.sh`. Each element of `input_file_paths` should be the path to a directory that contains `.jsonl` files with the texts to be trained on under the key `"text"`. For example:
14 | ```
15 | datasets/unprocessed/sft/one_d_heat_train
16 | datasets/unprocessed/sft/one_d_wave_train
17 | ```
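Each `.jsonl` file holds one JSON object per line with the training text under `"text"`. A minimal illustrative record (the value is a placeholder, not taken from the released datasets):
```
{"text": "<one training example, e.g. the natural-language problem statement followed by its solution>"}
```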
18 |
19 | You can modify `out_file_path` in `run_training.sh` to change the output directory for the parquet files. Then run:
20 |
21 | ```shell
22 | bash run_training.sh
23 | ```
24 | #### Tokenization
25 | This script automatically tokenizes the files by calling `train/scripts/tokenize_data.py`. The paths of the output tokenized parquet files will automatically be written to `file_names_to-be-trained.json`; each element is the path to one tokenized parquet file.
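Assuming the two example input directories above, `file_names_to-be-trained.json` would look roughly like the following (the exact file names depend on the dataset keywords, and the directory on `out_file_path`):
```
["datasets/processed/tokenized_one_d_heat_train.parquet", "datasets/processed/tokenized_one_d_wave_train.parquet"]
```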
26 |
27 | #### Grouping the texts into given context length
28 |
29 | You can modify `max_len` in `run_training.sh` to change the context length; it should match the context length you plan to use in training. The script then calls `train/scripts/group_text.py`, which outputs a single parquet file containing token indexes grouped into the given context length.
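If you need to run the grouping step by hand, a representative invocation (the paths are placeholders; the flags are the ones defined in `train/scripts/group_text.py`):
```shell
python train/scripts/group_text.py \
    --model_path meta-llama/Meta-Llama-3-8B \
    --max_len 4096 \
    --out_file_path datasets/processed \
    --sft --no_grouping
```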
30 |
31 | ### Step 2: train the model
32 |
33 | You can modify the training configs in `scripts/train.sh`. The example script runs on a single node with 4 A100 GPUs; you may need to modify it so that it runs on your cluster. Run the following command (it is also included in the `run_training.sh` script):
36 |
37 | ```shell
38 | bash train/scripts/train.sh
39 | ```
40 |
41 | This script trains the model with periodic validation. Validation can be turned off by replacing the validation loop at the bottom of the script with a single call to the `train` function:
42 | ```sh
43 | train
44 | exit_status=$?
45 | ```
--------------------------------------------------------------------------------
/train/config/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": 0.1
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "warmup_min_lr": "auto",
18 | "warmup_max_lr": "auto",
19 | "warmup_num_steps": "auto",
20 | "total_num_steps": "auto"
21 | }
22 | },
23 | "flops_profiler": {
24 | "enabled": true,
25 | "profile_step": 25,
26 | "module_depth": -1,
27 | "top_modules": 1,
28 | "detailed": false,
29 | "output_file": null
30 | },
31 | "zero_optimization": {
32 | "stage": 3,
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": true
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "train_batch_size": "auto",
46 | "train_micro_batch_size_per_gpu": "auto",
47 | "wall_clock_breakdown": false
48 | }
--------------------------------------------------------------------------------
/train/config/deepspeed_dpo.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": 0.1
12 | }
13 | },
14 |
15 | "flops_profiler": {
16 | "enabled": true,
17 | "profile_step": 25,
18 | "module_depth": -1,
19 | "top_modules": 1,
20 | "detailed": false,
21 | "output_file": null
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "overlap_comm": true,
26 | "contiguous_gradients": true,
27 | "sub_group_size": 1e9,
28 | "reduce_bucket_size": "auto",
29 | "stage3_prefetch_bucket_size": "auto",
30 | "stage3_param_persistence_threshold": "auto",
31 | "stage3_max_live_parameters": 1e9,
32 | "stage3_max_reuse_distance": 1e9,
33 | "stage3_gather_16bit_weights_on_model_save": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "train_batch_size": "auto",
38 | "train_micro_batch_size_per_gpu": "auto",
39 | "wall_clock_breakdown": false
40 | }
--------------------------------------------------------------------------------
/train/scripts/group_text.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import json
5 | import random
6 | import numpy as np
7 | from utils.loader import Processor
8 | from datasets import load_dataset, concatenate_datasets
9 | import argparse
10 | from transformers import (
11 | AutoTokenizer,
12 | )
13 |
14 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
15 |
16 | def find_common_keywords_update_outfile(file_paths, outfile, ignore_keywords=None):
17 | if ignore_keywords is None:
18 | ignore_keywords = []
19 | # extract the basenames from the filenames
20 | basenames = [os.path.basename(file_name) for file_name in file_paths]
21 | # split the basenames into words by underscores, ignoring the file extension
22 | words_list = [set(os.path.splitext(basename)[0].split('_')) for basename in basenames]
23 | # filter the keywords if any
24 | words_list = [words - set(ignore_keywords) for words in words_list]
25 | # common words in all basenames
26 | common_words = set.intersection(*words_list)
27 | # modify the out_file basename to include these common words
28 | out_file_basename = os.path.splitext(outfile)[0] # remove the extension from outfile
29 | if common_words:
30 | out_file_basename += "_" + "_".join(sorted(common_words))
31 | out_file_basename += os.path.splitext(outfile)[1] # add the original extension back
32 | return out_file_basename
33 |
34 | def balance_datasets(datasets, balance, total=None, dataset_names=None):
35 | """Balance the datasets based on the balance values. Assumes each dataset was pre-shuffled.
36 | # Args:
37 | datasets: list of datasets to balance.
38 | balance: list(float). The fraction of data to keep from each dataset. The values must sum to 1.
39 | total: int. The total number of datapoints to keep.
40 | dataset_names: list(str). The names of the datasets.
41 | ---
42 | All of these (except point 4) assume `total < sum([len(dataset) for dataset in datasets])`.
43 |
44 | 1. `Total = None, Balance = [1]` (**default**): Keep everything.
45 | 2. `Total = int, Balance = [1]`: Keep the first `total` datapoints from each (pre-shuffled) dataset.
46 | 3. `Total = None, len(Balance) > 1`: Keep all data from the smallest dataset; the rest is determined by `balance`.
47 | 4. `Total = int, len(Balance) > 1`: Keep `total` datapoints overall, balanced according to the `balance` values. This may double-sample random datapoints from datasets that are too small (see the commented usage sketch after this function).
48 | """
49 | assert abs(sum(balance) - 1) < 1e-6, "The balance values must sum to 1."  # float-safe check
50 | if len(balance) > 1:
51 | assert len(datasets) == len(balance), "The number of datasets and `balance` values must be the same."
52 | assert dataset_names is None or len(datasets) == len(dataset_names), "The number of datasets and `dataset_names` must be the same."
53 |
54 | balanced_datasets = []
55 | if total is None:
56 | if balance == [1]:
57 | # 1. Keep everything
58 | return datasets
59 | else:
60 | # 3. Keep all data for the smallest dataset and balance the rest
61 | min_i, min_dataset = min(enumerate(datasets), key=lambda x: len(x[1]))
62 | total_size = len(min_dataset) / balance[min_i]
63 |
64 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)):
65 | num_to_keep = int(total_size * proportion)
66 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}"
67 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}")
68 | balanced_datasets.append(dataset.select(range(num_to_keep)))
69 | else:
70 | if balance == [1]:
71 | # 2. Keep total number of data. Since assumes data was pre-shuffled, we can just take the first `total` number of data.
72 | balanced_datasets = [dataset.select(range(total)) for dataset in datasets]
73 | else:
74 | # 4. Keep total number of data, balance based on balance values
75 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)):
76 | num_to_keep = int(total * proportion)
77 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}"
78 | temp_dataset = None
79 | if len(dataset) < num_to_keep:
80 | full_repeats = num_to_keep // len(dataset)
81 | remainder = num_to_keep % len(dataset)
82 | for _ in range(full_repeats):
83 | if temp_dataset is None:
84 | temp_dataset = dataset
85 | else:
86 | temp_dataset = concatenate_datasets([temp_dataset, dataset])
87 | indices_to_keep = random.sample(range(len(dataset)), remainder)
88 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)} (with {full_repeats} replications for all points and random sampling for the remainder)")
89 | temp_dataset = concatenate_datasets([temp_dataset, dataset.select(indices_to_keep)])
90 | balanced_datasets.append(temp_dataset)
91 | else:
92 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}")
93 | indices = random.sample(range(len(dataset)), num_to_keep)
94 | balanced_datasets.append(dataset.select(indices))
95 | return balanced_datasets
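# A minimal usage sketch (the dataset variables and names below are illustrative):
# keep 1000 datapoints overall, split 50/25/25 across three pre-shuffled datasets (case 4 above):
#   balanced = balance_datasets([heat_set, wave_set, combined_set],
#                               balance=[0.5, 0.25, 0.25], total=1000,
#                               dataset_names=["heat", "wave", "combined"])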
96 |
97 |
98 |
99 | def group_text(tokenizer, out_file, max_len, sft=False, no_grouping=False, no_padding=False, args=None):
100 | seed = 3407
101 | if tokenizer.pad_token is None:
102 | tokenizer.pad_token = tokenizer.eos_token
103 | tokenizer.pad_token_id = tokenizer.eos_token_id
104 |
105 | train_file_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "file_names_to-be-trained.json")
106 | with open(train_file_config, "r") as f:
107 | train_files = json.load(f)
108 |
109 | out_file = find_common_keywords_update_outfile(train_files, out_file)
110 | train_sets = []
111 | for file in train_files:
112 | print("loading file:", file)
113 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train')
114 |
115 | # shuffle the dataset for the balance_datasets function
116 | _dataset = _dataset.shuffle(seed=seed)
117 | train_sets.append(_dataset)
118 |
119 | train_sets = balance_datasets(train_sets, args.balance, args.total, train_files)
120 |
121 | lengths = np.array([_set.shape[0] for _set in train_sets])
122 | print(f"Data Lengths: {lengths}")
123 | train_sets = concatenate_datasets(train_sets)
124 | print(f"Total Length: {train_sets.shape[0]}")
125 | process_batch_size = min(1000, len(train_sets))
126 |
127 | processor = Processor()
128 | train_sets = train_sets.shuffle(seed=seed)
129 | column_names = list(train_sets.features)
130 |
131 | if no_grouping:
132 | if no_padding:
133 | # only truncate
134 | train_sets = train_sets.map(
135 | processor.truncate,
136 | fn_kwargs={
137 | "max_len": max_len,
138 | "dpo": args.dpo
139 | },
140 | batched=True,
141 | load_from_cache_file=False,
142 | batch_size=process_batch_size,
143 | num_proc=96,
144 | desc=f"Truncating texts longer than {max_len} tokens. No padding. No grouping.",
145 | )
146 | else:
147 | # truncate and add padding
148 | train_sets = train_sets.map(
149 | processor.truncate_and_add_padding,
150 | fn_kwargs={
151 | "tokenizer": tokenizer,
152 | "max_len": max_len
153 | },
154 | batched=True,
155 | load_from_cache_file=False,
156 | batch_size=process_batch_size,
157 | num_proc=96,
158 | desc=f"Padding and truncating texts to {max_len} tokens. No grouping.",
159 | )
160 | else:
161 | # pretraining
162 | train_sets = train_sets.map(
163 | processor.group_texts,
164 | fn_kwargs={
165 | "tokenizer": tokenizer,
166 | "max_len": max_len
167 | },
168 | batched=True,
169 | load_from_cache_file=False,
170 | remove_columns=column_names,
171 | batch_size=process_batch_size,
172 | num_proc=96,
173 | desc=f"Grouping texts in chunks of {max_len}",
174 | )
175 | train_sets.to_parquet(out_file)
176 | print(f"Saved to {out_file}")
177 |
178 | if __name__ == "__main__":
179 | parser = argparse.ArgumentParser()
180 | parser.add_argument("--sft", action="store_true", help="Run for supervised fine-tuning data")
181 | parser.add_argument("--max_len", type=int, default=4096, help="Max context length (or the length to use for truncation and padding if applicable)")
182 | parser.add_argument("--dpo", action="store_true", help="Run for DPO data")
183 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to output grouped parquet file")
184 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Tokenizer path")
185 | parser.add_argument("--no_grouping", action="store_true", help="Do not group sequences into fixed-length chunks; instead truncate longer sequences (and pad shorter ones unless --no_padding is set).")
186 | parser.add_argument("--no_padding", action="store_true", help="Whether to not pad shorter sequences.")
187 | parser.add_argument("--balance", type=float, nargs="+", default=[1], help="The percentage of data to keep from each dataset. The sum of the list must be 1.")
188 | parser.add_argument("--total", type=int, default=None, help="The total number of data to keep.")
189 |
190 | args = parser.parse_args()
191 |
192 | sft = args.sft
193 | model_path = args.model_path
194 | max_len = args.max_len # context length
195 | out_file = f"{args.out_file_path}/{'not-' if args.no_grouping else ''}grouped_MaxContext{max_len}.parquet"
196 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
197 | group_text(tokenizer, out_file, max_len, sft, no_grouping=args.no_grouping, no_padding=args.no_padding, args=args)
--------------------------------------------------------------------------------
/train/scripts/group_text_dpo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import json
5 | import random
6 | import numpy as np
7 | from utils.loader import Processor
8 | from datasets import load_dataset, concatenate_datasets
9 | import argparse
10 | from transformers import (
11 | AutoTokenizer,
12 | )
13 |
14 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
15 |
16 | def find_common_keywords_update_outfile(file_paths, outfile, ignore_keywords=None):
17 | if ignore_keywords is None:
18 | ignore_keywords = []
19 | # extract the basenames from the filenames
20 | basenames = [os.path.basename(file_name) for file_name in file_paths]
21 | # split the basenames into words by underscores, ignoring the file extension
22 | words_list = [set(os.path.splitext(basename)[0].split('_')) for basename in basenames]
23 | # filter the keywords if any
24 | words_list = [words - set(ignore_keywords) for words in words_list]
25 | # common words in all basenames
26 | common_words = set.intersection(*words_list)
27 | # modify the out_file basename to include these common words
28 | out_file_basename = os.path.splitext(outfile)[0] # remove the extension from outfile
29 | if common_words:
30 | out_file_basename += "_" + "_".join(sorted(common_words))
31 | out_file_basename += os.path.splitext(outfile)[1] # add the original extension back
32 | return out_file_basename
33 |
34 | def balance_datasets(datasets, balance, total=None, dataset_names=None):
35 | """Balance the datasets based on the balance values. Assumes each dataset was pre-shuffled.
36 | # Args:
37 | datasets: list of datasets to balance.
38 | balance: list(float). The fraction of data to keep from each dataset. The values must sum to 1.
39 | total: int. The total number of datapoints to keep.
40 | dataset_names: list(str). The names of the datasets.
41 | ---
42 | All of these (except point 4) assume `total < sum([len(dataset) for dataset in datasets])`.
43 |
44 | 1. `Total = None, Balance = [1]` (**default**): Keep everything.
45 | 2. `Total = int, Balance = [1]`: Keep the first `total` datapoints from each (pre-shuffled) dataset.
46 | 3. `Total = None, len(Balance) > 1`: Keep all data from the smallest dataset; the rest is determined by `balance`.
47 | 4. `Total = int, len(Balance) > 1`: Keep `total` datapoints overall, balanced according to the `balance` values. This may double-sample random datapoints from datasets that are too small.
48 | """
49 | assert abs(sum(balance) - 1) < 1e-6, "The balance values must sum to 1."  # float-safe check
50 | if len(balance) > 1:
51 | assert len(datasets) == len(balance), "The number of datasets and `balance` values must be the same."
52 | assert dataset_names is None or len(datasets) == len(dataset_names), "The number of datasets and `dataset_names` must be the same."
53 |
54 | balanced_datasets = []
55 | if total is None:
56 | if balance == [1]:
57 | # 1. Keep everything
58 | return datasets
59 | else:
60 | # 3. Keep all data for the smallest dataset and balance the rest
61 | min_i, min_dataset = min(enumerate(datasets), key=lambda x: len(x[1]))
62 | total_size = len(min_dataset) / balance[min_i]
63 |
64 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)):
65 | num_to_keep = int(total_size * proportion)
66 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}"
67 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}")
68 | balanced_datasets.append(dataset.select(range(num_to_keep)))
69 | else:
70 | if balance == [1]:
71 | # 2. Keep total number of data. Since assumes data was pre-shuffled, we can just take the first `total` number of data.
72 | balanced_datasets = [dataset.select(range(total)) for dataset in datasets]
73 | else:
74 | # 4. Keep total number of data, balance based on balance values
75 | for i, (dataset, proportion) in enumerate(zip(datasets, balance)):
76 | num_to_keep = int(total * proportion)
77 | dataset_name = dataset_names[i] if dataset_names else f"Dataset {i+1}"
78 | temp_dataset = None
79 | if len(dataset) < num_to_keep:
80 | full_repeats = num_to_keep // len(dataset)
81 | remainder = num_to_keep % len(dataset)
82 | for _ in range(full_repeats):
83 | if temp_dataset is None:
84 | temp_dataset = dataset
85 | else:
86 | temp_dataset = concatenate_datasets([temp_dataset, dataset])
87 | indices_to_keep = random.sample(range(len(dataset)), remainder)
88 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)} (with {full_repeats} replications for all points and random sampling for the remainder)")
89 | temp_dataset = concatenate_datasets([temp_dataset, dataset.select(indices_to_keep)])
90 | balanced_datasets.append(temp_dataset)
91 | else:
92 | print(f"Sampling {num_to_keep} datapoints from {dataset_name} of size {len(dataset)}")
93 | indices = random.sample(range(len(dataset)), num_to_keep)
94 | balanced_datasets.append(dataset.select(indices))
95 | return balanced_datasets
96 |
97 |
98 |
99 | def group_text(tokenizer, out_file, max_len, sft=False, no_grouping=False, no_padding=False, args=None):
100 | seed = 3407
101 | if tokenizer.pad_token is None:
102 | tokenizer.pad_token = tokenizer.eos_token
103 | tokenizer.pad_token_id = tokenizer.eos_token_id
104 |
105 | train_file_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "file_names_to-be-trained.json")
106 | with open(train_file_config, "r") as f:
107 | train_files = json.load(f)
108 |
109 | out_file = find_common_keywords_update_outfile(train_files, out_file)
110 | train_sets = []
111 | for file in train_files:
112 | print("loading file:", file)
113 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train')
114 |
115 | # shuffle the dataset. this is important for the balance_datasets function
116 | _dataset = _dataset.shuffle(seed=seed)
117 | train_sets.append(_dataset)
118 |
119 | train_sets = balance_datasets(train_sets, args.balance, args.total, train_files)
120 |
121 | lengths = np.array([_set.shape[0] for _set in train_sets])
122 | print(f"Data Lengths: {lengths}")
123 | train_sets = concatenate_datasets(train_sets)
124 | print(f"Total Length: {train_sets.shape[0]}")
125 | process_batch_size = min(1000, len(train_sets))
126 |
127 | processor = Processor()
128 | train_sets = train_sets.shuffle(seed=seed)
129 | column_names = list(train_sets.features)
130 |
131 | if no_grouping:
132 | if no_padding:
133 | # only truncate
134 | train_sets = train_sets.map(
135 | processor.truncate,
136 | fn_kwargs={
137 | "max_len": max_len
138 | },
139 | batched=True,
140 | load_from_cache_file=False,
141 | batch_size=process_batch_size,
142 | num_proc=96,
143 | desc=f"Truncating texts longer than {max_len} tokens. No padding. No grouping.",
144 | )
145 | else:
146 | # truncate and add padding
147 | train_sets = train_sets.map(
148 | processor.truncate_and_add_padding,
149 | fn_kwargs={
150 | "tokenizer": tokenizer,
151 | "max_len": max_len
152 | },
153 | batched=True,
154 | load_from_cache_file=False,
155 | batch_size=process_batch_size,
156 | num_proc=96,
157 | desc=f"Padding and truncating texts to {max_len} tokens. No grouping.",
158 | )
159 | else:
160 | if sft:
161 | # group texts for seq-to-seq
162 | raise NotImplementedError("Seq-to-seq grouping used to be implemented but this code needs to be checked before using again.")
163 | else:
164 | # pretraining: the original grouping function
165 | train_sets = train_sets.map(
166 | processor.group_texts,
167 | fn_kwargs={
168 | "tokenizer": tokenizer,
169 | "max_len": max_len
170 | },
171 | batched=True,
172 | load_from_cache_file=False,
173 | remove_columns=column_names,
174 | batch_size=process_batch_size,
175 | num_proc=96,
176 | desc=f"Grouping texts in chunks of {max_len}",
177 | )
178 | train_sets.to_parquet(out_file)
179 |
180 | if __name__ == "__main__":
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument("--sft", action="store_true", help="Run for supervised fine-tuning data")
183 | parser.add_argument("--max_len", type=int, default=4096, help="Max context length (or the length to use for truncation and padding if applicable)")
184 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to output grouped parquet file")
185 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Model path")
186 | parser.add_argument("--no_grouping", action="store_true", help="Do not group sequences into fixed-length chunks; instead truncate longer sequences (and pad shorter ones unless --no_padding is set).")
187 | parser.add_argument("--no_padding", action="store_true", help="Whether to not pad shorter sequences.")
188 | parser.add_argument("--balance", type=float, nargs="+", default=[1], help="The percentage of data to keep from each dataset. The sum of the list must be 1.")
189 | parser.add_argument("--total", type=int, default=None, help="The total number of data to keep.")
190 |
191 | args = parser.parse_args()
192 |
193 | sft = args.sft
194 | model_path = args.model_path
195 | max_len = args.max_len # context length
196 | out_file = f"{args.out_file_path}/{'not-' if args.no_grouping else ''}grouped_MaxContext{max_len}.parquet"
197 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
198 | group_text(tokenizer, out_file, max_len, sft, no_grouping=args.no_grouping, no_padding=args.no_padding, args=args)
--------------------------------------------------------------------------------
/train/scripts/merge_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import torch
5 | from peft import PeftModel
6 | from transformers import (
7 | AutoTokenizer,
8 | AutoModelForCausalLM,
9 | )
12 |
13 | def main(args):
14 | model_path = args.model_path
15 | output_dir = args.output_dir
16 | adapter_path = args.adapter_path
17 |
18 | if args.gpus is not None:
19 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
20 |
21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 | print(f"Using device: {device}")
23 |
24 | print(f'Loading base model: {model_path}')
25 | tokenizer = AutoTokenizer.from_pretrained(model_path)
26 | base_model = AutoModelForCausalLM.from_pretrained(
27 | model_path,
28 | use_safetensors=True,
29 | trust_remote_code=True,
30 | device_map="auto",)
31 |
32 | if hasattr(base_model.config, 'torch_dtype'):
33 | dtype = base_model.config.torch_dtype
34 | print("Will save merged model following the base model dtype: ", dtype)
35 | else:
36 | dtype = torch.float32
37 | print("Couldn't find the base model dtype. Will save merged model in:", dtype)
38 |
39 | print(f"loading adapter model: {adapter_path}")
40 | model = PeftModel.from_pretrained(base_model, adapter_path)
41 | model.to(device)
42 | model = model.to(dtype)
43 | print(model)
44 | model.eval()
45 |
46 | print('Saving merged model')
47 | merged_model = model.merge_and_unload()
48 | merged_model_dir = os.path.join(output_dir, "merged_model")
49 | merged_model.save_pretrained(merged_model_dir, safe_serialization=True)
50 |
51 | print('Saving tokenizer')
52 | tokenizer.save_pretrained(merged_model_dir)
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser(description='Merge Model')
57 | parser.add_argument('--model_path', type=str, help='Path to the model')
58 | parser.add_argument('--adapter_path', type=str, help='Path to the adapter')
59 | parser.add_argument('--output_dir', type=str, help='Output directory')
60 | parser.add_argument("--gpus", type=str, default=None, help="Use to set the CUDA_VISIBLE_DEVICES environment variable.")
61 | args = parser.parse_args()
62 | try:
63 | main(args)
64 | except Exception as e:
65 | logging.exception(e)
66 | exit(-1)
67 |
--------------------------------------------------------------------------------
/train/scripts/tokenize_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import sys
5 | from transformers import AutoTokenizer
6 | from datasets import load_dataset
7 | from tqdm import tqdm
8 | from utils.loader import Processor
9 | import random
10 |
11 |
12 |
13 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
14 |
15 |
16 | def create_file_names(file_paths, out_file):
17 | with open(out_file, "w") as f:
18 | json.dump(file_paths, f)
19 |
20 |
21 | def load_jsonl(in_file):
22 | with open(in_file, "r", encoding="utf-8") as f:
23 | datas = [json.loads(line) for line in f]
24 | return datas
25 |
26 |
27 | def save_jsonl(datas, out_file):
28 | with open(out_file, "w", encoding="utf-8") as f:
29 | for data in datas:
30 | f.write(json.dumps(data, ensure_ascii=False) + "\n")
31 |
32 | def get_dataset_class(keywords):
33 | """Get the dataset class from the keywords in the provided dataset name. One of "heat", "wave"."""
34 | if "heat" in keywords:
35 | return "heat"
36 | elif "wave" in keywords:
37 | return "wave"
38 | else:
39 | raise ValueError(f"Dataset {keywords} does not contain one of 'heat' or 'wave'.")
40 |
41 |
42 | def name_output_file(in_file, out_dir):
43 | """Create a name for the output file using keywords in the base-file name and base-directory name.
44 | The returned path includes the .parquet file extension."""
45 | parent_folder = os.path.basename(os.path.dirname(in_file))
46 | basename_without_extension = os.path.splitext(os.path.basename(in_file))[0]
47 | # split into words
48 | parent_words = parent_folder.split('_')
49 | basename_words = basename_without_extension.split('_')
50 | # combine and remove duplicates
51 | unique_keywords = '_'.join(list(dict.fromkeys(parent_words + basename_words)))
52 | dataset_class = get_dataset_class(unique_keywords)
53 | out_file = os.path.join(out_dir, "tokenized_" + "".join(unique_keywords) + ".parquet")
54 | return str(out_file), dataset_class
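# Illustrative example (paths are placeholders): in_file "datasets/unprocessed/sft/one_d_heat_train/train.jsonl"
# -> out_file "<out_dir>/tokenized_one_d_heat_train.parquet" with dataset_class "heat"
# (the parent-folder words and file-name words are de-duplicated while preserving order).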
55 |
56 |
57 |
58 | def process_tokenize(in_file, out_dir, tokenizer, prompt_format, sft=False, truncate=False, padding=False):
59 | """Tokenize the data in the input file, and save the tokenized data in the output directory as a parquet file.
60 | Returns the name of the saved file."""
61 | out_file, dataset_class = name_output_file(in_file, out_dir)
62 | if os.path.isfile(out_file):
63 | print(f"Warning: {out_file} already exists. Will replace")
64 | os.remove(out_file)
65 |
66 | processor = Processor()
67 | _dataset = load_dataset(in_file.split(".")[-1] if in_file.split(".")[-1] != "jsonl" else "json", data_files=in_file, split='train')
68 | process_batch_size = min(2, len(_dataset))
69 |
70 | print(_dataset)
71 | _dataset = _dataset.map(
72 | processor.create_prompt,
73 | fn_kwargs={
74 | "prompt_format": prompt_format,
75 | "dataset_class": dataset_class
76 | },
77 | batched=True,
78 | load_from_cache_file=False,
79 | batch_size=process_batch_size,
80 | num_proc=64,
81 | desc=f"Creating prompts in format: {prompt_format}, with the standard keys {{'text', 'labels'}}.",
82 | )
83 |
84 | columns_to_remove = [feature for feature in _dataset.features if feature not in ["text", "labels"]]
85 | _dataset = _dataset.remove_columns(columns_to_remove)
86 |
87 | print("Example from the processed dataset:")
88 | random_indices = random.sample(range(len(_dataset)), min(2, len(_dataset)))
89 | for i in random_indices:
90 | print("\nexample", i, ":")
91 | print(_dataset[i])
92 |
93 |
94 | column_names = list(_dataset.features)
95 | if sft:
96 | _dataset = _dataset.map(
97 | processor.process_tokenize_sft,
98 | fn_kwargs={
99 | "tokenizer": tokenizer,
100 | "truncate": truncate,
101 | "padding": padding
102 | },
103 | batched=True,
104 | load_from_cache_file=False,
105 | remove_columns=column_names,
106 | batch_size=process_batch_size,
107 | num_proc=64,
108 | desc=f"Running tokenizer on SFT dataset. padding: {padding}, truncation: {truncate}.",
109 | )
110 | else:
111 | _dataset = _dataset.map(
112 | processor.process_tokenize,
113 | fn_kwargs={
114 | "tokenizer": tokenizer,
115 | "truncate": truncate,
116 | "padding": padding
117 | },
118 | batched=True,
119 | load_from_cache_file=False,
120 | remove_columns=column_names,
121 | batch_size=process_batch_size,
122 | num_proc=64,
123 | desc="Running tokenizer on pretraining dataset",
124 | )
125 |
126 | print(_dataset)
127 | _dataset.to_parquet(out_file)
128 | return str(out_file)
129 |
130 | def get_all_files_in_directories(directories):
131 | file_paths = []
132 | for in_dir in directories:
133 | dir_list = os.listdir(in_dir)
134 | for file_name in dir_list:
135 | full_path = os.path.join(in_dir, file_name)
136 | # ignore markdown/jsonl description files of datasets
137 | if os.path.isfile(full_path) and full_path.split(".")[-1] != "md" and "description" not in file_name.split(".")[0]:
138 | file_paths.append(full_path)
139 |
140 | # send file_paths out for debugging
141 | current_dir_path = os.path.dirname(os.path.realpath(__file__))
142 | out_file = os.path.join(current_dir_path, f"file-names_to-be-tokenized.json")
143 | with open(out_file, "w") as f:
144 | json.dump(file_paths, f)
145 |
146 | return file_paths
147 |
148 | def main():
149 | parser = argparse.ArgumentParser()
150 | parser.add_argument("--paths", nargs="+", help="List of paths to directories containing jsonl files of training data that will be preprocessed into individual parquet files for training.")
151 | parser.add_argument("--sft", action="store_true", help="""Tokenize as supervised fine-tuning for text-label data. For SFT, output data will contain:
152 | data["input_ids"], data["labels"], where data["labels"] = len(data["input_ids"]) * [-100] + `true labels`. [-100] is the default ignore token for cross-entropy loss.
153 | If not set, then the output data will use the pretraining objective: data["input_ids"], data["labels"], where data["labels"] = data["input_ids"].""")
154 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to directory where to store the output parquet file")
155 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Path to tokenizer.")
156 | parser.add_argument("--prompt_format", choices=["to_python_no_STL", "to_STL", "to_python_GT_STL", "to_python_given_STL", "to_python_misaligned"], default="to_python_no_STL", help="""Choose the prompt format, the Dataset requires the corresponding keys.
157 | `to_python_no_STL`: No STL is given in the prompt. The model is given natural language and should directly convert to code. Required keys: 'nl', 'python'.
158 | `to_STL`: No STL is given in the prompt. The model is given natural language and should directly convert to STL. Required keys: 'nl', 'sstl'.
159 | `to_python_GT_STL`: Ground truth STL is given in the prompt. The model is given natural language and STL and should directly convert to code. Required keys: 'nl', 'sstl', 'python'.
160 | `to_python_given_STL`: Must provide STL to be given in the prompt. The model is given natural language and STL and should directly convert to code. Required keys: 'nl', 'predicted_sstl', 'python'.
161 | `to_python_misaligned`: Must provide STL to be given in the prompt. The model is given natural language and STL which do not necessarily describe the same problem and should directly convert to code. Required keys: 'nl', 'predicted_sstl', 'python'.
162 | """)
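# Illustrative records for two of the formats above (all field values are placeholders):
#   to_STL:           {"nl": "<natural-language spec>", "sstl": "<target STL formula>"}
#   to_python_no_STL: {"nl": "<natural-language spec>", "python": "<solution code>"}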
163 |
164 | args = parser.parse_args()
165 |
166 | model_path = args.model_path
167 | tokenizer = AutoTokenizer.from_pretrained(model_path)
168 |
169 | file_paths = get_all_files_in_directories(args.paths)
170 |
171 | out_dir = args.out_file_path
172 | if not os.path.exists(out_dir):
173 | os.makedirs(out_dir)
174 |
175 | paths_to_be_trained = []
176 | for in_file in file_paths:
177 | path = process_tokenize(in_file, out_dir, tokenizer, args.prompt_format, args.sft)
178 | if path in paths_to_be_trained:
179 | raise ValueError(f"{path} was made twice in this script. The first has already been overwritten in the function `process_tokenize`. Check the dataset names and try again.")
180 | else:
181 | paths_to_be_trained.append(path)
182 |
183 |
184 | print(os.getcwd())
185 | dir_path = os.path.dirname(os.path.realpath(__file__))
186 | out_file = os.path.join(dir_path, "file_names_to-be-trained.json")
187 | create_file_names(paths_to_be_trained, out_file)
188 |
189 |
190 | if __name__ == "__main__":
191 | main()
--------------------------------------------------------------------------------
/train/scripts/tokenize_data_dpo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import sys
5 | from transformers import AutoTokenizer
6 | from datasets import load_dataset
7 | from tqdm import tqdm
8 | from utils.loader import Processor
9 | import random
10 |
11 |
12 |
13 | os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
14 |
15 |
16 | def create_file_names(file_paths, out_file):
17 | # Put into file_paths
18 | with open(out_file, "w") as f:
19 | json.dump(file_paths, f)
20 |
21 |
22 | def load_jsonl(in_file):
23 | with open(in_file, "r", encoding="utf-8") as f:
24 | datas = [json.loads(line) for line in f]
25 | return datas
26 |
27 |
28 | def save_jsonl(datas, out_file):
29 | with open(out_file, "w", encoding="utf-8") as f:
30 | for data in datas:
31 | f.write(json.dumps(data, ensure_ascii=False) + "\n")
32 |
33 | def get_dataset_class(keywords):
34 | """Get the dataset class from the keywords in the provided dataset name. One of "heat", "wave"."""
35 | if "heat" in keywords:
36 | return "heat"
37 | elif "wave" in keywords:
38 | return "wave"
39 | else:
40 | return None
41 | # raise ValueError(f"Dataset {keywords} does not contain one of 'heat' or 'wave'.")
42 |
43 |
44 | def name_output_file(in_file, out_dir):
45 | """Create a name for the output file using keywords in the base-file name and base-directory name.
46 | The returned path includes the .parquet file extension."""
47 | parent_folder = os.path.basename(os.path.dirname(in_file))
48 | basename_without_extension = os.path.splitext(os.path.basename(in_file))[0]
49 | # Split into words
50 | parent_words = parent_folder.split('_')
51 | basename_words = basename_without_extension.split('_')
52 | # Combine and remove duplicates
53 | unique_keywords = '_'.join(list(dict.fromkeys(parent_words + basename_words)))
54 | dataset_class = get_dataset_class(unique_keywords)
55 | out_file = os.path.join(out_dir, "tokenized_" + "".join(unique_keywords) + ".parquet")
56 | return str(out_file), dataset_class
57 |
58 |
59 |
60 | def process_tokenize(in_file, out_dir, tokenizer, prompt_format, truncate=False, padding=False):
61 | """Tokenize the data in the input file, and save the tokenized data in the output directory as a parquet file.
62 | Returns the name of the saved file."""
63 | out_file, dataset_class = name_output_file(in_file, out_dir)
64 | if os.path.isfile(out_file):
65 | print(f"Warning: {out_file} already exists. Will replace")
66 | os.remove(out_file)
67 |
68 | processor = Processor()
69 | _dataset = load_dataset(in_file.split(".")[-1] if in_file.split(".")[-1] != "jsonl" else "json", data_files=in_file, split='train')
70 | process_batch_size = min(2, len(_dataset))
71 |
72 | print(_dataset)
73 | _dataset = _dataset.map(
74 | processor.create_prompt_dpo,
75 | fn_kwargs={
76 | "prompt_format": prompt_format
77 | },
78 | batched=True,
79 | load_from_cache_file=False,
80 | batch_size=process_batch_size,
81 | num_proc=64,
82 | desc=f"Creating prompts in format: {prompt_format}, with the standard keys {{'prompt', 'chosen', 'rejected'}}.",
83 | )
84 |
85 | columns_to_remove = [feature for feature in _dataset.features if feature not in ["prompt", "chosen", "rejected"]]
86 | _dataset = _dataset.remove_columns(columns_to_remove)
87 |
88 | print("Example from the processed dataset:")
89 | random_indices = random.sample(range(len(_dataset)), min(2, len(_dataset)))
90 | for i in random_indices:
91 | print("\nexample", i, ":")
92 | print(_dataset[i])
93 |
94 |
95 | column_names = list(_dataset.features)
96 | ## dpo
97 | # DPOTrainer will do the tokenization of the prompt, chosen, and rejected in the training loop.
98 | _dataset = _dataset.map(
99 | processor.process_tokenize_dpo,
100 | fn_kwargs={
101 | "tokenizer": tokenizer,
102 | },
103 | batched=True,
104 | load_from_cache_file=False,
105 | remove_columns=column_names,
106 | batch_size=process_batch_size,
107 | num_proc=64,
108 | desc="Running tokenizer on DPO dataset",
109 | )
110 |
111 | print(_dataset)
112 | _dataset.to_parquet(out_file)
113 | return str(out_file)
114 |
115 | def get_all_files_in_directories(directories):
116 | file_paths = []
117 | for in_dir in directories:
118 | dir_list = os.listdir(in_dir)
119 | for file_name in dir_list:
120 | full_path = os.path.join(in_dir, file_name)
121 | # ignore markdown/jsonl description files of datasets
122 | if os.path.isfile(full_path) and full_path.split(".")[-1] != "md" and "description" not in file_name.split(".")[0]:
123 | file_paths.append(full_path)
124 |
125 | # Put file_paths out for debugging
126 | current_dir_path = os.path.dirname(os.path.realpath(__file__))
127 | out_file = os.path.join(current_dir_path, f"file-names_to-be-tokenized.json")
128 | with open(out_file, "w") as f:
129 | json.dump(file_paths, f)
130 |
131 | return file_paths
132 |
133 | def main():
134 | parser = argparse.ArgumentParser()
135 | parser.add_argument("--paths", nargs="+", help="List of paths to directories containing jsonl files of training data that will be preprocessed into individual parquet files for training.")
136 | parser.add_argument("--out_file_path", type=str, default=None, help="Path to directory where to store the output parquet file")
137 | parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="Path to the model that will be trained and whose tokenizer will be used here to preprocess the data.")
138 | parser.add_argument("--prompt_format", choices=["DPO"], default="DPO", help="""Choose the prompt format, the Dataset requires corresponding keys.
139 | `DPO`: No STL is given in the prompt. The model is given natural language and should directly convert to STL. Required keys: 'nl', 'sstl'.
140 | """)
141 | ## Padding and truncating is removed from this code because it should be done in the grouping stage (which can also combine other datasets so they all have the same context length).
142 | # parser.add_argument("--truncate", action="store_true", help="Whether to truncate the input and labels in tokenizer if either exceeds the context length of the model. If you plan to group the training data together then do not truncate here. Truncation will happen during grouping.")
143 | # parser.add_argument("--padding", action="store_true", help="Whether to pad the input and labels in tokenizer if either is shorter than the context length of the model. If you plan to group the training data together then do not pad here. Padding will happen during grouping. If you don't plan to group the data, padding here is equivalent to padding during grouping stage (just skipping the grouping part).")
144 |
145 | args = parser.parse_args()
146 |
147 | model_path = args.model_path
148 | tokenizer = AutoTokenizer.from_pretrained(model_path)
149 |
150 | file_paths = get_all_files_in_directories(args.paths)
151 |
152 | out_dir = args.out_file_path
153 | if not os.path.exists(out_dir):
154 | os.makedirs(out_dir)
155 |
156 | paths_to_be_trained = []
157 | for in_file in file_paths:
158 | path = process_tokenize(in_file, out_dir, tokenizer, args.prompt_format)
159 | if path in paths_to_be_trained:
160 | raise ValueError(f"{path} was made twice in this script. The first has already been overwritten in the function `process_tokenize`. Check the dataset names and try again.")
161 | else:
162 | paths_to_be_trained.append(path)
163 |
164 |
165 | print(os.getcwd())
166 | dir_path = os.path.dirname(os.path.realpath(__file__))
167 | out_file = os.path.join(dir_path, "file_names_to-be-trained.json")
168 | create_file_names(paths_to_be_trained, out_file)
169 |
170 |
171 | if __name__ == "__main__":
172 | main()
--------------------------------------------------------------------------------
/train/scripts/train.sh:
--------------------------------------------------------------------------------
1 | python --version
2 |
3 | export NCCL_DEBUG=WARN
4 |
5 | export NCCL_IB_TIMEOUT=22
6 | export NCCL_IB_RETRY_CNT=13
7 | export NCCL_IB_AR_THRESHOLD=0
8 |
9 | wandb login $WANDB_TOKEN
10 |
11 | # Constants for paths
12 | OUTS="outputs"
13 | MODELS="models"
14 | DATA="datasets/sft"
15 |
16 | # variables
17 | base_run_name=test
18 | max_steps=4
19 | per_gpu_batch_size=8
20 | accum_grad=8
21 | epoch_size=2
22 |
23 | save_merged_model=True
24 |
25 | # ["to_python_no_STL", "to_STL", "to_python_GT_STL", "to_python_given_STL", "to_python_misaligned"]
26 | prompt_format=to_STL
27 |
28 |
29 | ### specify path to base models for which an adapter will be attached.
30 | base_model="$MODELS/MathCoder2-DeepSeekMath-7B"
31 |
32 |
33 | ### Distributed settings:
34 | # export MASTER_ADDR=$(hostname -i)
35 | # export MASTER_PORT=29500
36 | export WORLD_SIZE=1
37 | export RANK=0
38 | export GPUPerNode=2
39 | export CUDA_VISIBLE_DEVICES=2,3
40 |
41 |
42 |
43 | run_name=${base_run_name}_${prompt_format}_s${max_steps}
44 |
45 | datafile_name=${prompt_format}/not-grouped_MaxContext4096_tokenized_train.parquet
46 |
47 | if [ "$prompt_format" = "to_python_given_STL" ]; then
48 | datafile_name=${prompt_format}/not-grouped_MaxContext4096_STL_predictions_to_tokenized_train.parquet
49 | elif [ "$prompt_format" = "to_python_misaligned" ]; then
50 | DATA="/localhome/mms43/scratch/mathcoder2/datasets/pdecontrol/dpo"
51 | datafile_name=${prompt_format}/not-grouped_MaxContext4096_misaligned_aligned_tokenized_train.parquet
52 | fi
53 |
54 |
55 |
56 | find_latest_checkpoint() {
57 | local checkpoint_dir=$1
58 | latest_checkpoint=$(ls -d ${checkpoint_dir}/checkpoint-* 2>/dev/null | sort -V | tail -n 1)
59 | echo $latest_checkpoint
60 | }
61 |
62 |
63 | train() {
64 | # directory where checkpoints are stored
65 | CHECKPOINT_DIR=$OUTS/$run_name
66 |
67 | export WANDB_DIR=$OUTS/$run_name
68 | export WANDB_RUN_ID=$run_name
69 |
70 | cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES OMP_NUM_THREADS=1 torchrun --nnodes $WORLD_SIZE --node_rank $RANK --nproc_per_node $GPUPerNode --rdzv-backend=c10d --rdzv-endpoint=localhost:0 train/train_finetune.py \
71 | --ddp_timeout 360000 \
72 | --train_parquet_file ${DATA}/${datafile_name} \
73 | --run_name $run_name \
74 | --output_dir $OUTS/${run_name} \
75 | --no_timestamps \
76 | --dataloader_num_workers 2 \
77 | --max_len 4096 \
78 | --max_steps $max_steps \
79 | --num_train_epochs -1 \
80 | --save_steps 5 \
81 | --save_total_limit 2 \
82 | --step_save_interval $epoch_size \
83 | --warmup_steps 50 \
84 | --logging_steps 10 \
85 | --learning_rate 4e-5 \
86 | --weight_decay 0.1 \
87 | --lr_scheduler_type cosine \
88 | --per_device_train_batch_size $per_gpu_batch_size \
89 | --gradient_accumulation_steps $accum_grad \
90 | --seed 3407 \
91 | --deepspeed train/config/deepspeed.json \
92 | --bf16 \
93 | --stream \
94 | --do_train \
95 | --gradient_checkpointing \
96 | --report_to wandb \
97 | --lora_r 64 \
98 | --lora_alpha 256 \
99 | --lora_dropout 0.1 \
100 | --model_cfg $base_model"
101 |
102 | if [ -n "$LATEST_CHECKPOINT" ]; then
103 | echo "Will try to load checkpoint from $LATEST_CHECKPOINT"
104 | export WANDB_RESUME="auto"
105 | cmd="$cmd --resume_from $LATEST_CHECKPOINT"
106 | fi
107 |
108 | if [ "$external_validation" == "True" ]; then
109 | cmd="$cmd --external_validation"
110 | fi
111 |
112 | eval $cmd
113 | }
114 |
115 |
116 |
117 | ############## Training loop with validation ###################
118 | ################################################################
119 |
120 | validate() {
121 | local checkpoint_dir=$1
122 | python train/scripts/merge_model.py --adapter_path $checkpoint_dir --output_dir $checkpoint_dir --model_path $base_model --gpus $CUDA_VISIBLE_DEVICES
123 | echo "validating"
124 | python train/validate.py --checkpoint_dir $checkpoint_dir --base_model $base_model --cuda_visible_devices $CUDA_VISIBLE_DEVICES --validation_data_dir /localhome/mms43/scratch/mathcoder2/MathCoder2/test/PDEcontrol/validation_data --wandb_dir $OUTS/${run_name}
125 | }
126 |
127 |
128 |
129 | validate_interval=$epoch_size
130 | iterations=$((max_steps / validate_interval))
131 | epoch_size=0
132 | external_validation=True
133 | for ((r=0; r<iterations; r++)); do
[... the body of this training loop and the end of train.sh were lost during extraction ...]
--------------------------------------------------------------------------------
/train/scripts/train_dpo.sh:
--------------------------------------------------------------------------------
[... lines 1-47 of train_dpo.sh were lost during extraction; the visible remainder starts inside its find_latest_checkpoint function ...]
48 | latest_checkpoint=$(ls -d ${checkpoint_dir}/checkpoint-* 2>/dev/null | sort -V | tail -n 1)
49 | echo $latest_checkpoint
50 | }
51 |
52 |
53 | train() {
54 | # directory where checkpoints are stored
55 | CHECKPOINT_DIR=$OUTS/$run_name
56 |
57 | export WANDB_DIR=$OUTS/$run_name
58 | export WANDB_RUN_ID=$run_name
59 |
60 | cmd="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES OMP_NUM_THREADS=1 torchrun --nnodes $WORLD_SIZE --node_rank $RANK --nproc_per_node $GPUPerNode --rdzv-backend=c10d --rdzv-endpoint=localhost:0 train/train_dpo.py \
61 | --ddp_timeout 360000 \
62 | --train_parquet_file ${DATA}/${datafile_name} \
63 | --run_name $run_name \
64 | --output_dir $OUTS/${run_name} \
65 | --no_timestamps \
66 | --dataloader_num_workers 2 \
67 | --max_len 4096 \
68 | --max_steps $max_steps \
69 | --num_train_epochs -1 \
70 | --save_steps 20 \
71 | --save_total_limit 2 \
72 | --step_save_interval $epoch_size \
73 | --warmup_steps 50 \
74 | --logging_steps 10 \
75 | --learning_rate 4e-5 \
76 | --weight_decay 0.1 \
77 | --lr_scheduler_type cosine \
78 | --per_device_train_batch_size $per_gpu_batch_size \
79 | --gradient_accumulation_steps $accum_grad \
80 | --seed 3407 \
81 | --bf16 \
82 | --stream \
83 | --do_train \
84 | --report_to wandb \
85 | --lora_r 64 \
86 | --lora_alpha 256 \
87 | --lora_dropout 0.1 \
88 | --adapter_cfg $translator_adapter \
89 | --model_cfg $base_model"
90 |
91 | if [ -n "$LATEST_CHECKPOINT" ]; then
92 | echo "Will try to load checkpoint from $LATEST_CHECKPOINT"
93 | export WANDB_RESUME="auto"
94 | cmd="$cmd --resume_from $LATEST_CHECKPOINT"
95 | fi
96 |
97 | if [ "$external_validation" == "True" ]; then
98 | cmd="$cmd --external_validation"
99 | fi
100 |
101 | eval $cmd
102 | }
103 |
104 |
105 | ############## Training loop with validation ###################
106 | ################################################################
107 | validate() {
108 | local checkpoint_dir=$1
109 | python train/scripts/merge_model.py --adapter_path $checkpoint_dir --output_dir $checkpoint_dir --model_path $base_model --gpus $CUDA_VISIBLE_DEVICES
110 | echo "validating"
111 | python train/validate.py --checkpoint_dir $checkpoint_dir --base_model $base_model --cuda_visible_devices $CUDA_VISIBLE_DEVICES --validation_data_dir /localhome/mms43/scratch/mathcoder2/MathCoder2/test/PDEcontrol/validation_data --wandb_dir $OUTS/${run_name}
112 | }
113 |
114 | validate_interval=$epoch_size
115 | iterations=$((max_steps / validate_interval))
116 | epoch_size=0
117 | external_validation=True
118 | for ((r=0; r<iterations; r++)); do
[... the body of this training loop, the end of train_dpo.sh, and the files /train/scripts/utils/loader.py and /train/scripts/utils/trainer.py were lost during extraction ...]
--------------------------------------------------------------------------------
/train/scripts/utils/util.py:
--------------------------------------------------------------------------------
[... lines 1-31 of util.py were lost during extraction; the visible remainder is the tail of its print_args helper ...]
32 | logger.info("%s -> %s", keystr, val)
33 | logger.info(f"******************* {name} *******************")
34 |
--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import random
5 | import shutil
6 | import logging
7 | import transformers
8 | import torch
9 | import json
10 |
11 | import numpy as np
12 | import torch.distributed as dist
13 |
14 | from datetime import datetime
15 | from dataclasses import field, dataclass
16 | from utils.util import set_logger, print_args
17 |
18 | from utils.loader import Processor
19 | from utils.trainer import LoggerCallback, RemoveStateCallback
20 |
21 | from transformers.tokenization_utils import AddedToken
22 | from datasets import load_dataset, concatenate_datasets
23 | from transformers import (
24 | Trainer,
25 | set_seed,
26 | AutoConfig,
27 | AutoTokenizer,
28 | HfArgumentParser,
29 | TrainingArguments,
30 | LlamaForCausalLM,
31 | AutoModelForCausalLM,
32 | default_data_collator
33 | )
34 |
35 | logger = logging.getLogger()
36 |
37 | @dataclass
38 | class DataArguments:
39 |
40 | no_timestamps: bool = field(default=False)
41 | no_load_model_parameters: bool = field(default=False)
42 |
43 | resume_step: int = field(default=None)
44 | resume_batch_size: int = field(default=None)
45 |
46 | # data
47 | train_parquet_file: str = field(default=None)
48 | train_file_config: str = field(default=None)
49 | train_dataset: str = field(default=None)
50 | train_coef: str = field(default=None)
51 | delete_long_sample: bool = field(default=False)
52 |
53 | # process
54 | max_len: int = field(default=2048)
55 | preprocessing_num_workers: int = field(default=64)
56 |
57 | # model
58 | model_cfg: str = field(default="data/models/starcoder")
59 | flash_attention: bool = field(default=False)
60 |
61 | resume_from: str = field(default=None)
62 |
63 | # output
64 | stream: bool = field(default=False)
65 |
66 |
67 | def train():
68 | parser = HfArgumentParser((DataArguments, TrainingArguments))
69 |
70 | data_args, training_args = parser.parse_args_into_dataclasses()
71 |
72 | training_args._frozen = False
73 |
74 | if not data_args.no_timestamps:
75 | timestr = datetime.now().strftime("-%m%d%H%M")
76 | training_args.output_dir = training_args.output_dir + timestr
77 |
78 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logging')
79 |
80 | if os.path.exists(training_args.output_dir):
81 | if training_args.overwrite_output_dir:
82 | if training_args.process_index == 0:
83 | shutil.rmtree(training_args.output_dir)
84 | else:
85 | raise ValueError(f"Output directory ({training_args.output_dir}) already exists. Use --overwrite_output_dir to overcome.")
86 |
87 | if training_args.world_size > 1:
88 | dist.barrier()
89 |
90 | if training_args.process_index == 0:
91 | os.makedirs(training_args.output_dir)
92 |
93 | if training_args.world_size > 1:
94 | dist.barrier()
95 |
96 | set_seed(training_args.seed)
97 |
98 | node_rank = int(os.getenv('GROUP_RANK', '0'))
99 |
100 | for _logger in [logger, transformers.utils.logging.get_logger(), logging.getLogger('DeepSpeed')]:
101 | set_logger(_logger, training_args.local_rank, data_args.stream, os.path.join(training_args.output_dir, f'log-node-{node_rank}.log'))
102 |
103 | logger.warning("Device: %s, rank: %s, world size: %s", training_args.device, training_args.process_index, training_args.world_size)
104 |
105 | if training_args.world_size > 1:
106 | dist.barrier()
107 |
108 | print_args(data_args, 'Data Arguments')
109 | print_args(training_args, 'Training Arguments')
110 |
111 | processor = Processor()
112 |
113 | config = AutoConfig.from_pretrained(data_args.model_cfg, trust_remote_code=True)
114 | config._attn_implementation = "flash_attention_2"
115 | config.use_cache = False
116 |
117 | if data_args.no_load_model_parameters:
118 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
119 | else:
120 | model = AutoModelForCausalLM.from_pretrained(data_args.model_cfg, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True)
121 | # tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, legacy=False, use_fast=True)
122 | tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, trust_remote_code=True)
123 |
124 | if tokenizer.pad_token is None:
125 | tokenizer.pad_token = tokenizer.eos_token
126 | tokenizer.pad_token_id = tokenizer.eos_token_id
127 |
128 | if data_args.train_parquet_file is not None:
129 | train_sets = load_dataset("parquet", data_files=data_args.train_parquet_file, split='train')
130 | elif data_args.train_file_config is not None:
131 | with open(data_args.train_file_config, "r") as f:
132 | train_files = json.load(f)
133 |
134 | train_sets = []
135 | for file in train_files:
136 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train')
137 | train_sets.append(_dataset)
138 |
139 | lengths = np.array([_set.shape[0] for _set in train_sets])
140 | logger.info("Data Lengths: %s", lengths)
141 |
142 | for i in range(1, len(train_sets)):
143 | train_sets[i] = train_sets[i].cast(train_sets[0].features)
144 |
145 | train_sets = concatenate_datasets(train_sets)
146 | else:
147 |         raise ValueError("Should provide either 'train_parquet_file' or 'train_file_config'")
148 |
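149 |     # Example (sketch, illustrative paths): `train_file_config` is assumed to
150 |     # be a JSON list of dataset files, each loaded by its file extension as the
151 |     # datasets builder name ("jsonl" is mapped to the "json" builder above):
152 |     #     ["data/tokenized/heat_train.jsonl", "data/tokenized/wave_train.parquet"]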
149 |     logger.info('Total %d cases', len(train_sets))
150 |
151 | process_batch_size = min(1000, len(train_sets))
152 |
153 | with training_args.main_process_first(desc="Log a few random samples from the training set"):
154 | for index in random.sample(range(len(train_sets)), 3):
155 | logger.info(
156 | "Sample %d of the raw training set:\n\ninput_tokens: %s\n\n%s",
157 | index,
158 | train_sets[index]['input_ids'],
159 | tokenizer.convert_ids_to_tokens(train_sets[index]['input_ids']),
160 | )
161 |
162 | train_sets = train_sets.shuffle(seed=training_args.seed)
163 | column_names = list(train_sets.features)
164 | if data_args.train_parquet_file is None:
165 | with training_args.main_process_first(desc="dataset map grouping"):
166 | train_sets = train_sets.map(
167 | processor.group_texts,
168 | fn_kwargs={
169 | "tokenizer": tokenizer,
170 | "max_len": data_args.max_len
171 | },
172 | batched=True,
173 | load_from_cache_file=False,
174 | remove_columns=column_names,
175 | batch_size=process_batch_size,
176 | num_proc=data_args.preprocessing_num_workers,
177 | desc=f"Grouping texts in chunks of {data_args.max_len}",
178 | )
179 |
180 | with training_args.main_process_first(desc="Log a few random samples from the grouped training set"):
181 | for index in random.sample(range(len(train_sets)), 3):
182 | logger.info(
183 | "Sample %d of the merged training set:\n\n%s",
184 | index, tokenizer.decode(train_sets[index]['input_ids'])
185 | )
186 |
187 |     if data_args.resume_step is not None and data_args.resume_batch_size is not None:
188 |         # Dataset.select keeps a datasets.Dataset (plain slicing would return a dict of columns)
189 |         train_sets = train_sets.select(range(data_args.resume_step * data_args.resume_batch_size, len(train_sets)))
190 |         training_args.max_steps -= data_args.resume_step
191 |         new_warmup_steps = max(0, training_args.warmup_steps - data_args.resume_step)
192 |         new_learning_rate = training_args.learning_rate - max(0, data_args.resume_step - training_args.warmup_steps) * (training_args.learning_rate / (training_args.max_steps - training_args.warmup_steps))
193 |         training_args.warmup_steps = new_warmup_steps
194 |         training_args.learning_rate = new_learning_rate
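195 |     # Worked example for the resume adjustment above (sketch): with lr=1e-4,
196 |     # max_steps=10000, warmup_steps=1000 and resume_step=4000, max_steps becomes
197 |     # 6000, the remaining warmup is 0, and lr is reduced by 3000 * (1e-4 / 5000)
198 |     # to 4e-5, approximating where the linear-decay schedule left off.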
194 |
195 | trainer = Trainer(
196 | args=training_args,
197 | model=model,
198 | tokenizer=tokenizer,
199 | train_dataset=train_sets,
200 | callbacks=[LoggerCallback, RemoveStateCallback],
201 | data_collator=default_data_collator,
202 | )
203 |
204 | trainer.train(resume_from_checkpoint=data_args.resume_from)
205 |
206 | trainer.save_model(training_args.output_dir)
207 |
208 | if __name__ == "__main__":
209 |
210 | try:
211 | train()
212 | except Exception as e:
213 | logging.exception(e)
214 | exit(-1)
215 |
--------------------------------------------------------------------------------
/train/train_dpo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import inspect
4 | import os
5 | import random
6 | import shutil
7 | import logging
8 | import transformers
9 | import torch
10 |
11 | import torch.distributed as dist
12 |
13 | from datetime import datetime
14 | from dataclasses import field, dataclass
15 | from utils.util import set_logger, print_args
16 |
17 | from utils.loader import Processor
18 | from utils.trainer import LoggerCallback, RemoveStateCallback, StepCheckpointCallback, DataCollatorForDPODataset
19 |
20 | from datasets import load_dataset
21 | from transformers import (
22 | set_seed,
23 | AutoConfig,
24 | AutoTokenizer,
25 | HfArgumentParser,
26 | TrainingArguments,
27 | AutoModelForCausalLM,
28 | )
29 |
30 | from peft import LoraConfig, PeftModel
31 | from trl import DPOTrainer, DPOConfig
32 |
33 | logger = logging.getLogger()
34 |
35 | @dataclass
36 | class DataArguments:
37 |
38 | no_timestamps: bool = field(default=False)
39 |
40 |
41 | # data
42 | train_parquet_file: str = field(default=None)
43 | train_file_config: str = field(default=None)
44 | train_dataset: str = field(default=None)
45 | train_coef: str = field(default=None)
46 | delete_long_sample: bool = field(default=False)
47 |
48 | # process
49 | max_len: int = field(default=4096)
50 | preprocessing_num_workers: int = field(default=64)
51 |
52 | # model
53 | model_cfg: str = field(default="data/models/starcoder")
54 | adapter_cfg: str = field(default="data/models/starcoder")
55 | flash_attention: bool = field(default=False)
56 |
57 | # LoRA model
58 | save_merged_model: bool = field(default=False)
59 | no_load_model_pararmeters: bool = field(default=False)
60 | resume_from: str = field(default=None)
61 |
62 | resume_step: int = field(default=None)
63 | resume_batch_size: int = field(default=None)
64 |
65 | # output
66 | stream: bool = field(default=False)
67 |
68 | step_save_interval: int = field(default=1000) # save model every n steps. These will persist.
69 |
70 | external_validation: bool = field(default=False)
71 |
72 | @dataclass
73 | class PeftArguments:
74 | ## See https://github.com/huggingface/peft/blob/f0b066eae888d5dea598e756b7e6d3401d0708e7/src/peft/tuners/lora/config.py#L72
75 | ## for the default values of the fields (or define more or new defaults here).
76 | # some_field: str = field(default="default_value")
77 | target_modules: str = field(default="k_proj,down_proj,q_proj,v_proj,gate_proj,o_proj,up_proj")
78 | task_type: str = field(default="CAUSAL_LM")
79 |
80 | lora_r: int = field(default=16)
81 | lora_alpha: int = field(default=64)
82 | lora_dropout: float = field(default=0.1)
83 | bias: str = field(default="none")
84 |
85 | def train():
86 | dist.init_process_group(backend='nccl', init_method='env://')
87 | rank = dist.get_rank()
88 | torch.cuda.set_device(rank)
89 | print("Rank", rank, "Current device", torch.cuda.current_device())
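90 |     # Note: this maps each process to cuda:<global rank>, which assumes a
91 |     # single-node run; multi-node jobs would need the local rank instead.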
90 |
91 | parser = HfArgumentParser((DataArguments, TrainingArguments, PeftArguments))
92 |
93 | data_args, training_args, peft_args = parser.parse_args_into_dataclasses()
94 |
95 | training_args._frozen = False
96 |
97 | if not data_args.no_timestamps:
98 | timestr = datetime.now().strftime("-%m%d%H%M")
99 | training_args.output_dir = training_args.output_dir + timestr
100 |
101 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logging')
102 |
103 | if os.path.exists(training_args.output_dir):
104 | if training_args.overwrite_output_dir:
105 | print(f"Output directory ({training_args.output_dir}) already exists. Overwriting output dir.")
106 | if training_args.process_index == 0:
107 | shutil.rmtree(training_args.output_dir)
108 | else:
109 | print(f"Output directory ({training_args.output_dir}) already exists. Use --overwrite_output_dir to overcome.")
110 |
111 | if training_args.world_size > 1:
112 | dist.barrier(device_ids=[rank])
113 |
114 | if training_args.process_index == 0:
115 | if not os.path.exists(training_args.output_dir):
116 | os.makedirs(training_args.output_dir)
117 |
118 | if training_args.world_size > 1:
119 | dist.barrier(device_ids=[rank])
120 |
121 | set_seed(training_args.seed)
122 |
123 | node_rank = int(os.getenv('GROUP_RANK', '0'))
124 |
125 | for _logger in [logger, transformers.utils.logging.get_logger(), logging.getLogger('DeepSpeed')]:
126 | set_logger(_logger, training_args.local_rank, data_args.stream, os.path.join(training_args.output_dir, f'log-node-{node_rank}.log'))
127 |
128 | logger.warning("Device: %s, rank: %s, world size: %s", training_args.device, training_args.process_index, training_args.world_size)
129 |
130 | if training_args.world_size > 1:
131 | dist.barrier(device_ids=[rank])
132 |
133 | print_args(data_args, 'Data Arguments')
134 | print_args(training_args, 'Training Arguments')
135 | print_args(peft_args, 'LoRA Arguments')
136 |
137 | config = AutoConfig.from_pretrained(data_args.model_cfg, trust_remote_code=True)
138 | config._attn_implementation = "flash_attention_2"
139 | config.use_cache = False
140 |
141 | # load base model
142 | base_model = AutoModelForCausalLM.from_pretrained(data_args.model_cfg, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True,)
143 | logger.info(base_model)
144 | base_model.config.use_cache = False
145 | ## load the same adapter twice. once in training mode and once in eval mode.
146 | model = PeftModel.from_pretrained(
147 | base_model,
148 | data_args.adapter_cfg,
149 | is_trainable=True,
150 | adapter_name='policy',
151 | )
152 | # load the adapter a second time, with a different name, which will be our reference model.
153 | model.load_adapter(
154 | data_args.adapter_cfg,
155 | adapter_name="reference",
156 | )
157 |
158 | tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, trust_remote_code=True)
159 | tokenizer.padding_side = 'right'
160 | logger.info(f"padding side: {tokenizer.padding_side}")
161 |
162 | if tokenizer.pad_token is None:
163 | tokenizer.pad_token = tokenizer.eos_token
164 | tokenizer.pad_token_id = tokenizer.eos_token_id
165 |
166 |
167 | print("Loading training data", data_args.train_parquet_file)
168 | if data_args.train_parquet_file is not None:
169 | train_sets = load_dataset("parquet", data_files=data_args.train_parquet_file, split='train')
170 | else:
171 |         raise ValueError("Should provide 'train_parquet_file' for DPO training")
172 |
173 |     logger.info('Total %d cases', len(train_sets))
174 |
175 | with training_args.main_process_first(desc="Log a few random samples from the training set"):
176 | for index in random.sample(range(len(train_sets)), 3):
177 |             ### DPOTrainer performs its own tokenization internally; there is no way to turn it off short of editing the source code.
178 | logger.info(
179 | "Sample %d of the raw training set:\n\nprompt_input_ids: %s\n\nprompt_tokens: %s\n\nchosen_input_ids: %s\n\nchosen_tokens: %s\n\nrejected_input_ids: %s\n\nrejected_tokens: %s\n\n",
180 | index,
181 | train_sets[index]['prompt_input_ids'],
182 | tokenizer.convert_ids_to_tokens(train_sets[index]['prompt_input_ids']),
183 | train_sets[index]['chosen_input_ids'],
184 | tokenizer.convert_ids_to_tokens(train_sets[index]['chosen_input_ids']),
185 | train_sets[index]['rejected_input_ids'],
186 | tokenizer.convert_ids_to_tokens(train_sets[index]['rejected_input_ids']),
187 | )
188 |
189 | train_sets = train_sets.shuffle(seed=training_args.seed)
190 |
191 | epoch_checkpoint_callback = StepCheckpointCallback(save_interval=data_args.step_save_interval, output_dir=training_args.output_dir, external_validation=data_args.external_validation)
192 |
193 | data_collator = DataCollatorForDPODataset(tokenizer=tokenizer)
194 |
195 | target_modules_str = peft_args.target_modules
196 | # convert the comma-separated string to a list of strings
197 | if target_modules_str:
198 | target_modules_list = [module.strip() for module in target_modules_str.split(',')]
199 | else:
200 | target_modules_list = None
201 |
202 | lora_config = LoraConfig(
203 | target_modules=target_modules_list,
204 | task_type=peft_args.task_type,
205 | r=peft_args.lora_r,
206 | lora_alpha=peft_args.lora_alpha,
207 | lora_dropout=peft_args.lora_dropout,
208 | bias=peft_args.bias,
209 | )
210 |
211 | print(lora_config)
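212 |     # Note: lora_config is only logged here for reference; the LoRA adapters were
213 |     # already loaded via PeftModel.from_pretrained above, and no peft_config is
214 |     # passed to DPOTrainer below.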
212 |
213 | training_args_main_output_dir = training_args.output_dir
214 |     training_args.output_dir = os.path.join(training_args.output_dir, "backups")
215 |
216 |
217 | ### DPO Config initialized this way since even the defaults will override the transformer.TrainingArguments
218 | # Get the list of parameters that DPOConfig accepts
219 | dpo_config_params = inspect.signature(DPOConfig).parameters
220 | # Filter training_args to include only those parameters
221 | relevant_args = {k: v for k, v in vars(training_args).items() if k in dpo_config_params}
222 | # Add any additional parameters that are not in training_args
223 | relevant_args.update({
224 | 'ref_adapter_name': "reference",
225 | 'model_adapter_name': "policy",
226 | 'beta': 0.1,
227 | 'loss_type': "sigmoid",
228 | 'is_encoder_decoder': False,
229 | 'rpo_alpha': 1.0,
230 | 'sync_ref_model': False,
231 |         'ref_model_sync_steps': 4, # only used if sync_ref_model is set to True
232 | 'force_use_ref_model': False,
233 | })
234 |
235 |
236 | dpo_config = DPOConfig(**relevant_args)
237 | logger.info(dpo_config)
238 |
239 | dpo_trainer = DPOTrainer(
240 | model=model,
241 | args=dpo_config,
242 | train_dataset=train_sets,
243 | tokenizer=tokenizer,
244 | data_collator=data_collator,
245 | callbacks=[LoggerCallback, epoch_checkpoint_callback],
246 | )
247 |
248 | epoch_checkpoint_callback.set_trainer(dpo_trainer)
249 |
250 | dpo_trainer.train(resume_from_checkpoint=data_args.resume_from)
251 |
252 | dpo_trainer.save_model(os.path.join(training_args_main_output_dir, "final"))
253 |
254 | if __name__ == "__main__":
255 |
256 | try:
257 | train()
258 | except Exception as e:
259 | logging.exception(e)
260 | exit(-1)
261 |
--------------------------------------------------------------------------------
/train/train_finetune.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import random
5 | import shutil
6 | import logging
7 | import transformers
8 | import torch
9 | import json
10 |
11 | import numpy as np
12 | import torch.distributed as dist
13 |
14 | from datetime import datetime
15 | from dataclasses import field, dataclass
16 | from utils.util import set_logger, print_args
17 |
18 | from utils.loader import Processor
19 | from utils.trainer import LoggerCallback, RemoveStateCallback, StepCheckpointCallback, DataCollatorForSupervisedDataset
20 |
21 | from datasets import load_dataset, concatenate_datasets
22 | from transformers import (
23 | Trainer,
24 | set_seed,
25 | AutoConfig,
26 | AutoTokenizer,
27 | HfArgumentParser,
28 | TrainingArguments,
29 | AutoModelForCausalLM,
30 | )
31 |
32 | from peft import LoraConfig
33 | from trl import SFTTrainer
34 |
35 | logger = logging.getLogger()
36 |
37 | @dataclass
38 | class DataArguments:
39 |
40 | no_timestamps: bool = field(default=False)
41 |
42 |
43 | # data
44 | train_parquet_file: str = field(default=None)
45 | train_file_config: str = field(default=None)
46 | train_dataset: str = field(default=None)
47 | train_coef: str = field(default=None)
48 | delete_long_sample: bool = field(default=False)
49 |
50 | # process
51 | max_len: int = field(default=4096)
52 | preprocessing_num_workers: int = field(default=64)
53 |
54 | # model
55 | model_cfg: str = field(default="data/models/starcoder")
56 | flash_attention: bool = field(default=False)
57 |
58 | # LoRA model
59 | save_merged_model: bool = field(default=False)
60 | no_load_model_pararmeters: bool = field(default=False)
61 | resume_from: str = field(default=None)
62 |
63 | resume_step: int = field(default=None)
64 | resume_batch_size: int = field(default=None)
65 |
66 | # output
67 | stream: bool = field(default=False)
68 |
69 | step_save_interval: int = field(default=1000) # save model every n steps. These will persist.
70 |
71 | external_validation: bool = field(default=False)
72 |
73 | @dataclass
74 | class PeftArguments:
75 | ## See https://github.com/huggingface/peft/blob/f0b066eae888d5dea598e756b7e6d3401d0708e7/src/peft/tuners/lora/config.py#L72
76 | ## for the default values of the fields (or define more or new defaults here).
77 | # some_field: str = field(default="default_value")
78 | target_modules: str = field(default="k_proj,down_proj,q_proj,v_proj,gate_proj,o_proj,up_proj")
79 | task_type: str = field(default="CAUSAL_LM")
80 |
81 | lora_r: int = field(default=16)
82 | lora_alpha: int = field(default=64)
83 | lora_dropout: float = field(default=0.1)
84 | bias: str = field(default="none")
85 |
86 | def train():
87 | dist.init_process_group(backend='nccl', init_method='env://')
88 | rank = dist.get_rank()
89 | torch.cuda.set_device(rank)
90 | print("Rank", rank, "Current device", torch.cuda.current_device())
91 |
92 | parser = HfArgumentParser((DataArguments, TrainingArguments, PeftArguments))
93 |
94 | data_args, training_args, peft_args = parser.parse_args_into_dataclasses()
95 |
96 | training_args._frozen = False
97 |
98 | if not data_args.no_timestamps:
99 | timestr = datetime.now().strftime("-%m%d%H%M")
100 | training_args.output_dir = training_args.output_dir + timestr
101 |
102 | training_args.logging_dir = os.path.join(training_args.output_dir, 'logging')
103 |
104 | if os.path.exists(training_args.output_dir):
105 | if training_args.overwrite_output_dir:
106 | print(f"Output directory ({training_args.output_dir}) already exists. Overwriting output dir.")
107 | if training_args.process_index == 0:
108 | shutil.rmtree(training_args.output_dir)
109 | else:
110 | print(f"Output directory ({training_args.output_dir}) already exists. Use --overwrite_output_dir to overcome.")
111 |
112 | if training_args.world_size > 1:
113 | dist.barrier(device_ids=[rank])
114 |
115 | if training_args.process_index == 0:
116 | if not os.path.exists(training_args.output_dir):
117 | os.makedirs(training_args.output_dir)
118 |
119 | if training_args.world_size > 1:
120 | dist.barrier(device_ids=[rank])
121 |
122 | set_seed(training_args.seed)
123 |
124 | node_rank = int(os.getenv('GROUP_RANK', '0'))
125 |
126 | for _logger in [logger, transformers.utils.logging.get_logger(), logging.getLogger('DeepSpeed')]:
127 | set_logger(_logger, training_args.local_rank, data_args.stream, os.path.join(training_args.output_dir, f'log-node-{node_rank}.log'))
128 |
129 | logger.warning("Device: %s, rank: %s, world size: %s", training_args.device, training_args.process_index, training_args.world_size)
130 |
131 | if training_args.world_size > 1:
132 | dist.barrier(device_ids=[rank])
133 |
134 | print_args(data_args, 'Data Arguments')
135 | print_args(training_args, 'Training Arguments')
136 | print_args(peft_args, 'LoRA Arguments')
137 |
138 | processor = Processor()
139 |
140 | config = AutoConfig.from_pretrained(data_args.model_cfg, trust_remote_code=True)
141 | config._attn_implementation = "flash_attention_2"
142 | config.use_cache = False
143 |
144 | base_model = AutoModelForCausalLM.from_pretrained(data_args.model_cfg, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True)
145 | logger.info(base_model)
146 |
147 | tokenizer = AutoTokenizer.from_pretrained(data_args.model_cfg, trust_remote_code=True)
148 | tokenizer.padding_side = 'right'
149 | logger.info(f"padding side: {tokenizer.padding_side}")
150 |
151 | if tokenizer.pad_token is None:
152 | tokenizer.pad_token = tokenizer.eos_token
153 | tokenizer.pad_token_id = tokenizer.eos_token_id
154 |
155 | if data_args.train_parquet_file is not None:
156 | train_sets = load_dataset("parquet", data_files=data_args.train_parquet_file, split='train')
157 | elif data_args.train_file_config is not None:
158 | with open(data_args.train_file_config, "r") as f:
159 | train_files = json.load(f)
160 |
161 | train_sets = []
162 | for file in train_files:
163 | _dataset = load_dataset(file.split(".")[-1] if file.split(".")[-1] != "jsonl" else "json", data_files=file, split='train')
164 | train_sets.append(_dataset)
165 |
166 | lengths = np.array([_set.shape[0] for _set in train_sets])
167 | logger.info("Data Lengths: %s", lengths)
168 |
169 | for i in range(1, len(train_sets)):
170 | train_sets[i] = train_sets[i].cast(train_sets[0].features)
171 |
172 | train_sets = concatenate_datasets(train_sets)
173 | else:
174 |         raise ValueError("Should provide either 'train_parquet_file' or 'train_file_config'")
175 |
176 |     logger.info('Total %d cases', len(train_sets))
177 |
178 | process_batch_size = min(1000, len(train_sets))
179 |
180 | with training_args.main_process_first(desc="Log a few random samples from the training set"):
181 | for index in random.sample(range(len(train_sets)), 3):
182 | logger.info(
183 | "Sample %d of the raw training set:\n\ninput_tokens: %s\n\n%s\n\n",
184 | index,
185 | train_sets[index]['input_ids'],
186 | tokenizer.convert_ids_to_tokens(train_sets[index]['input_ids']),
187 | )
188 |
189 | train_sets = train_sets.shuffle(seed=training_args.seed)
190 | column_names = list(train_sets.features)
191 | if data_args.train_parquet_file is None:
192 | with training_args.main_process_first(desc="dataset map grouping"):
193 | train_sets = train_sets.map(
194 | processor.group_texts,
195 | fn_kwargs={
196 | "tokenizer": tokenizer,
197 | "max_len": data_args.max_len
198 | },
199 | batched=True,
200 | load_from_cache_file=False,
201 | remove_columns=column_names,
202 | batch_size=process_batch_size,
203 | num_proc=data_args.preprocessing_num_workers,
204 | desc=f"Grouping texts in chunks of {data_args.max_len}",
205 | )
206 |
207 | with training_args.main_process_first(desc="Log a few random samples from the grouped training set"):
208 | for index in random.sample(range(len(train_sets)), 3):
209 | logger.info(
210 | "Sample %d of the merged training set:\n\n%s",
211 | index, tokenizer.decode(train_sets[index]['input_ids'])
212 | )
213 |
214 |     if data_args.resume_step is not None and data_args.resume_batch_size is not None:
215 |         # Dataset.select keeps a datasets.Dataset (plain slicing would return a dict of columns)
216 |         train_sets = train_sets.select(range(data_args.resume_step * data_args.resume_batch_size, len(train_sets)))
217 |         training_args.max_steps -= data_args.resume_step
218 |         new_warmup_steps = max(0, training_args.warmup_steps - data_args.resume_step)
219 |         new_learning_rate = training_args.learning_rate
220 |         new_learning_rate -= max(0, data_args.resume_step - training_args.warmup_steps) * (training_args.learning_rate / (training_args.max_steps - training_args.warmup_steps))
221 |         training_args.warmup_steps = new_warmup_steps
222 |         training_args.learning_rate = new_learning_rate
222 |
223 | epoch_checkpoint_callback = StepCheckpointCallback(save_interval=data_args.step_save_interval, output_dir=training_args.output_dir, external_validation=data_args.external_validation)
224 |
225 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
226 |
227 | target_modules_str = peft_args.target_modules
228 | # convert the comma-separated string to a list of strings
229 | if target_modules_str:
230 | target_modules_list = [module.strip() for module in target_modules_str.split(',')]
231 | else:
232 | target_modules_list = None
233 |
234 | lora_config = LoraConfig(
235 | target_modules=target_modules_list,
236 | task_type=peft_args.task_type,
237 | r=peft_args.lora_r,
238 | lora_alpha=peft_args.lora_alpha,
239 | lora_dropout=peft_args.lora_dropout,
240 | bias=peft_args.bias,
241 | )
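242 |     # Assumption: the default target_modules above name LLaMA-style attention and
243 |     # MLP projections; other architectures may use different module names.
244 |     # SFTTrainer wraps base_model with this peft_config, so only the LoRA
245 |     # matrices are trained.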
242 |
243 | training_args_main_output_dir = training_args.output_dir
244 |     training_args.output_dir = os.path.join(training_args.output_dir, "backups")
245 |
246 |
247 | trainer = SFTTrainer(
248 | model=base_model,
249 | train_dataset=train_sets,
250 | max_seq_length=data_args.max_len,
251 | peft_config=lora_config,
252 | tokenizer=tokenizer,
253 | args=training_args,
254 | data_collator=data_collator,
255 | callbacks=[LoggerCallback, epoch_checkpoint_callback],
256 | # RemoveStateCallback can be used to save disk space but you may not be able to resume runs from ckpt
257 | )
258 |
259 | epoch_checkpoint_callback.set_trainer(trainer)
260 | print(lora_config)
261 |
262 | trainer.train(resume_from_checkpoint=data_args.resume_from)
263 |
264 | trainer.save_model(os.path.join(training_args_main_output_dir, "final"))
265 |
266 | if __name__ == "__main__":
267 |
268 | try:
269 | train()
270 | except Exception as e:
271 | logging.exception(e)
272 | exit(-1)
273 |
--------------------------------------------------------------------------------
/train/utils/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import re
4 | import torch
5 | import logging
6 |
7 | logger = logging.getLogger()
8 |
9 | IGNORE_INDEX = -100
10 |
11 | class Processor:
12 |
13 | def group_texts(self, examples, tokenizer, max_len):
14 | input_ids, labels = [], []
15 | final_input_ids, final_labels = [], []
16 |
17 | for idx in range(len(examples['input_ids'])):
18 | _input_ids = examples['input_ids'][idx]
19 | _labels = examples['input_ids'][idx]
20 | examples['input_ids'][idx] = None
21 | if len(_input_ids) > max_len:
22 |                 # if a single sample is longer than max_len, break it into several chunks
23 |                 divided_input_ids, divided_labels = [], []
24 |                 for i in range(0, len(_input_ids), max_len):
25 |                     divided_input_ids = _input_ids[i: i + max_len]
26 |                     divided_labels = _labels[i: i + max_len]
27 |                     if len(divided_input_ids) < max_len:
28 |                         divided_pad_num = max_len - len(divided_input_ids)
29 |                         divided_input_ids += [tokenizer.pad_token_id] * divided_pad_num
30 |                         divided_labels += [IGNORE_INDEX] * divided_pad_num
31 |                     final_input_ids.append(divided_input_ids)
32 |                     final_labels.append(divided_labels)
33 |                 continue
34 |
35 | # if single sample shorter than max_len, combine together
36 | if len(input_ids) + len(_input_ids) > max_len:
37 | pad_num = max_len - len(input_ids)
38 | final_input_ids.append(input_ids + [tokenizer.pad_token_id] * pad_num)
39 | final_labels.append(labels + [IGNORE_INDEX] * pad_num)
40 |
41 | input_ids, labels = [], []
42 |
43 | input_ids.extend(_input_ids)
44 | labels.extend(_labels)
45 |
46 | if len(input_ids) > 0:
47 | pad_num = max_len - len(input_ids)
48 | final_input_ids.append(input_ids + [tokenizer.pad_token_id] * pad_num)
49 | final_labels.append(labels + [IGNORE_INDEX] * pad_num)
50 |
51 | return {
52 | "input_ids": torch.tensor(final_input_ids).long(),
53 | "labels": torch.tensor(final_labels).long()
54 | }
55 |
56 |     def process_tokenize(self, examples, tokenizer):
57 |         """
58 |         Tokenize samples and add BOS and EOS tokens.
59 |         """
60 |         inputs = tokenizer(examples['text'], truncation=False, padding=False)
61 |
62 |         input_ids = []
63 |         for input_id in inputs['input_ids']:
64 |             if tokenizer.bos_token_id is not None:
65 |                 input_ids.append([tokenizer.bos_token_id] + input_id + [tokenizer.eos_token_id])
66 |             else:
67 |                 input_ids.append(input_id + [tokenizer.eos_token_id])
68 |
69 |         return {
70 |             "input_ids": input_ids,
71 |         }
72 |
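73 | # Usage sketch (illustrative, assuming a datasets.Dataset `raw` with a "text"
74 | # column and a loaded `tokenizer`):
75 | #
76 | #   processor = Processor()
77 | #   tokenized = raw.map(processor.process_tokenize,
78 | #                       fn_kwargs={"tokenizer": tokenizer},
79 | #                       batched=True, remove_columns=raw.column_names)
80 | #   grouped = tokenized.map(processor.group_texts,
81 | #                           fn_kwargs={"tokenizer": tokenizer, "max_len": 2048},
82 | #                           batched=True, remove_columns=["input_ids"])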
--------------------------------------------------------------------------------
/train/utils/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import glob
4 | import logging
5 | import datetime
6 |
7 | from transformers import TrainerCallback
8 | from transformers import TrainingArguments, TrainerState, TrainerControl
10 |
11 | import torch
12 |
13 | from typing import Dict, Sequence
14 | from dataclasses import dataclass
15 | import transformers
16 |
17 |
18 | IGNORE_INDEX = -100
19 |
20 | logger = logging.getLogger()
21 |
22 | class LoggerCallback(TrainerCallback):
23 |
24 | def on_train_begin(self, args, state, control, **kwargs):
25 |
26 | self.start_time = datetime.datetime.now()
27 |
28 | def on_log(self, args, state, control, logs=None, **kwargs):
29 | if not state.is_local_process_zero:
30 | return
31 |
32 | if 'loss' not in logs:
33 | return
34 |
35 | loss_msg = ' '.join(["%s: %s" % (k, v) for k, v in logs.items() if 'loss' in k])
36 | now = datetime.datetime.now()
37 | pass_time = now - self.start_time
38 | rest_time = pass_time * (state.max_steps - state.global_step) / state.global_step
39 | eta = now + rest_time
40 |
41 | pt_min = pass_time.seconds // 60
42 | pass_time = '%.2d:%.2d' % (pt_min // 60 + pass_time.days * 24, pt_min % 60)
43 |
44 | rt_min = rest_time.seconds // 60
45 | rest_time = '%.2d:%.2d' % (rt_min // 60 + rest_time.days * 24, rt_min % 60)
46 |
47 | logger.info(
48 | 'step: %d epoch: %.2f %s lr: %.4g passed time: %s rest time: %s eta: %s',
49 | state.global_step, state.epoch, loss_msg, logs.get('learning_rate', 0),
50 | pass_time, rest_time, eta.strftime('%m/%d %H:%M')
51 | )
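52 |         # e.g. (sketch): 1000/10000 steps done in 2h -> rest_time = 2h * 9000/1000
53 |         # = 18h, so the printed ETA is now + 18h.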
52 |
53 | class RemoveStateCallback(TrainerCallback):
54 |
55 | def remove_state(self, args, step):
56 | step = int(step)
57 |
58 | if step <= 0:
59 | return
60 |
61 | step_dir = os.path.join(args.output_dir, f'checkpoint-{step}')
62 | logger.info('Remove state in %s', step_dir)
63 |
64 | remove_paths = [
65 | os.path.join(step_dir, 'latest'), # deepspeed state
66 | os.path.join(step_dir, f'global_step{step}'), # deepspeed state
67 | os.path.join(step_dir, 'optimizer.pt'), # optimizer state
68 | os.path.join(step_dir, 'scheduler.pt'), # scheduler state
69 | os.path.join(step_dir, 'generation_config.json'), # generation config
70 | os.path.join(step_dir, 'trainer_state.json'), # trainer state
71 | os.path.join(step_dir, 'training_args.bin'), # training args
72 | os.path.join(step_dir, 'zero_to_fp32.py')
73 | ]
74 |
75 | remove_paths.extend(glob.glob(os.path.join(step_dir, 'rng_state_*.pth'))) # numpy random state
76 |
77 | for path in remove_paths:
78 | if os.path.exists(path):
79 | os.system('rm -rf %s' % path)
80 |
81 | def on_save(self, args, state, control, **kwargs):
82 |
83 | if not state.is_world_process_zero:
84 | return
85 |
86 | self.remove_state(args, state.global_step - state.save_steps)
87 |
88 | def on_train_end(self, args, state, control, **kwargs):
89 |
90 | if not state.is_world_process_zero:
91 | return
92 |
93 | self.remove_state(args, state.global_step)
94 |
95 |
96 | class StepCheckpointCallback(TrainerCallback):
97 | def __init__(self, save_interval, output_dir, external_validation):
98 | self.save_interval = save_interval
99 | self.output_dir = output_dir
100 | self.trainer = None
101 | self.external_validation = external_validation
102 |
103 | def set_trainer(self, trainer):
104 | self.trainer = trainer
105 |
106 | def save_model(self, checkpoint_dir, state: TrainerState):
107 | self.trainer.save_model(checkpoint_dir)
108 | print(f"Global step: {state.global_step} (Epoch {int(state.epoch)}). Saved checkpoint at {checkpoint_dir}")
109 |
110 |
111 | def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
112 | if state.global_step % self.save_interval == 0:
113 | checkpoint_dir = os.path.join(self.output_dir, f"checkpoint-step-{state.global_step}")
114 | self.save_model(checkpoint_dir, state)
115 | control.should_save = True # save to resume later
116 | if self.external_validation:
117 | control.should_training_stop = True # signal to stop training
118 | return control
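119 |
120 | # Usage sketch (illustrative): the callback needs a back-reference to the trainer
121 | # before training starts, e.g.
122 | #
123 | #   cb = StepCheckpointCallback(save_interval=1000, output_dir=out_dir,
124 | #                               external_validation=False)
125 | #   trainer = Trainer(..., callbacks=[LoggerCallback, cb])
126 | #   cb.set_trainer(trainer)
127 | #   trainer.train()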
119 |
120 |
121 |
122 | @dataclass
123 | class DataCollatorForSupervisedDataset(object):
124 | """Collate examples for supervised fine-tuning.
125 |
126 | Will do right padding for input_ids and labels up to the longest sequence in the batch."""
127 |
128 | tokenizer: transformers.PreTrainedTokenizer
129 |
130 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
131 | input_ids, labels = tuple([torch.tensor(instance[key]) for instance in instances] for key in ("input_ids", "labels"))
132 | input_ids = torch.nn.utils.rnn.pad_sequence(
133 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
134 | )
135 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
136 | return dict(
137 | input_ids=input_ids,
138 | labels=labels,
139 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
140 | )
141 |
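142 | # Example (sketch): two instances with input_ids of length 3 and 5 collate to a
143 | # (2, 5) batch, the shorter row right-padded with pad_token_id and its labels
144 | # padded with IGNORE_INDEX (-100) so padding positions are ignored by the loss.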
142 |
143 |
144 |
145 | @dataclass
146 | class DataCollatorForDPODataset(object):
147 | """Collate examples for DPO fine-tuning.
148 |
149 | Will do right padding for input_ids and labels up to the longest sequence in the batch."""
150 |
151 | tokenizer: transformers.PreTrainedTokenizer
152 |
153 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
154 | input_ids, chosen_ids, rejected_ids = tuple([torch.tensor(instance[key]) for instance in instances] for key in ("prompt_input_ids", "chosen_input_ids", "rejected_input_ids"))
155 | input_ids = torch.nn.utils.rnn.pad_sequence(
156 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
157 | )
158 | chosen_ids = torch.nn.utils.rnn.pad_sequence(
159 | chosen_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
160 | )
161 | rejected_ids = torch.nn.utils.rnn.pad_sequence(
162 | rejected_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
163 | )
164 | return dict(
165 | prompt_input_ids=input_ids,
166 | prompt_attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
167 | chosen_input_ids=chosen_ids,
168 | chosen_attention_mask=chosen_ids.ne(self.tokenizer.pad_token_id),
169 | rejected_input_ids=rejected_ids,
170 | rejected_attention_mask=rejected_ids.ne(self.tokenizer.pad_token_id),
171 | )
172 |
--------------------------------------------------------------------------------
/train/utils/util.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger = logging.getLogger()
4 |
5 | def set_logger(_logger, local_rank, stream=True, log_file=None):
6 | _logger.handlers.clear()
7 |
8 | if local_rank in [-1, 0]:
9 | _logger.setLevel(logging.INFO)
10 | else:
11 | _logger.setLevel(logging.WARN)
12 |
13 | log_format = '[%(asctime)s] [Rank {} - %(levelname)s] [%(filename)s - %(lineno)d] %(message)s'.format(local_rank)
14 | log_format = logging.Formatter(log_format, '%Y-%m-%d %H:%M:%S')
15 |
16 | if stream:
17 | console = logging.StreamHandler()
18 | console.setFormatter(log_format)
19 | _logger.addHandler(console)
20 |
21 | if log_file is not None:
22 |
23 | file = logging.FileHandler(log_file, mode='a')
24 | file.setFormatter(log_format)
25 | _logger.addHandler(file)
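26 | # Usage sketch: set_logger(logging.getLogger(), local_rank=0, stream=True,
27 | #     log_file="out/log-node-0.log") gives rank 0 INFO-level console and file
28 | #     logging, while other ranks only emit warnings and above.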
26 |
27 | def print_args(args, name):
28 | max_len = max([len(k) for k in vars(args).keys()]) + 4
29 | logger.info(f"******************* {name} *******************")
30 | for key, val in sorted(vars(args).items()):
31 | keystr = "{}".format(key) + (" " * (max_len - len(key)))
32 | logger.info("%s --> %s", keystr, val)
33 | logger.info(f"******************* {name} *******************")
34 |
--------------------------------------------------------------------------------
/train/validate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import sys
5 | import wandb
6 | import concurrent.futures
7 |
8 | current_dir = os.path.dirname(os.path.realpath(__file__))
9 | merge_path = os.path.join(current_dir, "scripts")
10 | sys.path.append(merge_path)
11 | import merge_model
12 |
13 | eval_path = os.path.join(current_dir, "../test/PDEcontrol/evaluation/infer")
14 | sys.path.append(eval_path)
15 | import run_1d_pdecontrol_eval_full as run_1d_pdecontrol_eval
16 |
17 |
18 |
19 | def create_merge_args(args):
20 | class merge_Args:
21 | model_path = args.base_model
22 | output_dir = f"{args.checkpoint_dir}"
23 | adapter_path = args.checkpoint_dir
24 | gpus = args.cuda_visible_devices
25 | return merge_Args()
26 |
27 |
28 | def create_eval_args(args, data_type, p_format, validation_dir, save_directory, shots=2):
29 | class eval_Args:
30 | data_dir = validation_dir
31 | python_key = "python"
32 | stl_key = "sstl"
33 | nl_key = "nl"
34 | robustness_key = "robustness"
35 | max_num_examples = args.valid_num_examples
36 | save_dir = f"{save_directory}/{shots}_shots"
37 | model_name_or_path_coder = f"{args.checkpoint_dir}/merged_model"
38 | tokenizer_name_or_path_coder = args.base_model
39 | eval_batch_size = len(args.cuda_visible_devices)
40 | load_in_8bits = False
41 | gptq= False
42 | use_vllm = True
43 | load_in_half = False
44 | infer_on_train_set=True
45 | n_subsets=1
46 | subset_id=0
47 | temperature = 0
48 | repeat_id_start = 0
49 | n_repeat_sampling = 1
50 | prompt_format = "few_shot"
51 | few_shot_prompt = p_format
52 | prompt_dataset = data_type
53 | few_shot_number = shots
54 | answer_extraction_fn=None
55 | eval_perplexity=True
56 | eval_robustness = True
57 | eval_edit_distance = True
58 | eval_iou = True
59 | load_from_file=False
60 | skip_existing_scores=False
61 | gpus = args.cuda_visible_devices
62 | seed = None
63 | use_openai = False
64 | return eval_Args()
65 |
66 |
67 | def delete_merged_model(checkpoint_dir):
68 | os.system(f"rm -rf {checkpoint_dir}/merged_model")
69 |
70 | def log_to_wandb(results_dir, args):
71 | for directory in os.listdir(results_dir):
72 | eval_shots_dir = os.path.join(results_dir, directory)
73 | if os.path.isdir(eval_shots_dir):
74 | for eval_method_subdir in os.listdir(eval_shots_dir):
75 | eval_method_subdir_path = os.path.join(eval_shots_dir, eval_method_subdir)
76 | if os.path.isdir(eval_method_subdir_path):
77 | metrics_file = os.path.join(eval_method_subdir_path, f"metrics.{args.valid_num_examples}.json")
78 | if os.path.isfile(metrics_file):
79 | with open(metrics_file, 'r') as f:
80 | metrics = json.load(f)
81 |
82 | few_shot_prompt = metrics.get("few_shot_prompt", "")
83 | few_shot_number = metrics.get("few_shot_number", None)
84 |
85 | for metric, value in metrics.items():
86 | if metric not in ["few_shot_prompt", "few_shot_number"]:
87 | wandb.log({
88 | f"validation_{few_shot_prompt}_{few_shot_number}shots/{metric}": value
89 | })
90 |
91 |
92 | def validate_model(args):
93 | wandb.init(project=args.project_name, resume=True, dir=args.wandb_dir)
94 |
95 | merge_args = create_merge_args(args)
96 | merge_model.main(merge_args)
97 |
98 |     save_dir = f"{args.checkpoint_dir}/validation"
99 |     for dataset in os.listdir(args.validation_data_dir):
100 |         dataset_path = os.path.join(args.validation_data_dir, dataset)
101 |         if os.path.isdir(dataset_path):
102 |             for shots in [0, 2]:
103 |                 for p_format in ["to_python_no_STL"]:
104 |                     if 'heat' in dataset:
105 |                         data_type = "CoTOneDHeat"
106 |                     elif 'wave' in dataset:
107 |                         data_type = "CoTOneDWave"
108 |                     else:
109 |                         continue  # skip datasets that are neither heat nor wave
110 |
111 |                     eval_args = create_eval_args(args, data_type, p_format, dataset_path, save_dir, shots)
112 |
113 |                     # run the evaluation in a worker thread so a hung run can be abandoned after the timeout
114 |                     with concurrent.futures.ThreadPoolExecutor() as executor:
115 |                         future = executor.submit(run_1d_pdecontrol_eval.run_main, eval_args)
116 |                         try:
117 |                             future.result(timeout=900) # timeout in seconds
118 |                         except concurrent.futures.TimeoutError:
119 |                             print(f"Timeout occurred for dataset: {dataset}, shots: {shots}, p_format: {p_format}")
118 |
119 |
120 |
121 | delete_merged_model(args.checkpoint_dir)
122 | log_to_wandb(save_dir, args)
123 |
124 |
125 | if __name__ == "__main__":
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument('--checkpoint_dir', type=str, required=True)
128 | parser.add_argument('--base_model', type=str, required=True)
129 | parser.add_argument('--cuda_visible_devices', type=str, required=True)
130 | parser.add_argument('--valid_num_examples', type=int, default=8)
131 | parser.add_argument('--validation_data_dir', type=str, required=True)
132 | parser.add_argument('--project_name', type=str, default="huggingface")
133 | parser.add_argument('--wandb_dir', type=str, required=True)
134 |
135 |
136 | args = parser.parse_args()
137 |
138 | # vllm may freeze on multi-gpu
139 | args.cuda_visible_devices = args.cuda_visible_devices.split(',')[0]
140 |
141 | validate_model(args)
--------------------------------------------------------------------------------
/train_test_combined.yml:
--------------------------------------------------------------------------------
1 | name: trainenv
2 | channels:
3 | - conda-forge
4 | - https://repo.anaconda.com/pkgs/main
5 | - https://repo.anaconda.com/pkgs/r
6 | dependencies:
7 | - _libgcc_mutex=0.1=conda_forge
8 | - _openmp_mutex=4.5=2_gnu
9 | - bzip2=1.0.8=h4bc722e_7
10 | - ca-certificates=2024.8.30=hbcca054_0
11 | - ld_impl_linux-64=2.43=h712a8e2_1
12 | - libffi=3.4.2=h7f98852_5
13 | - libgcc=14.1.0=h77fa898_1
14 | - libgcc-ng=14.1.0=h69a702a_1
15 | - libgomp=14.1.0=h77fa898_1
16 | - libnsl=2.0.1=hd590300_0
17 | - libsqlite=3.46.1=hadc24fc_0
18 | - libuuid=2.38.1=h0b41bf4_0
19 | - libxcrypt=4.4.36=hd590300_1
20 | - libzlib=1.3.1=hb9d3cd8_2
21 | - ncurses=6.5=he02047a_1
22 | - openssl=3.3.2=hb9d3cd8_0
23 | - pip=24.2=pyh8b19718_1
24 | - python=3.10.15=h4a871b0_1_cpython
25 | - readline=8.2=h8228510_1
26 | - setuptools=75.1.0=pyhd8ed1ab_0
27 | - tk=8.6.13=noxft_h4845f30_101
28 | - wheel=0.44.0=pyhd8ed1ab_0
29 | - xz=5.2.6=h166bdaf_0
30 | - pip:
31 | - absl-py==2.0.0
32 | - accelerate==1.2.0
33 | - aiofiles==23.2.1
34 | - aiohttp==3.9.1
35 | - aiosignal==1.3.1
36 | - airportsdata==20241001
37 | - annotated-types==0.7.0
38 | - anyio==4.7.0
39 | - async-timeout==4.0.3
40 | - attrs==23.1.0
41 | - beautifulsoup4==4.12.3
42 | - bitarray==3.0.0
43 | - bs4==0.0.2
44 | - cachetools==5.3.2
45 | - certifi==2022.12.7
46 | - charset-normalizer==2.1.1
47 | - click==8.1.7
48 | - cloudpickle==3.1.0
49 | - compressed-tensors==0.6.0
50 | - datasets==3.2.0
51 | - deepspeed==0.15.4
52 | - dill==0.3.7
53 | - diskcache==5.6.3
54 | - distro==1.9.0
55 | - docker-pycreds==0.4.0
56 | - docstring-parser==0.16
57 | - editdistance==0.8.1
58 | - einops==0.7.0
59 | - exceptiongroup==1.2.2
60 | - fastapi==0.115.6
61 | - ffmpy==0.5.0
62 | - filelock==3.16.1
63 | - fire==0.5.0
64 | # - flash-attn==2.7.2.post1
65 | - frozenlist==1.4.0
66 | - fsspec==2023.10.0
67 | - gguf==0.10.0
68 | - gitdb==4.0.11
69 | - gitpython==3.1.43
70 | - google-auth==2.25.2
71 | - google-auth-oauthlib==1.0.0
72 | - gradio==5.9.1
73 | - gradio-client==1.5.2
74 | - grpcio==1.60.0
75 | - h11==0.14.0
76 | - hjson==3.1.0
77 | - httpcore==1.0.7
78 | - httptools==0.6.4
79 | - httpx==0.28.1
80 | - huggingface-hub==0.26.5
81 | - idna==3.4
82 | - importlib-metadata==8.5.0
83 | - interegular==0.3.3
84 | - jinja2==3.1.2
85 | - jiter==0.8.2
86 | - jsonlines==4.0.0
87 | - jsonschema==4.23.0
88 | - jsonschema-specifications==2024.10.1
89 | - lark==1.2.2
90 | - llvmlite==0.43.0
91 | - lm-format-enforcer==0.10.6
92 | - loader==2017.9.11
93 | - loaderr==1.0.0
94 | - markdown==3.5.1
95 | - markdown-it-py==3.0.0
96 | - markupsafe==2.1.3
97 | - mdurl==0.1.2
98 | - mistral-common==1.4.4
99 | - mpi4py==4.0.1
100 | - mpmath==1.3.0
101 | - msgpack==1.1.0
102 | - msgspec==0.18.6
103 | - multidict==6.0.4
104 | - multiprocess==0.70.15
105 | - nest-asyncio==1.6.0
106 | - networkx==3.0
107 | - ninja==1.11.1.1
108 | - numba==0.60.0
109 | - numpy==1.25.0
110 | - nvidia-cublas-cu12==12.1.3.1
111 | - nvidia-cuda-cupti-cu12==12.1.105
112 | - nvidia-cuda-nvrtc-cu12==12.1.105
113 | - nvidia-cuda-runtime-cu12==12.1.105
114 | - nvidia-cudnn-cu12==9.1.0.70
115 | - nvidia-cufft-cu12==11.0.2.54
116 | - nvidia-curand-cu12==10.3.2.106
117 | - nvidia-cusolver-cu12==11.4.5.107
118 | - nvidia-cusparse-cu12==12.1.0.106
119 | - nvidia-ml-py==12.560.30
120 | - nvidia-nccl-cu12==2.20.5
121 | - nvidia-nvjitlink-cu12==12.4.127
122 | - nvidia-nvtx-cu12==12.1.105
123 | - oauthlib==3.2.2
124 | - openai==1.58.1
125 | - orjson==3.10.12
126 | - outlines==0.0.43
127 | - outlines-core==0.1.26
128 | - packaging==23.2
129 | - pandas==2.1.4
130 | - partial-json-parser==0.2.1.1.post4
131 | - pebble==5.1.0
132 | - peft==0.14.0
133 | - pillow==10.3.0
134 | - platformdirs==4.2.2
135 | - prometheus-client==0.21.1
136 | - prometheus-fastapi-instrumentator==7.0.0
137 | - protobuf==4.25.1
138 | - psutil==5.9.6
139 | - py-cpuinfo==9.0.0
140 | - pyairports==2.1.1
141 | - pyarrow==18.1.0
142 | - pyarrow-hotfix==0.6
143 | - pyasn1==0.5.1
144 | - pyasn1-modules==0.3.0
145 | - pycountry==24.6.1
146 | - pydantic==2.10.3
147 | - pydantic-core==2.27.1
148 | - pydub==0.25.1
149 | - pygments==2.18.0
150 | - python-dateutil==2.8.2
151 | - python-dotenv==1.0.1
152 | - python-multipart==0.0.20
153 | - pytz==2023.3.post1
154 | - pyyaml==6.0.1
155 | - pyzmq==26.2.0
156 | - ray==2.40.0
157 | - referencing==0.35.1
158 | - regex==2023.10.3
159 | - requests==2.32.3
160 | - requests-oauthlib==1.3.1
161 | - rich==13.9.3
162 | - rpds-py==0.22.3
163 | - rsa==4.9
164 | - ruff==0.8.4
165 | - safehttpx==0.1.6
166 | - safetensors==0.4.5
167 | - semantic-version==2.10.0
168 | - sentencepiece==0.2.0
169 | - sentry-sdk==2.5.1
170 | - setproctitle==1.3.3
171 | - shellingham==1.5.4
172 | - shtab==1.7.1
173 | - six==1.16.0
174 | - smmap==5.0.1
175 | - sniffio==1.3.1
176 | - soupsieve==2.5
177 | - starlette==0.41.3
178 | - sympy==1.13.1
179 | - tensorboard==2.14.0
180 | - tensorboard-data-server==0.7.2
181 | - termcolor==2.4.0
182 | - tiktoken==0.7.0
183 | - tokenizers==0.20.3
184 | - tomlkit==0.13.2
185 | - torch==2.4.0
186 | - torchaudio==2.4.0
187 | - torchvision==0.19.0
188 | - tqdm==4.67.1
189 | - transformers==4.46.3
190 | - triton==3.0.0
191 | - trl==0.12.2
192 | - typer==0.15.1
193 | - typing-extensions==4.12.2
194 | - tyro==0.8.14
195 | - tzdata==2023.3
196 | - urllib3==1.26.13
197 | - uvicorn==0.32.1
198 | - uvloop==0.21.0
199 | - vllm==0.6.3.post1
200 | - wandb==0.17.1
201 | - watchfiles==1.0.0
202 | - websockets==14.1
203 | - werkzeug==3.0.1
204 | # - xformers==0.0.27.post2
205 | - xxhash==3.4.1
206 | - yarl==1.9.4
207 | - zipp==3.21.0
208 | prefix: miniconda3/envs/trainenv
209 |
--------------------------------------------------------------------------------
/utils/few_shot_prompts/__init__.py:
--------------------------------------------------------------------------------
1 | from .cot_one_d_heat_fewshot import CoTOneDHeat
2 | from .cot_one_d_wave_fewshot import CoTOneDWave
3 | from .cot_one_d_combined_fewshot import CoTOneDCombined
4 | from .few_shot_train import FewShotTrain
5 | from .few_shot_train_dpo import FewShotDPO
--------------------------------------------------------------------------------
/utils/few_shot_prompts/cot_one_d_combined_fewshot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_test import FewShotTest
2 | import json
3 | import os
4 |
5 | class CoTOneDCombined(FewShotTest):
6 | def __init__(self, num_shots, format, dataset="combined"):
7 | assert dataset in ["combined", "heat", "wave"]
8 |         if dataset == "combined" and num_shots not in [0, 2]:
9 |             raise ValueError(f"Number of shots must be 0 or 2 for dataset {dataset}")
10 |         super().__init__(num_shots)
11 |         self.format = format
12 | current_dir = os.path.dirname(os.path.abspath(__file__))
13 | jsonl_file_path = os.path.join(current_dir, "examples", f"one_d_{dataset}", "examples.jsonl")
14 | self.examples = self.load_examples(jsonl_file_path, format)
15 |
16 | def load_examples(self, jsonl_file_path, format):
17 | examples = []
18 | with open(jsonl_file_path, 'r') as file:
19 | for line in file:
20 | data = json.loads(line)
21 | nl = data['nl']
22 | python = data['python']
23 | sstl = data['sstl']
24 | example = super().format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip())
25 | examples.append(example)
26 | return examples
27 |
28 | def format_prompt(self, nl="", sstl="", python=""):
29 | few_shots = self._get_few_shot_prompt()
30 | curr_prompt = super().format_prompt(self.format, nl, sstl, python)
31 | return f"{few_shots}{curr_prompt}"
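32 |
33 | # Usage sketch (illustrative format name): CoTOneDCombined(2, "to_python_no_STL")
34 | # loads the examples from examples.jsonl, and format_prompt(nl=...) prepends the
35 | # rendered few-shot examples to the formatted query prompt.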
--------------------------------------------------------------------------------
/utils/few_shot_prompts/cot_one_d_heat_fewshot.py:
--------------------------------------------------------------------------------
1 | from .few_shot_test import FewShotTest
2 | import json
3 | import os
4 |
5 | class CoTOneDHeat(FewShotTest):
6 | def __init__(self, num_shots, format, dataset="heat"):
7 |         if dataset != "heat" and num_shots != 0:
8 |             raise ValueError(f"Dataset {dataset} is not supported with {num_shots} shots; only the 'heat' dataset supports few-shot examples")
9 | super().__init__(num_shots)
10 |         self.format = format
11 | current_dir = os.path.dirname(os.path.abspath(__file__))
12 | jsonl_file_path = os.path.join(current_dir, "examples", "one_d_heat", "examples.jsonl")
13 | self.examples = self.load_examples(jsonl_file_path, format)
14 |
15 | def load_examples(self, jsonl_file_path, format):
16 | examples = []
17 | with open(jsonl_file_path, 'r') as file:
18 | for line in file:
19 | data = json.loads(line)
20 | nl = data['nl']
21 | python = data['python']
22 | sstl = data['sstl']
23 | example = super().format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip())
24 | examples.append(example)
25 | return examples
26 |
27 | def format_prompt(self, nl="", sstl="", python=""):
28 | few_shots = self._get_few_shot_prompt()
29 | curr_prompt = super().format_prompt(self.format, nl, sstl, python)
30 | return f"{few_shots}{curr_prompt}"
--------------------------------------------------------------------------------
/utils/few_shot_prompts/cot_one_d_wave_fewshot.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from .few_shot_test import FewShotTest
4 |
5 | class CoTOneDWave(FewShotTest):
6 | def __init__(self, num_shots, format, dataset="wave"):
7 | if dataset != "wave" and num_shots != 0:
8 |             raise ValueError(f"Dataset {dataset} is not supported with {num_shots} shots; only the 'wave' dataset supports few-shot examples")
9 | super().__init__(num_shots)
10 | self.format = format
11 | current_dir = os.path.dirname(os.path.abspath(__file__))
12 | jsonl_file_path = os.path.join(current_dir, "examples", "one_d_wave", "examples.jsonl")
13 | self.examples = self.load_examples(jsonl_file_path, format)
14 |
15 | def load_examples(self, jsonl_file_path, format):
16 | examples = []
17 | with open(jsonl_file_path, 'r') as file:
18 | for line in file:
19 | data = json.loads(line)
20 | nl = data['nl']
21 | python = data['python']
22 | sstl = data['sstl']
23 | example = super().format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip())
24 | examples.append(example)
25 | return examples
26 |
27 | def format_prompt(self, nl="", sstl="", python=""):
28 | few_shots = self._get_few_shot_prompt()
29 | curr_prompt = super().format_prompt(self.format, nl, sstl, python)
30 | return f"{few_shots}{curr_prompt}"
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/DPO_one_d_combined/dataset_description.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "few_shot_examples_one_d_wave",
3 | "seed_used": 0,
4 | "test_split": 0.1,
5 | "training_set_size": 2,
6 | "test_set_size": 0,
7 | "training_set_keys": [
8 | "nl",
9 | "sstl",
10 | "python"
11 | ],
12 | "test_set_keys": [],
13 | "combined_files": [
14 | "examples.jsonl"
15 | ]
16 | }
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/DPO_one_d_combined/examples.jsonl:
--------------------------------------------------------------------------------
1 | {"nl": "Consider a metallic rod of 100 mm. The temperature at one end of the rod is fixed at 300k, while a heat source is applied to the other end. The temperature of the rod follows the 1D linear heat equation.At one specific instance within the interval 1.243045282863269 and 3.403832000478401, the temperature profile of the rod should exceed that of the linear expression mu0(x) = 0.12296790254167801 * x + 300.3119346197452 across the section extending from 13 to 27. Furthermore, within another time frame defined by 4.246509016670876 and 5.31524017104925, the temperature profile must be less than the linear expression mu1(x) = 0.47269239908880306 * x + 311.0471706338625 in the section between 43.0 and 61.0. Additionally, throughout the entire interval from 5.88805508229499 to 9.973974528713061, the temperature profile needs to be lower than the linear expression mu2(x) = 0.390105515041803 * x + 306.62859951293456 as it pertains to the segment between 75.0 and 93.0. We assume the rod is made of two different materials: the section from 30 to 60 mm is made of a material with parameters E_a = 1500000.0, rho_a = 4.5e-06 and c_a = 380000000.0, while the rest of the rod is made of a material with parameters E_b = 800000.0, rho_b = 4e-06 and c_b = 466000000.0. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 10.445529233891722 seconds. Assume a 30-element mesh is used.", "sstl": "G_[[0.39209616323405205, 0.6510075872323481]] (\\forall x \\in [12, 12] (u(x) - (0.11378382357293601 \\cdot x + 284.18670158160194) > 0)) \\land G_[[0.10999846031809801, 1.107322231186995]] (\\forall x \\in [48.0, 63.0] (u(x) - (-0.06923842950142 \\cdot x + 307.9263759927145) > 0)) \\land G_[[0.12370042271526001, 0.43663007356696704]] (\\forall x \\in [73.0, 94.0] (u(x) - (0.12971027690449002 \\cdot x + 286.1162890994643) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 100\nrho = lambda x: 4e-06*466000000.0 if x < 30 or x > 60 else 4.5e-06*380000000.0\nE = lambda x: 800000.0 if x < 30 or x > 60 else 1500000.0\nxpart = np.linspace(0, L, N + 1)\ng = [300, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 1.107322231186995\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([12, 12], \">\", lambda x: 0.11378382357293601 * x + 284.18670158160194, lambda x: 0.11378382357293601)\napc1 = logic.APCont([48.0, 63.0], \">\", lambda x: -0.06923842950142 * x + 307.9263759927145, lambda x: -0.06923842950142)\napc2 = logic.APCont([73.0, 94.0], \">\", lambda x: 0.12971027690449002 * x + 286.1162890994643, lambda x: 0.12971027690449002)\ncregions = {\"A\" : apc0, \"B\" : apc1, \"C\" : apc2}\ncspec = \"((G_[0.39209616323405205, 0.6510075872323481] (A)) & (G_[0.10999846031809801, 1.107322231186995] (B)) & (G_[0.12370042271526001, 0.43663007356696704] (C)))\""}
2 | {"nl": "Consider a steel and brass rod of length L = 100000 mm, the section between 30000 mm and 60000 mm being brass, with densities rho_steel = 8e-06, rho_brass = 8.5e-06 and Young's modulus E_steel = 200000000.0, E_brass = 100000000.0, with one end fixed and a time-variant force applied to the other end. This is a 1D elastic wave propagation problem. Consider the displacement of the rod as u(x).For the full span of time spanning 0.17743529415762402 to 0.30799846894522903, the rod must have its displacement compressed following the linear profile mu0(x) = 4.103648149081095e-05 * x + 2.595659429574442 in the area between sections 4888 and 24593. Similarly, from 0.40700869237930004 to 0.42879980260728906, the rod's displacement ought to be compressed according to the linear profile mu1(x) = 2.03250331324289e-05 * x + 1.686454607140632 across sections 31956.0 and 61131.0. Alternatively, within the period from 0.734575276782226 to 0.7784237515030701, there should be an instance where the rod's displacement is stretched in accordance with the linear profile mu2(x) = 7.539036123328967e-06 * x + 2.67167604686356, between sections 81192.0 and 92008.0. Assume that the discretized time interval is 0.0025s and the max time is 0.9783559046532111 seconds. Assume a 20-element mesh is used.", "sstl": "G_[[0.066415899408669, 0.172261358338638]] (\\forall x \\in [29936, 85026] (u(x) - (3.137321609165203e-06 \\cdot x + -1.526785067872162) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 100000\nrho = lambda x: 8e-06 if x < 30000 or x > 60000 else 8.5e-06\nE = lambda x: 200000000.0 if x < 30000 or x > 60000 else 100000000.0\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 0.172261358338638\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([29936, 85026], \"<\", lambda x: 3.137321609165203e-06 * x + -1.526785067872162, lambda x: 3.137321609165203e-06)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[0.066415899408669, 0.172261358338638] (A)))\""}
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/DPO_one_d_heat/dataset_description.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "few_shot_examples_one_d_heat",
3 | "seed_used": 0,
4 | "test_split": 0.1,
5 | "training_set_size": 2,
6 | "test_set_size": 0,
7 | "training_set_keys": [
8 | "nl",
9 | "sstl",
10 | "python"
11 | ],
12 | "test_set_keys": [],
13 | "combined_files": [
14 | "examples.jsonl"
15 | ]
16 | }
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/DPO_one_d_heat/examples.jsonl:
--------------------------------------------------------------------------------
1 | {"nl": "Consider a metallic rod of 100 mm. The temperature at one end of the rod is fixed at 300k, while a heat source is applied to the other end. The temperature of the rod follows the 1D linear heat equation.At a specific moment within the time frame 1.821027314615614 and 3.523724320287923, the temperature distribution of the rod must exceed the linear profile mu0(x) = -0.13868715203829302 * x + 309.2162061347626 in the section spanning from 3 to 20. Additionally, throughout the entirety of the time interval 10.222409852153994 to 10.228356455872138, the temperature distribution should surpass the linear profile mu1(x) = -0.018962241717187002 * x + 299.2091654709691 for the section between 45.0 and 51.0. We assume the rod is made of two different materials: the section from 30 to 60 mm is made of a material with parameters E_a = 1500000.0, rho_a = 4.5e-06 and c_a = 380000000.0, while the rest of the rod is made of a material with parameters E_b = 800000.0, rho_b = 4e-06 and c_b = 466000000.0. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 14.198467092759284 seconds. Assume a 30-element mesh is used.", "sstl": "G_[[0.8608569523120131, 1.308637095866543]] (\\forall x \\in [14, 25] (u(x) - (0.30619062826944204 \\cdot x + 307.88718554535666) < 0)) \\land G_[[0.8498494273514371, 1.7936246423777131]] (\\forall x \\in [62.0, 100.0] (u(x) - (0.32396276508224203 \\cdot x + 309.74287180858954) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 100\nrho = lambda x: 4e-06*466000000.0 if x < 30 or x > 60 else 4.5e-06*380000000.0\nE = lambda x: 800000.0 if x < 30 or x > 60 else 1500000.0\nxpart = np.linspace(0, L, N + 1)\ng = [300, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 1.7936246423777131\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([14, 25], \"<\", lambda x: 0.30619062826944204 * x + 307.88718554535666, lambda x: 0.30619062826944204)\napc1 = logic.APCont([62.0, 100.0], \">\", lambda x: 0.32396276508224203 * x + 309.74287180858954, lambda x: 0.32396276508224203)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((G_[0.8608569523120131, 1.308637095866543] (A)) & (G_[0.8498494273514371, 1.7936246423777131] (B)))\""}
2 | {"nl": "Consider a metallic rod of 100 mm. The temperature at one end of the rod is fixed at 300k, while a heat source is applied to the other end. The temperature of the rod follows the 1D linear heat equation.At one specific instance within the interval 1.243045282863269 and 3.403832000478401, the temperature profile of the rod should exceed that of the linear expression mu0(x) = 0.12296790254167801 * x + 300.3119346197452 across the section extending from 13 to 27. Furthermore, within another time frame defined by 4.246509016670876 and 5.31524017104925, the temperature profile must be less than the linear expression mu1(x) = 0.47269239908880306 * x + 311.0471706338625 in the section between 43.0 and 61.0. Additionally, throughout the entire interval from 5.88805508229499 to 9.973974528713061, the temperature profile needs to be lower than the linear expression mu2(x) = 0.390105515041803 * x + 306.62859951293456 as it pertains to the segment between 75.0 and 93.0. We assume the rod is made of two different materials: the section from 30 to 60 mm is made of a material with parameters E_a = 1500000.0, rho_a = 4.5e-06 and c_a = 380000000.0, while the rest of the rod is made of a material with parameters E_b = 800000.0, rho_b = 4e-06 and c_b = 466000000.0. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 10.445529233891722 seconds. Assume a 30-element mesh is used.", "sstl": "G_[[0.39209616323405205, 0.6510075872323481]] (\\forall x \\in [12, 12] (u(x) - (0.11378382357293601 \\cdot x + 284.18670158160194) > 0)) \\land G_[[0.10999846031809801, 1.107322231186995]] (\\forall x \\in [48.0, 63.0] (u(x) - (-0.06923842950142 \\cdot x + 307.9263759927145) > 0)) \\land G_[[0.12370042271526001, 0.43663007356696704]] (\\forall x \\in [73.0, 94.0] (u(x) - (0.12971027690449002 \\cdot x + 286.1162890994643) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 100\nrho = lambda x: 4e-06*466000000.0 if x < 30 or x > 60 else 4.5e-06*380000000.0\nE = lambda x: 800000.0 if x < 30 or x > 60 else 1500000.0\nxpart = np.linspace(0, L, N + 1)\ng = [300, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 1.107322231186995\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([12, 12], \">\", lambda x: 0.11378382357293601 * x + 284.18670158160194, lambda x: 0.11378382357293601)\napc1 = logic.APCont([48.0, 63.0], \">\", lambda x: -0.06923842950142 * x + 307.9263759927145, lambda x: -0.06923842950142)\napc2 = logic.APCont([73.0, 94.0], \">\", lambda x: 0.12971027690449002 * x + 286.1162890994643, lambda x: 0.12971027690449002)\ncregions = {\"A\" : apc0, \"B\" : apc1, \"C\" : apc2}\ncspec = \"((G_[0.39209616323405205, 0.6510075872323481] (A)) & (G_[0.10999846031809801, 1.107322231186995] (B)) & (G_[0.12370042271526001, 0.43663007356696704] (C)))\""}
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/DPO_one_d_wave/dataset_description.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "few_shot_examples_one_d_wave",
3 | "seed_used": 0,
4 | "test_split": 0.1,
5 | "training_set_size": 2,
6 | "test_set_size": 0,
7 | "training_set_keys": [
8 | "nl",
9 | "sstl",
10 | "python"
11 | ],
12 | "test_set_keys": [],
13 | "combined_files": [
14 | "examples.jsonl"
15 | ]
16 | }
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/DPO_one_d_wave/examples.jsonl:
--------------------------------------------------------------------------------
1 | {"nl": "Consider a steel and brass rod of length L = 100000 mm, the section between 30000 mm and 60000 mm being brass, with densities rho_steel = 8e-06, rho_brass = 8.5e-06 and Young's modulus E_steel = 200000000.0, E_brass = 100000000.0, with one end fixed and a time-variant force applied to the other end. This is a 1D elastic wave propagation problem. Consider the displacement of the rod as u(x).For the complete duration between the intervals 0.289483287576353 and 0.36810738390869, the rod's displacement should comply with the linear profile mu0(x) = 4.585477841781323e-05 * x + 2.331946581525069 for the sections ranging from 11861 to 34211. In contrast, at a specific point in time within the interval 0.6018777959138361 to 1.251736812496592, the displacement should reflect compression according to the linear profile mu1(x) = 9.629101387124293e-06 * x + -2.660932944224053 from section 73701.0 to 91138.0.Assume that the discretized time interval is 0.0025s and the max time is 1.348443583660752 seconds. Assume a 20-element mesh is used.", "sstl": "G_[[0.022889123990597, 0.23971012731697103]] (\\forall x \\in [10916, 42249] (u(x) - (-3.5012259820233114e-05 \\cdot x + -1.180840333310321) > 0)) \\land G_[[0.134370803878005, 0.223822164312479]] (\\forall x \\in [57461.0, 64690.0] (u(x) - (-3.293125603528961e-05 \\cdot x + -1.51293238845689) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 100000\nrho = lambda x: 8e-06 if x < 30000 or x > 60000 else 8.5e-06\nE = lambda x: 200000000.0 if x < 30000 or x > 60000 else 100000000.0\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 0.23971012731697103\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([10916, 42249], \">\", lambda x: -3.5012259820233114e-05 * x + -1.180840333310321, lambda x: -3.5012259820233114e-05)\napc1 = logic.APCont([57461.0, 64690.0], \"<\", lambda x: -3.293125603528961e-05 * x + -1.51293238845689, lambda x: -3.293125603528961e-05)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((G_[0.022889123990597, 0.23971012731697103] (A)) & (G_[0.134370803878005, 0.223822164312479] (B)))\""}
2 | {"nl": "Consider a steel and brass rod of length L = 100000 mm, the section between 30000 mm and 60000 mm being brass, with densities rho_steel = 8e-06, rho_brass = 8.5e-06 and Young's modulus E_steel = 200000000.0, E_brass = 100000000.0, with one end fixed and a time-variant force applied to the other end. This is a 1D elastic wave propagation problem. Consider the displacement of the rod as u(x).For the full span of time spanning 0.17743529415762402 to 0.30799846894522903, the rod must have its displacement compressed following the linear profile mu0(x) = 4.103648149081095e-05 * x + 2.595659429574442 in the area between sections 4888 and 24593. Similarly, from 0.40700869237930004 to 0.42879980260728906, the rod's displacement ought to be compressed according to the linear profile mu1(x) = 2.03250331324289e-05 * x + 1.686454607140632 across sections 31956.0 and 61131.0. Alternatively, within the period from 0.734575276782226 to 0.7784237515030701, there should be an instance where the rod's displacement is stretched in accordance with the linear profile mu2(x) = 7.539036123328967e-06 * x + 2.67167604686356, between sections 81192.0 and 92008.0. Assume that the discretized time interval is 0.0025s and the max time is 0.9783559046532111 seconds. Assume a 20-element mesh is used.", "sstl": "G_[[0.066415899408669, 0.172261358338638]] (\\forall x \\in [29936, 85026] (u(x) - (3.137321609165203e-06 \\cdot x + -1.526785067872162) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 100000\nrho = lambda x: 8e-06 if x < 30000 or x > 60000 else 8.5e-06\nE = lambda x: 200000000.0 if x < 30000 or x > 60000 else 100000000.0\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 0.172261358338638\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([29936, 85026], \"<\", lambda x: 3.137321609165203e-06 * x + -1.526785067872162, lambda x: 3.137321609165203e-06)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[0.066415899408669, 0.172261358338638] (A)))\""}
3 |
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/one_d_combined/dataset_description.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "few_shot_examples_one_d_combined",
3 | "seed_used": 0,
4 | "test_split": 0.1,
5 | "training_set_size": 2,
6 | "test_set_size": 0,
7 | "training_set_keys": [
8 | "nl",
9 | "sstl",
10 | "python"
11 | ],
12 | "test_set_keys": [],
13 | "combined_files": [
14 | "examples.jsonl"
15 | ]
16 | }
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/one_d_combined/examples.jsonl:
--------------------------------------------------------------------------------
1 | {"nl": "Consider a metallic rod with a maximum length of 112 mm, where the temperature at one extremity is held at 321k, and the opposite extremity is exposed to a heat source. The temperature profile of the rod is described by the 1D linear heat equation.Throughout the duration from 1.8288 to 4.6769, there exists a point in time where the temperature distribution of the rod should be greater than the linear profile mu0(x) = 0.0771 * x + 326.154 applicable between the sections 5 and 97.The rod is presumed to be fabricated from two varieties of materials: from 3 to 49 mm, a material with parameters E_a = 1682393, \rrho_a = 5.952e-06, and c_a = 438533237 is utilized, while the remainder of the rod features a material with parameters E_b = 410042, \rrho_b = 3.977e-06, and c_b = 470729859. We define the temperature at position x as u(x). We will consider a discretized time interval of 0.05 seconds and a total time of 8 seconds, employing a 30-element mesh.", "sstl": "F_[[1.8288, 4.6769]] (\\forall x \\in [5, 97] (u(x) - (0.0771 \\cdot x + 326.154) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 112\nrho = lambda x: 3.977e-06*470729859 if x < 3 or x > 49 else 5.952e-06*438533237\nE = lambda x: 410042 if x < 3 or x > 49 else 1682393\nxpart = np.linspace(0, L, N + 1)\ng = [321, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 8\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([5, 97], \">\", lambda x: 0.0771 * x + 326.154, lambda x: 0.0771)\ncregions = {\"A\" : apc0}\ncspec = \"((F_[1.8288, 4.6769] (A)))\""}
2 | {"nl": "Consider a rod composed of steel and brass with a length of L = 79651 mm, where the brass section is located between 33634 mm and 43799 mm. The densities are defined as \rrho_steel = 7.927e-06 and \rrho_brass = 8.452e-06, and the Young's moduli are E_steel = 222786951 and E_brass = 102749268. One end is held in place, and a time-dependent force is applied to the other end. This setup is focused on a 1D elastic wave propagation challenge. Let u(x) denote the displacement of the rod.At a given time in the interval from 0.0541 to 0.2621, the rod's displacement must match the linear profile mu0(x) = -1.4897e-05 * x + -1.7281 throughout the range of sections 8330 and 30692. In addition, throughout the interval from 0.2845 to 0.8982, the rod's displacement is expected to be compressed to fit the linear profile mu1(x) = 1.029e-06 * x + -0.3131 across the sections 56782 and 69640.We will assume that the time interval is discretized at 0.0025s, with the maximum time of 1.9424 seconds, using a mesh that contains 20 elements.", "sstl": "F_[[0.0541, 0.2621]] (\\forall x \\in [8330, 30692] (u(x) - (-1.4897e-05 \\cdot x + -1.7281) = 0)) \\land G_[[0.2845, 0.8982]] (\\forall x \\in [56782, 69640] (u(x) - (1.029e-06 \\cdot x + -0.3131) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 79651\nrho = lambda x: 7.927e-06 if x < 33634 or x > 43799 else 8.452e-06\nE = lambda x: 222786951 if x < 33634 or x > 43799 else 102749268\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.9424\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([8330, 30692], \"=\", lambda x: -1.4897e-05 * x + -1.7281, lambda x: -1.4897e-05)\napc1 = logic.APCont([56782, 69640], \"<\", lambda x: 1.029e-06 * x + -0.3131, lambda x: 1.029e-06)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((F_[0.0541, 0.2621] (A)) & (G_[0.2845, 0.8982] (B)))\""}
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/one_d_heat/dataset_description.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "few_shot_examples_one_d_heat",
3 | "seed_used": 0,
4 | "test_split": 0.1,
5 | "training_set_size": 3,
6 | "test_set_size": 0,
7 | "training_set_keys": [
8 | "nl",
9 | "sstl",
10 | "python"
11 | ],
12 | "test_set_keys": [],
13 | "combined_files": [
14 | "examples.jsonl"
15 | ]
16 | }
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/one_d_heat/examples.jsonl:
--------------------------------------------------------------------------------
1 | {"nl": "Consider a metallic rod with a maximum length of 112 mm, where the temperature at one extremity is held at 321k, and the opposite extremity is exposed to a heat source. The temperature profile of the rod is described by the 1D linear heat equation.Throughout the duration from 1.8288 to 4.6769, there exists a point in time where the temperature distribution of the rod should be greater than the linear profile mu0(x) = 0.0771 * x + 326.154 applicable between the sections 5 and 97.The rod is presumed to be fabricated from two varieties of materials: from 3 to 49 mm, a material with parameters E_a = 1682393, \rrho_a = 5.952e-06, and c_a = 438533237 is utilized, while the remainder of the rod features a material with parameters E_b = 410042, \rrho_b = 3.977e-06, and c_b = 470729859. We define the temperature at position x as u(x). We will consider a discretized time interval of 0.05 seconds and a total time of 8 seconds, employing a 30-element mesh.", "sstl": "F_[[1.8288, 4.6769]] (\\forall x \\in [5, 97] (u(x) - (0.0771 \\cdot x + 326.154) > 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 112\nrho = lambda x: 3.977e-06*470729859 if x < 3 or x > 49 else 5.952e-06*438533237\nE = lambda x: 410042 if x < 3 or x > 49 else 1682393\nxpart = np.linspace(0, L, N + 1)\ng = [321, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 8\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([5, 97], \">\", lambda x: 0.0771 * x + 326.154, lambda x: 0.0771)\ncregions = {\"A\" : apc0}\ncspec = \"((F_[1.8288, 4.6769] (A)))\""}
2 | {"nl": "Consider a metallic rod of 180 mm. The temperature at one end of the rod is fixed at 281k, while a heat source is applied to the other end. The temperature of the rod follows the 1D linear heat equation.For one point during the time interval 0.2591 and 2.7813, the temperature distribution of the rod should be the same as the linear profile mu0(x) = 0.3167 * x + 263.3785 between section 19 and 27. Moreover, for all time between the time interval 5.536 and 7.2884, the temperature distribution of the rod should be lower than the linear profile mu1(x) = -0.0214 * x + 265.8454 between section 132 and 145.We assume the rod is made of two different materials: the section from 137 to 159 mm is made of a material with parameters E_a = 1310187, rho_a = 5.456e-06 and c_a = 408307119, while the rest of the rod is made of a material with parameters E_b = 1194686, rho_b = 5.687e-06 and c_b = 476443964. Denote the temperature at location x as u(x). Assume that the discretized time interval is 0.05s and the max time is 11 seconds. Assume a 30-element mesh is used.", "sstl": "F_[[0.2591, 2.7813]] (\\forall x \\in [19, 27] (u(x) - (0.3167 \\cdot x + 263.3785) = 0)) \\land G_[[5.536, 7.2884]] (\\forall x \\in [132, 145] (u(x) - (-0.0214 \\cdot x + 265.8454) < 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 180\nrho = lambda x: 5.687e-06*476443964 if x < 137 or x > 159 else 5.456e-06*408307119\nE = lambda x: 1194686 if x < 137 or x > 159 else 1310187\nxpart = np.linspace(0, L, N + 1)\ng = [281, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 11\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([19, 27], \"=\", lambda x: 0.3167 * x + 263.3785, lambda x: 0.3167)\napc1 = logic.APCont([132, 145], \"<\", lambda x: -0.0214 * x + 265.8454, lambda x: -0.0214)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((F_[0.2591, 2.7813] (A)) & (G_[5.536, 7.2884] (B)))\""}
3 | {"nl": "Imagine a metallic rod with a maximum length of 112 mm, where one end is kept at a stable temperature of 321k, while the other end is subject to a heat source. The temperature within the rod obeys the 1D linear heat equation.During the entire duration from time 1.8288 to 4.6769, the temperature distribution along the rod must match the linear profile mu0(x) = 0.0771 * x + 326.154 in the segment between 5 and 97.We are assuming the rod comprises two types of materials: the segment from 3 to 49 mm is surrounding a material defined by parameters E_a = 1682393, \rrho_a = 5.952e-06, and c_a = 438533237, whereas the remaining length of the rod is fabricated from a material characterized by parameters E_b = 410042, \rrho_b = 3.977e-06, and c_b = 470729859. The temperature at position x will be denoted as u(x). The discretized time interval is set at 0.05 seconds, with a maximum duration of 8 seconds, while a mesh with 30 elements has been adopted.", "sstl": "G_[[1.8288, 4.6769]] (\\forall x \\in [5, 97] (u(x) - (0.0771 \\cdot x + 326.154) = 0))", "python": "\nfrom femformal.core.fem import heatlinfem as heatlinfem\n\nN = 30\nL = 112\nrho = lambda x: 3.977e-06*470729859 if x < 3 or x > 49 else 5.952e-06*438533237\nE = lambda x: 410042 if x < 3 or x > 49 else 1682393\nxpart = np.linspace(0, L, N + 1)\ng = [321, None]\nf_nodal = np.zeros(N + 1)\ndt = .05\n\nT = 8\nfosys = heatlinfem.heatlinfem_mix(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([5, 97], \"=\", lambda x: 0.0771 * x + 326.154, lambda x: 0.0771)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[1.8288, 4.6769] (A)))\""}
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/one_d_wave/dataset_description.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "few_shot_examples_one_d_wave",
3 | "seed_used": 0,
4 | "test_split": 0.1,
5 | "training_set_size": 3,
6 | "test_set_size": 0,
7 | "training_set_keys": [
8 | "nl",
9 | "sstl",
10 | "python"
11 | ],
12 | "test_set_keys": [],
13 | "combined_files": [
14 | "examples.jsonl"
15 | ]
16 | }
--------------------------------------------------------------------------------
/utils/few_shot_prompts/examples/one_d_wave/examples.jsonl:
--------------------------------------------------------------------------------
1 | {"nl": "Let us examine a rod made of steel and brass, measuring L = 76182 mm in length, where the segment between 19212 mm and 48319 mm consists of brass. The densities are given as \rrho_steel = 7.628e-06 and \rrho_brass = 8.473e-06, with Young's moduli provided as E_steel = 225415054 and E_brass = 179787202. One end of the rod is fixed, while a force that varies with time is applied to the opposite end. This presents a 1D problem regarding the propagation of elastic waves. Denote the displacement of the rod as u(x).Throughout a particular moment in the duration from 0.09 to 0.192, the rod's displacement ought to correspond to the linear profile mu0(x) = -4.692e-05 * x + 1.3255 across the regions defined by 32712 and 42454.Assume the time discretization is 0.0025 seconds, and that the maximum time is 1.5266 seconds, with a 20-element mesh employed for this analysis.", "sstl": "F_[[0.09, 0.192]] (\\forall x \\in [32712, 42454] (u(x) - (-4.692e-05 \\cdot x + 1.3255) > 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 76182\nrho = lambda x: 7.628e-06 if x < 19212 or x > 48319 else 8.473e-06\nE = lambda x: 225415054 if x < 19212 or x > 48319 else 179787202\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.5266\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([32712, 42454], \">\", lambda x: -4.692e-05 * x + 1.3255, lambda x: -4.692e-05)\ncregions = {\"A\" : apc0}\ncspec = \"((F_[0.09, 0.192] (A)))\""}
2 | {"nl": "Consider a rod composed of steel and brass with a length of L = 79651 mm, where the brass section is located between 33634 mm and 43799 mm. The densities are defined as \rrho_steel = 7.927e-06 and \rrho_brass = 8.452e-06, and the Young's moduli are E_steel = 222786951 and E_brass = 102749268. One end is held in place, and a time-dependent force is applied to the other end. This setup is focused on a 1D elastic wave propagation challenge. Let u(x) denote the displacement of the rod.At a given time in the interval from 0.0541 to 0.2621, the rod's displacement must match the linear profile mu0(x) = -1.4897e-05 * x + -1.7281 throughout the range of sections 8330 and 30692. In addition, throughout the interval from 0.2845 to 0.8982, the rod's displacement is expected to be compressed to fit the linear profile mu1(x) = 1.029e-06 * x + -0.3131 across the sections 56782 and 69640.We will assume that the time interval is discretized at 0.0025s, with the maximum time of 1.9424 seconds, using a mesh that contains 20 elements.", "sstl": "F_[[0.0541, 0.2621]] (\\forall x \\in [8330, 30692] (u(x) - (-1.4897e-05 \\cdot x + -1.7281) = 0)) \\land G_[[0.2845, 0.8982]] (\\forall x \\in [56782, 69640] (u(x) - (1.029e-06 \\cdot x + -0.3131) < 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 79651\nrho = lambda x: 7.927e-06 if x < 33634 or x > 43799 else 8.452e-06\nE = lambda x: 222786951 if x < 33634 or x > 43799 else 102749268\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.9424\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([8330, 30692], \"=\", lambda x: -1.4897e-05 * x + -1.7281, lambda x: -1.4897e-05)\napc1 = logic.APCont([56782, 69640], \"<\", lambda x: 1.029e-06 * x + -0.3131, lambda x: 1.029e-06)\ncregions = {\"A\" : apc0, \"B\" : apc1}\ncspec = \"((F_[0.0541, 0.2621] (A)) & (G_[0.2845, 0.8982] (B)))\""}
3 | {"nl": "Consider a rod composed of steel and brass with a length of L = 76182 mm, where the brass section is located between 19212 mm and 48319 mm. The densities are defined as \rrho_steel = 7.628e-06 and \rrho_brass = 8.473e-06, and the Young's moduli are E_steel = 225415054 and E_brass = 179787202. One end is held in place, and a time-dependent force is applied to the other end. This setup is focused on a 1D elastic wave propagation challenge. Let u(x) denote the displacement of the rod.Across the time interval from 0.09 to 0.192, the displacement of the rod should conform to the linear profile mu0(x) = -4.692e-05 * x + 1.3255 over the range specified by 32712 and 42454.We will assume that the time interval is discretized at 0.0025s, with the maximum time of 1.5266 seconds, using a mesh that contains 20 elements.", "sstl": "G_[[0.09, 0.192]] (\\forall x \\in [32712, 42454] (u(x) - (-4.692e-05 \\cdot x + 1.3255) = 0))", "python": "\nfrom femformal.core.fem import mechlinfem as mechlinfem\n\nN = 20\nL = 76182\nrho = lambda x: 7.628e-06 if x < 19212 or x > 48319 else 8.473e-06\nE = lambda x: 225415054 if x < 19212 or x > 48319 else 179787202\nxpart = np.linspace(0, L, N + 1)\ng = [0.0, None]\nf_nodal = np.zeros(N + 1)\ndt = .0025\n\nT = 1.5266\nsosys = mechlinfem.mechlinfem(xpart, rho, E, g, f_nodal, dt)\n\napc0 = logic.APCont([32712, 42454], \"=\", lambda x: -4.692e-05 * x + 1.3255, lambda x: -4.692e-05)\ncregions = {\"A\" : apc0}\ncspec = \"((G_[0.09, 0.192] (A)))\""}
--------------------------------------------------------------------------------
/utils/few_shot_prompts/few_shot_prompting.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 |
4 |
5 | class FewShotPrompting:
6 |     def __init__(self, num_shots):
7 |         if not 0 <= num_shots <= 3:
8 |             raise ValueError("Only supports 0 up to 3 shots.")
9 |         self.num_shots = num_shots
10 | 
11 |
12 |     def get_alpaca_format(self, instruction, task_input, task_output="", wrap_in_code_block=None) -> str:
13 |         # Build an Alpaca-style prompt in one of three shapes:
14 |         #   * wrapped in a ```python block,
15 |         #   * wrapped in a ```latex block, or
16 |         #   * plain text.
17 |         # When task_output is empty the fence (or response) is left open so the
18 |         # model's generation completes it; it is closed only for few-shot examples.
19 |         if wrap_in_code_block in ('python', 'latex'):
20 |             prompt = f"""### Instruction:\n{instruction}\n\n### Input:\n{task_input}\n\n### Response:\n```{wrap_in_code_block}\n{task_output}"""
21 |             if task_output != "":
22 |                 prompt += "\n```\n\n"
23 |             return prompt
24 |         elif task_output != "":
25 |             return f"""### Instruction:\n{instruction}\n\n### Input:\n{task_input}\n\n### Response:\n{task_output}\n\n"""
26 |         else:
27 |             return f"""### Instruction:\n{instruction}\n\n### Input:\n{task_input}\n\n### Response:\n"""
28 | 
29 | 
30 | 
31 |     def _get_few_shot_prompt(self):
32 |         # Concatenate up to num_shots formatted examples into a single prefix.
33 |         # Subclasses populate self.examples; fail loudly if they have not.
34 |         if not hasattr(self, 'examples'):
35 |             raise AttributeError("Subclasses of FewShotPrompting must define 'self.examples'")
36 |         # When self.shuffle is set, shuffle a copy so the stored order is kept.
37 |         if getattr(self, 'shuffle', False):
38 |             temp_examples = copy.copy(self.examples)
39 |             random.shuffle(temp_examples)
40 |             return "".join(temp_examples[:self.num_shots])
41 |         return "".join(self.examples[:self.num_shots])
42 |
43 | def get_intruction(self, format):
44 |         if format == "nl_to_python":
45 |             return """Below is a natural language description of a partial differential equation optimization problem. Translate the problem into Python code following spatial-signal temporal logic."""
46 |         elif format == "nl_to_sstl":
47 |             return """Below is a natural language description of a partial differential equation optimization problem. Translate the problem into LaTeX code following spatial-signal temporal logic."""
48 |         elif format == "train_nl_and_sstl_to_python":
49 |             return """Below is a natural language description of a partial differential equation optimization problem, paired with a spatial-signal temporal logic description of the same problem. Translate the problem into Python code following spatial-signal temporal logic."""
50 |         elif format == "test_nl_to_python_with_sstl_cot":
51 |             return """Below is a natural language description of a partial differential equation optimization problem. Translate the problem into Python code following spatial-signal temporal logic. Explain your reasoning by first providing a spatial-signal temporal logic statement in LaTeX. Let's think step by step."""
52 |         elif format == "test_nl_with_given_sstl_to_python" or format == "train_nl_with_given_sstl_to_python":
53 |             return """Below is a natural language description of a partial differential equation optimization problem, paired with your spatial-signal temporal logic description of the same problem provided earlier. Note that there may be mistakes in the spatial-signal temporal logic statement, but the natural language description is accurate. Translate the problem into Python code following spatial-signal temporal logic."""
54 |         elif format == "dpo_train_nl_to_sstl":
55 |             return """Below is a natural language description of a partial differential equation optimization problem. Instead of optimizing the provided problem directly, we want to optimize an intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Generate a spatial-signal temporal logic description in LaTeX code for such an intermediate problem."""
56 |         elif format == "dpo_test_sstl_to_python":
57 |             return """Below is a natural language description of a partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. Instead of optimizing the natural language problem directly, we want to optimize the intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Your spatial-signal temporal logic description in LaTeX, paired with the original problem, describes this intermediate problem. Translate the intermediate problem into Python code following spatial-signal temporal logic."""
58 | # return """Below is a natural language description of partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. Translate the intermediate problem into Python code following spatial-signal temporal logic."""
59 |
60 | # """Below is a natural language description of partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. Instead of optimizing the provided problem directly, we want to optimize an intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Your spatial-signal temporal logic description paired to the original problem describes this intermediate problem. Translate the intermediate problem into Python code following spatial-signal temporal logic."""
61 |
62 | # try the prompt below if the above one doesn't work.
63 |
64 | # """Below is a natural language description of a partial differential equation optimization problem, paired with your spatial-signal temporal logic description of an intermediate problem provided earlier. Instead of optimizing the provided problem directly, we want to optimize this intermediate problem to produce a state that will better serve to achieve the final conditions outlined in the natural language problem. Use both the natural language description and the LaTeX code of the intermediate problem to translate the intermediate problem into Python code following spatial-signal temporal logic."""
65 | else:
66 | raise ValueError(f"Invalid format: {format}")
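67 | 
68 | # Minimal usage sketch (illustrative values, not part of the pipeline API):
69 | #
70 | #   base = FewShotPrompting(num_shots=1)
71 | #   base.get_alpaca_format("Translate the problem.", "nl text", "code",
72 | #                          wrap_in_code_block='python')
73 | #   # -> "### Instruction:\nTranslate the problem.\n\n### Input:\nnl text\n\n"
74 | #   #    "### Response:\n```python\ncode\n```\n\n"
75 | #   # With task_output="" the ```python fence is left open for generation.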
--------------------------------------------------------------------------------
/utils/few_shot_prompts/few_shot_test.py:
--------------------------------------------------------------------------------
1 | from .few_shot_prompting import FewShotPrompting
2 |
3 | class FewShotTest(FewShotPrompting):
4 | def __init__(self, num_shots) -> None:
5 | super().__init__(num_shots=num_shots)
6 |
7 | def format_prompt(self, format, nl, sstl="", python=""):
8 | intruction = self.get_intruction(format)
9 | nl = nl.strip()
10 | sstl = sstl.strip()
11 | python = python.strip()
12 | if format == "nl_to_python":
13 | prompt = self.get_alpaca_format(intruction, task_input=nl, task_output=python, wrap_in_code_block='python')
14 | return prompt
15 | elif format == "test_nl_to_python_with_sstl_cot":
16 | # Don't wrap sstl in code block because the output is custom formatted
17 | if sstl != "" and python != "":
18 | # few shot prompt
19 | task_output = f"""Spatial Signal Temporal Logic:\n```latex\n{sstl}\n```\n\nPython:\n```python\n{python}\n```"""
20 | else:
21 | # actual test case
22 | task_output = ""
23 | prompt = self.get_alpaca_format(intruction, task_input=nl, task_output=task_output)
24 | return prompt
25 | elif format == "nl_to_sstl":
26 |             # Leave the latex block open here; it is closed below only for few-shot examples.
27 | task_output = f"""Spatial Signal Temporal Logic:\n```latex\n{sstl}"""
28 | if sstl != "":
29 | # few shot prompt
30 | task_output += """\n```"""
31 | prompt = self.get_alpaca_format(intruction, task_input=nl, task_output=task_output, wrap_in_code_block=None)
32 | return prompt
33 | elif format == "test_nl_with_given_sstl_to_python":
34 | task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```"""
35 | prompt = self.get_alpaca_format(intruction, task_input=task_input, task_output=python, wrap_in_code_block='python')
36 | return prompt
37 |         raise ValueError(f"Unsupported format for FewShotTest: {format}")
38 |
39 | def stop_words(self):
40 | return ["\n### Instruction:", "### Instruction:"]
41 |
42 |
43 |
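44 | # Usage sketch (illustrative): a zero-shot CoT test prompt ends at
45 | # "### Response:\n" so that generation fills in the SSTL and Python blocks.
46 | #
47 | #   tester = FewShotTest(num_shots=0)
48 | #   prompt = tester.format_prompt("test_nl_to_python_with_sstl_cot", nl="<problem>")
49 | #   # stop_words() lists the markers that end sampling before a new example.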
--------------------------------------------------------------------------------
/utils/few_shot_prompts/few_shot_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from .few_shot_prompting import FewShotPrompting
4 |
5 |
6 | class FewShotTrain(FewShotPrompting):
7 | def __init__(self, num_shots=0, format=None, dataset=None) -> None:
8 | super().__init__(num_shots=num_shots)
9 | self.examples = []
10 |
11 | if num_shots > 0:
12 |             # Loaded only for few-shot training evaluation: the predicted SSTL from this pass is used to build the second training set of the two-part training process.
13 | assert format is not None, "Please provide format for few-shot training evaluation"
14 | assert dataset is not None, "Please provide dataset for few-shot training evaluation"
15 |             self.format = format
16 | self.shuffle = True
17 | current_dir = os.path.dirname(os.path.abspath(__file__))
18 | jsonl_file_path = os.path.join(current_dir, "examples", f"one_d_{dataset}", "examples.jsonl")
19 | self.examples = self.load_examples(jsonl_file_path, format)
20 |
21 | def format_prompt(self, format, nl, sstl="", python=""):
22 | intruction = self.get_intruction(format)
23 | nl = nl.strip()
24 | sstl = sstl.strip()
25 | python = python.strip()
26 | if format == "nl_to_python":
27 | prompt = self.get_alpaca_format(intruction, task_input=nl, task_output=python, wrap_in_code_block='python')
28 | return prompt
29 | elif format == "nl_to_sstl":
30 | prompt = self.get_alpaca_format(intruction, task_input=nl, task_output=sstl, wrap_in_code_block='latex')
31 | return prompt
32 | elif format == "train_nl_and_sstl_to_python" or format == "train_nl_with_given_sstl_to_python":
33 | task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```"""
34 | prompt = self.get_alpaca_format(intruction, task_input=task_input, task_output=python, wrap_in_code_block='python')
35 | return prompt
36 |
37 | def format_prompt_test(self, nl, sstl="", python=""):
38 |         return f"{self._get_few_shot_prompt()}{self.format_prompt(self.format, nl.strip(), sstl.strip(), python.strip())}"  # prepend the loaded few-shot examples (empty when num_shots == 0)
39 |
40 | def load_examples(self, jsonl_file_path, format):
41 | examples = []
42 | with open(jsonl_file_path, 'r') as file:
43 | for line in file:
44 | data = json.loads(line)
45 | nl = data['nl']
46 | python = data['python']
47 | sstl = data['sstl']
48 | example = self.format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip())
49 | examples.append(example)
50 | return examples
51 |
52 | def stop_words(self):
53 | return ["\n### Instruction:", "### Instruction:"]
--------------------------------------------------------------------------------
/utils/few_shot_prompts/few_shot_train_dpo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from .few_shot_prompting import FewShotPrompting
4 |
5 |
6 | class FewShotDPO(FewShotPrompting):
7 | def __init__(self, num_shots=0, format=None, dataset=None) -> None:
8 | super().__init__(num_shots=num_shots)
9 | self.examples = []
10 |         self.format = format
11 |
12 | if num_shots > 0:
13 |             # Loaded only for few-shot training evaluation: the predicted SSTL from this pass is used to build the second training set of the two-part training process.
14 | assert format is not None, "Please provide format for few-shot training evaluation"
15 | assert dataset is not None, "Please provide dataset for few-shot training evaluation"
16 |             self.format = format
17 | self.shuffle = True
18 | current_dir = os.path.dirname(os.path.abspath(__file__))
19 | jsonl_file_path = os.path.join(current_dir, "examples", f"DPO_one_d_{dataset}", "examples.jsonl")
20 | self.examples = self.load_examples(jsonl_file_path, format)
21 |
22 | def format_prompt(self, format, nl, sstl="", python=""):
23 | intruction = self.get_intruction(format)
24 | nl = nl.strip()
25 | sstl = sstl.strip()
26 | python = python.strip()
27 | # if format == "nl_to_python":
28 | # prompt = self.get_alpaca_format(intruction, task_input=nl, task_output=python, wrap_in_code_block='python')
29 | # return prompt
30 | if format == "dpo_train_nl_to_sstl":
31 | prompt = self.get_alpaca_format(intruction, task_input=nl, task_output=sstl, wrap_in_code_block='latex')
32 | return prompt
33 | elif format == "dpo_test_sstl_to_python":
34 | task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```"""
35 | prompt = self.get_alpaca_format(intruction, task_input=task_input, task_output=python, wrap_in_code_block='python')
36 | return prompt
37 | # elif format == "train_nl_and_sstl_to_python" or format == "train_nl_with_given_sstl_to_python":
38 | # task_input = f"""{nl}\n\nSpatial Signal Temporal Logic:\n```latex\n{sstl}\n```"""
39 | # prompt = self.get_alpaca_format(intruction, task_input=task_input, task_output=python, wrap_in_code_block='python')
40 | # return prompt
41 |
42 | def format_prompt_test(self, nl, sstl="", python=""):
43 | few_shots = self._get_few_shot_prompt()
44 | curr_prompt = self.format_prompt(self.format, nl.strip(), sstl.strip(), python.strip())
45 | return f"{few_shots}{curr_prompt}"
46 | # return self.format_prompt(self.format, nl.strip(), sstl.strip(), python.strip())
47 |
48 | def load_examples(self, jsonl_file_path, format):
49 | examples = []
50 | with open(jsonl_file_path, 'r') as file:
51 | for line in file:
52 | data = json.loads(line)
53 | nl = data['nl']
54 | python = data['python']
55 | sstl = data['sstl']
56 | example = self.format_prompt(format, nl=nl.strip(), sstl=sstl.strip(), python=python.strip())
57 | examples.append(example)
58 | return examples
59 |
60 | def stop_words(self):
61 | return ["\n### Instruction:", "### Instruction:"]
--------------------------------------------------------------------------------