├── .gitignore ├── LICENSE ├── README.md ├── datasets └── converted_aime_dataset │ ├── dataset_dict.json │ ├── test │ ├── cache-80251b5741d7ee37.arrow │ ├── cache-86b4d5a4725288d2.arrow │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json │ └── train │ ├── cache-d363ce53f400e210.arrow │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json ├── dockerfile ├── Dockerfile └── docker-entrypoint.sh ├── docs ├── logo.png ├── ppo_examples.md └── ray_architecture.png ├── evaluate_model.py ├── examples └── scripts │ ├── docker_run.sh │ ├── nvidia_docker_install.sh │ ├── serve_remote_rm.sh │ ├── train_conditional_llama.sh │ ├── train_continue_pretrain_llama.sh │ ├── train_dpo_llama.sh │ ├── train_dpo_llama_34b.sh │ ├── train_dpo_ring_llama.sh │ ├── train_iterative_dpo_llama.sh │ ├── train_knowledge_distillation.sh │ ├── train_kto_llama.sh │ ├── train_llama_slurm.sh │ ├── train_ppo_llama.sh │ ├── train_ppo_llama_ray.sh │ ├── train_ppo_llama_ray_70b.sh │ ├── train_ppo_llama_ray_slurm.sh │ ├── train_ppo_llama_with_remote_rm.sh │ ├── train_prm_mistral.sh │ ├── train_reinforce_llama_ray.sh │ ├── train_rejection_sampling_llama.sh │ ├── train_rm_llama.sh │ ├── train_sft_llama.sh │ └── train_sft_mixtral_lora.sh ├── index.html ├── openrlhf ├── __init__.py ├── cli │ ├── __init__.py │ ├── batch_inference.py │ ├── interactive_chat.py │ ├── serve_rm.py │ ├── train_dpo.py │ ├── train_gsm.py │ ├── train_kd.py │ ├── train_kto.py │ ├── train_ppo.py │ ├── train_ppo_ray.py │ ├── train_prm.py │ ├── train_rm.py │ └── train_sft.py ├── datasets │ ├── __init__.py │ ├── process_reward_dataset.py │ ├── prompts_dataset.py │ ├── reward_dataset.py │ ├── sft_dataset.py │ ├── unpaired_preference_dataset.py │ └── utils.py ├── models │ ├── __init__.py │ ├── actor.py │ ├── loss.py │ ├── model.py │ ├── ring_attn_utils.py │ └── utils.py ├── trainer │ ├── __init__.py │ ├── dpo_trainer.py │ ├── kd_trainer.py │ ├── kto_trainer.py │ ├── ppo_trainer.py │ ├── ppo_utils │ │ ├── 
__init__.py │ │ ├── experience_maker.py │ │ ├── kl_controller.py │ │ └── replay_buffer.py │ ├── prm_trainer.py │ ├── ray │ │ ├── __init__.py │ │ ├── launcher.py │ │ ├── ppo_actor.py │ │ ├── ppo_critic.py │ │ ├── utils.py │ │ ├── vllm_engine.py │ │ └── vllm_worker_wrap.py │ ├── rm_trainer.py │ └── sft_trainer.py └── utils │ ├── __init__.py │ ├── deepspeed │ ├── __init__.py │ ├── deepspeed.py │ └── deepspeed_utils.py │ ├── distributed_sampler.py │ ├── distributed_util.py │ ├── logging_utils.py │ ├── processor.py │ ├── remote_rm_utils.py │ └── utils.py ├── pyproject.toml ├── requirements.txt ├── reward_server ├── math_server.py └── math_server_test.py ├── run_rloo_1.5B.sh ├── run_rloo_7B.sh ├── setup.py ├── static ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css ├── images │ ├── Fast.png │ ├── Hierarchy.png │ ├── Idea.png │ ├── Observation.png │ ├── Telescope.png │ ├── combined_dual_axis_plots_horizontal-1.png │ ├── combined_dual_axis_plots_horizontal.pdf │ ├── eqn-2.png │ ├── eqn.drawio-3.pdf │ ├── eqn.drawio-4-1.png │ ├── eqn.drawio-4.pdf │ ├── eqn.png │ ├── gpt.png │ ├── method-1.png │ ├── method.pdf │ ├── normalized_averages copy-1.png │ ├── normalized_averages copy.pdf │ ├── normalized_averages-1.png │ ├── normalized_averages.pdf │ ├── output.png │ ├── perf_rm.png │ ├── pipeline.drawio-5.png │ ├── rej.png │ ├── result.png │ ├── spr.png │ └── train.png ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js └── pdfs │ └── sample.pdf ├── utils ├── __init__.py ├── grader.py ├── latex2sympy │ ├── .coveragerc │ ├── .gitignore │ ├── LICENSE.txt │ ├── PS.g4 │ ├── README.md │ ├── __init__.py │ ├── antlr-4.11.1-complete.jar │ ├── asciimath_printer.py │ ├── description.txt │ ├── dev-requirements.in │ ├── dev-requirements.txt │ ├── gen │ │ ├── PS.interp │ │ ├── PS.tokens │ │ 
├── PSLexer.interp │ │ ├── PSLexer.py │ │ ├── PSLexer.tokens │ │ ├── PSListener.py │ │ ├── PSParser.py │ │ └── __init__.py │ ├── icon.png │ ├── latex2sympy2.py │ ├── requirements.in │ ├── requirements.txt │ ├── sandbox │ │ ├── linalg_equations.py │ │ ├── linalg_span.py │ │ ├── matrix.py │ │ ├── matrix_placeholders.py │ │ ├── sandbox.py │ │ ├── sandbox_equality.py │ │ ├── sectan.py │ │ └── vector.py │ ├── scripts │ │ ├── compile.sh │ │ ├── coverage-ci.sh │ │ ├── coverage.sh │ │ ├── pre-commit │ │ ├── pre-push │ │ ├── publish.sh │ │ ├── setup-hooks.sh │ │ ├── setup.sh │ │ └── test.sh │ ├── setup.cfg │ ├── setup.py │ └── tests │ │ ├── __init__.py │ │ ├── abs_test.py │ │ ├── all_bad_test.py │ │ ├── all_good_test.py │ │ ├── atom_expr_test.py │ │ ├── binomial_test.py │ │ ├── ceil_test.py │ │ ├── complex_test.py │ │ ├── context.py │ │ ├── exp_test.py │ │ ├── floor_test.py │ │ ├── gcd_test.py │ │ ├── greek_test.py │ │ ├── grouping_test.py │ │ ├── lcm_test.py │ │ ├── left_right_cdot_test.py │ │ ├── linalg_test.py │ │ ├── max_test.py │ │ ├── min_test.py │ │ ├── mod_test.py │ │ ├── overline_test.py │ │ ├── pi_test.py │ │ ├── trig_test.py │ │ └── variable_test.py ├── parser.py └── utils.py └── version.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | outputs/ 176 | results/ 177 | logs/ 178 | wandb/ 179 | datasets/compression_dataset/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training Large Language Models to Reason Efficiently 2 | 3 | This is the codebase for our paper ["Training Large Language Models to Reason Efficiently"](https://arxiv.org/abs/2502.04463). 4 | The codebase has been tested on GH200 GPUs with Python 3.10.15 and CUDA 12.6. 5 | Other environments may require different tweaks to the installation of Flash Attention or vLLM. 6 | 7 | ## Trained Checkpoints 8 | 9 | Trained checkpoints are available in this [Hugging Face collection](https://huggingface.co/collections/Zanette-Labs/efficient-reasoning-67f75497d50c1316c04f9332). 10 | 11 | ## Installation 12 | 13 | ``` 14 | conda create -n efficient_reasoning python=3.10.15 15 | conda activate efficient_reasoning 16 | cd utils/latex2sympy 17 | pip install -e . 18 | cd ../../ 19 | pip install -e . 20 | ``` 21 | 22 | ## Dataset 23 | 24 | Download the dataset used in the paper with: 25 | 26 | ``` 27 | huggingface-cli download daman1209arora/compression_dataset --repo-type dataset --local-dir datasets/compression_dataset 28 | ``` 29 | 30 | This dataset is a random split created from easily parsed problems in the MATH, cn_k12, AIME, AoPS, and Olympiad subsets of the NuminaMath dataset. 31 | 32 | ## Design 33 | Our codebase is adapted from the [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) library. 34 | To keep changes to the codebase minimal, we launch a remote reward server, defined in `reward_server/math_server.py`, 35 | whose URL is then passed to the trainer defined in OpenRLHF. 36 | 37 | 38 | ## Usage 39 | 40 | For illustration, we provide scripts to run on a Slurm cluster: 41 | 42 | 1. 1.5B with 4 GH200 GPUs on 1 node.
`run_rloo_1.5B.sh` 43 | 2. 7B with 8 GH200 GPUs on 2 nodes. `run_rloo_7B.sh` 44 | 45 | To train the `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` model, set `WANDB_KEY` and `ALPHA` in run_rloo_1.5B.sh and run the following command: 46 | 47 | ``` 48 | sbatch run_rloo_1.5B.sh 49 | ``` 50 | 51 | ## Evaluation 52 | 53 | To evaluate a model, use the script provided in the main directory: 54 | 55 | ``` 56 | python evaluate_model.py \ 57 | --model_path='scale:1.5B_alpha:0.1/' \ 58 | --dataset=openai/gsm8k \ 59 | --scale=1.5B 60 | ``` 61 | 62 | ## Citation 63 | 64 | If you find this code repository useful, please cite us! 65 | 66 | ``` 67 | @misc{arora2025traininglanguagemodelsreason, 68 | title={Training Language Models to Reason Efficiently}, 69 | author={Daman Arora and Andrea Zanette}, 70 | year={2025}, 71 | eprint={2502.04463}, 72 | archivePrefix={arXiv}, 73 | primaryClass={cs.LG}, 74 | url={https://arxiv.org/abs/2502.04463}, 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/dataset_dict.json: -------------------------------------------------------------------------------- 1 | {"splits": ["train", "test"]} -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/test/cache-80251b5741d7ee37.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/datasets/converted_aime_dataset/test/cache-80251b5741d7ee37.arrow -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/test/cache-86b4d5a4725288d2.arrow: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/datasets/converted_aime_dataset/test/cache-86b4d5a4725288d2.arrow -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/test/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/datasets/converted_aime_dataset/test/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/test/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "parquet", 3 | "citation": "", 4 | "config_name": "default", 5 | "dataset_name": "aimo-validation-aime", 6 | "dataset_size": 520443, 7 | "description": "", 8 | "download_checksums": { 9 | "hf://datasets/AI-MO/aimo-validation-aime@8c6be017548b0fe2eb9fdf2239fca99700b9bccf/data/train-00000-of-00001.parquet": { 10 | "num_bytes": 261038, 11 | "checksum": null 12 | } 13 | }, 14 | "download_size": 261038, 15 | "features": { 16 | "problem": { 17 | "dtype": "string", 18 | "_type": "Value" 19 | }, 20 | "solution": { 21 | "dtype": "string", 22 | "_type": "Value" 23 | }, 24 | "year": { 25 | "dtype": "string", 26 | "_type": "Value" 27 | } 28 | }, 29 | "homepage": "", 30 | "license": "", 31 | "size_in_bytes": 781481, 32 | "splits": { 33 | "train": { 34 | "name": "train", 35 | "num_bytes": 520443, 36 | "num_examples": 90, 37 | "dataset_name": "aimo-validation-aime" 38 | } 39 | }, 40 | "version": { 41 | "version_str": "0.0.0", 42 | "major": 0, 43 | "minor": 0, 44 | "patch": 0 45 | } 46 | } -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/test/state.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "b1373567ca1973ef", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "train" 13 | } -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/train/cache-d363ce53f400e210.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/datasets/converted_aime_dataset/train/cache-d363ce53f400e210.arrow -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/train/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/datasets/converted_aime_dataset/train/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/train/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "csv", 3 | "citation": "", 4 | "config_name": "default", 5 | "dataset_name": "csv", 6 | "dataset_size": 358605, 7 | "description": "", 8 | "download_checksums": { 9 | "/u/darora1/AIME_Dataset_1983_2024.csv": { 10 | "num_bytes": 342215, 11 | "checksum": null 12 | } 13 | }, 14 | "download_size": 342215, 15 | "features": { 16 | "problem": { 17 | "dtype": "string", 18 | "_type": "Value" 19 | }, 20 | "solution": { 21 | "dtype": "string", 22 | "_type": "Value" 23 | }, 24 | "year": { 25 | "dtype": "int64", 26 | "_type": "Value" 27 | } 28 | }, 29 | "homepage": "", 30 | "license": "", 31 | 
"size_in_bytes": 700820, 32 | "splits": { 33 | "train": { 34 | "name": "train", 35 | "num_bytes": 358605, 36 | "num_examples": 933, 37 | "dataset_name": "csv" 38 | } 39 | }, 40 | "version": { 41 | "version_str": "0.0.0", 42 | "major": 0, 43 | "minor": 0, 44 | "patch": 0 45 | } 46 | } -------------------------------------------------------------------------------- /datasets/converted_aime_dataset/train/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "b78409907fd81052", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "train" 13 | } -------------------------------------------------------------------------------- /dockerfile/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.07-py3 2 | 3 | WORKDIR /app 4 | 5 | RUN set -eux && \ 6 | apt-get update && \ 7 | apt-get install -y gosu && \ 8 | rm -rf /var/lib/apt/lists/* && \ 9 | gosu nobody true 10 | 11 | RUN apt-get update && apt-get -y install sudo 12 | RUN sudo su - 13 | 14 | RUN DEBIAN_FRONTEND=noninteractive apt install -y tzdata 15 | 16 | RUN apt-get -y install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev gdb 17 | RUN pip uninstall xgboost transformer_engine flash_attn -y 18 | RUN pip install vllm==0.6.4.post1 19 | 20 | COPY docker-entrypoint.sh . 
21 | RUN chmod a+x docker-entrypoint.sh 22 | 23 | ENTRYPOINT ["/app/docker-entrypoint.sh"] 24 | -------------------------------------------------------------------------------- /dockerfile/docker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | USER=${LOCAL_USER:-"root"} 4 | 5 | if [[ "${USER}" != "root" ]]; then 6 | USER_ID=${LOCAL_USER_ID:-9001} 7 | echo ${USER} 8 | echo ${USER_ID} 9 | 10 | chown ${USER_ID} /home/${USER} 11 | useradd --shell /bin/bash -u ${USER_ID} -o -c "" -m ${USER} 12 | usermod -a -G root ${USER} 13 | adduser ${USER} sudo 14 | 15 | # user:password 16 | echo "${USER}:123" | chpasswd 17 | 18 | export HOME=/home/${USER} 19 | export PATH=/home/${USER}/.local/bin/:$PATH 20 | else 21 | export PATH=/root/.local/bin/:$PATH 22 | fi 23 | 24 | cd $HOME 25 | exec gosu ${USER} "$@" -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/docs/logo.png -------------------------------------------------------------------------------- /docs/ppo_examples.md: -------------------------------------------------------------------------------- 1 | ### Training settings 2 | 3 | * Base model: LLaMA2 7B 4 | * SFT dataset: Ocra 500k 5 | * RM datasets: Anthropic HH, LMSys, OASST, WebGPT ~ 240K+ 6 | * PPO prompts: alpaca, Anthropic HH, OASST ~ 100K 7 | 8 | ### Question: Please introduce the game GTA5 9 | 10 | **SFT** 11 | 12 | Human: Please introduce the game GTA5 13 | 14 | Assistant: GTA5, or Grand Theft Auto V, is an action-adventure video game developed by Rockstar North and published by Rockstar Games. It was released in 2013 for the PlayStation 3 and Xbox 360, and later in 2014 for the PlayStation 4 and Xbox One. 
The game is set in the fictional city of Los Santos, based in Los Angeles, and follows the story of three criminals as they plan and execute heists. 15 | 16 | GTA5 is an open-world game, meaning that players can explore the game's vast and detailed world at their leisure, completing missions, engaging in side activities, and interacting with other characters. The game features an extensive array of vehicles, from cars to boats to planes, and players can use them to travel around the city and engage in various forms of criminal activity. 17 | 18 | GTA5 is known for its immersive gameplay, stunning graphics, and rich storytelling. It has received critical acclaim and has become one of the best-selling video games of all time, with over 150 million copies sold worldwide. 19 | 20 | **PPO** 21 | 22 | Human: Please introduce the game GTA5 23 | 24 | Assistant: 1. **Game Name:** Grand Theft Auto V (GTA V) 25 | 26 | 2. **Platforms:** GTA V is available on PlayStation 3 and Xbox 360, and later released on PlayStation 4, Xbox One, and Microsoft Windows. 27 | 28 | 3. **Genre:** Action-adventure game 29 | 30 | 4. **Gameplay:** GTA V is an open-world game where players can explore the fictional city of Los Santos and its surrounding areas, completing missions, engaging in side activities, and interacting with non-player characters. The game features a third-person perspective and allows players to freely roam the game world, engage in combat, and complete objectives. 31 | 32 | 5. **Storyline:** The game follows the story of three protagonists, Michael De Santa, Franklin Clinton, and Trevor Philips, as they become embroiled in a series of heists and other criminal activities. The game's storyline is considered one of the strongest aspects of GTA 33 | 34 | 6. **Graphics and Sound:** GTA V features stunning graphics and realistic character animations, making the game world feel alive and immersive. 
The soundtrack and voice acting are also highly praised, with a diverse and engaging musical score that complements the game's action-packed gameplay. 35 | 36 | 7. **Reception:** GTA V received critical acclaim upon its release, with many praising its storyline, gameplay, graphics, and sound. It is considered one of the best video games of all time and has sold over 150 million copies worldwide, making it one of the best-selling video games of all time. 37 | 38 | In summary, GTA V is an action-adventure game set in a massive open world, with a complex and engaging storyline, stunning graphics, and highly praised soundtrack and voice acting. It is considered one of the greatest video games of all time and continues to be a popular choice among gamers worldwide. 39 | 40 | 41 | ### Question: Which GPU is better, AMD or NVIDIA? 42 | 43 | **SFT** 44 | 45 | Human: Which GPU is better, AMD or NVIDIA? 46 | 47 | Assistant: 48 | 49 | The performance and suitability of a GPU (Graphics Processing Unit) depend on various factors, including the specific task or application being used, the model and generation of the GPU, and the system it is being used in. Both AMD and NVIDIA are well-known and reputable manufacturers of GPUs, and they both have a wide range of models and generations. 50 | 51 | In general, the performance of a GPU can vary depending on the specific task or application being used. For example, AMD's GPUs are known for their performance in gaming and video editing, while NVIDIA's GPUs are known for their performance in machine learning and deep learning. 52 | 53 | It is important to note that the performance of a GPU can also vary depending on the generation of the GPU. Newer generations of GPUs tend to have better performance and features than older generations. 54 | 55 | Ultimately, the best GPU for you will depend on your specific needs and the tasks you plan to use it for. 
It is recommended to research and compare the specific models and generations of AMD and NVIDIA GPUs to determine which one is the best fit for your needs. 56 | 57 | **PPO** 58 | 59 | Human: Which GPU is better, AMD or NVIDIA? 60 | 61 | Assistant: 1. **Performance:** In terms of raw performance, NVIDIA typically outperforms AMD in most cases, especially in gaming and video editing. 62 | 63 | 2. **Price:** NVIDIA GPUs tend to be more expensive than AMD GPUs, although this can vary depending on the specific model and its features. 64 | 65 | 3. **Architecture:** NVIDIA's GPU architecture is based on the CUDA (Compute Unified Device Architecture) platform, which allows for more efficient processing of parallel computing tasks. AMD's GPU architecture, on the other hand, is based on the Stream Compute architecture, which is designed for general-purpose computing. 66 | 67 | 4. **Application Support:** NVIDIA's GPUs have better support for certain applications, particularly in the fields of deep learning and artificial intelligence. AMD's GPUs, however, have better support for certain games and video editing software. 68 | 69 | In summary, NVIDIA's GPUs tend to be more powerful and efficient for certain tasks, while AMD's GPUs are better suited for other tasks and have better support for certain applications. Ultimately, the choice between the two depends on the specific needs and preferences of the user. 
70 | -------------------------------------------------------------------------------- /docs/ray_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/docs/ray_architecture.png -------------------------------------------------------------------------------- /examples/scripts/docker_run.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | PROJECT_PATH=$(cd $(dirname $0)/../../; pwd) 4 | IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" 5 | 6 | docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN \ 7 | -v $PROJECT_PATH:/openrlhf -v $HOME/.cache:/root/.cache -v $HOME/.bash_history2:/root/.bash_history \ 8 | $IMAGE_NAME bash -------------------------------------------------------------------------------- /examples/scripts/nvidia_docker_install.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | # remove old docker 4 | sudo apt-get autoremove docker docker-ce docker-engine docker.io containerd runc 5 | dpkg -l |grep ^rc|awk '{print $2}' |sudo xargs dpkg -P 6 | sudo apt-get autoremove docker-ce-* 7 | sudo rm -rf /etc/systemd/system/docker.service.d 8 | sudo rm -rf /var/lib/docker 9 | 10 | # install docker 11 | curl https://get.docker.com | sh \ 12 | && sudo systemctl --now enable docker 13 | 14 | # install nvidia-docker 15 | distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) \ 16 | && curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ 17 | && curl -s -L https://nvidia.github.io/libnvidia-container/$distribution/libnvidia-container.list | \ 18 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ 19 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list 20 | 21 | sudo apt-get update 22 | sudo apt-get install -y nvidia-container-toolkit 23 | sudo nvidia-ctk runtime configure --runtime=docker 24 | 25 | sudo groupadd docker 26 | sudo usermod -aG docker $USER 27 | newgrp docker 28 | docker ps -------------------------------------------------------------------------------- /examples/scripts/serve_remote_rm.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | python -m openrlhf.cli.serve_rm \ 4 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ 5 | --port 5000 \ 6 | --bf16 \ 7 | --flash_attn \ 8 | --normalize_reward \ 9 | --max_len 8192 \ 10 | --batch_size 16 -------------------------------------------------------------------------------- /examples/scripts/train_conditional_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | checkSuccess() { 4 | if [[ $? 
!= 0 ]]; then 5 | echo "FAILED $1" 6 | exit 1 7 | fi 8 | } 9 | 10 | mkdir -p ./checkpoint/llama-2-8b-csft 11 | RM_OUTPUT=./checkpoint/llama-2-8b-csft/rm.jsonl 12 | 13 | read -r -d '' get_rewards_commands < 0)); then 32 | POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH 33 | fi 34 | 35 | read -r -d '' generate_commands <$ITER_LOG_PATH 98 | fi 99 | done -------------------------------------------------------------------------------- /examples/scripts/train_knowledge_distillation.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands <> ${JOBLOG} 26 | 27 | # load training commands 28 | source ./${training_script} slurm 29 | echo training_commands &>> ${JOBLOG} 30 | echo $training_commands &>> ${JOBLOG} 31 | 32 | # master addr and port 33 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 34 | export MASTER_PORT=9901 35 | 36 | srun --container-image="$IMAGE_NAME" \ 37 | --container-mounts="$PROJECT_PATH:/openrlhf,$HOME/.cache:/root/.cache" \ 38 | bash -c " cd /openrlhf; pip install . ; torchrun \ 39 | --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \ 40 | --master_addr $MASTER_ADDR --master_port $MASTER_PORT -m ${training_commands}" &>> ${JOBLOG} 41 | 42 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." &>> ${JOBLOG} -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands < 16 | MOUNT="$OPENRLHF_PATH:/openrlhf,$HOME/.cache:/root/.cache" 17 | IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" 18 | RAY_VERSION=2.12.0 19 | 20 | JOBLOG="$(realpath .)/train_ppo_llama_ray-$SLURM_JOB_ID.log" 21 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} started ..." 
&>> ${JOBLOG} 22 | 23 | # launch ray daemon 24 | nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") # Getting the node names 25 | nodes_array=( $nodes ) 26 | node_1=${nodes_array[0]} 27 | ip=$node_1 28 | 29 | port=6379 30 | ip_head=$ip:$port 31 | export ip_head 32 | echo "IP Head: $ip_head" &>> ${JOBLOG} 33 | 34 | echo "STARTING HEAD at $node_1" &>> ${JOBLOG} 35 | srun --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 36 | && pip install ray[default]==$RAY_VERSION \ 37 | && /root/.local/bin/ray start --head --node-ip-address=$ip --port=$port --block" &>> ${JOBLOG} & 38 | sleep 10s 39 | 40 | worker_num=$((SLURM_JOB_NUM_NODES)) #number of nodes other than the head node 41 | for ((i = 1; i < worker_num; i++)); do 42 | node_i=${nodes_array[$i]} 43 | echo "STARTING WORKER $i at $node_i" &>> ${JOBLOG} 44 | srun --nodes=1 --ntasks=1 -w "$node_i" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 45 | && pip install ray[default]==$RAY_VERSION \ 46 | && /root/.local/bin/ray start --address "$ip_head" --block" &>> ${JOBLOG} & 47 | sleep 1s; 48 | done 49 | 50 | sleep 30s 51 | 52 | # ===== submit ray job ===== 53 | # Job start 54 | srun --overlap --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 55 | "pip install ray[default]==$RAY_VERSION \ 56 | && /root/.local/bin/ray job submit --address=http://localhost:8265 \ 57 | --runtime-env-json='{\"working_dir\": \"/openrlhf\", \"pip\": \"/openrlhf/requirements.txt\"}' \ 58 | -- python3 -m openrlhf.cli.train_ppo_ray \ 59 | --ref_num_nodes 1 \ 60 | --ref_num_gpus_per_node 4 \ 61 | --reward_num_nodes 1 \ 62 | --reward_num_gpus_per_node 4 \ 63 | --critic_num_nodes 1 \ 64 | --critic_num_gpus_per_node 4 \ 65 | --actor_num_nodes 1 \ 66 | --actor_num_gpus_per_node 4 \ 67 | --vllm_num_engines 4 \ 68 | --vllm_tensor_parallel_size 2 \ 69 | --colocate_critic_reward \ 70 | --colocate_actor_ref \ 71 | --pretrain 
OpenRLHF/Llama-3-8b-sft-mixture \ 72 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ 73 | --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ 74 | --micro_train_batch_size 8 \ 75 | --train_batch_size 128 \ 76 | --micro_rollout_batch_size 16 \ 77 | --rollout_batch_size 1024 \ 78 | --max_samples 100000 \ 79 | --max_epochs 1 \ 80 | --prompt_max_len 1024 \ 81 | --generate_max_len 1024 \ 82 | --zero_stage 3 \ 83 | --bf16 \ 84 | --actor_learning_rate 5e-7 \ 85 | --critic_learning_rate 9e-6 \ 86 | --init_kl_coef 0.01 \ 87 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 88 | --input_key context_messages \ 89 | --apply_chat_template \ 90 | --normalize_reward \ 91 | --adam_offload \ 92 | --flash_attn \ 93 | --packing_samples \ 94 | --vllm_sync_backend nccl \ 95 | --gradient_checkpointing \ 96 | --use_wandb {wandb_token}" &>> ${JOBLOG} 97 | 98 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." &>> ${JOBLOG} -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_with_remote_rm.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | # python -m openrlhf.cli.serve_rm \ 4 | # --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ 5 | # --port 5000 \ 6 | # --bf16 \ 7 | # --flash_attn \ 8 | # --normalize_reward \ 9 | # --max_len 8192 \ 10 | # --batch_size 16 11 | 12 | ray job submit --address="http://127.0.0.1:8265" \ 13 | --runtime-env-json='{"working_dir": "/openrlhf"}' \ 14 | -- python3 -m openrlhf.cli.train_ppo_ray \ 15 | --ref_num_nodes 1 \ 16 | --ref_num_gpus_per_node 2 \ 17 | --critic_num_nodes 1 \ 18 | --critic_num_gpus_per_node 2 \ 19 | --actor_num_nodes 1 \ 20 | --actor_num_gpus_per_node 2 \ 21 | --vllm_num_engines 2 \ 22 | --vllm_tensor_parallel_size 2 \ 23 | --colocate_actor_ref \ 24 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 25 | --remote_rm_url http://localhost:5000/get_reward \ 26 | --save_path 
/openrlhf/examples/checkpoint/llama3-8b-rlhf \ 27 | --micro_train_batch_size 8 \ 28 | --train_batch_size 128 \ 29 | --micro_rollout_batch_size 16 \ 30 | --rollout_batch_size 1024 \ 31 | --max_samples 100000 \ 32 | --max_epochs 1 \ 33 | --prompt_max_len 1024 \ 34 | --generate_max_len 1024 \ 35 | --zero_stage 3 \ 36 | --bf16 \ 37 | --actor_learning_rate 5e-7 \ 38 | --critic_learning_rate 9e-6 \ 39 | --init_kl_coef 0.01 \ 40 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 41 | --input_key context_messages \ 42 | --apply_chat_template \ 43 | --normalize_reward \ 44 | --packing_samples \ 45 | --adam_offload \ 46 | --flash_attn \ 47 | --gradient_checkpointing \ 48 | --use_wandb {wandb_token} 49 | 50 | # --vllm_sync_backend nccl (Only for multi-nodes with vLLM 0.6.4+ or vLLM 0.4.2) -------------------------------------------------------------------------------- /examples/scripts/train_prm_mistral.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands < 0)); then 31 | POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH 32 | fi 33 | 34 | read -r -d '' generate_commands <$ITER_LOG_PATH 98 | fi 99 | done -------------------------------------------------------------------------------- /examples/scripts/train_rm_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands < None: 30 | super().__init__() 31 | self.tokenizer = tokenizer 32 | self.strategy = strategy 33 | self.max_length = max_length 34 | self.multiple_of = multiple_of 35 | 36 | # chat_template 37 | self.input_key = getattr(self.strategy.args, "input_key", None) 38 | self.label_key = getattr(self.strategy.args, "label_key", None) 39 | self.placeholder_token = getattr(self.strategy.args, "placeholder_token", None) 40 | self.reward_tokens = getattr(self.strategy.args, "reward_tokens", None) 41 | 42 | self.placeholder_token_id = 
convert_token_to_id(self.placeholder_token, self.tokenizer) 43 | 44 | # Store the processed data in class attributes 45 | self.inputs = dataset[self.input_key] 46 | self.labels = dataset[self.label_key] 47 | 48 | def __len__(self): 49 | length = len(self.inputs) 50 | return length 51 | 52 | def __getitem__(self, idx): 53 | input_token = self.tokenizer( 54 | self.inputs[idx], 55 | max_length=self.max_length, 56 | padding=False, 57 | truncation=True, 58 | return_tensors="pt", 59 | add_special_tokens=False, 60 | ) 61 | 62 | input_ids = input_token["input_ids"] 63 | label_values = self.labels[idx] 64 | assert isinstance(label_values, list), "labels should be a list of strings or numbers" 65 | if isinstance(label_values[0], str): 66 | label_tokens = [] 67 | for label in label_values: 68 | assert ( 69 | self.reward_tokens is None or label in self.reward_tokens 70 | ), f"label should be in reward tokens {self.reward_tokens}, got {label}" 71 | label_tokens.append(convert_token_to_id(label, self.tokenizer)) 72 | 73 | labels = torch.full_like(input_ids, -100) 74 | labels[input_ids == self.placeholder_token_id] = torch.tensor(label_tokens, dtype=input_ids.dtype) 75 | else: 76 | assert isinstance(label_values[0], numbers.Number), "labels should be a list of strings or numbers" 77 | labels = torch.full_like(input_ids, -100, dtype=torch.float) 78 | labels[input_ids == self.placeholder_token_id] = torch.tensor(label_values, dtype=torch.float) 79 | 80 | return ( 81 | input_ids, 82 | input_token["attention_mask"], 83 | labels, 84 | ) 85 | 86 | def collate_fn(self, item_list): 87 | input_ids = [] 88 | input_masks = [] 89 | label_ids = [] 90 | for input_id, input_mask, label_id in item_list: 91 | input_ids.append(input_id) 92 | input_masks.append(input_mask) 93 | label_ids.append(label_id) 94 | 95 | padding_side = "right" 96 | input_ids = zero_pad_sequences(input_ids, side=padding_side, value=self.tokenizer.pad_token_id) 97 | input_masks = zero_pad_sequences(input_masks, 
side=padding_side) 98 | label_ids = zero_pad_sequences(label_ids, side=padding_side, value=self.tokenizer.pad_token_id) 99 | return input_ids, input_masks, label_ids 100 | 101 | def packing_collate_fn(self, item_list): 102 | input_ids = [] 103 | input_att_masks = [] 104 | input_seq_lens = [] 105 | label_ids = [] 106 | index = 1 107 | for input_id, input_mask, label_id in item_list: 108 | input_ids.append(input_id.flatten()) 109 | input_att_masks.append(torch.full_like(input_id.flatten(), index)) 110 | input_seq_lens.append(len(input_id.flatten())) 111 | 112 | label_ids.append(label_id.flatten()) 113 | index += 1 114 | 115 | packed_input_ids = torch.cat(input_ids, dim=0).unsqueeze(0) 116 | packed_attention_masks = torch.cat(input_att_masks, dim=0).unsqueeze(0) 117 | packed_seq_lens = input_seq_lens 118 | packed_label_ids = torch.cat(label_ids, dim=0).unsqueeze(0) 119 | 120 | if self.multiple_of > 1 and packed_input_ids.numel() % self.multiple_of != 0: 121 | padding_len = self.multiple_of - (packed_input_ids.numel() % self.multiple_of) 122 | packed_input_ids = F.pad(packed_input_ids, (0, padding_len), value=self.tokenizer.pad_token_id) 123 | packed_attention_masks = F.pad(packed_attention_masks, (0, padding_len), value=0) 124 | 125 | return packed_input_ids, packed_attention_masks, packed_seq_lens, packed_label_ids 126 | -------------------------------------------------------------------------------- /openrlhf/datasets/prompts_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from tqdm import tqdm 3 | from utils import DATASET_KEYS, RESPONSE_COMPARATOR, RESPONSE_EXTRACTOR 4 | 5 | def preprocess_data(data, input_template=None, apply_chat_template=None, system_prompt=None) -> str: 6 | dataset_name = data.get("dataset_name", None) 7 | input_key = DATASET_KEYS[dataset_name]["question"] 8 | if apply_chat_template: 9 | chat = data[input_key] 10 | if isinstance(chat, str): 11 | chat = [ 
12 | {"role": "system", "content": system_prompt}, 13 | {"role": "user", "content": chat} 14 | ] 15 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 16 | else: 17 | prompt = data[input_key] 18 | if input_template: 19 | prompt = input_template.format(prompt) 20 | return prompt, data 21 | 22 | 23 | class PromptDataset(Dataset): 24 | """ 25 | Dataset for PPO model 26 | 27 | Args: 28 | dataset: dataset for PPO model 29 | tokenizer: tokenizer for PPO model 30 | max_length: max length of input 31 | """ 32 | 33 | def __init__( 34 | self, 35 | dataset, 36 | tokenizer, 37 | strategy, 38 | input_template=None, 39 | system_prompt=None, 40 | ) -> None: 41 | super().__init__() 42 | self.strategy = strategy 43 | self.tokenizer = tokenizer 44 | 45 | # chat_template 46 | self.input_template = input_template 47 | # input_key = getattr(self.strategy.args, "input_key", None) no need for this since we are passing it in the dataset 48 | apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 49 | 50 | if apply_chat_template: 51 | apply_chat_template = self.tokenizer.apply_chat_template 52 | 53 | self.prompts = [] 54 | for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()): 55 | prompt, aux_info = preprocess_data(data, input_template, apply_chat_template, system_prompt) 56 | self.prompts.append((prompt, aux_info)) 57 | 58 | def __len__(self): 59 | length = len(self.prompts) 60 | return length 61 | 62 | def __getitem__(self, idx): 63 | return self.prompts[idx] 64 | -------------------------------------------------------------------------------- /openrlhf/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def zero_pad_sequences(sequences, side: str = "left", value=0): 6 | assert side in ("left", "right") 7 | max_len = max(seq.size(-1) for seq in sequences) 8 | padded_sequences = [] 9 | 
for seq in sequences: 10 | pad_len = max_len - seq.size(-1) 11 | padding = (pad_len, 0) if side == "left" else (0, pad_len) 12 | padded_sequences.append(F.pad(seq, padding, value=value)) 13 | return torch.stack(padded_sequences, dim=0) 14 | 15 | 16 | def exist_and_not_none(d, key): 17 | return key in d and not d[key] is None 18 | -------------------------------------------------------------------------------- /openrlhf/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import Actor 2 | from .loss import ( 3 | DPOLoss, 4 | GPTLMLoss, 5 | KDLoss, 6 | KTOLoss, 7 | LogExpLoss, 8 | PairWiseLoss, 9 | PolicyLoss, 10 | PRMLoss, 11 | ValueLoss, 12 | VanillaKTOLoss, 13 | ) 14 | from .model import get_llm_for_sequence_regression 15 | 16 | __all__ = [ 17 | "Actor", 18 | "DPOLoss", 19 | "GPTLMLoss", 20 | "KDLoss", 21 | "KTOLoss", 22 | "LogExpLoss", 23 | "PairWiseLoss", 24 | "PolicyLoss", 25 | "PRMLoss", 26 | "ValueLoss", 27 | "VanillaKTOLoss", 28 | "get_llm_for_sequence_regression", 29 | ] 30 | -------------------------------------------------------------------------------- /openrlhf/models/ring_attn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch.nn.functional as F 4 | 5 | 6 | RING_ATTN_GROUP = None 7 | 8 | 9 | def set_ring_attn_group(group): 10 | global RING_ATTN_GROUP 11 | RING_ATTN_GROUP = group 12 | 13 | 14 | def get_ring_attn_group(): 15 | return RING_ATTN_GROUP 16 | 17 | 18 | def reset_ring_attn_position_ids(start, end, packed_seq_lens): 19 | """ 20 | Calculate position ids for packed_seq_ids[start:end]. 21 | For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8, 22 | the position ids will be [2, 0, 1, 0, 1, 2]. 
23 | 24 | Args: 25 | start: the start position 26 | end: the end position 27 | packed_seq_lens: the sequence lengths of packed sequences 28 | """ 29 | position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device()) 30 | offset = 0 31 | for seqlen in packed_seq_lens: 32 | seq_start = max(offset, start) 33 | seq_end = min(offset + seqlen, end) 34 | if seq_start < seq_end: 35 | position_ids[0, seq_start - start : seq_end - start] = torch.arange(seq_start - offset, seq_end - offset) 36 | 37 | offset += seqlen 38 | if offset >= end: 39 | break 40 | return position_ids 41 | 42 | 43 | def update_ring_attn_params(packed_seq_lens, total_seq_len): 44 | """ 45 | Calculate the cu_seqlens for the current forward pass and pass the value to 46 | the substituted ring_flash_attn. 47 | 48 | Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding. 49 | """ 50 | assert RING_ATTN_GROUP is not None 51 | cu_seqlens = torch.cumsum( 52 | torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32), 53 | dim=-1, 54 | dtype=torch.int32, 55 | ) 56 | cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len) 57 | 58 | from ring_flash_attn import update_ring_flash_attn_params 59 | 60 | update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP) 61 | 62 | 63 | def convert_ring_attn_params(sequences, attention_mask, packed_seq_lens, ring_attn_group): 64 | # each rank within the ring group will process sequences[start:end] 65 | ring_attn_rank = dist.get_rank(group=ring_attn_group) 66 | ring_attn_size = dist.get_world_size(group=ring_attn_group) 67 | total_seq_len = sequences.numel() 68 | local_seq_len = total_seq_len // ring_attn_size 69 | start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len 70 | sequences = sequences[:, start:end] 71 | attention_mask = attention_mask[:, start:end] 72 | position_ids = reset_ring_attn_position_ids(start, end, 
packed_seq_lens) 73 | update_ring_attn_params(packed_seq_lens, total_seq_len) 74 | return sequences, attention_mask, position_ids 75 | -------------------------------------------------------------------------------- /openrlhf/models/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def compute_approx_kl( 8 | log_probs: torch.Tensor, 9 | log_probs_base: torch.Tensor, 10 | action_mask: Optional[torch.Tensor] = None, 11 | use_kl_estimator_k3: bool = False, 12 | ) -> torch.Tensor: 13 | """ 14 | Compute the approximate KL divergence between two distributions. 15 | Schulman blog: http://joschu.net/blog/kl-approx.html 16 | 17 | Args: 18 | log_probs: Log probabilities of the new distribution. 19 | log_probs_base: Log probabilities of the base distribution. 20 | action_mask: Mask for actions. 21 | """ 22 | 23 | log_ratio = log_probs - log_probs_base 24 | if action_mask is not None: 25 | log_ratio = log_ratio * action_mask 26 | 27 | # The k3 estimator is the non negative kl approximation in 28 | # http://joschu.net/blog/kl-approx.html 29 | # Besides non negative, it is also unbiased and have lower variance. 
30 | if use_kl_estimator_k3: 31 | log_ratio = -log_ratio 32 | log_ratio = log_ratio.exp() - 1 - log_ratio 33 | 34 | return log_ratio 35 | 36 | 37 | def compute_reward( 38 | r: Union[torch.Tensor, float], 39 | kl_coef: float, 40 | kl: Union[torch.Tensor, list[torch.Tensor]], 41 | action_mask: Optional[torch.Tensor] = None, 42 | num_actions: Optional[Union[int, list[int]]] = None, 43 | reward_clip_range: Tuple[float, float] = None, 44 | ) -> Union[torch.Tensor, list[torch.Tensor]]: 45 | if kl_coef <= 0.0: 46 | kl_coef = 0.0 47 | 48 | if reward_clip_range: 49 | r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1]) 50 | 51 | if action_mask is not None: 52 | kl_reward = -kl_coef * kl 53 | # The following code is equivalent to: 54 | # 55 | # last_reward = torch.zeros_like(kl) 56 | # for i in range(last_reward.size(0)): 57 | # for t in reversed(range(last_reward.size(1))): 58 | # if action_mask[i][t] > 0.5: 59 | # last_reward[i][t] = r[i] 60 | # break 61 | # 62 | eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True) 63 | last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype)) 64 | 65 | reward = last_reward + kl_reward 66 | else: 67 | # TODO: write a more efficient version 68 | reward = [] 69 | for i, (kl_seg, action_len) in enumerate(zip(kl, num_actions)): 70 | kl_reward = -kl_coef * kl_seg 71 | kl_reward[action_len - 1] += r[i] 72 | reward.append(kl_reward) 73 | 74 | return reward 75 | 76 | 77 | def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 78 | log_probs = F.log_softmax(logits, dim=-1) 79 | log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) 80 | return log_probs_labels.squeeze(-1) 81 | 82 | 83 | def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor: 84 | if mask is None: 85 | return tensor.mean(axis=dim) 86 | return (tensor * mask).sum(axis=dim) / mask.sum(axis=dim) 
87 | 88 | 89 | def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: 90 | tensor = tensor * mask 91 | mean = masked_mean(tensor, mask, dim=dim) 92 | mean_centered = tensor - mean 93 | var = masked_mean(mean_centered**2, mask, dim=dim) 94 | return mean_centered * var.clamp(min=eps).rsqrt() 95 | 96 | 97 | # Reset positions for packed samples 98 | # For example 99 | # Input: attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2, 3, 3, 0]]) 100 | # Output: position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 0]]) 101 | def reset_position_ids(attention_mask): 102 | position_ids = torch.zeros_like(attention_mask, dtype=torch.long) 103 | for i in range(attention_mask.size(0)): 104 | mask = attention_mask[i] 105 | seq_num = mask.max().item() 106 | for index in range(1, seq_num + 1): 107 | sample_mask = mask == index 108 | sample_length = sample_mask.sum().item() 109 | position_ids[i, sample_mask] = torch.arange(sample_length, device=mask.device) 110 | return position_ids 111 | 112 | 113 | def unpacking_samples(values: torch.Tensor, packed_seqlens: list[int]): 114 | values = values.squeeze(0) 115 | unpacked_values = [] 116 | offset = 0 117 | for seqlen in packed_seqlens: 118 | unpacked_values.append(values[offset : offset + seqlen]) 119 | offset += seqlen 120 | return unpacked_values 121 | -------------------------------------------------------------------------------- /openrlhf/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .dpo_trainer import DPOTrainer 2 | from .kd_trainer import KDTrainer 3 | from .kto_trainer import KTOTrainer 4 | from .ppo_trainer import PPOTrainer 5 | from .prm_trainer import ProcessRewardModelTrainer 6 | from .rm_trainer import RewardModelTrainer 7 | from .sft_trainer import SFTTrainer 8 | 9 | __all__ = [ 10 | "DPOTrainer", 11 | "KDTrainer", 12 | "KTOTrainer", 13 | "PPOTrainer", 14 | "ProcessRewardModelTrainer", 15 | 
"RewardModelTrainer", 16 | "SFTTrainer", 17 | ] 18 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .experience_maker import Experience, NaiveExperienceMaker, RemoteExperienceMaker 2 | from .kl_controller import AdaptiveKLController, FixedKLController 3 | from .replay_buffer import NaiveReplayBuffer 4 | 5 | __all__ = [ 6 | "Experience", 7 | "NaiveExperienceMaker", 8 | "RemoteExperienceMaker", 9 | "AdaptiveKLController", 10 | "FixedKLController", 11 | "NaiveReplayBuffer", 12 | ] 13 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/kl_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AdaptiveKLController: 5 | """ 6 | Adaptive KL controller described in the paper: 7 | https://arxiv.org/pdf/1909.08593.pdf 8 | """ 9 | 10 | def __init__(self, init_kl_coef, target, horizon): 11 | self.value = init_kl_coef 12 | self.target = target 13 | self.horizon = horizon 14 | 15 | def update(self, current, n_steps): 16 | target = self.target 17 | proportional_error = np.clip(current / target - 1, -0.2, 0.2) 18 | mult = 1 + proportional_error * n_steps / self.horizon 19 | self.value *= mult 20 | 21 | 22 | class FixedKLController: 23 | """Fixed KL controller.""" 24 | 25 | def __init__(self, kl_coef): 26 | self.value = kl_coef 27 | 28 | def update(self, current, n_steps): 29 | pass 30 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/__init__.py: -------------------------------------------------------------------------------- 1 | from .launcher import DistributedTorchRayActor, PPORayActorGroup, ReferenceModelRayActor, RewardModelRayActor 2 | from .ppo_actor import ActorModelRayActor 3 | from .ppo_critic import 
CriticModelRayActor 4 | from .vllm_engine import create_vllm_engines 5 | 6 | __all__ = [ 7 | "DistributedTorchRayActor", 8 | "PPORayActorGroup", 9 | "ReferenceModelRayActor", 10 | "RewardModelRayActor", 11 | "ActorModelRayActor", 12 | "CriticModelRayActor", 13 | "create_vllm_engines", 14 | ] 15 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def ray_noset_visible_devices(env_vars=os.environ): 5 | # Refer to 6 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 7 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 8 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 9 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 10 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 11 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 12 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 13 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ 14 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", 15 | "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", 16 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", 17 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", 18 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", 19 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", 20 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", 21 | ] 22 | return 
any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) 23 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/vllm_worker_wrap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vllm.worker.worker import Worker 3 | 4 | from openrlhf.utils.distributed_util import init_process_group 5 | from openrlhf.utils.logging_utils import init_logger 6 | 7 | logger = init_logger(__name__) 8 | 9 | 10 | class WorkerWrap(Worker): 11 | def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"): 12 | """Init torch process group for model weights update""" 13 | assert torch.distributed.is_initialized(), f"default torch process group must be initialized" 14 | assert group_name != "", f"group name must not be empty" 15 | 16 | rank = torch.distributed.get_rank() + rank_offset 17 | self._model_update_group = init_process_group( 18 | backend=backend, 19 | init_method=f"tcp://{master_address}:{master_port}", 20 | world_size=world_size, 21 | rank=rank, 22 | group_name=group_name, 23 | ) 24 | print( 25 | f"init_process_group: master_address={master_address}, master_port={master_port}, ", 26 | f"rank={rank}, world_size={world_size}, group_name={group_name}", 27 | ) 28 | 29 | def update_weight(self, name, dtype, shape, empty_cache=False): 30 | """Broadcast weight to all vllm workers from source rank 0 (actor model)""" 31 | if torch.distributed.get_rank() == 0: 32 | print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") 33 | 34 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" 35 | weight = torch.empty(shape, dtype=dtype, device="cuda") 36 | torch.distributed.broadcast(weight, 0, group=self._model_update_group) 37 | 38 | self.model_runner.model.load_weights(weights=[(name, weight)]) 39 | 40 | del weight 41 | # TODO: should we empty cache if all 
weights have updated? 42 | # if empty_cache: 43 | # torch.cuda.empty_cache() 44 | -------------------------------------------------------------------------------- /openrlhf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import get_processor, reward_normalization 2 | from .utils import blending_datasets, get_strategy, get_tokenizer 3 | 4 | __all__ = [ 5 | "get_processor", 6 | "reward_normalization", 7 | "blending_datasets", 8 | "get_strategy", 9 | "get_tokenizer", 10 | ] 11 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepspeed import DeepspeedStrategy 2 | 3 | __all__ = [ 4 | "DeepspeedStrategy", 5 | ] 6 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/deepspeed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 7 | 8 | 9 | def get_train_ds_config( 10 | offload, 11 | adam_offload=True, 12 | stage=2, 13 | bf16=True, 14 | max_norm=1.0, 15 | zpg=8, 16 | grad_accum_dtype=None, 17 | disable_trace_cache=False, 18 | ): 19 | device = "cpu" if offload else "none" 20 | zero_opt_dict = { 21 | "stage": stage, 22 | "offload_param": {"device": device}, 23 | "offload_optimizer": { 24 | "device": "cpu" if adam_offload else "none", 25 | "pin_memory": True, 26 | }, 27 | "sub_group_size": "auto", 28 | "stage3_max_live_parameters": "auto", 29 | "stage3_max_reuse_distance": "auto", 30 | "stage3_param_persistence_threshold": "auto", 31 | "stage3_prefetch_bucket_size": "auto", 32 | "reduce_bucket_size": "auto", 33 | # ZeRO++ 34 | "zero_hpz_partition_size": zpg, 35 | "zero_quantized_weights": False, 36 | "zero_quantized_gradients": False, 37 | } 38 | if disable_trace_cache: 39 | zero_opt_dict["stage3_prefetch_bucket_size"] = 0 40 | zero_opt_dict["stage3_max_live_parameters"] = 0 41 | zero_opt_dict["stage3_max_reuse_distance"] = 0 42 | 43 | return { 44 | "steps_per_print": 100, 45 | "zero_optimization": zero_opt_dict, 46 | "bf16": { 47 | "enabled": bf16, 48 | }, 49 | "gradient_clipping": max_norm, 50 | "prescale_gradients": False, 51 | "wall_clock_breakdown": False, 52 | "data_types": {"grad_accum_dtype": grad_accum_dtype}, 53 | } 54 | 55 | 56 | def get_eval_ds_config( 57 | offload, 58 | stage=0, 59 | bf16=True, 60 | ): 61 | zero_opt_dict = { 62 | "stage": stage, 63 | "stage3_param_persistence_threshold": "auto", 64 | "offload_param": { 65 | "device": "cpu" if offload else "none", 66 | "pin_memory": True, 67 | }, 68 | } 69 | return { 70 | "steps_per_print": 100, 71 | "zero_optimization": zero_opt_dict, 72 | "bf16": { 73 | "enabled": bf16, 74 | }, 75 | "gradient_clipping": 1.0, 76 | "prescale_gradients": False, 77 | "wall_clock_breakdown": False, 78 | } 
79 | 80 | 81 | def get_optimizer_grouped_parameters( 82 | model, 83 | weight_decay, 84 | no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], 85 | ): 86 | optimizer_grouped_parameters = [ 87 | { 88 | "params": [ 89 | p 90 | for n, p in model.named_parameters() 91 | if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) 92 | ], 93 | "weight_decay": weight_decay, 94 | }, 95 | { 96 | "params": [ 97 | p 98 | for n, p in model.named_parameters() 99 | if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) 100 | ], 101 | "weight_decay": 0.0, 102 | }, 103 | ] 104 | return optimizer_grouped_parameters 105 | 106 | 107 | def _z3_params_to_fetch(param_list): 108 | return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] 109 | -------------------------------------------------------------------------------- /openrlhf/utils/distributed_util.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Optional, Union 3 | 4 | import torch 5 | import torch.distributed 6 | from torch.distributed.distributed_c10d import ( 7 | Backend, 8 | PrefixStore, 9 | Store, 10 | _new_process_group_helper, 11 | _world, 12 | default_pg_timeout, 13 | rendezvous, 14 | ) 15 | 16 | 17 | # Copy from pytorch to allow creating multiple main groups. 18 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py 19 | def init_process_group( 20 | backend: Union[str, Backend] = None, 21 | init_method: Optional[str] = None, 22 | timeout: Optional[timedelta] = None, 23 | world_size: int = -1, 24 | rank: int = -1, 25 | store: Optional[Store] = None, 26 | group_name: str = None, 27 | pg_options: Optional[Any] = None, 28 | ): 29 | assert (store is None) or (init_method is None), "Cannot specify both init_method and store." 
30 | 31 | if store is not None: 32 | assert world_size > 0, "world_size must be positive if using store" 33 | assert rank >= 0, "rank must be non-negative if using store" 34 | elif init_method is None: 35 | init_method = "env://" 36 | 37 | if backend: 38 | backend = Backend(backend) 39 | else: 40 | backend = Backend("undefined") 41 | 42 | if timeout is None: 43 | timeout = default_pg_timeout 44 | 45 | # backward compatible API 46 | if store is None: 47 | rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) 48 | store, rank, world_size = next(rendezvous_iterator) 49 | store.set_timeout(timeout) 50 | 51 | # Use a PrefixStore to avoid accidental overrides of keys used by 52 | # different systems (e.g. RPC) in case the store is multi-tenant. 53 | store = PrefixStore(group_name, store) 54 | 55 | # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 56 | # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 57 | # We need to determine the appropriate parameter name based on PyTorch version 58 | pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" 59 | pg, _ = _new_process_group_helper( 60 | world_size, 61 | rank, 62 | [], 63 | backend, 64 | store, 65 | group_name=group_name, 66 | **{pg_options_param_name: pg_options}, 67 | timeout=timeout, 68 | ) 69 | 70 | _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} 71 | 72 | return pg 73 | -------------------------------------------------------------------------------- /openrlhf/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py 3 | """Logging configuration for vLLM.""" 4 | import logging 5 | import sys 6 | 7 | _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" 8 | _DATE_FORMAT = "%m-%d 
%H:%M:%S" 9 | 10 | 11 | class NewLineFormatter(logging.Formatter): 12 | """Adds logging prefix to newlines to align multi-line messages.""" 13 | 14 | def __init__(self, fmt, datefmt=None): 15 | logging.Formatter.__init__(self, fmt, datefmt) 16 | 17 | def format(self, record): 18 | msg = logging.Formatter.format(self, record) 19 | if record.message != "": 20 | parts = msg.split(record.message) 21 | msg = msg.replace("\n", "\r\n" + parts[0]) 22 | return msg 23 | 24 | 25 | _root_logger = logging.getLogger("openrlhf") 26 | _default_handler = None 27 | 28 | 29 | def _setup_logger(): 30 | _root_logger.setLevel(logging.DEBUG) 31 | global _default_handler 32 | if _default_handler is None: 33 | _default_handler = logging.StreamHandler(sys.stdout) 34 | _default_handler.flush = sys.stdout.flush # type: ignore 35 | _default_handler.setLevel(logging.INFO) 36 | _root_logger.addHandler(_default_handler) 37 | fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) 38 | _default_handler.setFormatter(fmt) 39 | # Setting this will avoid the message 40 | # being propagated to the parent logger. 41 | _root_logger.propagate = False 42 | 43 | 44 | # The logger is initialized when the module is imported. 45 | # This is thread-safe as the module is only imported once, 46 | # guaranteed by the Python GIL. 
47 | _setup_logger() 48 | 49 | 50 | def init_logger(name: str): 51 | # Use the same settings as above for root logger 52 | logger = logging.getLogger(name) 53 | logger.setLevel(logging.DEBUG) 54 | logger.addHandler(_default_handler) 55 | logger.propagate = False 56 | return logger 57 | -------------------------------------------------------------------------------- /openrlhf/utils/processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | 5 | def reward_normalization(objs): 6 | rewards = [float(obj["reward"]) for obj in objs] 7 | rewards = torch.tensor(rewards, dtype=torch.float64) 8 | rewards = (rewards - rewards.mean()) / rewards.std() 9 | for i, obj in enumerate(objs): 10 | obj["reward"] = rewards[i].item() 11 | 12 | 13 | # Conditional SFT 14 | # See https://arxiv.org/abs/2308.12050 15 | DEFAULT_REWARD_PROMPT = "{input} : {reward} " 16 | 17 | 18 | def conditional_sft_processor(args, objs): 19 | if "reward_template" not in args or args.reward_template is None: 20 | reward_template = DEFAULT_REWARD_PROMPT 21 | else: 22 | reward_template = args.reward_template 23 | assert "{input}" in reward_template 24 | assert "{reward}" in reward_template 25 | 26 | if args.normalize_reward: 27 | reward_normalization(objs) 28 | 29 | for obj in tqdm(objs, desc="Conditional SFT process..."): 30 | input = obj["input"] 31 | reward = "{:.2f}".format(float(obj["reward"])) 32 | input = reward_template.replace("{reward}", reward).replace("{input}", input) 33 | obj["input"] = input 34 | 35 | return objs 36 | 37 | 38 | # Rejection Sampling 39 | # See https://arxiv.org/abs/2307.09288 40 | def rejection_sampling_processor(args, objs): 41 | out = {} 42 | for obj in tqdm(objs, desc="Rejection Sampling process...."): 43 | input = obj["input"] 44 | output = obj["output"] 45 | reward = float(obj["reward"]) 46 | 47 | if input not in out: 48 | out[input] = {"output": output, "reward": reward} 49 | elif reward > 
out[input]["reward"]: 50 | out[input]["reward"] = reward 51 | out[input]["output"] = output 52 | 53 | return [{"input": k, "output": v["output"], "reward": v["reward"]} for k, v in out.items()] 54 | 55 | 56 | # Iterative DPO 57 | # See https://github.com/RLHFlow/Online-RLHF/blob/main/run_loop.sh 58 | def iterative_dpo_processor(args, objs): 59 | out = {} 60 | for obj in tqdm(objs, desc="Iterative DPO process...."): 61 | input = obj["input"] 62 | output = obj["output"] 63 | reward = float(obj["reward"]) 64 | 65 | if input not in out: 66 | out[input] = { 67 | "output": output, 68 | "chosen": output, 69 | "chosen_reward": reward, 70 | "rejected": output, 71 | "rejected_reward": reward, 72 | } 73 | elif reward > out[input]["chosen_reward"]: 74 | out[input]["chosen_reward"] = reward 75 | out[input]["chosen"] = output 76 | elif reward < out[input]["rejected_reward"]: 77 | out[input]["rejected_reward"] = reward 78 | out[input]["rejected"] = output 79 | 80 | return [ 81 | { 82 | "prompt": k, 83 | "chosen": v["chosen"], 84 | "chosen_reward": v["chosen_reward"], 85 | "rejected": v["rejected"], 86 | "rejected_reward": v["rejected_reward"], 87 | } 88 | for k, v in out.items() 89 | ] 90 | 91 | 92 | PROCESSORS = { 93 | "rs": rejection_sampling_processor, 94 | "csft": conditional_sft_processor, 95 | "iter_dpo": iterative_dpo_processor, 96 | } 97 | 98 | 99 | def get_processor(name): 100 | if name in PROCESSORS: 101 | return PROCESSORS[name] 102 | else: 103 | raise ValueError(f"Processor {name} does not exist.") 104 | -------------------------------------------------------------------------------- /openrlhf/utils/remote_rm_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import ray 3 | import requests 4 | import torch 5 | 6 | from openrlhf.utils.logging_utils import init_logger 7 | 8 | logger = init_logger(__name__) 9 | 10 | 11 | def request_api_wrapper(url, data, score_key="rewards", try_max_times=5): 12 | 
"""Synchronous request API wrapper""" 13 | headers = { 14 | "Content-Type": "application/json", 15 | } 16 | for _ in range(try_max_times): 17 | try: 18 | response = requests.post(url=url, json=data, headers=headers, timeout=180) 19 | response.raise_for_status() # Raise an HTTPError for bad responses 20 | response = response.json() 21 | assert score_key in response, f"{score_key} not in {response}" 22 | # Return auxiliary information as well not just the score_key 23 | return response 24 | except requests.RequestException as e: 25 | logger.info(f"Request error, please check: {e}") 26 | except Exception as e: 27 | logger.info(f"Unexpected error, please check: {e}") 28 | time.sleep(1) 29 | 30 | raise Exception(f"Request error for {try_max_times} times, returning None. Please check the API server.") 31 | 32 | 33 | def remote_rm_fn(api_url, queries, steps, total_steps, score_key="rewards"): 34 | """remote reward model API 35 | api_url: RM API, We assume that the API supports two modes: merging query + response and not merging 36 | queries: query+response with the template 37 | design is made optional. 
38 | score_key: RM score key 39 | """ 40 | api_outputs = request_api_wrapper(api_url, {"query": queries, "steps": steps, "total_steps": total_steps}, score_key) 41 | # TODO: Return dict with scores and other information 42 | return {key: torch.tensor(value) for key, value in api_outputs.items()} 43 | 44 | 45 | @ray.remote 46 | def remote_rm_fn_ray(api_url, queries, steps, total_steps, score_key="rewards"): 47 | return remote_rm_fn(api_url, queries, steps, total_steps, score_key) 48 | 49 | 50 | if __name__ == "__main__": 51 | # test utils 52 | url = "http:xxx/get_rm_score" 53 | score = remote_rm_fn(url, ["example query"], ["example response"]) 54 | print(score) 55 | -------------------------------------------------------------------------------- /openrlhf/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import interleave_datasets, load_dataset, load_from_disk, concatenate_datasets 4 | from transformers import AutoTokenizer 5 | 6 | 7 | DEFAULT_PAD_TOKEN = "[PAD]" 8 | DEFAULT_EOS_TOKEN = "" 9 | DEFAULT_BOS_TOKEN = "" 10 | DEFAULT_UNK_TOKEN = "" 11 | 12 | 13 | def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True): 14 | tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast) 15 | tokenizer.padding_side = padding_side 16 | # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM. 
17 | # https://github.com/facebookresearch/llama-recipes/pull/196 18 | if tokenizer.pad_token is None: 19 | tokenizer.pad_token = tokenizer.eos_token 20 | tokenizer.pad_token_id = tokenizer.eos_token_id 21 | model.config.pad_token_id = tokenizer.pad_token_id 22 | 23 | return tokenizer 24 | 25 | 26 | def get_strategy(args): 27 | from openrlhf.utils.deepspeed import DeepspeedStrategy 28 | 29 | strategy = DeepspeedStrategy( 30 | seed=getattr(args, "seed", 42), 31 | max_norm=getattr(args, "max_norm", 1.0), 32 | micro_train_batch_size=getattr(args, "micro_train_batch_size", 1), 33 | train_batch_size=getattr(args, "train_batch_size", 128), 34 | zero_stage=args.zero_stage, 35 | bf16=getattr(args, "bf16", True), 36 | args=args, 37 | ) 38 | return strategy 39 | 40 | 41 | def blending_datasets( 42 | datasets, 43 | probabilities, 44 | strategy=None, 45 | seed=42, 46 | max_count=5000000, 47 | return_eval=True, 48 | stopping_strategy="first_exhausted", 49 | train_split="train", 50 | eval_split="test", 51 | ): 52 | datasets = datasets.split(",") 53 | probabilities = list(map(float, probabilities.split(","))) 54 | assert len(probabilities) == len(datasets) 55 | 56 | train_data_list = [] 57 | eval_data_list = [] 58 | for i, dataset in enumerate(datasets): 59 | dataset = dataset.strip() 60 | strategy.print(f"dataset: {dataset}") 61 | 62 | data_dir = dataset.split("@")[1].strip() if "@" in dataset else None 63 | dataset = dataset.split("@")[0].strip() 64 | dataset_basename = os.path.basename(dataset) 65 | 66 | ext = os.path.splitext(dataset)[-1] 67 | # local python script 68 | if ext == ".py" or ( 69 | os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py")) 70 | ): 71 | data = load_dataset(dataset, trust_remote_code=True) 72 | strategy.print(f"loaded {dataset} with python script") 73 | # local text file 74 | elif ext in [".json", ".jsonl", ".csv"]: 75 | ext = ext.lower().strip(".") 76 | if ext == "jsonl": 77 | ext = "json" 78 | data = 
load_dataset(ext, data_files=dataset) 79 | strategy.print(f"loaded {dataset} with data_files={dataset}") 80 | # local dataset saved with `datasets.Dataset.save_to_disk` 81 | elif os.path.isdir(dataset): 82 | data = load_from_disk(dataset) 83 | strategy.print(f"loaded {dataset} from disk") 84 | # remote/local folder or common file 85 | else: 86 | if dataset == 'openai/gsm8k': 87 | data = load_dataset(dataset, 'main') 88 | elif dataset == 'hendrycks/competition_math': 89 | data = load_dataset(dataset) 90 | else: 91 | data = load_dataset(dataset, data_dir=data_dir) 92 | strategy.print(f"loaded {dataset} from files") 93 | 94 | if train_split and train_split in data: 95 | train_data = data[train_split].select(range(min(max_count, len(data[train_split])))) 96 | else: 97 | train_data = data.select(range(min(max_count, len(data)))) 98 | # Add dataset_name column to choose appropriate keys 99 | train_data = train_data.add_column('dataset_name', [dataset] * len(train_data)) 100 | train_data_list.append(train_data) 101 | 102 | if return_eval: 103 | if eval_split and eval_split in data: 104 | eval_data = data[eval_split].select(range(min(max_count, len(data[eval_split])))) 105 | # train will contains eval? 
TODO 106 | else: 107 | eval_data = train_data.select(range(min(max_count, int(len(train_data) * 0.03)))) 108 | eval_data_list.append(eval_data) 109 | 110 | # merge datasets 111 | if strategy.is_rank_0(): 112 | print(train_data_list) 113 | 114 | # train_dataset = interleave_datasets( 115 | # train_data_list, 116 | # probabilities=probabilities, 117 | # seed=seed, 118 | # stopping_strategy=stopping_strategy, 119 | # ) 120 | train_dataset = concatenate_datasets( 121 | train_data_list 122 | ) 123 | train_dataset = train_dataset.shuffle(seed=42) 124 | if return_eval: 125 | eval_dataset = interleave_datasets( 126 | eval_data_list, 127 | probabilities=probabilities, 128 | seed=seed, 129 | stopping_strategy=stopping_strategy, 130 | ) 131 | return train_dataset, eval_dataset 132 | else: 133 | return train_dataset 134 | 135 | 136 | def convert_token_to_id(token, tokenizer): 137 | if isinstance(token, str): 138 | token = tokenizer.encode(token, add_special_tokens=False) 139 | assert len(token) == 1 140 | return token[0] 141 | else: 142 | raise ValueError("token should be int or str") 143 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "packaging", 4 | "setuptools >= 49.4.0", 5 | "wheel", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.isort] 10 | profile = "black" # black-compatible 11 | line_length = 119 # should match black parameters 12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style 13 | py_version = 310 # python 3.10 as a target version 14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] 15 | default_section = "THIRDPARTY" 16 | extend_skip = ["setup.py", "docs/source/conf.py"] 17 | 18 | 19 | [tool.black] 20 | line_length = 119 21 | 22 | [tool.ruff] 23 | line-length = 119 24 | 25 | [tool.pytest.ini_options] 26 | # 
durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. 27 | # -vv will also display tests with durration = 0.00s 28 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest 29 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module 30 | # directories to ignore when discovering tests 31 | norecursedirs = [ 32 | "external", 33 | "examples", 34 | "docs", 35 | "scripts", 36 | "tools", 37 | "tutorials", 38 | "*.egg", 39 | ".*", 40 | "_darcs", 41 | "build", 42 | "CVS", 43 | "dist", 44 | "venv", 45 | "{arch}", 46 | ] 47 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests 48 | markers = [ 49 | "unit: marks unit test, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')", 50 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')", 51 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')", 52 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')", 53 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')", 54 | "skipduringci: marks tests that are skipped ci as they are addressed by Jenkins jobs but should be run to test user setups", 55 | "pleasefixme: marks tests that are broken and need fixing", 56 | ] 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets 4 | deepspeed==0.15.0 5 | einops 6 | flash-attn==2.7.0.post2 7 | Flask==3.1.0 8 | isort 9 | jsonlines 10 | loralib 11 | optimum 12 | packaging 13 | peft 14 | ray[default]==2.12.0 15 | 
tensorboard 16 | torch 17 | torchmetrics 18 | tqdm 19 | transformers==4.46.3 20 | transformers_stream_generator 21 | wandb 22 | wheel -------------------------------------------------------------------------------- /reward_server/math_server_test.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import json 3 | from datasets import load_dataset 4 | import requests 5 | import time 6 | 7 | def load_math_dataset(): 8 | """Loads the GSM8K dataset.""" 9 | dataset = load_dataset("hendrycks/competition_math", 'main', split="train") 10 | return {entry['problem']: entry['solution'] for entry in dataset} 11 | 12 | def load_gsm8k_dataset(): 13 | """Loads the GSM8K dataset.""" 14 | dataset = load_dataset("gsm8k", 'main', split="train") 15 | return {entry['question']: entry['answer'] for entry in dataset} 16 | 17 | def send_query_to_server(query, host='localhost', port=12345): 18 | """Sends a query to the server and returns the response.""" 19 | try: 20 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 21 | client_socket.connect((host, port)) 22 | 23 | client_socket.sendall(json.dumps(query).encode('utf-8')) 24 | response = client_socket.recv(1024).decode('utf-8') 25 | 26 | return json.loads(response) 27 | except Exception as e: 28 | return {"error": str(e)} 29 | finally: 30 | client_socket.close() 31 | 32 | def request_api_wrapper(url, data, try_max_times=5): 33 | """Synchronous request API wrapper.""" 34 | headers = { 35 | "Content-Type": "application/json", 36 | } 37 | for _ in range(try_max_times): 38 | try: 39 | response = requests.post(url=url, json=data, headers=headers, timeout=180) 40 | response.raise_for_status() 41 | response = response.json() 42 | 43 | if "error" in response: 44 | raise Exception(response["error"]) 45 | 46 | return response 47 | except requests.RequestException as e: 48 | print(f"Request error, please check: {e}") 49 | except Exception as e: 50 | print(f"Unexpected error, 
please check: {e}") 51 | time.sleep(1) 52 | 53 | raise Exception(f"Request error for {try_max_times} times, returning None. Please check the API server.") 54 | 55 | if __name__ == "__main__": 56 | # Example queries 57 | # gsm8k_entries = load_gsm8k_dataset() 58 | # math_entries = load_math_dataset() 59 | # q1 = list(gsm8k_entries.keys())[0] 60 | # q2 = list(math_entries.keys())[0] 61 | # print(q1, gsm8k_entries[q1]) 62 | # print(q2, math_entries[q2]) 63 | query = json.load(open('error.json')) 64 | # query['steps'] = 120 65 | queries = [ 66 | # {'query': [ 67 | # {'response': 'the answer is most probably \\boxed{72.0}.<|eot_id|>', 'aux_info': {'question': q1, 'answer': gsm8k_entries[q1], 'dataset_name': 'openai/gsm8k'}}, 68 | # {'response': 'the answer is most probably \\boxed{72.0} but i don\'t know.', 'aux_info': {'question': q1, 'answer': gsm8k_entries[q1], 'dataset_name': 'openai/gsm8k'}}, 69 | # {'response': 'the answer is most probably \\boxed{72.0} but i don\'t know what do you think.', 'aux_info': {'question': q1, 'answer': gsm8k_entries[q1], 'dataset_name': 'openai/gsm8k'}}, 70 | # # {'response': 'the answer is \\bo', 'aux_info': {'problem': q2, 'solution': math_entries[q2], 'dataset_name': 'hendrycks/competition_math'}}, 71 | # # {'response': 'the answer is \\bo', 'aux_info': {'problem': q2, 'solution': math_entries[q2], 'dataset_name': 'hendrycks/competition_math'}} 72 | # ]} 73 | # json.load(open('error3.log')) 74 | query 75 | ] 76 | url = "http://127.0.0.1:12310/query" 77 | 78 | for i, query in enumerate(queries): 79 | print(f"Sending Query {i + 1}: {query}") 80 | # response = send_query_to_server(query) 81 | response = request_api_wrapper(url, query) 82 | print(f"Response {i + 1}: {response}\n") 83 | -------------------------------------------------------------------------------- /run_rloo_1.5B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --mem=100g 4 | #SBATCH --nodes=1 5 | 
#SBATCH --gpus-per-node=4 6 | #SBATCH --cpus-per-task=16 7 | #SBATCH --job-name=run 8 | #SBATCH --partition=ghx4 9 | #SBATCH --time=24:00:00 # hh:mm:ss for the job 10 | #SBATCH -e logs/slurm-%j.err 11 | #SBATCH -o logs/slurm-%j.out 12 | 13 | echo "job is starting on `hostname`" 14 | 15 | NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 16 | 17 | MODEL_SCALE='1.5B' 18 | REWARD_TYPE='sigmoid' 19 | ALPHA=0.1 # This controls the penalty for longer correct respones. Increase to penalize longer responses. 20 | WANDB_KEY="" # Provide your wandb key here before running 21 | CHECK_EOS='--check_eos' 22 | SCHEDULER_TYPE='warmup_with_constant_lr' # can be cosine otherwise 23 | 24 | PRETRAIN='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B' 25 | ACTOR_NUM_GPUS=2 26 | REF_NUM_GPUS=2 27 | VLLM_NUM_ENGINES=2 28 | ACTOR_LEARNING_RATE=5e-6 29 | INIT_KL_COEF=0.001 30 | MIN_P=0 31 | MAX_EPOCHS=1 32 | TOKENIZER='deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B' 33 | NUM_EPISODES=1 34 | GENERATE_MAX_LEN=32000 35 | SAVE_STEPS=10 36 | SEED=42 37 | 38 | RUN_NAME="scale:"$MODEL_SCALE"_alpha:"$ALPHA 39 | INPUT_KEY="problem" 40 | DATASET='datasets/compression_dataset' 41 | BASE_PROJECT_DIR=$PWD # Change this to the path of the project directory 42 | RM_ADDRESS="0.0.0.0:24372" 43 | SAVE_PATH="$BASE_PROJECT_DIR/$RUN_NAME" 44 | CKPT_PATH="$SAVE_PATH" 45 | 46 | echo "Using: ($DATASET) logging run to ($RUN_NAME)" 47 | 48 | # stop if any previous instances are running 49 | ray stop 50 | # launch the master node of ray in container 51 | ray start --head --node-ip-address 0.0.0.0 --num-gpus $NUM_GPUS --ray-debugger-external 52 | 53 | # launch reward server 54 | python -m reward_server.math_server \ 55 | --address $RM_ADDRESS \ 56 | --dataset $DATASET \ 57 | --tokenizer $TOKENIZER \ 58 | --reward_type $REWARD_TYPE \ 59 | --alpha $ALPHA \ 60 | $CHECK_EOS \ 61 | 1> logs/server$SLURM_JOB_ID.out 2> logs/server$SLURM_JOB_ID.err& 62 | 63 | python -m openrlhf.cli.train_ppo_ray \ 64 | 
--advantage_estimator rloo \ 65 | --n_samples_per_prompt 8 \ 66 | --max_epochs $MAX_EPOCHS \ 67 | --remote_rm_url http://$RM_ADDRESS/query \ 68 | --ref_num_nodes 1 \ 69 | --ref_num_gpus_per_node $REF_NUM_GPUS \ 70 | --actor_num_nodes 1 \ 71 | --actor_num_gpus_per_node $ACTOR_NUM_GPUS \ 72 | --vllm_num_engines $VLLM_NUM_ENGINES \ 73 | --vllm_tensor_parallel_size 1 \ 74 | --max_ckpt_num 10 \ 75 | --num_episodes $NUM_EPISODES \ 76 | --colocate_critic_reward \ 77 | --colocate_actor_ref \ 78 | --pretrain $PRETRAIN \ 79 | --wandb_run_name $RUN_NAME \ 80 | --save_path $SAVE_PATH \ 81 | --ckpt_path $CKPT_PATH \ 82 | --save_steps $SAVE_STEPS \ 83 | --prompt_data_probs 1.0 \ 84 | --scheduler_type $SCHEDULER_TYPE \ 85 | --min_p $MIN_P \ 86 | --micro_train_batch_size 1 \ 87 | --train_batch_size 128 \ 88 | --micro_rollout_batch_size 1 \ 89 | --rollout_batch_size 32 \ 90 | --max_samples 3200 \ 91 | --prompt_max_len 512 \ 92 | --generate_max_len $GENERATE_MAX_LEN \ 93 | --zero_stage 2 \ 94 | --bf16 \ 95 | --seed $SEED \ 96 | --actor_learning_rate $ACTOR_LEARNING_RATE \ 97 | --init_kl_coef $INIT_KL_COEF \ 98 | --prompt_data $DATASET \ 99 | --input_key $INPUT_KEY \ 100 | --input_template $'<|begin▁of▁sentence|><|User|>Please reason step by step, and put your final answer within \\boxed{{}}. 
Question: {}<|Assistant|>' \ 101 | --flash_attn \ 102 | --gradient_checkpointing \ 103 | --use_wandb $WANDB_KEY 104 | -------------------------------------------------------------------------------- /run_rloo_7B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --mem=300g 4 | #SBATCH --nodes=2 5 | #SBATCH --gpus-per-node=4 6 | #SBATCH --cpus-per-task=16 7 | #SBATCH --job-name=run 8 | #SBATCH --partition=ghx4 9 | #SBATCH --time=24:00:00 # hh:mm:ss for the job 10 | #SBATCH -e logs/slurm-%j.err 11 | #SBATCH -o logs/slurm-%j.out 12 | 13 | echo "job is starting on `hostname`" 14 | 15 | NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 16 | 17 | nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") 18 | nodes_array=($nodes) 19 | 20 | head_node=${nodes_array[0]} 21 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 22 | 23 | # if we detect a space character in the head node IP, we'll 24 | # convert it to an ipv4 address. This step is optional. 25 | if [[ "$head_node_ip" == *" "* ]]; then 26 | IFS=' ' read -ra ADDR <<<"$head_node_ip" 27 | if [[ ${#ADDR[0]} -gt 16 ]]; then 28 | head_node_ip=${ADDR[1]} 29 | else 30 | head_node_ip=${ADDR[0]} 31 | fi 32 | echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" 33 | fi 34 | # __doc_head_address_end__ 35 | 36 | # __doc_head_ray_start__ 37 | port=6379 38 | ip_head=$head_node_ip:$port 39 | export ip_head 40 | echo "IP Head: $ip_head" 41 | 42 | echo "Starting HEAD at $head_node" 43 | srun --nodes=1 --ntasks=1 -w "$head_node" \ 44 | ray start --head --node-ip-address="$head_node_ip" --port=$port \ 45 | --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${NUM_GPUS}" --block & 46 | # __doc_head_ray_end__ 47 | 48 | # __doc_worker_ray_start__ 49 | # optional, though may be useful in certain versions of Ray < 1.0. 
50 | sleep 10 51 | 52 | # number of nodes other than the head node 53 | worker_num=$((SLURM_JOB_NUM_NODES - 1)) 54 | 55 | for ((i = 1; i <= worker_num; i++)); do 56 | node_i=${nodes_array[$i]} 57 | echo "Starting WORKER $i at $node_i" 58 | srun --nodes=1 --ntasks=1 -w "$node_i" \ 59 | ray start --address "$ip_head" \ 60 | --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${NUM_GPUS}" --block & 61 | sleep 5 62 | done 63 | 64 | 65 | MODEL_SCALE='7B' 66 | REWARD_TYPE='sigmoid' 67 | ALPHA=0.1 68 | WANDB_KEY="" 69 | CHECK_EOS='--check_eos' 70 | SCHEDULER_TYPE='warmup_with_constant_lr' # can be cosine otherwise 71 | 72 | 73 | PRETRAIN='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B' 74 | ACTOR_NUM_GPUS=4 75 | REF_NUM_GPUS=4 76 | VLLM_NUM_ENGINES=4 77 | ACTOR_LEARNING_RATE=2e-6 78 | INIT_KL_COEF=0.001 79 | MIN_P=0 80 | MAX_EPOCHS=1 81 | TOKENIZER='deepseek-ai/DeepSeek-R1-Distill-Qwen-7B' 82 | NUM_EPISODES=1 83 | GENERATE_MAX_LEN=32000 84 | SAVE_STEPS=10 85 | SEED=42 86 | 87 | RUN_NAME="scale:"$MODEL_SCALE"_alpha:"$ALPHA 88 | INPUT_KEY="problem" 89 | DATASET='datasets/compression_dataset' 90 | BASE_PROJECT_DIR=$PWD 91 | ADDRESS=$head_node_ip:12345 92 | SAVE_PATH="$BASE_PROJECT_DIR/$RUN_NAME" 93 | CKPT_PATH="$SAVE_PATH" 94 | 95 | echo "Using: ($DATASET) logging run to ($RUN_NAME)" 96 | 97 | # launch gsm_server 98 | python -m reward_server.math_server \ 99 | --address $ADDRESS \ 100 | --dataset $DATASET \ 101 | --tokenizer $TOKENIZER \ 102 | --reward_type $REWARD_TYPE \ 103 | --alpha $ALPHA \ 104 | $CHECK_EOS \ 105 | 1> logs/server$SLURM_JOB_ID.out 2> logs/server$SLURM_JOB_ID.err& 106 | 107 | python -m openrlhf.cli.train_ppo_ray \ 108 | --advantage_estimator rloo \ 109 | --n_samples_per_prompt 8 \ 110 | --max_epochs $MAX_EPOCHS \ 111 | --remote_rm_url http://$ADDRESS/query \ 112 | --ref_num_nodes 1 \ 113 | --ref_num_gpus_per_node $REF_NUM_GPUS \ 114 | --actor_num_nodes 1 \ 115 | --actor_num_gpus_per_node $ACTOR_NUM_GPUS \ 116 | --vllm_num_engines $VLLM_NUM_ENGINES \ 117 | 
--vllm_tensor_parallel_size 1 \ 118 | --max_ckpt_num 10 \ 119 | --num_episodes $NUM_EPISODES \ 120 | --colocate_critic_reward \ 121 | --colocate_actor_ref \ 122 | --pretrain $PRETRAIN \ 123 | --wandb_run_name $RUN_NAME \ 124 | --save_path $SAVE_PATH \ 125 | --ckpt_path $CKPT_PATH \ 126 | --save_steps $SAVE_STEPS \ 127 | --prompt_data_probs 1.0 \ 128 | --scheduler_type $SCHEDULER_TYPE \ 129 | --min_p $MIN_P \ 130 | --micro_train_batch_size 1 \ 131 | --train_batch_size 128 \ 132 | --micro_rollout_batch_size 1 \ 133 | --rollout_batch_size 32 \ 134 | --max_samples 3200 \ 135 | --prompt_max_len 512 \ 136 | --generate_max_len $GENERATE_MAX_LEN \ 137 | --zero_stage 3 \ 138 | --bf16 \ 139 | --seed $SEED \ 140 | --actor_learning_rate $ACTOR_LEARNING_RATE \ 141 | --init_kl_coef $INIT_KL_COEF \ 142 | --prompt_data $DATASET \ 143 | --input_key $INPUT_KEY \ 144 | --input_template $'<|begin▁of▁sentence|><|User|>Please reason step by step, and put your final answer within \\boxed{{}}. Question: {}<|Assistant|>' \ 145 | --flash_attn \ 146 | --gradient_checkpointing \ 147 | --adam_offload \ 148 | --use_wandb $WANDB_KEY 149 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import platform 4 | 5 | from datetime import datetime 6 | from setuptools import find_packages, setup 7 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 8 | 9 | _build_mode = os.getenv("OPENRLHF_BUILD_MODE", "") 10 | 11 | 12 | def _is_nightly(): 13 | return _build_mode.lower() == "nightly" 14 | 15 | 16 | def _fetch_requirements(path): 17 | with open(path, "r") as fd: 18 | return [r.strip() for r in fd.readlines()] 19 | 20 | 21 | def _fetch_readme(): 22 | with open("README.md", encoding="utf-8") as f: 23 | return f.read() 24 | 25 | 26 | def _fetch_version(): 27 | with open("version.txt", "r") as f: 28 | version = f.read().strip() 29 | 30 | if 
_is_nightly(): 31 | now = datetime.now() 32 | date_str = now.strftime("%Y%m%d") 33 | version += f".dev{date_str}" 34 | 35 | return version 36 | 37 | 38 | def _fetch_package_name(): 39 | return "openrlhf-nightly" if _is_nightly() else "openrlhf" 40 | 41 | 42 | # Custom wheel class to modify the wheel name 43 | class bdist_wheel(_bdist_wheel): 44 | def finalize_options(self): 45 | _bdist_wheel.finalize_options(self) 46 | self.root_is_pure = False 47 | 48 | def get_tag(self): 49 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 50 | abi_tag = f"{python_version}" 51 | 52 | if platform.system() == "Linux": 53 | platform_tag = "manylinux1_x86_64" 54 | else: 55 | platform_tag = platform.system().lower() 56 | 57 | return python_version, abi_tag, platform_tag 58 | 59 | 60 | # Setup configuration 61 | setup( 62 | name=_fetch_package_name(), 63 | version=_fetch_version(), 64 | packages=find_packages( 65 | exclude=( 66 | "data", 67 | "docs", 68 | "examples", 69 | ) 70 | ), 71 | description="A Ray-based High-performance RLHF framework.", 72 | long_description=_fetch_readme(), 73 | long_description_content_type="text/markdown", 74 | install_requires=_fetch_requirements("requirements.txt"), 75 | extras_require={ 76 | "vllm": ["vllm==0.6.4.post1"], 77 | "vllm_latest": ["vllm>0.6.4.post1"], 78 | }, 79 | python_requires=">=3.10", 80 | classifiers=[ 81 | "Programming Language :: Python :: 3.10", 82 | "Programming Language :: Python :: 3.11", 83 | "Environment :: GPU :: NVIDIA CUDA", 84 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 85 | "Topic :: System :: Distributed Computing", 86 | ], 87 | cmdclass={"bdist_wheel": bdist_wheel}, 88 | ) 89 | -------------------------------------------------------------------------------- /static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes 
spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform 
.3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto 
Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 1rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | 
padding: 20px; 121 | font-size: 0; 122 | } 123 | 124 | .results-carousel video { 125 | margin: 0; 126 | } 127 | 128 | .slider-pagination .slider-page { 129 | background: #000000; 130 | } 131 | 132 | .eql-cntrb { 133 | font-size: smaller; 134 | } 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /static/images/Fast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/Fast.png -------------------------------------------------------------------------------- /static/images/Hierarchy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/Hierarchy.png -------------------------------------------------------------------------------- /static/images/Idea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/Idea.png -------------------------------------------------------------------------------- /static/images/Observation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/Observation.png -------------------------------------------------------------------------------- /static/images/Telescope.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/Telescope.png -------------------------------------------------------------------------------- 
/static/images/combined_dual_axis_plots_horizontal-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/combined_dual_axis_plots_horizontal-1.png -------------------------------------------------------------------------------- /static/images/combined_dual_axis_plots_horizontal.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/combined_dual_axis_plots_horizontal.pdf -------------------------------------------------------------------------------- /static/images/eqn-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/eqn-2.png -------------------------------------------------------------------------------- /static/images/eqn.drawio-3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/eqn.drawio-3.pdf -------------------------------------------------------------------------------- /static/images/eqn.drawio-4-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/eqn.drawio-4-1.png -------------------------------------------------------------------------------- /static/images/eqn.drawio-4.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/eqn.drawio-4.pdf -------------------------------------------------------------------------------- /static/images/eqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/eqn.png -------------------------------------------------------------------------------- /static/images/gpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/gpt.png -------------------------------------------------------------------------------- /static/images/method-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/method-1.png -------------------------------------------------------------------------------- /static/images/method.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/method.pdf -------------------------------------------------------------------------------- /static/images/normalized_averages copy-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/normalized_averages copy-1.png -------------------------------------------------------------------------------- /static/images/normalized_averages copy.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/normalized_averages copy.pdf -------------------------------------------------------------------------------- /static/images/normalized_averages-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/normalized_averages-1.png -------------------------------------------------------------------------------- /static/images/normalized_averages.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/normalized_averages.pdf -------------------------------------------------------------------------------- /static/images/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/output.png -------------------------------------------------------------------------------- /static/images/perf_rm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/perf_rm.png -------------------------------------------------------------------------------- /static/images/pipeline.drawio-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/pipeline.drawio-5.png -------------------------------------------------------------------------------- /static/images/rej.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/rej.png -------------------------------------------------------------------------------- /static/images/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/result.png -------------------------------------------------------------------------------- /static/images/spr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/spr.png -------------------------------------------------------------------------------- /static/images/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/images/train.png -------------------------------------------------------------------------------- /static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | $(document).ready(function() { 5 | // Check for click events on the navbar burger icon 6 | 7 | var options = { 8 | slidesToScroll: 1, 9 | slidesToShow: 1, 10 | loop: true, 11 | infinite: true, 12 | autoplay: true, 13 | autoplaySpeed: 5000, 14 | } 15 | 16 | // Initialize all div with carousel class 17 | var carousels = bulmaCarousel.attach('.carousel', options); 18 | 19 | bulmaSlider.attach(); 20 | 21 | }) 22 | -------------------------------------------------------------------------------- /static/pdfs/sample.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/static/pdfs/sample.pdf -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import DATASET_KEYS, RESPONSE_COMPARATOR, RESPONSE_EXTRACTOR -------------------------------------------------------------------------------- /utils/latex2sympy/.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | include = 5 | latex2sympy.py 6 | omit = 7 | sandbox/* 8 | gen/* 9 | asciimath_printer.py 10 | setup.py 11 | __init__.py 12 | 13 | [report] 14 | # Regexes for lines to exclude from consideration 15 | exclude_lines = 16 | # Have to re-enable the standard pragma 17 | pragma: no cover 18 | 19 | # Don't complain about missing debug-only code: 20 | def __repr__ 21 | if self\.debug 22 | 23 | # Don't complain if tests don't hit defensive assertion code: 24 | raise AssertionError 25 | raise NotImplementedError 26 | 27 | # Don't complain if non-runnable code isn't run: 28 | if 0: 29 | if __name__ == .__main__.: 30 | 31 | ignore_errors = True 32 | -------------------------------------------------------------------------------- /utils/latex2sympy/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .antlr 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 
| *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 90 | # install all needed dependencies. 
91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # Azure Functions artifacts 127 | bin 128 | obj 129 | appsettings.json 130 | local.settings.json 131 | .python_packages 132 | stemgen-solution-engine.zip -------------------------------------------------------------------------------- /utils/latex2sympy/LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright 2016, latex2sympy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/latex2sympy/README.md: -------------------------------------------------------------------------------- 1 | ![Logo](https://picgo-1258602555.cos.ap-nanjing.myqcloud.com/icon.png) 2 | 3 | # [latex2sympy2](https://github.com/OrangeX4/latex2sympy) 4 | 5 | ## About 6 | 7 | `latex2sympy2` parses **LaTeX math expressions** and converts them into the equivalent **SymPy form**. latex2sympy2 is adapted from [augustt198/latex2sympy](https://github.com/augustt198/latex2sympy) and [purdue-tlt / latex2sympy](https://github.com/purdue-tlt/latex2sympy). 8 | 9 | This project is part of a VS Code extension called [Latex Sympy Calculator](https://marketplace.visualstudio.com/items?itemName=OrangeX4.latex-sympy-calculator). It is designed to give people writing in LaTeX or Markdown the ability to evaluate math expressions as they write them. 10 | 11 | [ANTLR](http://www.antlr.org/) is used to generate the parser. 12 | 13 | ## Features 14 | 15 | * **Arithmetic:** Add (+), Sub (-), Dot Mul (·), Cross Mul (×), Frac (/), Power (^), Abs (|x|), Sqrt (√), etc... 16 | * **Alphabet:** a - z, A - Z, α - ω, Subscript (x_1), Accent Bar (ā), etc... 17 | * **Common Functions:** gcd, lcm, floor, ceil, max, min, log, ln, exp, sin, cos, tan, csc, sec, cot, arcsin, sinh, arsinh, etc... 18 | * **Function Symbol:** f(x), f(x-1,), g(x,y), etc... 19 | * **Calculus:** Limit ($\lim_{n\to\infty}$), Differentiation ($\frac{d}{dx}(x^2+x)$), Integration ($\int xdx$), etc... 20 | * **Linear Algebra:** Matrix, Determinant, Transpose, Inverse, Elementary Transformation, etc... 21 | * **Other:** Binomial...
22 | 23 | **NOTICE:** Some irreversible calculations are performed when converting determinants, transposed matrices and elementary transformations... 24 | 25 | ## Installation 26 | 27 | ``` 28 | pip install latex2sympy2 29 | ``` 30 | 31 | **Requirements:** `sympy` and `antlr4-python3-runtime` packages. 32 | 33 | ## Usage 34 | 35 | ### Basic 36 | 37 | In Python: 38 | 39 | ```python 40 | from latex2sympy2 import latex2sympy, latex2latex 41 | 42 | tex = r"\frac{d}{dx}(x^{2}+x)" 43 | # Or you can use '\mathrm{d}' to replace 'd' 44 | latex2sympy(tex) 45 | # => "Derivative(x**2 + x, x)" 46 | latex2latex(tex) 47 | # => "2 x + 1" 48 | ``` 49 | 50 | ### Examples 51 | 52 | |LaTeX|Converted SymPy|Calculated LaTeX| 53 | |-----|-----|---------------| 54 | |`x^{3}` $x^{3}$| `x**3`|`x^{3}` $x^{3}$| 55 | |`\frac{d}{dx} tx` $\frac{d}{dx}tx$|`Derivative(x*t, x)`|`t` $t$| 56 | |`\sum_{i = 1}^{n} i` $\sum_{i = 1}^{n} i$|`Sum(i, (i, 1, n))`|`\frac{n \left(n + 1\right)}{2}` $\frac{n \left(n + 1\right)}{2}$| 57 | |`\int_{a}^{b} \frac{dt}{t}`|`Integral(1/t, (t, a, b))`|`-\log{(a)} + \log{(b)}` $-\log{(a)} + \log{(b)}$| 58 | |`(2x^3 - x + z)\|_{x=3}` $(2x^3 - x + z)\|_{x=3}$|`z + 51`| `z + 51` $z + 51$ | 59 | 60 | To read the math formulas rendered, view this README on [GitNotes](https://notes.orangex4.cool/?git=github&github=OrangeX4/latex2sympy).
61 | 62 | ### Solve Equation 63 | 64 | ``` latex 65 | # Before 66 | x + y = 1 67 | 68 | # After 69 | [ y = 1 - x, \ x = 1 - y] 70 | ``` 71 | 72 | ### Eval At 73 | 74 | ``` latex 75 | # Before 76 | (x+2)|_{x=y+1} 77 | 78 | # After 79 | y + 3 80 | ``` 81 | 82 | ### Matrix 83 | 84 | #### Identity matrix 85 | 86 | ``` 87 | tex = r"\bm{I}_3" 88 | latex2sympy(tex) 89 | # => "Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]])" 90 | ``` 91 | 92 | #### Determinant 93 | 94 | ``` python 95 | from latex2sympy2 import latex2sympy 96 | 97 | tex = r"\begin{vmatrix} x & 0 & 0 \\ 0 & x & 0 \\ 0 & 0 & x \end{vmatrix}" 98 | latex2sympy(tex) 99 | # => "x^{3}" 100 | ``` 101 | 102 | #### Transpose 103 | 104 | ``` python 105 | from latex2sympy2 import latex2sympy 106 | 107 | tex = r"\begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{pmatrix}^T" 108 | # Or you can use "\begin{pmatrix}1&2&3\\4&5&6\\7&8&9\end{pmatrix}'" 109 | latex2sympy(tex) 110 | # => "Matrix([[1, 4, 7], [2, 5, 8], [3, 6, 9]])" 111 | ``` 112 | 113 | #### Elementary Transformation 114 | 115 | ``` python 116 | from latex2sympy2 import latex2sympy 117 | 118 | matrix = r''' 119 | \begin{pmatrix} 120 | 1 & 2 & 3 \\ 121 | 4 & 5 & 6 \\ 122 | 7 & 8 & 9 \\ 123 | \end{pmatrix} 124 | ''' 125 | 126 | # Scale the row with grammar "\xrightarrow{kr_n}" 127 | tex = matrix + r'\xrightarrow{3r_1}' 128 | latex2sympy(tex) 129 | # => "Matrix([[3, 6, 9], [4, 5, 6], [7, 8, 9]])" 130 | 131 | # Swap the cols with grammar "\xrightarrow{c_1<=>c_2}" 132 | # Of course, you can use "\leftrightarrow" to replace "<=>" 133 | tex = matrix + r'\xrightarrow{c_1<=>c_2}' 134 | latex2sympy(tex) 135 | # => "Matrix([[2, 1, 3], [5, 4, 6], [8, 7, 9]])" 136 | 137 | # Scale the second row and add it to the first row 138 | # with grammar "\xrightarrow{r_1+kr_2}" 139 | tex = matrix + r'\xrightarrow{r_1+kr_2}' 140 | latex2sympy(tex) 141 | # => "Matrix([[4*k + 1, 5*k + 2, 6*k + 3], [4, 5, 6], [7, 8, 9]])" 142 | 143 | # You can compose the transform with comma "," 144 | # 
and grammar "\xrightarrow[4r_3]{2r_1, 3r_2}" 145 | # Remember the priority of "{}" is higher than "[]" 146 | tex = matrix + r'\xrightarrow[4r_3]{2r_1, 3r_2}' 147 | latex2sympy(tex) 148 | # => "Matrix([[2, 4, 6], [12, 15, 18], [28, 32, 36]])" 149 | ``` 150 | 151 | ### Variances 152 | 153 | ``` python 154 | from latex2sympy2 import latex2sympy, variances, var, set_variances 155 | 156 | # Assign x a value of 1 157 | latex2sympy(r"x = 1") 158 | 159 | # Assign x a matrix symbol with dimension of n x m 160 | latex2sympy(r"x \in \mathbb{R}^{n \times m}") 161 | 162 | # Calculate x + y 163 | latex2sympy(r"x + y") 164 | # => "y + 1" 165 | 166 | # Get all variances 167 | print(variances) 168 | # => "{x: 1}" 169 | 170 | # Get variance of "x" 171 | print(var["x"]) 172 | # => "1" 173 | 174 | # Reset all variances 175 | set_variances({}) 176 | latex2sympy(r"x + y") 177 | # => "x + y" 178 | ``` 179 | 180 | ### Complex Number Support 181 | 182 | ``` python 183 | from latex2sympy2 import set_real 184 | 185 | set_real(False) 186 | ``` 187 | 188 | 189 | ## Contributing 190 | 191 | If you want to add a new grammar, you can fork the code from [OrangeX4/latex2sympy](https://github.com/OrangeX4/latex2sympy). 192 | 193 | * To modify parser grammar, view the existing structure in `PS.g4`. 194 | * To modify the action associated with each grammar, look into `latex2sympy.py`. 195 | 196 | Contributors are welcome! Feel free to open a pull request or an issue. 
197 | -------------------------------------------------------------------------------- /utils/latex2sympy/__init__.py: -------------------------------------------------------------------------------- 1 | import latex2sympy -------------------------------------------------------------------------------- /utils/latex2sympy/antlr-4.11.1-complete.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/utils/latex2sympy/antlr-4.11.1-complete.jar -------------------------------------------------------------------------------- /utils/latex2sympy/asciimath_printer.py: -------------------------------------------------------------------------------- 1 | from sympy.printing.str import StrPrinter 2 | from sympy.core import S 3 | 4 | class AsciiMathPrinter(StrPrinter): 5 | 6 | def _print_Limit(self, expr): 7 | e, z = expr.args 8 | 9 | return "lim_(%s -> %s) %s" % (self._print(z), self._print(z), self._print(e)) 10 | 11 | def _print_Integral(self, expr): 12 | e, lims = expr.args 13 | if len(lims) > 1: 14 | return "int_(%s)^(%s) %s d%s" % (self._print(lims[1]), self._print(lims[2]), self._print(e), self._print(lims[0])) 15 | else: 16 | return "int %s d%s" % (self._print(e), self._print(lims)) 17 | 18 | def _print_Sum(self, expr): 19 | e, lims = expr.args 20 | return "sum_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e)) 21 | 22 | def _print_Product(self, expr): 23 | e, lims = expr.args 24 | return "prod_(%s = %s)^(%s) %s" % (self._print(lims[0]), self._print(lims[1]), self._print(lims[2]), self._print(e)) 25 | 26 | def _print_factorial(self, expr): 27 | return "%s!" 
% self._print(expr.args[0]) 28 | 29 | def _print_Derivative(self, expr): 30 | e = expr.args[0] 31 | wrt = expr.args[1] 32 | return "d/d%s %s" % (self._print(wrt), self._print(e)) 33 | 34 | def _print_Abs(self, expr): 35 | return "|%s|" % self._print(expr.args[0]) 36 | 37 | def _print_Equality(self, expr): 38 | return "%s = %s" % (self._print(expr.args[0]), self._print(expr.args[1])) 39 | 40 | def _print_Pow(self, expr): 41 | b = self._print(expr.base) 42 | if expr.exp is S.Half: 43 | return "sqrt(%s)" % b 44 | 45 | if -expr.exp is S.Half: 46 | return "1/sqrt(%s)" % b 47 | if expr.exp is -S.One: 48 | return "1/%s" % b 49 | 50 | return "%s^(%s)" % (b, self._print(expr.exp)) 51 | -------------------------------------------------------------------------------- /utils/latex2sympy/description.txt: -------------------------------------------------------------------------------- 1 | latex2sympy2: https://github.com/OrangeX4/latex2sympy 2 | 3 | About 4 | 5 | `latex2sympy2` parses **LaTeX math expressions** and converts them into the equivalent **SymPy form**. latex2sympy2 is adapted from [augustt198/latex2sympy](https://github.com/augustt198/latex2sympy) and [purdue-tlt / latex2sympy](https://github.com/purdue-tlt/latex2sympy). 6 | 7 | [ANTLR](http://www.antlr.org/) is used to generate the parser. 8 | 9 | Features 10 | 11 | * **Arithmetic:** Add (+), Sub (-), Dot Mul (·), Cross Mul (×), Frac (/), Power (^), Abs (|x|), Sqrt (√), etc... 12 | * **Alphabet:** a - z, A - Z, α - ω, Subscript (x_1), Accent Bar (ā), etc... 13 | * **Common Functions:** gcd, lcm, floor, ceil, max, min, log, ln, exp, sin, cos, tan, csc, sec, cot, arcsin, sinh, arsinh, etc... 14 | * **Calculus:** Limit ($\lim_{n\to\infty}$), Differentiation ($\frac{d}{dx}(x^2+x)$), Integration ($\int xdx$), etc... 15 | * **Linear Algebra:** Matrix, Determinant, Transpose, Inverse, Elementary Transformation, etc... 16 | * **Other:** Binomial...
17 | 18 | **NOTICE:** Some irreversible calculations are performed when converting determinants, transposed matrices and elementary transformations... 19 | 20 | Installation 21 | 22 | ``` 23 | pip install latex2sympy2 24 | ``` 25 | 26 | **Requirements:** `sympy` and `antlr4-python3-runtime` packages. 27 | 28 | Usage 29 | 30 | Basic 31 | 32 | In Python: 33 | 34 | ```python 35 | from latex2sympy2 import latex2sympy, latex2latex 36 | 37 | tex = r"\frac{d}{dx}(x^{2}+x)" 38 | # Or you can use '\mathrm{d}' to replace 'd' 39 | latex2sympy(tex) 40 | # => "Derivative(x**2 + x, x)" 41 | latex2latex(tex) 42 | # => "2 x + 1" 43 | ``` 44 | 45 | Examples 46 | 47 | |LaTeX|Converted SymPy|Calculated LaTeX| 48 | |-----|-----|---------------| 49 | |`x^{3}` $x^{3}$| `x**3`|`x^{3}` $x^{3}$| 50 | |`\frac{d}{dx} tx` $\frac{d}{dx}tx$|`Derivative(x*t, x)`|`t` $t$| 51 | |`\sum_{i = 1}^{n} i` $\sum_{i = 1}^{n} i$|`Sum(i, (i, 1, n))`|`\frac{n \left(n + 1\right)}{2}` $\frac{n \left(n + 1\right)}{2}$| 52 | |`\int_{a}^{b} \frac{dt}{t}`|`Integral(1/t, (t, a, b))`|`-\log{(a)} + \log{(b)}` $-\log{(a)} + \log{(b)}$| 53 | |`(2x^3 - x + z)\|_{x=3}` $(2x^3 - x + z)\|_{x=3}$|`z + 51`| `z + 51` $z + 51$ | 54 | 55 | To read the math formulas rendered, view this document on [GitNotes](https://notes.orangex4.cool/?git=github&github=OrangeX4/latex2sympy).
56 | 57 | Matrix 58 | 59 | Determinant 60 | 61 | ``` python 62 | from latex2sympy2 import latex2sympy 63 | 64 | tex = r"\begin{vmatrix} x & 0 & 0 \\ 0 & x & 0 \\ 0 & 0 & x \end{vmatrix}" 65 | latex2sympy(tex) 66 | # => "x^{3}" 67 | ``` 68 | 69 | Transpose 70 | 71 | ``` python 72 | from latex2sympy2 import latex2sympy 73 | 74 | tex = r"\begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{pmatrix}^T" 75 | # Or you can use "\begin{pmatrix}1&2&3\\4&5&6\\7&8&9\end{pmatrix}'" 76 | latex2sympy(tex) 77 | # => "Matrix([[1, 4, 7], [2, 5, 8], [3, 6, 9]])" 78 | ``` 79 | 80 | Elementary Transformation 81 | 82 | ``` python 83 | from latex2sympy2 import latex2sympy 84 | 85 | matrix = r''' 86 | \begin{pmatrix} 87 | 1 & 2 & 3 \\ 88 | 4 & 5 & 6 \\ 89 | 7 & 8 & 9 \\ 90 | \end{pmatrix} 91 | ''' 92 | 93 | # Scale the row with grammar "\xrightarrow{kr_n}" 94 | tex = matrix + r'\xrightarrow{3r_1}' 95 | latex2sympy(tex) 96 | # => "Matrix([[3, 6, 9], [4, 5, 6], [7, 8, 9]])" 97 | 98 | # Swap the cols with grammar "\xrightarrow{c_1<=>c_2}" 99 | # Of course, you can use "\leftrightarrow" to replace "<=>" 100 | tex = matrix + r'\xrightarrow{c_1<=>c_2}' 101 | latex2sympy(tex) 102 | # => "Matrix([[2, 1, 3], [5, 4, 6], [8, 7, 9]])" 103 | 104 | # Scale the second row and add it to the first row 105 | # with grammar "\xrightarrow{r_1+kr_2}" 106 | tex = matrix + r'\xrightarrow{r_1+kr_2}' 107 | latex2sympy(tex) 108 | # => "Matrix([[4*k + 1, 5*k + 2, 6*k + 3], [4, 5, 6], [7, 8, 9]])" 109 | 110 | # You can compose the transform with comma "," 111 | # and grammar "\xrightarrow[4r_3]{2r_1, 3r_2}" 112 | # Remember the priority of "{}" is higher than "[]" 113 | tex = matrix + r'\xrightarrow[4r_3]{2r_1, 3r_2}' 114 | latex2sympy(tex) 115 | # => "Matrix([[2, 4, 6], [12, 15, 18], [28, 32, 36]])" 116 | ``` 117 | 118 | Variances 119 | 120 | ``` python 121 | from latex2sympy2 import latex2sympy, variances, var, set_variances 122 | 123 | # Assign x a value of 1 124 | latex2sympy(r"x = 1") 125 | 126 | # 
Calculate x + y 127 | latex2sympy(r"x + y") 128 | # => "y + 1" 129 | 130 | # Get all variances 131 | print(variances) 132 | # => "{x: 1}" 133 | 134 | # Get variance of "x" 135 | print(var["x"]) 136 | # => "1" 137 | 138 | # Reset all variances 139 | set_variances({}) 140 | latex2sympy(r"x + y") 141 | # => "x + y" 142 | ``` 143 | 144 | 145 | Contributing 146 | 147 | If you want to add a new grammar, you can fork the code from [OrangeX4/latex2sympy](https://github.com/OrangeX4/latex2sympy). 148 | 149 | * To modify parser grammar, view the existing structure in `PS.g4`. 150 | * To modify the action associated with each grammar, look into `latex2sympy.py`. 151 | 152 | Contributors are welcome! Feel free to open a pull request or an issue. 153 | -------------------------------------------------------------------------------- /utils/latex2sympy/dev-requirements.in: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | # Development 3 | pip-tools 4 | pytest 5 | pytest-cov 6 | pycodestyle 7 | autopep8 8 | -e . 
9 | -------------------------------------------------------------------------------- /utils/latex2sympy/dev-requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.10 3 | # by the following command: 4 | # 5 | # pip-compile dev-requirements.in 6 | # 7 | # via -r dev-requirements.in 8 | antlr4-python3-runtime==4.11.1 9 | # via 10 | # -r requirements.txt 11 | # latex2sympy2 12 | atomicwrites==1.3.0 13 | # via pytest 14 | attrs==19.3.0 15 | # via pytest 16 | autopep8==1.4.4 17 | # via -r dev-requirements.in 18 | click==7.0 19 | # via pip-tools 20 | coverage==4.5.4 21 | # via pytest-cov 22 | more-itertools==7.2.0 23 | # via pytest 24 | mpmath==1.3.0 25 | # via 26 | # -r requirements.txt 27 | # sympy 28 | packaging==19.2 29 | # via pytest 30 | pip-tools==4.2.0 31 | # via -r dev-requirements.in 32 | pluggy==0.13.0 33 | # via pytest 34 | py==1.8.0 35 | # via pytest 36 | pycodestyle==2.5.0 37 | # via 38 | # -r dev-requirements.in 39 | # autopep8 40 | pyparsing==2.4.4 41 | # via packaging 42 | pytest==5.2.2 43 | # via 44 | # -r dev-requirements.in 45 | # pytest-cov 46 | pytest-cov==2.8.1 47 | # via -r dev-requirements.in 48 | six==1.13.0 49 | # via 50 | # packaging 51 | # pip-tools 52 | sympy==1.12 53 | # via 54 | # -r requirements.txt 55 | # latex2sympy2 56 | wcwidth==0.1.7 57 | # via pytest 58 | 59 | # THIS MUST BE MAINTAINED AS-IS 60 | -e . 
-------------------------------------------------------------------------------- /utils/latex2sympy/gen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/utils/latex2sympy/gen/__init__.py -------------------------------------------------------------------------------- /utils/latex2sympy/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/utils/latex2sympy/icon.png -------------------------------------------------------------------------------- /utils/latex2sympy/requirements.in: -------------------------------------------------------------------------------- 1 | sympy 2 | antlr4-python3-runtime 3 | -------------------------------------------------------------------------------- /utils/latex2sympy/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.10 3 | # by the following command: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | antlr4-python3-runtime==4.11.1 8 | # via -r requirements.in 9 | mpmath==1.3.0 10 | # via sympy 11 | sympy==1.12 12 | # via -r requirements.in 13 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/linalg_equations.py: -------------------------------------------------------------------------------- 1 | from latex2sympy import process_sympy 2 | import sys 3 | sys.path.append("..") 4 | 5 | # latex = "2\\begin{pmatrix}1&1&1\\\\0&1&1\\\\0&0&1\\end{pmatrix}\\begin{pmatrix}1&1&1\\\\0&1&1\\\\0&0&1\\end{pmatrix}" 6 | latex = "\\frac{a^{2} \\left(3 \\pi - 4 \\sin{\\left(\\pi \\right)} + \\frac{\\sin{\\left(2 \\pi \\right)}}{2}\\right)}{2}" 7 | math = process_sympy(latex) 8 | 9 | 
print(type(math)) 10 | print("latex: %s to math: %s" % (latex, math)) 11 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/linalg_span.py: -------------------------------------------------------------------------------- 1 | from latex2sympy import process_sympy 2 | import sys 3 | sys.path.append("..") 4 | 5 | latex = "\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix}" 6 | math = process_sympy(latex) 7 | print("latex: %s to math: %s" % (latex, math)) 8 | 9 | latex = "\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}" 10 | math = process_sympy(latex) 11 | print("latex: %s to math: %s" % (latex, math)) 12 | 13 | latex = "[\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}]" 14 | math = process_sympy(latex) 15 | print("latex: %s to math: %s" % (latex, math)) 16 | 17 | latex = "\\left\\{\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix}\\right\\}" 18 | math = process_sympy(latex) 19 | print("latex: %s to math: %s" % (latex, math)) 20 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/matrix.py: -------------------------------------------------------------------------------- 1 | from latex2sympy import process_sympy 2 | from sympy import * 3 | import sys 4 | sys.path.append("..") 5 | 6 | theta = Symbol('theta', real=True) 7 | 8 | latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}" 9 | math = process_sympy(latex) 10 | print("latex: %s to math: %s" % (latex, math)) 11 | 12 | latex = "\\begin{matrix}1&2\\\\3&4\\\\5&6\\end{matrix}" 13 | math = process_sympy(latex) 14 | print("latex: %s to math: %s" % (latex, math)) 15 | 16 | latex = "\\begin{matrix}1&2&3\\\\4&5&6\\\\7&8&9\\end{matrix}" 17 | math = process_sympy(latex) 18 | print("latex: %s to math: %s" % (latex, math)) 19 | 20 | latex = "\\begin{matrix}x^1&x^2&x^3\\\\y^1&y^2&y^3\\\\z^1&z^2&z^3\\end{matrix}" 21 | math = 
process_sympy(latex) 22 | print("latex: %s to math: %s" % (latex, math)) 23 | 24 | latex = "\\begin{matrix}x\\\\y\\end{matrix}" 25 | math = process_sympy(latex) 26 | print("latex: %s to math: %s" % (latex, math)) 27 | 28 | latex = "2\\cdot\\begin{matrix}x\\\\y\\end{matrix}" 29 | math = process_sympy(latex) 30 | print("latex: %s to math: %s" % (latex, math)) 31 | 32 | latex = "2\\cdot\\begin{matrix}x\\\\y\\end{matrix} + \\begin{matrix}2\\\\3\\end{matrix}" 33 | math = process_sympy(latex) 34 | print("latex: %s to math: %s" % (latex, math)) 35 | 36 | latex = "-2\\begin{matrix}1&2\\\\3&4\\end{matrix}" 37 | math = process_sympy(latex) 38 | print("latex: %s to math: %s" % (latex, math)) 39 | 40 | latex = "2\\cdot\\theta\\begin{matrix}x\\\\y\\end{matrix} + \\begin{matrix}2\\\\3\\end{matrix}" 41 | math = process_sympy(latex) 42 | print("latex: %s to math: %s" % (latex, math)) 43 | 44 | latex = "\\theta\\begin{matrix}1\\\\3\\end{matrix} - \\begin{matrix}-1\\\\2\\end{matrix}" 45 | math = process_sympy(latex) 46 | print("latex: %s to math: %s" % (latex, math)) 47 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/matrix_placeholders.py: -------------------------------------------------------------------------------- 1 | from latex2sympy import process_sympy 2 | from sympy import * 3 | import sys 4 | import hashlib 5 | import time 6 | 7 | sys.path.append("..") 8 | 9 | 10 | M = Matrix([[1, 2], [3, 4]]) 11 | v = Matrix([1, 2]) 12 | 13 | # sub settings 14 | sub_settings_symbols = {} 15 | sub_settings_symbols[Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), real=True)] = M 16 | sub_settings_symbols[Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), real=True)] = v 17 | 18 | 19 | # one parameters 20 | latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}\\cdot[!v!]" 21 | equation_sympy_check = MatMul(M, Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), real=True)) 22 | equation_sympy_subs_check = MatMul(M, v) 23 | # 
placeholders 24 | equation_sympy = process_sympy(latex) 25 | print('latex = %s' % latex) 26 | print('equation_sympy = %s' % equation_sympy) 27 | print('equation_sympy_check = %s' % equation_sympy_check) 28 | print('equation_sympy = %s' % (srepr(equation_sympy))) 29 | 30 | equation_sympy_subs = equation_sympy.subs(sub_settings_symbols, evaluate=False) 31 | print('equation_sympy_subs = %s' % equation_sympy_subs) 32 | print('equation_sympy_subs_check = %s' % equation_sympy_subs_check) 33 | 34 | 35 | # two parameters 36 | 37 | # sub settings 38 | print('') 39 | print('============== Two Parameters -> M*v = Matrix*Vector =============') 40 | sub_settings_symbols = {} 41 | sub_settings_symbols[Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), commutative=False)] = M 42 | sub_settings_symbols[Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), commutative=False)] = v 43 | 44 | latex = "[!M!]\\cdot[!v!]" 45 | math_check = Mul(Symbol('M' + hashlib.md5('M'.encode()).hexdigest(), commutative=False), Symbol('v' + hashlib.md5('v'.encode()).hexdigest(), commutative=False)) 46 | # placeholders 47 | equation_sympy = process_sympy(latex) 48 | print(latex) 49 | print(math_check) 50 | print(equation_sympy) 51 | print(srepr(equation_sympy)) 52 | 53 | # performance 54 | t0 = time.time() 55 | 56 | # process_sympy and substitute at the same time 57 | # Only needed for linalg input 58 | placeholder_values = {'M': M, 'v': v} 59 | equation_sympy_subs = process_sympy(latex, variable_values=placeholder_values) 60 | 61 | t1 = time.time() 62 | print('equation with substituted placeholders = %s' % (str(equation_sympy_subs))) 63 | print('time to process to sympy with placeholders = %s s' % (t1 - t0)) 64 | print('') 65 | print('============== Two Parameters -> M*v = Matrix*Vector =============') 66 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/sandbox.py: 
-------------------------------------------------------------------------------- 1 | from sympy import * 2 | from latex2sympy import process_sympy 3 | 4 | 5 | # latex = '\\variable{a}^{\\variable{b}}' 6 | # variables = {'a': process_sympy('658.95998'), 'b': process_sympy('185083.8060')} 7 | # c_ans_expr = process_sympy(latex, variables) 8 | # print(c_ans_expr) 9 | # print(srepr(c_ans_expr)) 10 | # c_ans = c_ans_expr.doit(deep=False).evalf(chop=True) 11 | # print(c_ans) 12 | # print(srepr(c_ans)) 13 | 14 | 15 | # numeric_responses = ['1', '1.0', '-1', '-1.0', '.5', '-.5', '3x10^3', '3E3', '3,000x10^{-3}', '0.5E-1', '\\frac{1}{3}', '(5\\times 3)^3', '\\sin(1)'] 16 | # for latex in numeric_responses: 17 | # parsed = process_sympy(latex) 18 | # print('latex: ', latex) 19 | # print('sympy: ', parsed) 20 | # print('is_number: ', parsed.is_number) 21 | # print('is_Number: ', parsed.is_Number) 22 | # print('srepr: ', srepr(parsed)) 23 | # print('-----------------------------------------------------') 24 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/sandbox_equality.py: -------------------------------------------------------------------------------- 1 | from sympy import * 2 | from latex2sympy import process_sympy 3 | 4 | 5 | # 6 | # Equality Testing 7 | # 8 | 9 | answer_sets = [ 10 | { 11 | 'correct_answer': '(x-y)(x+2y)', 12 | 'student_answers': [ 13 | 'x^2+xy-2y^2', 14 | '(x-y)(x+2y)', 15 | '(x+2y)(x-y)', 16 | '(2\\times y+x)(-y+x)', 17 | '(y\\cdot 2+x)(-y+x)' 18 | ] 19 | }, 20 | { 21 | 'correct_answer': '2\\pi \\variable{r}^2', 22 | 'student_answers': [ 23 | '2\\pi \\variable{r}^2', 24 | '\\pi 2\\variable{r}^2', 25 | '2\\times \\pi \\times \\variable{r}^2', 26 | '2\\pi \\variable{r} \\times \\variable{r}' 27 | ] 28 | }, 29 | { 30 | 'correct_answer': '2x - 3y', 31 | 'student_answers': [ 32 | '-3y + 2x' 33 | ] 34 | }, 35 | { 36 | 'correct_answer': 'x\\times x', 37 | 'student_answers': [ 38 | 'x\\times x', 39 
| 'x\\cdot x', 40 | 'x^2', 41 | '(\\sqrt{x})^{4}' 42 | ] 43 | }, 44 | { 45 | 'correct_answer': '23e^{-1\\times \\sqrt{t^2}}', 46 | 'student_answers': [ 47 | '23e^{-t}' 48 | ] 49 | }, 50 | { 51 | 'correct_answer': 'a=x^2+1', 52 | 'student_answers': [ 53 | 'x^2+1=a' 54 | ] 55 | } 56 | ] 57 | 58 | for answer_set in answer_sets: 59 | correct_answer = answer_set['correct_answer'] 60 | correct_answer_parsed = process_sympy(answer_set['correct_answer']) 61 | for student_answer in answer_set['student_answers']: 62 | student_answer_parsed = process_sympy(student_answer) 63 | print('correct_answer (c): ', correct_answer, correct_answer_parsed) 64 | print('student_answer (a): ', student_answer, student_answer_parsed) 65 | print('') 66 | print('Expression Tree (srepr(c) == srepr(a)) =>', srepr(correct_answer_parsed) == srepr(student_answer_parsed)) 67 | print('srepr(c) =>', srepr(correct_answer_parsed)) 68 | print('srepr(a) =>', srepr(student_answer_parsed)) 69 | print('') 70 | # print('Structural (c == a) =>', correct_answer_parsed == student_answer_parsed) 71 | print('Symbolic (simplify(c - s) == 0) =>', simplify(correct_answer_parsed - student_answer_parsed) == 0) 72 | print('simplified =>', simplify(correct_answer_parsed - student_answer_parsed)) 73 | print('') 74 | print('Numeric Substitution (c.equals(s)) =>', correct_answer_parsed.equals(student_answer_parsed)) 75 | print('-----------------------------------------------------') 76 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/sectan.py: -------------------------------------------------------------------------------- 1 | from sympy import * 2 | import sys 3 | sys.path.append("..") 4 | 5 | # # x^2\cdot \left(3\cdot \tan \left([!a!]\cdot x+[!c!]\right)+[!a!]\cdot x\left(\sec \left([!a!]\cdot x+[!c!]\right)\right)^2\right) 6 | # latex1 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(\\sec \\left(2\\cdot 
x+5\\right)\\right)^2\\right)" 7 | # math1 = process_sympy(latex1) 8 | # print("latex: %s to math: %s" %(latex1,math1)) 9 | # 10 | # latex2 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(\\sec \\left(2\\cdot x+5\\right)^2\\right)\\right)" 11 | # math2 = process_sympy(latex2) 12 | # print("latex: %s to math: %s" %(latex2,math2)) 13 | # 14 | # latex3 = "x^2\\cdot \\left(3\\cdot \\tan \\left(2\\cdot x+5\\right)+2\\cdot x\\left(1+\\tan \\left(2\\cdot x+5\\right)^2\\right)\\right)" 15 | # math3 = process_sympy(latex3) 16 | # print("latex: %s to math: %s" %(latex3,math3)) 17 | # 18 | # print(simplify(math1 - math2)) 19 | # print(simplify(math1 - math3)) 20 | 21 | # 22 | # latex1 = "\\sec^2(2\\cdot x+5)" 23 | # math1 = process_sympy(latex1) 24 | # print("latex: %s to math: %s" %(latex1,math1)) 25 | # 26 | # latex2 = "1+\\tan^2(2\\cdot x+5)" 27 | # math2 = process_sympy(latex2) 28 | # print("latex: %s to math: %s" %(latex2,math2)) 29 | # print(simplify(math1 - math2)) 30 | 31 | 32 | x = Symbol('x', real=True) 33 | y = Symbol('y', real=True) 34 | 35 | # BUG: 1 + tan^2(x+1) should be == sec^2(x+1) but isnt 36 | lhs = (1 + (tan(x + 1))**2) 37 | rhs = (sec(x + 1))**2 38 | eq = lhs - rhs 39 | print(simplify(lhs)) 40 | print(simplify(rhs)) 41 | print(simplify(eq)) 42 | print(simplify(lhs) == simplify(rhs)) 43 | 44 | # 1 + tan^2(x) == sec^2(x) but isnt 45 | lhs = (1 + (tan(x))**2) 46 | rhs = (sec(x))**2 47 | eq = lhs - rhs 48 | print(simplify(lhs)) 49 | print(simplify(rhs)) 50 | print(simplify(eq)) 51 | print(simplify(lhs) == simplify(rhs)) 52 | -------------------------------------------------------------------------------- /utils/latex2sympy/sandbox/vector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sympy import * 3 | import sys 4 | sys.path.append("..") 5 | 6 | # row column matrix = vector 7 | v = [1, 2, 3] 8 | 9 | # single column matrix = vector 10 | m = Matrix([1, 2, 3]) 11 
| print(m[:, 0]) 12 | 13 | # a three row and 2 column matrix 14 | m = Matrix([[1, 2], [3, 4], [5, 6]]) 15 | print(m[:, 0]) 16 | 17 | # determinant of lin indp system != 0 18 | m = Matrix([[1, 1], [1, 2]]) 19 | print(m.det()) 20 | 21 | # determinant of lin dep system = 0 22 | m = Matrix([[1, 1], [2, 2]]) 23 | print(m.det()) 24 | 25 | # determinant of lin dep system = 0 26 | x = Symbol('x') 27 | y = Symbol('y') 28 | m = Matrix([[x, y], [x, y]]) 29 | print(m.det()) 30 | # Reduced Row-Echelon Form 31 | _, ind = m.rref() 32 | print(len(ind)) 33 | 34 | # determinant of lin dep system != 0 35 | m = Matrix([[x, y], [y, x]]) 36 | print(m.det()) 37 | # Reduced Row-Echelon Form 38 | _, ind = m.rref() 39 | print(len(ind)) 40 | 41 | # determinant of lin dep system != 0 42 | # Reduced Row-Echelon Form 43 | m = Matrix([[x, x, y], [y, y, y]]) 44 | _, ind = m.rref() 45 | # Reduced Row-Echelon Form 46 | print(len(ind)) 47 | 48 | #==================# 49 | #===== Numpy ======# 50 | #==================# 51 | # http://kitchingroup.cheme.cmu.edu/blog/2013/03/01/Determining-linear-independence-of-a-set-of-vectors/ 52 | # Lin Indp of set of numerical vectors 53 | TOLERANCE = 1e-14 54 | v1 = [6, 0, 3, 1, 4, 2] 55 | v2 = [0, -1, 2, 7, 0, 5] 56 | v3 = [12, 3, 0, -19, 8, -11] 57 | 58 | A = np.row_stack([v1, v2, v3]) 59 | 60 | U, s, V = np.linalg.svd(A) 61 | print(s) 62 | print(np.sum(s > TOLERANCE)) 63 | 64 | v1 = [1, 1] 65 | v2 = [4, 4] 66 | 67 | A = np.row_stack([v1, v2]) 68 | U, s, V = np.linalg.svd(A) 69 | print(s) 70 | print(np.sum(s > TOLERANCE)) 71 | 72 | 73 | latex = "\\begin{matrix}1&2\\\\3&4\\end{matrix}" 74 | # math = process_sympy(latex) 75 | print("latex: %s to math: %s" % (latex, 1)) 76 | -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Get relative path of the root directory of the project 4 | rdir=`git 
rev-parse --git-dir` 5 | rel_path="$(dirname "$rdir")" 6 | # Change to that path and run the file 7 | cd $rel_path 8 | 9 | java -jar antlr-4.11.1-complete.jar PS.g4 -o gen 10 | -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/coverage-ci.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | pytest --doctest-modules --junitxml=junit/test-results.xml --cov-report=xml --cov-config=.coveragerc --cov=latex2sympy tests -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Get relative path of the root directory of the project 4 | rdir=`git rev-parse --git-dir` 5 | rel_path="$(dirname "$rdir")" 6 | # Change to that path and run the file 7 | cd $rel_path 8 | 9 | # Activate virtual environment 10 | echo "activating venv..." 11 | if test -f .env/bin/activate 12 | then source .env/bin/activate && echo "venv activate (bin)" 13 | elif test -f .env/Scripts/activate 14 | then source .env/Scripts/activate && echo "venv activated (Scripts)" 15 | else exit 1 16 | fi 17 | 18 | # Run unit test coverage 19 | echo "starting coverage..." 20 | if pytest --doctest-modules --cov-report=html --cov-config=.coveragerc --cov=latex2sympy tests 21 | then echo "coverage finished" 22 | else exit 1 23 | fi 24 | -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Get relative path of the root directory of the project 4 | rdir=`git rev-parse --git-dir` 5 | rel_path="$(dirname "$rdir")" 6 | 7 | # Change to that path and run the file 8 | cd $rel_path 9 | 10 | echo "pre-commit hook started..." 
11 | 12 | # Activate virtual environment 13 | echo "activating venv..." 14 | if test -f .env/bin/activate 15 | then source .env/bin/activate && echo "venv activated." 16 | elif test -f .env/Scripts/activate 17 | then source .env/Scripts/activate && echo "venv activated." 18 | else exit 1 19 | fi 20 | 21 | # Run auto formatting on all staged python files, then add those changes 22 | echo "auto-formatting code..." 23 | if autopep8 --in-place `git diff --name-status --cached | grep '.py' | awk 'match($1, "A|M"){print $2}'` && git add `git diff --name-status --cached | grep '.py' | awk 'match($1, "A|M"){print $2}'` 24 | then echo "code was auto-formatted." 25 | else echo "no code was auto-formatted." 26 | fi 27 | 28 | exit 0 29 | -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/pre-push: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Get relative path of the root directory of the project 4 | rdir=`git rev-parse --git-dir` 5 | rel_path="$(dirname "$rdir")" 6 | 7 | # Change to that path and run the file 8 | cd $rel_path 9 | 10 | echo "pre-push hook started..." 11 | 12 | # Activate virtual environment 13 | echo "activating venv..." 14 | if test -f .env/bin/activate 15 | then source .env/bin/activate && echo "venv activated." 16 | elif test -f .env/Scripts/activate 17 | then source .env/Scripts/activate && echo "venv activated." 18 | else exit 1 19 | fi 20 | 21 | # Run unit tests 22 | echo "starting tests..." 23 | # if pytest tests 24 | # then echo "tests finished." 
25 | # else exit 1 26 | # fi 27 | 28 | exit 0 29 | -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/publish.sh: -------------------------------------------------------------------------------- 1 | rm ./dist/* 2 | python3 setup.py bdist_wheel 3 | twine upload dist/* 4 | -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/setup-hooks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | cp scripts/pre-push .git/hooks/ 3 | cp scripts/pre-commit .git/hooks/ -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Get relative path of the root directory of the project 4 | rdir=`git rev-parse --git-dir` 5 | rel_path="$(dirname "$rdir")" 6 | # Change to that path and run the file 7 | cd $rel_path 8 | 9 | echo "creating venv..." 10 | if test -d .env 11 | then echo "venv exists" 12 | else python3 -m venv .env && echo "venv created" 13 | fi 14 | 15 | echo '' 16 | # Activate virtual environment 17 | echo "activating venv..." 18 | if test -f .env/bin/activate 19 | then source .env/bin/activate && echo "venv activate (bin)" 20 | elif test -f .env/Scripts/activate 21 | then source .env/Scripts/activate && echo "venv activated (Scripts)" 22 | else exit 1 23 | fi 24 | 25 | echo '' 26 | echo "installing requirements..." 27 | if pip install -r dev-requirements.txt 28 | then echo "requirements installed" 29 | else exit 1 30 | fi 31 | 32 | echo '' 33 | echo "compiling parser..." 34 | sh scripts/compile.sh 35 | echo "parser compiled" 36 | 37 | echo '' 38 | echo "setup git hooks..." 
39 | sh scripts/setup-hooks.sh 40 | echo "git hooks setup" 41 | 42 | exit 0 43 | -------------------------------------------------------------------------------- /utils/latex2sympy/scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Get relative path of the root directory of the project 4 | rdir=`git rev-parse --git-dir` 5 | rel_path="$(dirname "$rdir")" 6 | # Change to that path and run the file 7 | cd $rel_path 8 | 9 | # Activate virtual environment 10 | echo "activating venv..." 11 | if test -f .env/bin/activate 12 | then source .env/bin/activate && echo "venv activate (bin)" 13 | elif test -f .env/Scripts/activate 14 | then source .env/Scripts/activate && echo "venv activated (Scripts)" 15 | else exit 1 16 | fi 17 | 18 | echo '' 19 | echo "compiling parser..." 20 | sh scripts/compile.sh 21 | echo "parser compiled" 22 | 23 | echo '' 24 | # Run unit tests 25 | echo "starting tests..." 26 | if pytest tests 27 | then echo "tests finished" 28 | else exit 1 29 | fi 30 | 31 | exit 0 32 | -------------------------------------------------------------------------------- /utils/latex2sympy/setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | max-line-length = 120 3 | ignore = E501 4 | -------------------------------------------------------------------------------- /utils/latex2sympy/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from codecs import open 3 | from os import path 4 | here = path.abspath(path.dirname(__file__)) 5 | 6 | 7 | setup( 8 | name="latex2sympy2", 9 | version="1.9.0", 10 | description='Convert latex to sympy with ANTLR and support Matrix, Linear Algebra and CAS functions.', 11 | long_description_content_type='text/markdown', 12 | long_description=open(path.join(here, "README.md"), encoding='utf-8').read(), 13 | # The project's 
main homepage. 14 | url='https://github.com/ZubinGou/latex2sympy', 15 | # Author details 16 | author='ZubinGou', 17 | author_email='zebgou@gmail.com', 18 | # Choose your license 19 | license='MIT', 20 | classifiers=[ 21 | 'Development Status :: 4 - Beta', 22 | 'Intended Audience :: Developers', 23 | 'Intended Audience :: Education', 24 | 'Intended Audience :: Science/Research', 25 | 'License :: OSI Approved :: MIT License', 26 | 'Topic :: Education', 27 | 'Topic :: Scientific/Engineering :: Mathematics', 28 | 'Topic :: Software Development :: Compilers', 29 | 'Topic :: Text Processing :: Markup :: LaTeX', 30 | 'Topic :: Text Processing :: Markup :: Markdown', 31 | 'Programming Language :: Python :: 3', 32 | 'Programming Language :: Python :: 3.3', 33 | 'Programming Language :: Python :: 3.4', 34 | 'Programming Language :: Python :: 3.5', 35 | 'Programming Language :: Python :: 3.6', 36 | 'Programming Language :: Python :: 3.7', 37 | 'Programming Language :: Python :: 3.8', 38 | ], 39 | packages=find_packages(exclude=('tests')), 40 | py_modules=['asciimath_printer', 'latex2sympy2'], 41 | install_requires=[ 42 | 'sympy>=1.4', 43 | 'antlr4-python3-runtime==4.11.1' 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /utils/latex2sympy/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zanette-Labs/efficient-reasoning/e3611e4b37e792a65e78a2d21fa3202a01a38306/utils/latex2sympy/tests/__init__.py -------------------------------------------------------------------------------- /utils/latex2sympy/tests/abs_test.py: -------------------------------------------------------------------------------- 1 | from .context import assert_equal, get_simple_examples 2 | import pytest 3 | from sympy import Abs 4 | 5 | examples = get_simple_examples(Abs) 6 | 7 | delimiter_pairs = { 8 | '|': '|', 9 | '\\vert': '\\vert', 10 | '\\lvert': '\\rvert' 11 | } 12 | 13 | 
14 | @pytest.mark.parametrize('input, output, symbolically', examples) 15 | def test_abs(input, output, symbolically): 16 | for left, right in delimiter_pairs.items(): 17 | assert_equal("{left}{input}{right}".format(left=left, right=right, input=input), output, symbolically=symbolically) 18 | assert_equal("\\left{left}{input}\\right{right}".format(left=left, right=right, input=input), output, symbolically=symbolically) 19 | assert_equal("\\mleft{left}{input}\\mright{right}".format(left=left, right=right, input=input), output, symbolically=symbolically) 20 | -------------------------------------------------------------------------------- /utils/latex2sympy/tests/all_bad_test.py: -------------------------------------------------------------------------------- 1 | from .context import assert_equal, process_sympy 2 | import pytest 3 | 4 | 5 | def pytest_generate_tests(metafunc): 6 | metafunc.parametrize('s', metafunc.cls.BAD_STRINGS) 7 | 8 | 9 | class TestAllBad(object): 10 | # These bad latex strings should raise an exception when parsed 11 | BAD_STRINGS = [ 12 | "(", 13 | ")", 14 | # "a / b /", 15 | "\\frac{d}{dx}", 16 | "(\\frac{d}{dx})" 17 | "\\sqrt{}", 18 | "\\sqrt", 19 | "{", 20 | "}", 21 | # "1.1.1", 22 | "\\mathit{TEST}" 23 | "\\frac{2}{}", 24 | "\\frac{}{2}", 25 | "\\int", 26 | # "1 +", 27 | # "a +", 28 | "!", 29 | "!0", 30 | "_", 31 | "^", 32 | # "a // b", 33 | # "a \\cdot \\cdot b", 34 | # "a \\div \\div b", 35 | "a\\mod \\begin{matrix}b\\end{matrix}" 36 | "|", 37 | "||x|", 38 | "\\lfloor x", 39 | "\\lfloor a \\rceil", 40 | "\\operatorname{floor}(12.3, 123.4)", 41 | "()", 42 | "((((((((((((((((()))))))))))))))))", 43 | "-", 44 | "\\frac{d}{dx} + \\frac{d}{dt}", 45 | # "f()", 46 | # "f(,", 47 | # "f(x,,y)", 48 | # "f(x,y,", 49 | "\\sin^x", 50 | "\\cos^2", 51 | # "\\cos 1 \\cos", 52 | # "\\gcd(3)", 53 | # "\\lcm(2)", 54 | "@", "#", "$", "%", "&", "*", 55 | "\\", 56 | "~", 57 | "\\frac{(2 + x}{1 - x)}", 58 | "\\lim_{\\pi \\to 3} a", 59 | # because mix of COMMA 
and SEMICOLON 60 | "\\left\\{\\begin{pmatrix}1\\\\2\\\\3\\end{pmatrix},\\begin{pmatrix}4\\\\3\\\\1\\end{pmatrix};\\begin{pmatrix}1\\\\1\\\\1\\end{pmatrix}\\right\\}", 61 | # percentages without numbers before-hand 62 | "a\\%", 63 | "\\%100", 64 | # dollar signs without numbers after 65 | "\\$" 66 | ] 67 | 68 | def test_bad_string(self, s): 69 | with pytest.raises(Exception): 70 | process_sympy(s) 71 | -------------------------------------------------------------------------------- /utils/latex2sympy/tests/atom_expr_test.py: -------------------------------------------------------------------------------- 1 | from .context import assert_equal 2 | import pytest 3 | from sympy import Symbol, Integer, Pow 4 | 5 | # label, text, symbol_text 6 | symbols = [ 7 | ('letter', 'x', 'x'), 8 | ('greek letter', '\\lambda', 'lambda'), 9 | ('greek letter w/ space', '\\alpha ', 'alpha'), 10 | ('accented letter', '\\overline{x}', 'xbar') 11 | ] 12 | 13 | subscripts = [ 14 | ('2'), 15 | ('{23}'), 16 | ('i'), 17 | ('{ij}'), 18 | ('{i,j}'), 19 | ('{good}'), 20 | ('{x^2}') 21 | ] 22 | 23 | examples = [] 24 | for symbol in symbols: 25 | for subscript in subscripts: 26 | examples.append(tuple(list(symbol) + [subscript])) 27 | 28 | 29 | @pytest.mark.parametrize('label, text, symbol_text, subscript', examples) 30 | def test_with_supexpr(label, text, symbol_text, subscript): 31 | assert_equal(text + '^2', Pow(Symbol(symbol_text, real=True), Integer(2))) 32 | 33 | 34 | @pytest.mark.parametrize('label, text, symbol_text, subscript', examples) 35 | def test_with_subexpr(label, text, symbol_text, subscript): 36 | assert_equal(text + '_' + subscript, Symbol(symbol_text + '_' + subscript, real=True)) 37 | 38 | 39 | @pytest.mark.parametrize('label, text, symbol_text, subscript', examples) 40 | def test_with_subexpr_before_supexpr(label, text, symbol_text, subscript): 41 | assert_equal(text + '_' + subscript + '^2', Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2))) 42 | 43 | 44 | 
@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_subexpr_before_supexpr_with_braces(label, text, symbol_text, subscript):
    wrapped_subscript = subscript if '{' in subscript else '{' + subscript + '}'
    assert_equal(text + '_' + wrapped_subscript + '^{2}', Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2)))


@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_supexpr_before_subexpr(label, text, symbol_text, subscript):
    assert_equal(text + '^2_' + subscript, Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2)))


@pytest.mark.parametrize('label, text, symbol_text, subscript', examples)
def test_with_supexpr_before_subexpr_with_braces(label, text, symbol_text, subscript):
    wrapped_subscript = subscript if '{' in subscript else '{' + subscript + '}'
    assert_equal(text + '^{2}_' + wrapped_subscript, Pow(Symbol(symbol_text + '_' + subscript, real=True), Integer(2)))

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/binomial_test.py:
# --------------------------------------------------------------------------------
# \binom and \choose should both parse to sympy.binomial.
from .context import assert_equal, _Add, _Mul, _Pow
import pytest
from sympy import binomial, Symbol

x = Symbol('x', real=True)
y = Symbol('y', real=True)
theta = Symbol('theta', real=True)
gamma = Symbol('gamma', real=True)


def test_binomial_numeric():
    assert_equal("\\binom{16}{2}", binomial(16, 2))


def test_binomial_symbols():
    assert_equal("\\binom{x}{y}", binomial(x, y))


def test_binomial_greek_symbols():
    assert_equal("\\binom{\\theta}{\\gamma}", binomial(theta, gamma))


def test_binomial_expr():
    assert_equal("\\binom{16+2}{\\frac{4}{2}}", binomial(_Add(16, 2), _Mul(4, _Pow(2, -1)), evaluate=False))


def test_choose_numeric():
    assert_equal("\\choose{16}{2}", binomial(16, 2))


def test_choose_symbols():
    assert_equal("\\choose{x}{y}", binomial(x, y))


def test_choose_greek_symbols():
    assert_equal("\\choose{\\theta}{\\gamma}", binomial(theta, gamma))

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/ceil_test.py:
# --------------------------------------------------------------------------------
# Every ceiling notation should parse to sympy.ceiling for the shared simple examples.
from .context import assert_equal, get_simple_examples
import pytest
from sympy import ceiling

examples = get_simple_examples(ceiling)


@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_func(input, output, symbolically):
    assert_equal("\\ceil({input})".format(input=input), output, symbolically=symbolically)


@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_operatorname(input, output, symbolically):
    assert_equal("\\operatorname{{ceil}}({input})".format(input=input), output, symbolically=symbolically)


@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_cmd(input, output, symbolically):
    assert_equal("\\lceil {input}\\rceil".format(input=input), output, symbolically=symbolically)
    assert_equal("\\left\\lceil {input}\\right\\rceil".format(input=input), output, symbolically=symbolically)
    assert_equal("\\mleft\\lceil {input}\\mright\\rceil".format(input=input), output, symbolically=symbolically)


@pytest.mark.parametrize('input, output, symbolically', examples)
def test_ceil_corners(input, output, symbolically):
    assert_equal("\\ulcorner {input}\\urcorner".format(input=input), output, symbolically=symbolically)
    assert_equal("\\left\\ulcorner {input}\\right\\urcorner".format(input=input), output, symbolically=symbolically)
    assert_equal("\\mleft\\ulcorner {input}\\mright\\urcorner".format(input=input), output, symbolically=symbolically)

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/complex_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import Sum, I, Symbol, Integer

a = Symbol('a', real=True)
b = Symbol('b', real=True)
i = Symbol('i', real=True)
n = Symbol('n', real=True)
x = Symbol('x', real=True)


def test_complex():
    assert_equal("a+Ib", a + I * b)


def test_complex_e():
    assert_equal("e^{I\\pi}", Integer(-1))


def test_complex_sum():
    assert_equal("\\sum_{i=0}^{n} i \\cdot x", Sum(i * x, (i, 0, n)))

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/context.py:
# --------------------------------------------------------------------------------
# Shared helpers for the latex2sympy test modules.
from sympy import simplify, srepr, Add, Mul, Pow, Rational, pi, sqrt, Symbol
from latex2sympy.latex2sympy2 import latex2sympy as process_sympy
import sys
import os
# NOTE(review): this runs *after* the import above, so it cannot influence how
# `latex2sympy.latex2sympy2` is resolved; it only affects imports made later.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

x = Symbol('x', real=True)
y = Symbol('y', real=True)

# shorthand definitions: unevaluated Add/Mul/Pow, so expected trees keep their shape


def _Add(a, b):
    return Add(a, b, evaluate=False)


def _Mul(a, b):
    return Mul(a, b, evaluate=False)


def _Pow(a, b):
    return Pow(a, b, evaluate=False)


def get_simple_examples(func):
    '''
    Return a list of ``(input, output, symbolically)`` tuples.

    ``input`` is a LaTeX string and ``output`` is ``func`` applied to the sympy
    value that string should parse to. ``symbolically`` is the flag to pass to
    ``compare``; it is True only for the entries that contain the symbols x / y.
    '''
    return [
        ("1.1", func(1.1), False),
        ("6.9", func(6.9), False),
        ("3.5", func(3.5), False),
        ("8", func(8), False),
        ("0", func(0), False),
        ("290348E32", func(Rational('290348E32')), False),
        ("1237.293894239480234", func(Rational('1237.293894239480234')), False),
        ("8623.4592104E-2", func(Rational('8623.4592104E-2')), False),
        ("\\pi ", func(pi), False),
        ("\\sqrt{100}", func(sqrt(100)), False),
        ("12,123.4", func(Rational('12123.4')), False),
        ("-9.4", func(-9.4), False),
        ("-35.9825", func(-35.9825), False),
        ("-\\sqrt{5}", func(-sqrt(5)), False),
        ("-324E-3", func(Rational('-324E-3')), False),
        ("-0.23", func(-0.23), False),
        ("\\frac{1}{2}", func(Rational('1/2')), False),
        ("\\frac{6}{2}", func(Rational('6/2')), False),
        ("\\frac{9}{5}", func(Rational('9/5')), False),
        ("\\frac{-42}{6}", func(Rational('-42/6')), False),
        ("-\\frac{325}{3}", func(Rational('-325/3')), False),
        ("\\frac{\\pi }{2}", func(pi / 2), False),
        ("(1+6)/3", func(Rational(1 + 6, 3)), False),
        ("1+6/3", func(1 + Rational('6/3')), False),
        ("7*4/5", func(7 * 4 / 5), False),
        ("15-2.3", func(15 - Rational('2.3')), False),
        ("x", func(x), True),
        ("x + y", func(x + y), True),
        ("\\frac{9x}{4}", func(9 * x / 4), True),
        ("y\\pi", func(y * pi), True),
        ("2y-y-y", func(2 * y - y - y), True)
    ]


def _is_numeric(value):
    """Return True for Python ints/floats and for objects whose ``is_number`` is truthy."""
    return isinstance(value, (int, float)) or value.is_number


def compare(actual, expected, symbolically=False):
    """Assert that ``actual`` matches ``expected``.

    With ``symbolically=True`` the difference must simplify to zero. Otherwise
    the ``srepr`` expression trees must be identical. If the trees differ and
    BOTH values are numeric, numeric equality is accepted instead. In every
    other case both trees are printed and the AssertionError is re-raised.
    """
    if symbolically:
        assert simplify(actual - expected) == 0
    else:
        actual_exp_tree = srepr(actual)
        expected_exp_tree = srepr(expected)
        try:
            assert actual_exp_tree == expected_exp_tree
        except AssertionError:
            # Both sides must be numeric. The previous unparenthesized
            # `... or actual.is_number and isinstance(expected, int) or ...`
            # bound `and` tighter than `or`, so ONE numeric side was enough.
            if _is_numeric(actual) and _is_numeric(expected):
                assert actual == expected or actual - expected == 0 or simplify(actual - expected) == 0
            else:
                print('expected_exp_tree = ', expected_exp_tree)
                print('actual exp tree = ', actual_exp_tree)
                raise


def assert_equal(latex, expr, variable_values=None, symbolically=False):
    """Parse ``latex`` (with optional ``\\variable`` bindings) and ``compare`` it to ``expr``."""
    # A fresh dict per call: a shared `{}` default would leak any mutation between tests.
    if variable_values is None:
        variable_values = {}
    parsed = process_sympy(latex, variable_values)
    compare(parsed, expr, symbolically)

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/exp_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import exp, sin, Symbol, E

x = Symbol('x', real=True)
y = Symbol('y', real=True)


def test_exp_letter():
    assert_equal("e", E)
    assert_equal("e", exp(1))


def test_exp_func():
    assert_equal("\\exp(3)", exp(3))


def test_exp_func_no_delim():
    assert_equal("\\exp3", exp(3))


def test_exp_command_symbol():
    assert_equal("\\exponentialE", E)
    assert_equal("\\exponentialE", exp(1))


def test_exp_command_symbol_expression():
    assert_equal("\\exponentialE^{3}", exp(3))


def test_exp_command_symbol_multiplied():
    '''
    \\exponentialE is NOT a function, so using the following notation equates to multiplication
    '''
    assert_equal("\\exponentialE (3)", E * 3)
    assert_equal("\\exponentialE \\left( 3\\right)", E * 3)
    assert_equal("\\exponentialE \\times 3", E * 3)


def test_exp_numeric():
    assert_equal("e^3", exp(3))


def test_exp_symbol():
    assert_equal("e^x", exp(x))


def test_exp_symbol_expr():
    assert_equal("e^{x+y}", exp(x + y))


def test_exp_symbol_expr_group():
    assert_equal("e^{(x+y)}", exp(x + y))


def test_exp_expr():
    assert_equal("\\sin(x)*e^x", sin(x) * exp(x))

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/floor_test.py:
# --------------------------------------------------------------------------------
# Every floor notation should parse to sympy.floor for the shared simple examples.
from .context import assert_equal, get_simple_examples
import pytest
from sympy import floor

# (latex input, expected sympy value, compare-symbolically flag) triples for `floor`.
examples = get_simple_examples(floor)

# One parametrization shared by every test below.
_over_examples = pytest.mark.parametrize('input, output, symbolically', examples)


def _assert_templates(templates, input, output, symbolically):
    """Substitute ``input`` into each LaTeX template in turn and assert it parses to ``output``."""
    for template in templates:
        assert_equal(template.format(input=input), output, symbolically=symbolically)


@_over_examples
def test_floor_func(input, output, symbolically):
    _assert_templates(["\\floor({input})"], input, output, symbolically)


@_over_examples
def test_floor_operatorname(input, output, symbolically):
    _assert_templates(["\\operatorname{{floor}}({input})"], input, output, symbolically)


@_over_examples
def test_floor_cmd(input, output, symbolically):
    # bare, \left/\right and \mleft/\mright variants, in that order
    _assert_templates(
        [
            "\\lfloor {input}\\rfloor",
            "\\left\\lfloor {input}\\right\\rfloor",
            "\\mleft\\lfloor {input}\\mright\\rfloor",
        ],
        input, output, symbolically,
    )


@_over_examples
def test_floor_corners(input, output, symbolically):
    _assert_templates(
        [
            "\\llcorner {input}\\lrcorner",
            "\\left\\llcorner {input}\\right\\lrcorner",
            "\\mleft\\llcorner {input}\\mright\\lrcorner",
        ],
        input, output, symbolically,
    )

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/greek_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import Symbol

epsilon_upper = Symbol('char"000190', real=True)
epsilon_lower = Symbol('epsilon', real=True)
varepsilon = Symbol('varepsilon', real=True)


def test_greek_epsilon():
    assert_equal("\\epsilon", epsilon_lower)


def test_greek_epsilon_upper():
    assert_equal('\\char"000190', epsilon_upper)


def test_greek_varepsilon():
    assert_equal('\\varepsilon', varepsilon)

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/grouping_test.py:
# --------------------------------------------------------------------------------
# Function arguments and juxtaposed groups must parse the same under every delimiter style.
from .context import assert_equal, _Pow, _Add, _Mul
import pytest
from sympy import Integral, sin, Symbol, Mul, Integer, Pow
from latex2sympy.latex2sympy2 import latex2sympy as process_sympy

a = Symbol('a', real=True)
b = Symbol('b', real=True)
x = Symbol('x', real=True)
theta = Symbol('theta', real=True)


func_arg_examples = [
    ('\\int ', 'x dx', Integral(x, x)),
    ('\\sin', '\\theta ', sin(theta))
]

example_groups = [
    ('1+2', '3-4', _Mul(_Add(1, 2), _Add(3, _Mul(-1, 4))))
]

# Delimiters that can be prefixed with \left/\right and \mleft/\mright.
modifiable_delimiter_pairs = {
    '(': ')',
    '\\lgroup': '\\rgroup',
    '\\{': '\\}',
    '\\lbrace': '\\rbrace',
    '[': ']',
    '\\lbrack': '\\rbrack',
}


@pytest.mark.parametrize('func, args, output', func_arg_examples)
def test_func_arg_groupings(func, args, output):
    # none
    assert_equal("{func} {args}".format(func=func, args=args), output)
    # normal brace (not modifiable)
    assert_equal("{func}{{{args}}}".format(func=func, args=args), output)
    # rest of delimiters, with modifications
    for left, right in modifiable_delimiter_pairs.items():
        assert_equal("{func}{left}{args}{right}".format(left=left, right=right, func=func, args=args), output)
        assert_equal("{func}\\left{left}{args}\\right{right}".format(left=left, right=right, func=func, args=args), output)
        assert_equal("{func}\\mleft{left}{args}\\mright{right}".format(left=left, right=right, func=func, args=args), output)


@pytest.mark.parametrize('group1, group2, output', example_groups)
def test_delimiter_groupings(group1, group2, output):
    # normal brace (not modifiable)
    assert_equal("{{{group1}}}{{{group2}}}".format(group1=group1, group2=group2), output)
    # rest of delimiters, with modifications
    for left, right in modifiable_delimiter_pairs.items():
        assert_equal("{left}{group1}{right}{left}{group2}{right}".format(left=left, right=right, group1=group1, group2=group2), output)
        assert_equal("\\left{left}{group1}\\right{right}\\left{left}{group2}\\right{right}".format(left=left, right=right, group1=group1, group2=group2), output)
        assert_equal("\\mleft{left}{group1}\\mright{right}\\mleft{left}{group2}\\mright{right}".format(left=left, right=right, group1=group1, group2=group2), output)

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/left_right_cdot_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import sin, Symbol

x = Symbol('x', real=True)


def test_left_right_cdot():
    assert_equal("\\sin\\left(x\\right)\\cdot x", sin(x) * x)

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/linalg_test.py:
# --------------------------------------------------------------------------------
# \variable placeholders bound to matrices via assert_equal's variable_values argument.
from .context import assert_equal
import pytest
from sympy import MatMul, Matrix


def test_linalg_placeholder():
    assert_equal("\\begin{pmatrix}1&2\\\\3&4\\end{pmatrix}\\cdot\\variable{v}", MatMul(Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), {'v': Matrix([1, 2])})


def test_linalg_placeholder_multiple():
    assert_equal("\\variable{M}\\cdot\\variable{v}", MatMul(Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), {'M': Matrix([[1, 2], [3, 4]]), 'v': Matrix([1, 2])})


def test_linalg_placeholder_multiple_mul():
    assert_equal("\\begin{pmatrix}3&-1\\end{pmatrix}\\cdot\\variable{M}\\cdot\\variable{v}", MatMul(Matrix([[3, -1]]), Matrix([[1, 2], [3, 4]]), Matrix([1, 2])), {'M': Matrix([[1, 2], [3, 4]]), 'v': Matrix([1, 2])})

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/max_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import Symbol, Rational, Float, Max, sqrt, exp, pi, nsimplify

x = Symbol('x', real=True)
y = Symbol('y', real=True)
z = Symbol('z', real=True)


def test_max_usual():
    assert_equal("\\max(1, 5)", Max(1, 5))
    assert_equal("\\max(12, 4)", Max(12, 4))
    assert_equal("\\max(109, 120)", Max(109, 120))
    assert_equal("\\max(3, 3)", Max(3, 3))
    assert_equal("\\max(0, 0)", Max(0, 0))
    assert_equal("\\max(1)", Max(1))
    assert_equal("\\max(1092198374, 290348E32)", Max(1092198374, Rational('290348E32')))
    assert_equal("\\max(5, 2, 17, 4)", Max(5, 2, 17, 4))


def test_max_negative():
    assert_equal("\\max(-9, 4)", Max(-9, 4))
    assert_equal("\\max(4, -9)", Max(4, -9))
    assert_equal("\\max(-7)", Max(-7))
    assert_equal("\\max(-2, -2)", Max(-2, -2))
    assert_equal("\\max(-324E-3, -58)", Max(Rational('-324E-3'), -58))
    assert_equal("\\max(-1, 0, 1, -37, 42)", Max(-1, 0, 1, -37, 42))


def test_max_float():
    assert_equal("\\max(\\pi, 3)", Max(pi, 3))
    assert_equal("\\max(1234.56789, 1234.5678901)", Max(Rational('1234.56789'), Rational('1234.5678901')))
    assert_equal("\\max(12.4, 9.5)", Max(12.4, 9.5))
    assert_equal("\\max(6, 6.2)", Max(6, 6.2))
    assert_equal("\\max(-98.7)", Max(-98.7))
    assert_equal("\\max(7.1, 9)", Max(7.1, 9))
    assert_equal("\\max(-21E-12, 0.00005)", Max(nsimplify(Rational('-21E-12')), Rational('0.00005')), symbolically=True)
    assert_equal("\\max(\\sqrt{3}, 0, 1)", Max(sqrt(3), 0, 1))


def test_max_fraction():
    assert_equal("\\max(1/2, 1/4)", Max(Rational('1/2'), Rational('1/4')))
    assert_equal("\\max(6/2, 3)", Max(Rational('6/2'), 3))
    assert_equal("\\max(2/4, 1/2)", Max(Rational('2/4'), Rational('1/2')))
    assert_equal("\\max(-12/5, 6.4)", Max(Rational('-12/5'), Rational('6.4')))
    assert_equal("\\max(1/10)", Max(Rational('1/10')))
    assert_equal("\\max(1.5, \\pi/2)", Max(Rational('1.5'), pi / 2, evaluate=False))
    assert_equal("\\max(-4/3, -2/1, 0/9, -3)", Max(Rational('-4/3'), Rational('-2/1'), Rational('0/9'), -3))


def test_max_expr():
    assert_equal("\\max((1+6)/3, 7)", Max(Rational(1 + 6, 3), 7))
    assert_equal("\\max(58*9)", Max(58 * 9))
    assert_equal("\\max(1+6/3, -5)", Max(1 + Rational('6/3'), -5))
    assert_equal("\\max(7*4/5, 092) * 2", Max(7 * 4 / 5, 92) * 2)
    assert_equal("38+\\max(13, 15-2.3)", 38 + Max(13, 15 - Rational('2.3')))
    assert_equal("\\sqrt{\\max(99.9999999999999, 100)}", sqrt(Max(Rational('99.9999999999999'), 100)))
    assert_equal("\\max(274/(5+2), \\exp(12.4), 1.4E2)", Max(Rational(274, 5 + 2), exp(Rational('12.4')), Rational('1.4E2')))


def test_max_symbol():
    assert_equal("\\max(x)", Max(x), symbolically=True)
    assert_equal("\\max(x, y)", Max(x, y), symbolically=True)
    assert_equal("\\max(y, x)", Max(y, x), symbolically=True)
    assert_equal("\\max(x+y, y+x)", Max(x + y, y + x), symbolically=True)
    assert_equal("\\max(9x/4, z)", Max(9 * x / 4, z), symbolically=True)
    assert_equal("\\max(y\\pi, 9)", Max(y * pi, 9), symbolically=True)
    assert_equal("\\max(2y-y, y + 1)", Max(2 * y - y, y + 1), symbolically=True)
    assert_equal("\\max(z, y, x)", Max(z, y, x), symbolically=True)


def test_max_multiarg():
    assert_equal("\\max(1,2)", Max(1, 2))
    assert_equal("\\max(9,876,543)", Max(9, 876, 543))
    assert_equal("\\max(x, y,z)", Max(x, y, z), symbolically=True)
    assert_equal("\\max(5.8,7.4, 2.2,-10)", Max(Rational('5.8'), Rational('7.4'), Rational('2.2'), -10))
    assert_equal("\\max(\\pi,12E2,84,\\sqrt{5},12/5)", Max(pi, Rational('12E2'), 84, sqrt(5), Rational('12/5')))
    assert_equal("\\max(823,51)", Max(823, 51))
    assert_equal("\\max(72*4,23, 9)", Max(72 * 4, 23, 9))

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/min_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import Symbol, Rational, Float, Min, sqrt, exp, pi, nsimplify

x = Symbol('x', real=True)
y = Symbol('y', real=True)
z = Symbol('z', real=True)


def test_min_usual():
    assert_equal("\\min(1, 5)", Min(1, 5))
    assert_equal("\\min(12, 4)", Min(12, 4))
    assert_equal("\\min(109, 120)", Min(109, 120))
    assert_equal("\\min(3, 3)", Min(3, 3))
    assert_equal("\\min(0, 0)", Min(0, 0))
    assert_equal("\\min(1)", Min(1))
    assert_equal("\\min(1092198374, 290348E32)", Min(1092198374, Rational('290348E32')))
    assert_equal("\\min(5, 2, 17, 4)", Min(5, 2, 17, 4))


def test_min_negative():
    assert_equal("\\min(-9, 4)", Min(-9, 4))
    assert_equal("\\min(4, -9)", Min(4, -9))
    assert_equal("\\min(-7)", Min(-7))
    assert_equal("\\min(-2, -2)", Min(-2, -2))
    assert_equal("\\min(-324E-3, -58)", Min(Rational('-324E-3'), -58))
    assert_equal("\\min(-1, 0, 1, -37, 42)", Min(-1, 0, 1, -37, 42))


def test_min_float():
    assert_equal("\\min(\\pi, 3)", Min(pi, 3))
    assert_equal("\\min(1234.56789, 1234.5678901)", Min(Rational('1234.56789'), Rational('1234.5678901')))
    assert_equal("\\min(12.4, 9.5)", Min(12.4, 9.5))
    assert_equal("\\min(6, 6.2)", Min(6, 6.2))
    assert_equal("\\min(-98.7)", Min(-98.7))
    assert_equal("\\min(7.1, 9)", Min(7.1, 9))
    assert_equal("\\min(-21E-12, 0.00005)", Min(nsimplify(Rational('-21E-12')), Rational('0.00005')), symbolically=True)
    assert_equal("\\min(\\sqrt{3}, 0, 1)", Min(sqrt(3), 0, 1))


def test_min_fraction():
    assert_equal("\\min(1/2, 1/4)", Min(Rational('1/2'), Rational('1/4')))
    assert_equal("\\min(6/2, 3)", Min(Rational('6/2'), 3))
    assert_equal("\\min(2/4, 1/2)", Min(Rational('2/4'), Rational('1/2')))
    assert_equal("\\min(-12/5, 6.4)", Min(Rational('-12/5'), Rational('6.4')))
    assert_equal("\\min(1/10)", Min(Rational('1/10')))
    assert_equal("\\min(1.5, \\pi/2)", Min(Rational('1.5'), pi / 2, evaluate=False))
    assert_equal("\\min(-4/3, -2/1, 0/9, -3)", Min(Rational('-4/3'), Rational('-2/1'), Rational('0/9'), -3))


def test_min_expr():
    assert_equal("\\min((1+6)/3, 7)", Min(Rational(1 + 6, 3), 7))
    assert_equal("\\min(58*9)", Min(58 * 9))
    assert_equal("\\min(1+6/3, -5)", Min(1 + Rational('6/3'), -5))
    assert_equal("\\min(7*4/5, 092) * 2", Min(7 * 4 / 5, 92) * 2)
    assert_equal("38+\\min(13, 15-2.3)", 38 + Min(13, 15 - Rational('2.3')))
    assert_equal("\\sqrt{\\min(99.9999999999999, 100)}", sqrt(Min(Rational('99.9999999999999'), 100)))
    assert_equal("\\min(274/(5+2), \\exp(12.4), 1.4E2)", Min(Rational(274, 5 + 2), exp(Rational('12.4')), Rational('1.4E2')))


def test_min_symbol():
    assert_equal("\\min(x)", Min(x), symbolically=True)
    assert_equal("\\min(x, y)", Min(x, y), symbolically=True)
    assert_equal("\\min(y, x)", Min(y, x), symbolically=True)
    assert_equal("\\min(x+y, y+x)", Min(x + y, y + x), symbolically=True)
    assert_equal("\\min(9x/4, z)", Min(9 * x / 4, z), symbolically=True)
    assert_equal("\\min(y\\pi, 9)", Min(y * pi, 9), symbolically=True)
    assert_equal("\\min(2y-y, y + 1)", Min(2 * y - y, y + 1), symbolically=True)
    assert_equal("\\min(z, y, x)", Min(z, y, x), symbolically=True)


def test_min_multiarg():
    assert_equal("\\min(1,2)", Min(1, 2))
    assert_equal("\\min(9,876,543)", Min(9, 876, 543))
    assert_equal("\\min(x, y,z)", Min(x, y, z), symbolically=True)
    assert_equal("\\min(5.8,7.4, 2.2,-10)", Min(Rational('5.8'), Rational('7.4'), Rational('2.2'), -10))
    assert_equal("\\min(\\pi,12E2,84,\\sqrt{5},12/5)", Min(pi, Rational('12E2'), 84, sqrt(5), Rational('12/5')))
    assert_equal("\\min(823,51)", Min(823, 51))
    assert_equal("\\min(72*4,23, 9)", Min(72 * 4, 23, 9))

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/mod_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import Symbol, Rational, Mod, sqrt, nsimplify, pi, GoldenRatio
from sympy.physics.units import hbar

x = Symbol('x', real=True)
y = Symbol('y', real=True)


def test_mod_usual():
    assert_equal("128\\mod 3", Mod(128, 3))
    assert_equal("7\\mod 128", Mod(7, 128))
    assert_equal("5\\mod 10", Mod(5, 10))
    assert_equal("5\\mod 5", Mod(5, 5))
    assert_equal("3\\mod 2", Mod(3, 2))
    assert_equal("0 \\mod 6", Mod(0, 6))
    assert_equal("6109\\mod 28", Mod(6109, 28))
    assert_equal("4000000000\\mod 28791", Mod(4000000000, 28791))
    assert_equal("128*10^300\\mod 876123", Mod(Rational('128E300'), 876123))
    # The input previously ended in a stray, unbalanced ")" (typo).
    assert_equal("876,123\\mod 128E300", Mod(876123, Rational('128E300')))


def test_mod_negative():
    assert_equal("-1\\mod 2", Mod(-1, 2))
    assert_equal("-3\\mod 3", Mod(-3, 3))
    assert_equal("-12\\mod -12", Mod(-12, -12))
    assert_equal("-128\\mod 4", Mod(-128, 4))
    assert_equal("9\\mod -213", Mod(9, -213))
    assert_equal("123123\\mod -541", Mod(123123, -541))
    assert_equal("-123123\\mod 541", Mod(-123123, 541))
    assert_equal("-97E34\\mod 7", Mod(Rational('-97E34'), 7))


def test_mod_fraction():
    assert_equal("1/2\\mod 3", Mod(Rational(1, 2), 3))
    assert_equal("6/2\\mod 3", Mod(Rational(6, 2), 3))
    assert_equal("-14/2\\mod 5", Mod(Rational(-14, 2), 5))
    assert_equal("123\\mod (42/6)", Mod(123, Rational(42, 6)))
    assert_equal("431\\mod (2/123)", Mod(431, Rational(2, 123)))
    assert_equal("5/5\\mod (5/5)", Mod(Rational(5, 5), Rational(5, 5)))
    assert_equal("849/-21\\mod (092/2)", Mod(Rational(849, -21), Rational(92, 2)))
    assert_equal("13*10^9\\mod (21/-2)", Mod(13E9, Rational(21, -2)))


def test_mod_float():
    assert_equal("0.41\\mod 2", Mod(Rational('0.41'), 2))
    assert_equal("143E-13\\mod 21", Mod(Rational('143E-13'), 21))
    assert_equal("-9.80665\\mod 9.80665", Mod(-9.80665, 9.80665))
    assert_equal("0.0000923423\\mod -8341.234802909", nsimplify(Mod(0.0000923423, -8341.234802909)))
    assert_equal("\\sqrt{5}\\mod \\sqrt{2}", Mod(sqrt(5), sqrt(2)))
    assert_equal("987\\mod \\pi", Mod(987, pi))
    assert_equal("\\pi\\mod ((1+\\sqrt{5})/2)", Mod(pi, nsimplify(GoldenRatio)), symbolically=True)
    assert_equal("1234\\mod 1E-29", Mod(1234, Rational('1E-29'), evaluate=False))


def test_mod_expr():
    assert_equal("1+1\\mod 2", 1 + Mod(1, 2))
    assert_equal("876123\\mod 128\\times 10^300", Mod(876123, 128) * 1E300)
    assert_equal("141\\mod 9/3", Rational(Mod(141, 9) / 3))
    assert_equal("872 / (12\\mod 9 * 4) * 2", Rational(2 * 872, (Mod(12, 9) * 4)))
    assert_equal("1E-32 * (1E29\\mod 74)", Rational('1E-32') * Mod(Rational('1E29'), 74))
    assert_equal("299,792,458\\mod 9.81", Mod(299792458, Rational('9.81')))


def test_mod_symbol():
    assert_equal("x\\mod y", Mod(x, y))
    assert_equal("2x\\mod y", Mod(2 * x, y))
    assert_equal("y + 3\\mod 2 / 4", y + Rational(Mod(3, 2), 4), symbolically=True)
    assert_equal("0.5x * 2 + \\sqrt{x}\\mod 8y", 0.5 * x * 2 + Mod(sqrt(x), 8 * y), symbolically=True)
    assert_equal("6.673E-11 * ((8.85418782E-12\\mod 9x) + 4) / 2y", Rational('6.673E-11') * (Mod(Rational('8.85418782E-12'), 9 * x) + 4) / (2 * y), symbolically=True)

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/overline_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import sin, Symbol

x = Symbol('x', real=True)


def test_overline():
    assert_equal("\\frac{\\sin(x)}{\\overline{x}_n}", sin(x) / Symbol('xbar_n', real=True))

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/pi_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal, _Mul, _Pow
import pytest
from sympy import pi, Symbol, acos, cos


def test_pi_frac():
    assert_equal("\\frac{\\pi}{3}", _Mul(pi, _Pow(3, -1)))


def test_pi_nested():
    assert_equal("\\arccos{\\cos{\\frac{\\pi}{3}}}", acos(cos(_Mul(pi, _Pow(3, -1)), evaluate=False), evaluate=False))


def test_pi_arccos():
    assert_equal("\\arccos{-1}", pi, symbolically=True)

# --------------------------------------------------------------------------------
# /utils/latex2sympy/tests/trig_test.py:
# --------------------------------------------------------------------------------
from .context import assert_equal
import pytest
from sympy import asinh, Symbol

# x = Symbol('x', real=True);

# latex = "\\sinh(x)"
# math = process_sympy(latex)
# print("latex: %s to math: %s" %(latex,math))
#
# latex =
"\\arcsinh(x)" 12 | # math = process_sympy(latex) 13 | # print("latex: %s to math: %s" %(latex,math)) 14 | # 15 | # latex = "\\arsinh(x)" 16 | # math = process_sympy(latex) 17 | # print("latex: %s to math: %s" %(latex,math)) 18 | 19 | 20 | def test_arcsinh(): 21 | assert_equal("\\operatorname{arcsinh}\\left(1\\right)", asinh(1, evaluate=False)) 22 | -------------------------------------------------------------------------------- /utils/latex2sympy/tests/variable_test.py: -------------------------------------------------------------------------------- 1 | from .context import assert_equal 2 | import pytest 3 | import hashlib 4 | from sympy import UnevaluatedExpr, Symbol, Mul, Pow, Max, Min, gcd, lcm, floor, ceiling 5 | 6 | x = Symbol('x', real=True) 7 | y = Symbol('y', real=True) 8 | 9 | 10 | def test_variable_letter(): 11 | assert_equal("\\variable{x}", Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True)) 12 | 13 | 14 | def test_variable_digit(): 15 | assert_equal("\\variable{1}", Symbol('1' + hashlib.md5('1'.encode()).hexdigest(), real=True)) 16 | 17 | 18 | def test_variable_letter_subscript(): 19 | assert_equal("\\variable{x_y}", Symbol('x_y' + hashlib.md5('x_y'.encode()).hexdigest(), real=True)) 20 | 21 | 22 | def test_variable_letter_comma_subscript(): 23 | assert_equal("\\variable{x_{i,j}}", Symbol('x_{i,j}' + hashlib.md5('x_{i,j}'.encode()).hexdigest(), real=True)) 24 | 25 | 26 | def test_variable_digit_subscript(): 27 | assert_equal("\\variable{x_1}", Symbol('x_1' + hashlib.md5('x_1'.encode()).hexdigest(), real=True)) 28 | 29 | 30 | def test_variable_after_subscript_required(): 31 | with pytest.raises(Exception): 32 | assert_equal("\\variable{x_}", Symbol('x_' + hashlib.md5('x_'.encode()).hexdigest(), real=True)) 33 | 34 | 35 | def test_variable_before_subscript_required(): 36 | with pytest.raises(Exception): 37 | assert_equal("\\variable{_x}", Symbol('_x' + hashlib.md5('_x'.encode()).hexdigest(), real=True)) 38 | 39 | 40 | def 
test_variable_bad_name(): 41 | with pytest.raises(Exception): 42 | assert_equal("\\variable{\\sin xy}", None) 43 | 44 | 45 | def test_variable_in_expr(): 46 | assert_equal("4\\cdot\\variable{x}", 4 * Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True)) 47 | 48 | 49 | def test_variable_greek_letter(): 50 | assert_equal("\\variable{\\alpha }\\alpha", Symbol('\\alpha ' + hashlib.md5('\\alpha '.encode()).hexdigest(), real=True) * Symbol('alpha', real=True)) 51 | 52 | 53 | def test_variable_greek_letter_subscript(): 54 | assert_equal("\\variable{\\alpha _{\\beta }}\\alpha ", Symbol('\\alpha _{\\beta }' + hashlib.md5('\\alpha _{\\beta }'.encode()).hexdigest(), real=True) * Symbol('alpha', real=True)) 55 | 56 | 57 | def test_variable_bad_unbraced_long_subscript(): 58 | with pytest.raises(Exception): 59 | assert_equal("\\variable{x_yz}", None) 60 | 61 | 62 | def test_variable_bad_unbraced_long_complex_subscript(): 63 | with pytest.raises(Exception): 64 | assert_equal("\\variable{x\\beta 10_y\\alpha 20}", None) 65 | 66 | 67 | def test_variable_braced_subscript(): 68 | assert_equal("\\variable{x\\beta 10_{y\\alpha 20}}", Symbol('x\\beta 10_{y\\alpha 20}' + hashlib.md5('x\\beta 10_{y\\alpha 20}'.encode()).hexdigest(), real=True)) 69 | 70 | 71 | def test_variable_complex_expr(): 72 | assert_equal("4\\cdot\\variable{value1}\\frac{\\variable{value_2}}{\\variable{a}}\\cdot x^2", 4 * Symbol('value1' + hashlib.md5('value1'.encode()).hexdigest(), real=True) * Symbol('value_2' + hashlib.md5('value_2'.encode()).hexdigest(), real=True) / Symbol('a' + hashlib.md5('a'.encode()).hexdigest(), real=True) * x**2) 73 | 74 | 75 | def test_variable_dollars(): 76 | assert_equal("\\$\\variable{x}", Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True)) 77 | 78 | 79 | def test_variable_percentage(): 80 | assert_equal("\\variable{x}\\%", Mul(Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True), Pow(100, -1, evaluate=False), evaluate=False)) 81 | 82 | 83 | def 
test_variable_single_arg_func(): 84 | assert_equal("\\floor(\\variable{x})", floor(Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True))) 85 | assert_equal("\\ceil(\\variable{x})", ceiling(Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True))) 86 | 87 | 88 | def test_variable_multi_arg_func(): 89 | assert_equal("\\gcd(\\variable{x}, \\variable{y})", UnevaluatedExpr(gcd(Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True), Symbol('y' + hashlib.md5('y'.encode()).hexdigest(), real=True)))) 90 | assert_equal("\\lcm(\\variable{x}, \\variable{y})", UnevaluatedExpr(lcm(Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True), Symbol('y' + hashlib.md5('y'.encode()).hexdigest(), real=True)))) 91 | assert_equal("\\max(\\variable{x}, \\variable{y})", Max(Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True), Symbol('y' + hashlib.md5('y'.encode()).hexdigest(), real=True), evaluate=False)) 92 | assert_equal("\\min(\\variable{x}, \\variable{y})", Min(Symbol('x' + hashlib.md5('x'.encode()).hexdigest(), real=True), Symbol('y' + hashlib.md5('y'.encode()).hexdigest(), real=True), evaluate=False)) 93 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from utils.parser import extract_answer 2 | from utils.grader import math_equal 3 | 4 | DATASET_KEYS = { 5 | 'openai/gsm8k': {'question': 'question', 'answer': 'answer'}, 6 | 'hendrycks/competition_math': {'question': 'problem', 'answer': 'solution'}, 7 | 'datasets/converted_aime_dataset': {'question': 'problem', 'answer': 'solution'}, 8 | 'di-zhang-fdu/MATH500': {'question': 'problem', 'answer': 'solution'}, 9 | 'datasets/compression_dataset': {'question': 'problem', 'answer': 'solution'}, 10 | } 11 | 12 | RESPONSE_EXTRACTOR = { 13 | 'openai/gsm8k': lambda x: extract_answer(x, data_name='gsm8k'), 14 | 'hendrycks/competition_math': lambda x: extract_answer(x, 
data_name='math'), 15 | 'di-zhang-fdu/MATH500': lambda x: extract_answer(x, data_name='math'), 16 | 'datasets/compression_dataset': lambda x: extract_answer(x, data_name='math'), 17 | 'datasets/converted_aime_dataset': lambda x: extract_answer(x, data_name='math') 18 | } 19 | 20 | RESPONSE_COMPARATOR = { 21 | 'openai/gsm8k': lambda x, y: math_equal(x, y, timeout=True), 22 | 'hendrycks/competition_math': lambda x, y: math_equal(x, y, timeout=True), 23 | 'di-zhang-fdu/MATH500': lambda x, y: math_equal(x, y, timeout=True), 24 | 'datasets/compression_dataset': lambda x, y: math_equal(x, y, timeout=True), 25 | 'datasets/converted_aime_dataset': lambda x, y: math_equal(x, y, timeout=True) 26 | } 27 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.5.4 --------------------------------------------------------------------------------