├── .github └── workflows │ ├── publish.yml │ └── smoketest.yml ├── .gitignore ├── LICENSE ├── NOTICES.md ├── README.md ├── docs └── static │ ├── architecture.png │ ├── image-1.png │ ├── image-10.png │ ├── image-11.png │ ├── image-2.png │ ├── image-3.png │ ├── image-4.png │ ├── image-5.png │ ├── image-6.png │ ├── image-7.png │ ├── image-8.png │ ├── image-9.png │ └── image.png ├── higgsfield ├── __init__.py ├── checkpoint │ ├── __init__.py │ ├── fsdp_checkpoint.py │ └── fsdp_utils.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ └── openai.py ├── experiment.py ├── internal │ ├── __init__.py │ ├── cfg.py │ ├── ci │ │ ├── cli.py │ │ └── setup.py │ ├── cli.py │ ├── experiment │ │ ├── ast_parser.py │ │ ├── builder.py │ │ ├── decorator.py │ │ └── params.py │ ├── init.py │ ├── launch.py │ ├── main.py │ └── util.py ├── llama │ ├── __init__.py │ ├── llama.py │ └── llama_utils.py ├── loaders │ ├── __init__.py │ └── llama_loader.py ├── mistral │ ├── __init__.py │ ├── mistral.py │ ├── mistral_loader.py │ └── mistral_utils.py ├── path.py ├── rl │ ├── README.md │ └── rl_adventure_2 │ │ ├── 1.actor-critic.ipynb │ │ ├── 2.gae.ipynb │ │ ├── 3.ppo.ipynb │ │ ├── 4.acer.ipynb │ │ ├── 5.ddpg.ipynb │ │ ├── 6.td3.ipynb │ │ ├── 7.soft actor-critic.ipynb │ │ ├── 8.gail.ipynb │ │ ├── 9.her.ipynb │ │ ├── README.md │ │ └── common │ │ ├── __init__.py │ │ └── multiprocessing_env.py ├── static │ ├── project │ │ ├── .gitignore │ │ ├── Dockerfile │ │ ├── env │ │ ├── requirements.txt │ │ └── src │ │ │ ├── alpaca_bf16.py │ │ │ ├── alpaca_fp16.py │ │ │ └── dataset.py │ └── templates │ │ ├── README_md.j2 │ │ ├── config_py.j2 │ │ ├── deploy_action.j2 │ │ ├── experiment_action.j2 │ │ └── kill_action.j2 ├── training │ ├── __init__.py │ ├── grads.py │ └── scaler.py └── utils │ ├── __init__.py │ └── flush.py ├── poetry.lock ├── pyproject.toml ├── setup.md ├── tutorial.md └── tutorials ├── README.md ├── chatgpt.ipynb ├── prompt_completion.ipynb └── text_format.ipynb /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | on: 3 | push: 4 | tags: 5 | - 'v*.*.*' 6 | workflow_dispatch: 7 | inputs: 8 | tag: 9 | description: 'Tag' 10 | required: true 11 | default: 'v0.0.0' 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - name: Build and publish to pypi 19 | uses: JRubics/poetry-publish@v1.17 20 | with: 21 | pypi_token: ${{ secrets.PYPI_TOKEN }} 22 | -------------------------------------------------------------------------------- /.github/workflows/smoketest.yml: -------------------------------------------------------------------------------- 1 | name: Basic Init Smoketest 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | 9 | jobs: 10 | smoketest: 11 | name: Smoketest 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@v3 16 | - name: Setup Python 3.8 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.8' 20 | - name: Install higgsfield 21 | run: | 22 | git_hash=$(git rev-parse --short "$GITHUB_SHA") 23 | python -m pip install git+https://github.com/higgsfield/higgsfield.git@$git_hash 24 | - name: Run smoketest 25 | run: | 26 | mkdir -p /tmp/smoketest && cd /tmp/smoketest 27 | 28 | higgsfield init some_project 29 | cd some_project 30 | 31 | git init 32 | git remote add origin git@github.com:user/project.git 33 | 34 | higgsfield build-experiments 35 | 36 | # TODO: compare the generated files. 
37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | dist/ 11 | build/ 12 | *.egg-info/ 13 | 14 | # Poetry 15 | .poetry/ 16 | venv/ 17 | .idea/ 18 | .vscode/ 19 | 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/NOTICES.md:
--------------------------------------------------------------------------------
 1 | # Third-Party Licenses
 2 | 
 3 | ### python
 4 | - **License:** [Python Software Foundation License](https://docs.python.org/3.8/license.html)
 5 | - **Version:** 3.8.13
 6 | - **Link to Source:** [https://www.python.org/]
 7 | 
 8 | ### click
 9 | - **License:** [BSD 3-Clause License](https://opensource.org/license/bsd-3-clause/)
10 | - **Version:** 8.1.7
11 | - **Link to Source:** [https://github.com/pallets/click]
12 | 
13 | ### pyyaml
14 | - **License:** [MIT License](https://opensource.org/licenses/MIT)
15 | - **Version:** 6.0.1
16 | - **Link to Source:** [https://github.com/yaml/pyyaml]
17 | 
18 | ### asyncer
19 | - **License:** [MIT License](https://opensource.org/licenses/MIT)
20 | - **Version:** 0.0.2
21 | - **Link to Source:** [https://github.com/tiangolo/asyncer]
22 | 
23 | ### jinja2
24 | - **License:** [BSD 3-Clause License](https://opensource.org/licenses/BSD-3-Clause)
25 | - **Version:** 3.1.2
26 | - **Link to Source:** [https://github.com/pallets/jinja]
27 | 
28 | ### python-dotenv
29 | - **License:** [BSD 3-Clause License](https://opensource.org/licenses/BSD-3-Clause)
30 | - **Version:** 1.0.0
31 | - **Link to Source:** [https://github.com/theskumar/python-dotenv]
32 | 
33 | ### cryptography
34 | - **License:** [BSD 3-Clause License](https://opensource.org/licenses/BSD-3-Clause)
35 | - **Version:** 41.0.4
36 | - **Link to Source:** [https://github.com/pyca/cryptography]
37 | 
38 | ### asyncssh
39 | - **License:** [Eclipse Public License 2.0](https://www.eclipse.org/legal/epl-2.0/)
40 | - **Version:** 2.14.0
41 | - **Link to Source:** [https://github.com/ronf/asyncssh]
42 | 
43 | ### bcrypt
44 | - **License:** [Apache License 2.0](https://opensource.org/licenses/Apache-2.0)
45 | - **Version:** 3.1.3
46 | - **Link to Source:** [https://github.com/pyca/bcrypt/]
47 | 
48 | ### libnacl
49 | - **License:** [Apache License 2.0](https://opensource.org/licenses/Apache-2.0)
50 | - **Version:** 1.4.2
51 | - **Link to Source:** [https://github.com/saltstack/libnacl]
52 | 
53 | ### pyopenssl
54 | - **License:** [Apache License 2.0](https://opensource.org/licenses/Apache-2.0)
55 | - **Version:** 23.2.0
56 | - **Link to Source:** [https://github.com/pyca/pyopenssl]
57 | 
58 | ### libsodium
59 | - **License:** [MIT License](https://opensource.org/licenses/MIT)
60 | - **Version:** 2.6.1
61 | - **Link to Source:** [https://github.com/libsodiumproject/sodium]
62 | 
63 | ### invoker
64 | - **License:** [Apache License 2.0](https://opensource.org/licenses/Apache-2.0)
65 | - **Version:** latest
66 | - **Link to Source:** [https://github.com/higgsfield-ai/invoker]
67 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # higgsfield - multi node training without crying
  2 | 
  3 | 
  4 | Higgsfield is an open-source, fault-tolerant, highly scalable GPU orchestrator and machine learning framework designed for training models with billions to trillions of parameters, such as Large Language Models (LLMs).
  5 | 
  6 | [![PyPI version](https://badge.fury.io/py/higgsfield.svg)](https://badge.fury.io/py/higgsfield)
  7 | 
  8 | ![architecture](https://raw.githubusercontent.com/higgsfield/higgsfield/main/docs/static/architecture.png)
  9 | 
 10 | Higgsfield serves as a GPU workload manager and machine learning framework with five primary functions:
 11 | 
 12 | 1. Allocating exclusive and non-exclusive access to compute resources (nodes) to users for their training tasks.
 13 | 2. Supporting the ZeRO-3 DeepSpeed API and PyTorch's fully sharded data parallel API, enabling efficient sharding for trillion-parameter models.
 14 | 3. Offering a framework for initiating, executing, and monitoring the training of large neural networks on allocated nodes.
 15 | 4. Managing resource contention by maintaining a queue for running experiments.
 16 | 5. Facilitating continuous integration of machine learning development through seamless integration with GitHub and GitHub Actions.
 17 | Higgsfield streamlines the process of training massive models and empowers developers with a versatile and robust toolset.
 18 | ## Install
 19 | 
 20 | ```bash
 21 | $ pip install higgsfield==0.0.3
 22 | ```
 23 | 
 24 | 
 25 | 
 26 | ## Train example
 27 | 
 28 | That's all you have to do to train LLaMA in a distributed setting:
 29 | 
 30 | ```python
 31 | from higgsfield.llama import Llama70b
 32 | from higgsfield.loaders import LlamaLoader
 33 | from higgsfield.experiment import experiment
 34 | 
 35 | import torch.optim as optim
 36 | from alpaca import get_alpaca_data
 37 | 
 38 | @experiment("alpaca")
 39 | def train(params):
 40 |     model = Llama70b(zero_stage=3, fast_attn=False, precision="bf16")
 41 | 
 42 |     optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.0)
 43 | 
 44 |     dataset = get_alpaca_data(split="train")
 45 |     train_loader = LlamaLoader(dataset, max_words=2048)
 46 | 
 47 |     for batch in train_loader:
 48 |         optimizer.zero_grad()
 49 |         loss = model(batch)
 50 |         loss.backward()
 51 |         optimizer.step()
 52 | 
 53 |     model.push_to_hub('alpaca-70b')
 54 | ```
 55 | 
 56 | ## How is it all done?
 57 | 
 58 | 1. We install all the required tools on your servers (Docker, your project's deploy keys, the higgsfield binary).
 59 | 2. Then we generate deploy & run workflows for your experiments.
 60 | 3. As soon as your code lands on GitHub, it is automatically deployed to your nodes.
 61 | 4. Then you access your experiments' run UI through GitHub, which launches experiments and saves the checkpoints.
 62 | 
 63 | ## Design
 64 | 
 65 | We follow the standard PyTorch workflow, so you can incorporate anything on top of what we provide: `deepspeed`, `accelerate`, or your own custom `pytorch` sharding implemented from scratch.
 66 | 
 67 | **Environment hell**
 68 | 
 69 | No more mismatched versions of PyTorch, NVIDIA drivers, or data processing libraries.
 70 | You can easily orchestrate experiments and their environments, documenting and tracking the specific versions and configurations of all dependencies to ensure reproducibility.
 71 | 
 72 | **Config hell**
 73 | 
 74 | No need to define [600 arguments for your experiment](https://github.com/huggingface/transformers/blob/aaccf1844eccbb90cc923378e3c37a6b143d03fb/src/transformers/training_args.py#L161). No more [yaml witchcraft](https://hydra.cc/).
 75 | You can use whatever you want, whenever you want. We just introduce a simple interface to define your experiments; a sketch of it is shown below. We have taken it even further: now you only need to design the way you interact with your experiments.
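For illustration, here is a minimal sketch of that interface, built from the `experiment` and `param` decorators that `higgsfield.experiment` exports (the experiment name, parameter names, defaults, and options below are made up for the example):

```python
from higgsfield.experiment import experiment, param

# `experiment` must be the outermost decorator; the `param` fields shown
# (name, default, description, required, type, options) mirror the ones the
# workflow parser accepts. All concrete values here are illustrative.
@experiment("my_experiment")
@param("lr", default=1e-5, description="learning rate", type=float)
@param("precision", default="bf16", options=("bf16", "fp16"))
def train(params):
    # hyperparameters arrive through `params`, as in the train example above
    ...
```

Each decorated experiment then gets its own generated GitHub Actions workflow (via `higgsfield build-experiments`), with the declared params surfaced as workflow inputs.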
 76 | 
 77 | ## Compatibility
 78 | 
 79 | **We need you to have nodes with:**
 80 | 
 81 | - Ubuntu
 82 | - SSH access
 83 | - Non-root user with sudo privileges (passwordless sudo is required)
 84 | 
 85 | **Clouds we have tested on:**
 86 | 
 87 | - Azure
 88 | - LambdaLabs
 89 | - FluidStack
 90 | 
 91 | Feel free to open an issue if you have any problems with other clouds.
 92 | 
 93 | ## Getting started
 94 | 
 95 | #### [Setup](./setup.md)
 96 | 
 97 | Here you can find the quick start guide on how to set up your nodes and start training.
 98 | 
 99 | - [Initialize the project](https://github.com/higgsfield/higgsfield/blob/main/setup.md#initialize-the-project)
100 | - [Setup the environment](https://github.com/higgsfield/higgsfield/blob/main/setup.md#setup-the-environment)
101 | - [Setup git](https://github.com/higgsfield/higgsfield/blob/main/setup.md#setup-git)
102 | - [Time to setup your nodes!](https://github.com/higgsfield/higgsfield/blob/main/setup.md#time-to-setup-your-nodes)
103 | - [Run your very first experiment](https://github.com/higgsfield/higgsfield/blob/main/setup.md#run-your-very-first-experiment)
104 | - [Fasten your seatbelt, it's time to deploy!](https://github.com/higgsfield/higgsfield/blob/main/setup.md#fasten-your-seatbelt-its-time-to-deploy)
105 | 
106 | #### [Tutorial](./tutorial.md)
107 | 
108 | API for common tasks in Large Language Model training.
109 | 
110 | - [Working with distributed model](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#working-with-distributed-model)
111 | - [Preparing Data](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#preparing-data)
112 | - [Optimizing the Model Parameters](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#optimizing-the-model-parameters)
113 | - [Saving Model](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#saving-model)
114 | - [Training stabilization techniques](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#training-stabilization-techniques)
115 | - [Monitoring](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#monitoring)
116 | 
117 | | Platform | Purpose | Estimated Response Time | Support Level |
118 | | ----------------------------------------------------------------- | ----------------------------------------------------------------- | ----------------------- | --------------- |
119 | | [GitHub Issues](https://github.com/higgsfield/higgsfield/issues/) | Bug reports, feature requests, install issues, usage issues, etc. | < 1 day | Higgsfield Team |
120 | | [Twitter](https://twitter.com/higgsfield_ai/) | For staying up-to-date on new features. | Daily | Higgsfield Team |
121 | | [Website](https://higgsfield.ai/) | Discussion, news.
| < 2 days | Higgsfield Team | 122 | 123 | -------------------------------------------------------------------------------- /docs/static/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/architecture.png -------------------------------------------------------------------------------- /docs/static/image-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-1.png -------------------------------------------------------------------------------- /docs/static/image-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-10.png -------------------------------------------------------------------------------- /docs/static/image-11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-11.png -------------------------------------------------------------------------------- /docs/static/image-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-2.png -------------------------------------------------------------------------------- /docs/static/image-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-3.png -------------------------------------------------------------------------------- /docs/static/image-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-4.png -------------------------------------------------------------------------------- /docs/static/image-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-5.png -------------------------------------------------------------------------------- /docs/static/image-6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-6.png -------------------------------------------------------------------------------- /docs/static/image-7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-7.png -------------------------------------------------------------------------------- /docs/static/image-8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-8.png 
--------------------------------------------------------------------------------
/docs/static/image-9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image-9.png
--------------------------------------------------------------------------------
/docs/static/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/docs/static/image.png
--------------------------------------------------------------------------------
/higgsfield/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/higgsfield/__init__.py
--------------------------------------------------------------------------------
/higgsfield/checkpoint/__init__.py:
--------------------------------------------------------------------------------
1 | from .fsdp_checkpoint import Checkpoint
2 | from .fsdp_utils import fsdp_model_state_dict_rank0
--------------------------------------------------------------------------------
/higgsfield/checkpoint/fsdp_checkpoint.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import shutil
  3 | from pathlib import Path
  4 | 
  5 | import time
  6 | import json
  7 | import torch
  8 | import torch.distributed as dist
  9 | 
 10 | from torch.distributed.fsdp.fully_sharded_data_parallel import (
 11 |     FullyShardedDataParallel as FSDP,
 12 | )
 13 | 
 14 | from optimum.bettertransformer import BetterTransformer
 15 | 
 16 | from .fsdp_utils import fsdp_model_state_dict_rank0
 17 | 
 18 | 
 19 | DEFAULT_CHECKPOINT_PATH = Path.home() / \
 20 |     ".cache" / \
 21 |     "higgsfield" / \
 22 |     os.environ["PROJECT_NAME"] / \
 23 |     "experiments" / \
 24 |     os.environ["EXPERIMENT_NAME"] / \
 25 |     os.environ["RUN_NAME"]
 26 | 
 27 | class Checkpoint:
 28 |     '''
 29 |     Saving checkpoint to:
 30 |         ~/.cache/higgsfield/{project_name}/experiments/{experiment_name}/{run_name}
 31 |     '''
 32 |     def __init__(
 33 |         self,
 34 |         model,
 35 |         optimizer=None,
 36 |         lr_scheduler=None,
 37 |         scaler=None,
 38 |     ):
 39 |         '''
 40 |         model: Higgsfield.model
 41 |         '''
 42 |         if os.environ["PROJECT_NAME"] and os.environ["EXPERIMENT_NAME"] and os.environ["RUN_NAME"]:
 43 |             save_dir = DEFAULT_CHECKPOINT_PATH
 44 |         else:
 45 |             raise NotImplementedError("Support for single GPU/process not implemented yet")
 46 | 
 47 | 
 48 |         self.save_dir = save_dir
 49 |         self.model = model
 50 |         self.optimizer = optimizer
 51 |         self.lr_scheduler = lr_scheduler
 52 |         self.scaler = scaler
 53 | 
 54 |     def save(self, epoch, steps=0, metadata={}):
 55 | 
 56 |         save_path = Path(self.save_dir) / f"epoch_{epoch}_steps_{steps}"
 57 |         save_path.mkdir(exist_ok=True, parents=True)
 58 | 
 59 |         model_path = save_path / "model.pt"
 60 |         optimizer_path = save_path / "optimizer.pt"
 61 | 
 62 |         t0 = time.perf_counter()
 63 |         save_distributed_model_rank0(model_path, self.model)
 64 | 
 65 |         if self.optimizer:
 66 |             save_distributed_optimizer_rank0(optimizer_path, self.model, self.optimizer)
 67 |         t1 = time.perf_counter()
 68 | 
 69 |         if int(os.environ["LOCAL_RANK"]) == 0:
 70 |             print(f"State checkpoint of {steps} steps saved to {save_path}")
 71 |             print(f"Checkpoint Time = {t1-t0:.4f}\n")
 72 | 
 73 |         if self.lr_scheduler:
 74 |             lr_scheduler_path = save_path / "lr_scheduler.pt"
 75 | 
            torch.save(self.lr_scheduler.state_dict(), lr_scheduler_path)
 76 | 
 77 |         if self.scaler:
 78 |             scaler_path = save_path / "scaler.pt"
 79 |             torch.save(self.scaler.state_dict(), scaler_path)
 80 | 
 81 |         metadata_path = save_path / "metadata.json"
 82 |         metadata["epoch"] = epoch
 83 |         metadata["steps"] = steps
 84 | 
 85 |         with open(metadata_path, "w+") as jsonFile:
 86 |             json.dump(metadata, jsonFile)
 87 | 
 88 | def save_distributed_model_rank0(checkpoint_path, model):
 89 |     '''
 90 |     model: FSDP
 91 |     '''
 92 |     rank = dist.get_rank()
 93 | 
 94 |     cpu_state = fsdp_model_state_dict_rank0(model)
 95 | 
 96 |     if rank == 0:
 97 |         torch.save(cpu_state, checkpoint_path)
 98 | 
 99 | def save_distributed_optimizer_rank0(checkpoint_path, model, optimizer):
100 |     '''
101 |     model: FSDP
102 |     optimizer: torch.optim
103 |     '''
104 |     rank = dist.get_rank()
105 | 
106 |     optim_state = FSDP.full_optim_state_dict(model, optimizer)
107 | 
108 |     if rank == 0:
109 |         torch.save(optim_state, checkpoint_path)
--------------------------------------------------------------------------------
/higgsfield/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import (
2 |     CompletionDataset,
3 |     LMDataset,
4 |     TorchMultiTurnDataset,
5 |     TorchCompletionDataset,
6 |     TorchLMDataset,
7 | )
--------------------------------------------------------------------------------
/higgsfield/dataset/dataset.py:
--------------------------------------------------------------------------------
  1 | import copy
  2 | import torch
  3 | from torch.utils.data import Dataset
  4 | 
  5 | class CompletionDataset(Dataset):
  6 |     '''
  7 |     def __getitem__(self, idx):
  8 |         ...
  9 |         return {
 10 |             "prompt": prompt,
 11 |             "completion": completion,
 12 |         }
 13 |     '''
 14 |     pass
 15 | 
 16 | class LMDataset(Dataset):
 17 |     '''
 18 |     def __getitem__(self, idx):
 19 |         ...
20 | return "whatever sequence you want to return as a string" 21 | ''' 22 | pass 23 | 24 | class TorchMultiTurnDataset(Dataset): 25 | def __init__(self, dataset, tokenizer, max_sequence_length): 26 | self.dataset = dataset 27 | self.tokenizer = tokenizer 28 | self.max_sequence_length = max_sequence_length 29 | 30 | def __len__(self): 31 | return len(self.dataset) 32 | 33 | def __getitem__(self, idx): 34 | items = self.dataset[idx] 35 | 36 | IGNORE_INDEX = -100 37 | 38 | multi_labels = [] 39 | multi_example = [] 40 | for item in items: 41 | prompt = item["prompt"] 42 | completion = item["completion"] 43 | 44 | example = prompt + completion 45 | prompt = torch.tensor( 46 | self.tokenizer.encode(prompt), dtype=torch.int64 47 | ) 48 | example = self.tokenizer.encode(example) 49 | example.append(self.tokenizer.eos_token_id) 50 | example = torch.tensor( 51 | example, dtype=torch.int64 52 | ) 53 | 54 | labels = copy.deepcopy(example) 55 | labels[: len(prompt)] = -1 56 | 57 | multi_example.append(example) 58 | multi_labels.append(labels) 59 | 60 | example = torch.cat(multi_example) 61 | labels = torch.cat(multi_labels) 62 | 63 | padding = self.max_sequence_length - example.shape[0] 64 | if padding > 0: 65 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 66 | labels = torch.cat((labels, torch.zeros(padding, dtype=torch.int64) - 1)) 67 | elif padding < 0: 68 | example = example[: self.max_sequence_length] 69 | labels = labels[: self.max_sequence_length] 70 | 71 | label_mask = labels.ge(0) 72 | labels[~label_mask] = IGNORE_INDEX 73 | 74 | example_mask = example.ge(0) 75 | example[~example_mask] = 0 76 | example_mask = example_mask.float() 77 | 78 | return { 79 | "input_ids": example, 80 | "labels": labels, 81 | "attention_mask": example_mask, 82 | } 83 | 84 | class TorchCompletionDataset(Dataset): 85 | def __init__(self, dataset, tokenizer, max_sequence_length): 86 | self.dataset = dataset 87 | self.tokenizer = tokenizer 88 | self.max_sequence_length = max_sequence_length 89 | 90 | def __len__(self): 91 | return len(self.dataset) 92 | 93 | def __getitem__(self, idx): 94 | item = self.dataset[idx] 95 | 96 | prompt = item["prompt"] 97 | completion = item["completion"] 98 | 99 | IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss 100 | 101 | example = prompt + completion 102 | prompt = torch.tensor( 103 | self.tokenizer.encode(prompt), dtype=torch.int64 104 | ) 105 | example = self.tokenizer.encode(example) 106 | example.append(self.tokenizer.eos_token_id) 107 | example = torch.tensor( 108 | example, dtype=torch.int64 109 | ) 110 | padding = self.max_sequence_length - example.shape[0] 111 | if padding > 0: 112 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 113 | elif padding < 0: 114 | example = example[: self.max_sequence_length] 115 | 116 | labels = copy.deepcopy(example) 117 | labels[: len(prompt)] = -1 118 | label_mask = labels.ge(0) 119 | labels[~label_mask] = IGNORE_INDEX 120 | 121 | example_mask = example.ge(0) 122 | example[~example_mask] = 0 123 | example_mask = example_mask.float() 124 | 125 | return { 126 | "input_ids": example, 127 | "labels": labels, 128 | "attention_mask":example_mask, 129 | } 130 | 131 | class TorchLMDataset(Dataset): 132 | def __init__(self, dataset, tokenizer, max_sequence_length): 133 | self.dataset = dataset 134 | self.tokenizer = tokenizer 135 | self.max_sequence_length = max_sequence_length 136 | 137 | def __len__(self): 138 | return len(self.dataset) 139 | 140 | def __getitem__(self, idx): 141 | 
x = self.dataset[idx] 142 | 143 | IGNORE_INDEX = -100 144 | 145 | example = self.tokenizer.encode(x, add_special_tokens=False) 146 | example = torch.tensor( 147 | example, dtype=torch.int64 148 | ) 149 | 150 | padding = self.max_sequence_length - example.shape[0] 151 | if padding > 0: 152 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 153 | elif padding < 0: 154 | example = example[: self.max_sequence_length] 155 | 156 | labels = copy.deepcopy(example) 157 | example_mask = example.ge(0) 158 | label_mask = labels.ge(0) 159 | example[~example_mask] = 0 160 | labels[~label_mask] = IGNORE_INDEX 161 | example_mask = example_mask.float() 162 | label_mask = label_mask.float() 163 | 164 | return { 165 | "input_ids": example, 166 | "labels": labels, 167 | "attention_mask":example_mask, 168 | } 169 | 170 | -------------------------------------------------------------------------------- /higgsfield/dataset/openai.py: -------------------------------------------------------------------------------- 1 | from .dataset import CompletionDataset 2 | 3 | def chat_to_prompt(chat): 4 | joined = [] 5 | for message in chat: 6 | joined.append(f"###{message['role'].upper()}: {message['content']}") 7 | 8 | prompt = "\n".join(joined) 9 | prompt += "\n###ASSISTANT: " 10 | 11 | return prompt 12 | 13 | class ChatCompletionDataset(CompletionDataset): 14 | ''' OpenAI's api format: 15 | chats = [ 16 | [ 17 | {"role": "system", "content": "You are a helpful assistant."}, 18 | {"role": "user", "content": "Who won the world series in 2020?"}, 19 | {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, 20 | {"role": "user", "content": "Where was it played?"} 21 | ], 22 | ] 23 | ''' 24 | def __init__(self, chats, chat_to_prompt=chat_to_prompt): 25 | self.chat_to_prompt = chat_to_prompt 26 | 27 | self.chats = chats 28 | 29 | items = [] 30 | for chat in self.chats: 31 | current_chat = [] 32 | last_user = False 33 | 34 | for message in chat: 35 | if message["role"] == "system": 36 | current_chat.append(message) 37 | 38 | elif message["role"] == "user": 39 | last_user = True 40 | current_chat.append(message) 41 | 42 | elif message["role"] == "assistant": 43 | if last_user: 44 | items.append([ 45 | [c for c in current_chat], message["content"] 46 | ]) 47 | current_chat.append(message) 48 | 49 | self.items = items 50 | 51 | def __len__(self): 52 | return len(self.items) 53 | 54 | def __getitem__(self, idx): 55 | chat, completion = self.items[idx] 56 | prompt = self.chat_to_prompt(chat) 57 | 58 | return { 59 | "prompt": prompt, 60 | "completion": completion 61 | } -------------------------------------------------------------------------------- /higgsfield/experiment.py: -------------------------------------------------------------------------------- 1 | from .internal.experiment.decorator import experiment, param 2 | -------------------------------------------------------------------------------- /higgsfield/internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/higgsfield/internal/__init__.py -------------------------------------------------------------------------------- /higgsfield/internal/cfg.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from pathlib import Path 3 | from importlib.machinery import SourceFileLoader 4 | import os 5 | import 
dotenv
  6 | 
  7 | import subprocess
  8 | 
  9 | 
 10 | def get_key_from_path_or_key(key_or_path: Optional[str]) -> str:
 11 |     if key_or_path is None or key_or_path == "":
 12 |         raise ValueError("SSH_KEY in env is None")
 13 | 
 14 |     if (
 15 |         key_or_path.startswith("-----BEGIN")
 16 |         or key_or_path.startswith("ssh-rsa")
 17 |         or key_or_path.startswith("NOTHING")
 18 |     ):
 19 |         return key_or_path
 20 | 
 21 |     path = Path(os.path.expanduser(key_or_path)).resolve()
 22 |     if path.exists():
 23 |         return path.read_text()
 24 | 
 25 |     return key_or_path
 26 | 
 27 | 
 28 | class AppConfig:
 29 |     name: str
 30 |     github_repo_url: Optional[str] = None
 31 |     hosts: List[str]
 32 |     user: str
 33 |     key: str
 34 |     port: int
 35 |     number_of_processes_per_node: int
 36 | 
 37 |     def __init__(self, **kwargs):
 38 |         self.__dict__.update(kwargs)
 39 |     @classmethod
 40 |     def from_path(cls, path: Path) -> "AppConfig":
 41 |         config_path = path / "src" / "config.py"
 42 |         if not config_path.exists():
 43 |             raise ValueError(f"Config file {config_path} not found")
 44 | 
 45 |         if not (path / "env").exists():
 46 |             raise ValueError(f"Env file {path / 'env'} not found")
 47 |         dotenv.load_dotenv(path / "env", verbose=True, override=True)
 48 | 
 49 |         try:
 50 |             module = SourceFileLoader("module.name", str(config_path)).load_module()
 51 |         except Exception as e:
 52 |             raise ValueError(
 53 |                 f"Config file {config_path} could not be loaded; make sure it meets the requirements"
 54 |             ) from e
 55 | 
 56 |         name = str(module.__dict__["NAME"])
 57 |         github_repo_url: str = module.__dict__.get("GITHUB_REPO_URL", None)
 58 |         hosts = [host.strip() for host in module.__dict__["HOSTS"]]
 59 |         user = module.__dict__["HOSTS_USER"]
 60 | 
 61 |         if user == "root":
 62 |             raise ValueError("Please don't use root as the user")
 63 | 
 64 |         port = module.__dict__["HOSTS_PORT"]
 65 |         number_of_processes_per_node = module.__dict__["NUMBER_OF_PROCESSES_PER_NODE"]
 66 | 
 67 |         key = get_key_from_path_or_key(os.getenv("SSH_KEY"))
 68 | 
 69 |         return AppConfig(
 70 |             name=name,
 71 |             github_repo_url=github_repo_url,
 72 |             hosts=hosts,
 73 |             user=user,
 74 |             key=key,
 75 |             port=port,
 76 |             number_of_processes_per_node=number_of_processes_per_node,
 77 |         )
 78 | 
 79 |     def get_git_origin_url(self, path) -> Optional[str]:
 80 |         if self.github_repo_url is not None:
 81 |             return self.github_repo_url
 82 |         try:
 83 |             # Run the Git command to get the remote origin URL
 84 |             result = subprocess.check_output(
 85 |                 ["git", "config", "--get", "remote.origin.url"],
 86 |                 cwd=path,
 87 |                 universal_newlines=True,
 88 |             )
 89 | 
 90 |             # Strip any leading/trailing whitespace from the result
 91 |             origin_url = result.strip()
 92 | 
 93 |             return origin_url
 94 |         except subprocess.CalledProcessError:
 95 |             return None
 96 | 
 97 |     def set_git_origin_url(self, path: Path):
 98 |         config_path = path / "src" / "config.py"
 99 |         # remove the GITHUB_REPO_URL line
100 |         with open(config_path, "r") as f:
101 |             lines = f.readlines()
102 | 
103 |         # Remove all occurrences of GITHUB_REPO_URL
104 |         lines = [line for line in lines if "GITHUB_REPO_URL" not in line]
105 | 
106 |         with open(config_path, "w") as f:
107 |             # Write back the modified lines
108 |             f.writelines(lines)
109 | 
110 |             if not lines or not lines[-1].endswith("\n"):
111 |                 f.write("\n")
112 | 
113 |         # Add GITHUB_REPO_URL at the end without adding extra newlines
114 |         with open(config_path, "a") as f:
115 |             f.write(f'GITHUB_REPO_URL = "{self.github_repo_url}"\n')
116 | 
117 |     def is_valid(self) -> Optional[str]:
118 |         if self.github_repo_url is None:
119 |             return "GITHUB_REPO_URL is None"
120 |         if self.key is None:
121 |             return "SSH_KEY in env is None"
122 | return 123 | -------------------------------------------------------------------------------- /higgsfield/internal/ci/cli.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | 4 | import click 5 | 6 | from higgsfield.internal.util import wd_path, parse_origin_link_or_else 7 | from higgsfield.internal.cfg import AppConfig 8 | from .setup import Setup 9 | 10 | from base64 import b64decode 11 | 12 | 13 | @click.command("get-hosts") 14 | def hosts(): 15 | wd = wd_path() 16 | app_config = AppConfig.from_path(wd) 17 | click.echo(",".join(app_config.hosts)) 18 | 19 | 20 | @click.command("get-nproc-per-node") 21 | def proc_per_node(): 22 | wd = wd_path() 23 | app_config = AppConfig.from_path(wd) 24 | click.echo(str(app_config.number_of_processes_per_node)) 25 | 26 | 27 | @click.command("get-ssh-details") 28 | def ssh_details(): 29 | wd = wd_path() 30 | app_config = AppConfig.from_path(wd) 31 | print( 32 | json.dumps( 33 | { 34 | "key": app_config.key, 35 | "user": app_config.user, 36 | "port": app_config.port, 37 | "hosts": ",".join(app_config.hosts), 38 | }, 39 | indent=2, 40 | ) 41 | ) 42 | 43 | 44 | @click.command("decode-secrets") 45 | @click.argument("env", type=str, required=True) 46 | def decode_secrets(env: str): 47 | env_path = wd_path() / "env" 48 | if env_path.exists(): 49 | raise ValueError("env file already exists") 50 | 51 | env_path.write_text(b64decode(env.encode()).decode()) 52 | 53 | 54 | 55 | 56 | @click.command("setup-nodes") 57 | @click.option('--invoker_tag', default="v0.0.1", help="Tag of the invoker binary to use") 58 | def setup_nodes(invoker_tag: str = "v0.0.1"): 59 | wd = wd_path() 60 | app_config = AppConfig.from_path(wd) 61 | 62 | project_path = wd 63 | 64 | origin_url = app_config.get_git_origin_url(project_path) 65 | 66 | if origin_url is None: 67 | raise ValueError("Have you pushed your project to github?") 68 | 69 | origin_url = parse_origin_link_or_else(origin_url) 70 | 71 | if origin_url is None: 72 | raise ValueError("Please use ssh or https url for github repo.") 73 | 74 | app_config.github_repo_url = origin_url 75 | 76 | app_config.set_git_origin_url(project_path) 77 | 78 | setup = Setup(app_config, project_path, invoker_tag) 79 | 80 | try: 81 | setup.create_ssh_key_file() 82 | asyncio.run(setup.setup_nodes()) 83 | finally: 84 | setup.finish() 85 | -------------------------------------------------------------------------------- /higgsfield/internal/ci/setup.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from pathlib import Path 3 | from typing import List 4 | 5 | from jinja2 import Environment, FileSystemLoader 6 | from higgsfield.internal.cfg import AppConfig 7 | import asyncssh 8 | 9 | from higgsfield.internal.util import templates_path 10 | from higgsfield.internal.experiment.builder import header 11 | 12 | 13 | docker_install_script = '''/bin/bash -c "$(curl -fsSL https://gist.githubusercontent.com/arpanetus/1c1210b9e432a04dcfb494725a407a70/raw/5d47baa19b7100261a2368a43ace610528e0dfa2/install.sh)"''' 14 | 15 | def invoker_install_script(tag: str) -> str: 16 | return f"""wget https://github.com/ml-doom/invoker/releases/download/{tag}/invoker-{tag}-linux-amd64.tar.gz && \ 17 | tar -xvf invoker-{tag}-linux-amd64.tar.gz && \ 18 | sudo mv invoker /usr/bin/invoker && \ 19 | rm invoker-{tag}-linux-amd64.tar.gz""" 20 | 21 | 22 | def deploy_key_script(key: str, project_name: str, deploy_key_string: str): 23 | return f"""sudo mkdir 
-p ~/.ssh && \
 24 | echo "{key}" > ~/.ssh/{project_name}-github-deploy.key && \
 25 | chmod 600 ~/.ssh/{project_name}-github-deploy.key && \
 26 | sudo touch ~/.ssh/config && \
 27 | sudo chmod 644 ~/.ssh/config && \
 28 | echo "{deploy_key_string}" | sudo tee -a ~/.ssh/config
 29 | """
 30 | 
 31 | 
 32 | class Setup:
 33 |     app_config: AppConfig
 34 |     path: str
 35 |     deploy_key: str
 36 |     project_path: Path
 37 |     invoker_tag: str
 38 | 
 39 |     def __init__(
 40 |         self,
 41 |         app_config: AppConfig,
 42 |         project_path: Path,
 43 |         invoker_tag: str,
 44 |     ):
 45 |         self.app_config = app_config
 46 | 
 47 |         if (reason := self.app_config.is_valid()) is not None:
 48 |             raise ValueError(reason)
 49 | 
 50 |         self.project_path = project_path
 51 |         self.invoker_tag = invoker_tag
 52 | 
 53 |     def create_ssh_key_file(self):
 54 |         if self.app_config.key is None:
 55 |             raise ValueError("SSH_KEY in env is None")
 56 | 
 57 |         with Path.home() / ".ssh" / f"temp-{self.app_config.name}.key" as f:
 58 |             f.write_text(self.app_config.key)
 59 |             f.chmod(0o600)
 60 |             self.path = str(Path.resolve(f.absolute()))
 61 | 
 62 |     def finish(self):
 63 |         Path(self.path).unlink()
 64 | 
 65 |     async def establish_connections(self):
 66 |         if self.app_config.key is None:
 67 |             raise ValueError("SSH_KEY in env is None")
 68 | 
 69 |         self.connections: List[asyncssh.SSHClientConnection] = []
 70 |         for host in self.app_config.hosts:
 71 |             self.connections.append(
 72 |                 await asyncssh.connect(
 73 |                     host,
 74 |                     port=self.app_config.port,
 75 |                     username=self.app_config.user,
 76 |                     client_keys=[self.path],
 77 |                 )
 78 |             )
 79 | 
 80 |     def set_deploy_key(self):
 81 |         with Path.home() / ".ssh" / "higgsfield" / f"{self.app_config.name}-github-deploy.key" as f:
 82 |             self.deploy_key = f.read_text()
 83 | 
 84 |     def _build_deploy_key_string(self):
 85 |         return f"Host github.com-{self.app_config.name}\n\tHostName github.com\n\tIdentityFile ~/.ssh/{self.app_config.name}-github-deploy.key\n\tIdentitiesOnly yes\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile=/dev/null\n\tLogLevel=ERROR\n"
 86 | 
 87 |     async def setup_nodes(self):
 88 |         await self.establish_connections()
 89 | 
 90 |         if len(self.connections) == 0:
 91 |             print("\n\n\nNO CONNECTIONS!!!\n\n\n")
 92 | 
 93 |         print("\n\n\nINSTALLING DOCKER\n\n\n")
 94 |         async def printer(coro):
 95 |             print(await coro)  # awaiting here lets gather run all nodes concurrently
 96 |         to_run = []
 97 |         for conn in self.connections:
 98 |             to_run.append(printer(conn.run(docker_install_script)))
 99 | 
100 |         await asyncio.gather(*to_run)
101 | 
102 |         print("\n\n\nINSTALLING INVOKER\n\n\n")
103 |         to_run = []
104 |         for conn in self.connections:
105 |             to_run.append(printer(conn.run(invoker_install_script(self.invoker_tag))))
106 | 
107 |         await asyncio.gather(*to_run)
108 | 
109 |         print("\n\n\nSETTING UP DEPLOY KEY\n\n\n")
110 |         self.set_deploy_key()
111 |         dk_script = deploy_key_script(
112 |             self.deploy_key, self.app_config.name, self._build_deploy_key_string()
113 |         )
114 |         to_run = []
115 |         for conn in self.connections:
116 |             to_run.append(printer(conn.run(dk_script)))
117 | 
118 |         await asyncio.gather(*to_run)
119 | 
120 |         # close connections
121 |         for conn in self.connections:
122 |             conn.close()
123 | 
124 | 
125 |         # re-establish connections
126 |         await self.establish_connections()
127 | 
128 | 
129 |         print("\n\n\nPULLING BASE DOCKER IMAGE\n\n\n")
130 |         to_run = []
131 |         for conn in self.connections:
132 |             to_run.append(printer(conn.run("docker pull higgsfield/pytorch:latest")))
133 | 
134 |         await asyncio.gather(*to_run)
135 | 
136 | 
137 |         print("\n\n\nSeems like everything is done by now. 
Go run your experiments.\n\n\n") 138 | 139 | -------------------------------------------------------------------------------- /higgsfield/internal/cli.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | import click 5 | 6 | from .util import wd_path, check_name, parse_origin_link_or_else 7 | from .launch import Launch 8 | 9 | from .cfg import AppConfig 10 | from .experiment.builder import DeployBuilder, KillBuilder, build_all_experiment_actions 11 | 12 | from .ci import cli as ci_cli 13 | from higgsfield.internal.init import init 14 | 15 | import os 16 | 17 | 18 | def setup_environ_flags(rank): 19 | """Environment flags for debugging purposes""" 20 | 21 | os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) 22 | os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) 23 | 24 | 25 | def setup(seed): 26 | import torch 27 | import torch.distributed as dist 28 | 29 | torch.cuda.manual_seed(seed) 30 | torch.manual_seed(seed) 31 | 32 | dist.init_process_group(backend="nccl") 33 | 34 | local_rank = int(os.environ["LOCAL_RANK"]) 35 | rank = int(os.environ["RANK"]) 36 | world_size = int(os.environ["WORLD_SIZE"]) 37 | 38 | if dist.is_initialized(): 39 | torch.cuda.set_device(local_rank) 40 | torch.cuda.empty_cache() 41 | setup_environ_flags(rank) 42 | 43 | 44 | @click.command("run") 45 | @click.option("--experiment_name", type=str, help="experiment name") 46 | @click.option("--run_name", type=str, help="run name") 47 | @click.option("--max_repeats", type=int, help="max repeats") 48 | @click.argument("extra_args", nargs=-1) 49 | def run_experiment( 50 | experiment_name: str, 51 | run_name: str, 52 | max_repeats: int, 53 | extra_args: List[str], 54 | ): 55 | wd = wd_path() 56 | app_config = AppConfig.from_path(wd) 57 | 58 | os.environ["PROJECT_NAME"] = app_config.name 59 | os.environ["EXPERIMENT_NAME"] = experiment_name 60 | os.environ["RUN_NAME"] = run_name 61 | setup(42) 62 | Launch(wd, app_config.name, experiment_name, run_name, max_repeats, extra_args) 63 | 64 | 65 | @click.command("init") 66 | @click.argument("project_name", type=str, required=True) 67 | def init_cmd(project_name: str): 68 | print( 69 | "Initializing {} project at:\n{}/{}".format( 70 | project_name, wd_path(), project_name 71 | ) 72 | ) 73 | init(wd_path(), check_name(project_name)) 74 | 75 | 76 | @click.command("build-experiments") 77 | def build_experiments(): 78 | wd = wd_path() 79 | 80 | app_config = AppConfig.from_path(wd) 81 | 82 | origin_url = app_config.get_git_origin_url(wd) 83 | 84 | if origin_url is None: 85 | raise ValueError("Have you pushed your project to github?") 86 | 87 | origin_url = parse_origin_link_or_else(origin_url) 88 | 89 | if origin_url is None: 90 | raise ValueError("Please use ssh or https url for github repo.") 91 | 92 | app_config.github_repo_url = origin_url 93 | 94 | app_config.set_git_origin_url(wd) 95 | 96 | DeployBuilder(app_config, wd).generate() 97 | KillBuilder(app_config, wd).generate() 98 | build_all_experiment_actions(wd, app_config) 99 | 100 | 101 | @click.command("show-deploy-key") 102 | def show_deploy_key(): 103 | wd = wd_path() 104 | app_config = AppConfig.from_path(wd) 105 | deploy_key_path = ( 106 | Path.home() / ".ssh" / "higgsfield" / f"{app_config.name}-github-deploy.key.pub" 107 | ) 108 | if not deploy_key_path.exists(): 109 | raise ValueError("No deploy key found, file an issue on github") 110 | 111 | with deploy_key_path as f: 112 | click.echo(f.read_text() + "\n") 113 | 114 | 115 | 
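# Note: `run_experiment` above exports PROJECT_NAME, EXPERIMENT_NAME and
# RUN_NAME into the environment before calling setup() and Launch();
# higgsfield/checkpoint/fsdp_checkpoint.py reads the same variables to build
# its default checkpoint path:
#   ~/.cache/higgsfield/<PROJECT_NAME>/experiments/<EXPERIMENT_NAME>/<RUN_NAME>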
@click.group("ci") 116 | def ci(): 117 | pass 118 | 119 | 120 | ci.add_command(ci_cli.proc_per_node) 121 | ci.add_command(ci_cli.ssh_details) 122 | ci.add_command(ci_cli.decode_secrets) 123 | -------------------------------------------------------------------------------- /higgsfield/internal/experiment/ast_parser.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from typing import Any, List, Dict, Tuple, Optional 4 | 5 | 6 | def func_defs(module: ast.Module) -> List[ast.FunctionDef]: 7 | defs = [] 8 | for node in ast.iter_child_nodes(module): 9 | if isinstance(node, ast.FunctionDef): 10 | defs.append(node) 11 | 12 | return defs 13 | 14 | 15 | def filter_experiment_defs(defs: List[ast.FunctionDef]): 16 | filtered_defs: List[ast.FunctionDef] = [] 17 | for node in defs: 18 | try: 19 | if ( 20 | len(node.decorator_list) >= 1 21 | and node.decorator_list[0].func.id == "experiment" # type: ignore 22 | ): 23 | filtered_defs.append(node) 24 | except Exception: 25 | pass 26 | return filtered_defs 27 | 28 | 29 | class Dec: 30 | name: str 31 | arg_pairs: Dict[str, Any] 32 | allowed_args = dict() 33 | 34 | def __init__( 35 | self, 36 | name: str, 37 | allowed_args: Dict[str, Tuple[type, ...]], 38 | arg_pairs: Optional[Dict[str, Any]] = None, 39 | ): 40 | self.name = name 41 | self.allowed_args = allowed_args 42 | self.arg_pairs = dict() if arg_pairs is None else arg_pairs 43 | 44 | def add_arg_pair(self, left: str, right: Any): 45 | if left not in self.allowed_args: 46 | raise ValueError(f"argument {left} of {self.name} is not allowed") 47 | if type(right) not in self.allowed_args[left]: 48 | raise ValueError( 49 | f"argument {left} of {self.name} has type {type(right)}, need {self.allowed_args[left]}" 50 | ) 51 | 52 | if left in self.arg_pairs: 53 | raise ValueError(f"argument {left} is redefined in {self.name}") 54 | 55 | self.arg_pairs[left] = right 56 | 57 | 58 | noneType = type(None) 59 | 60 | 61 | class Expdec(Dec): 62 | def __init__(self): 63 | super(Expdec, self).__init__( 64 | name="experiment", allowed_args={"name": (str,), "seed": (int, noneType)} 65 | ) 66 | 67 | 68 | class Paramdec(Dec): 69 | def __init__(self): 70 | super(Paramdec, self).__init__( 71 | name="param", 72 | allowed_args={ 73 | "name": (str,), 74 | "default": (str, int, bool, float, noneType), 75 | "description": (str, noneType), 76 | "required": (bool, noneType), 77 | "type": (type,), 78 | "options": (tuple, noneType), 79 | }, 80 | ) 81 | 82 | 83 | builder = {"experiment": Expdec, "param": Paramdec} 84 | 85 | type_dict = { 86 | "str": str, 87 | "int": int, 88 | "float": float, 89 | "bool": bool, 90 | } 91 | 92 | 93 | def build_experiment_def( 94 | node: ast.FunctionDef, 95 | ) -> Optional[Tuple[Expdec, Dict[str, Paramdec]]]: 96 | experiment: Optional[Expdec] = None 97 | params: Dict[str, Paramdec] = dict() 98 | stop = False 99 | 100 | for maybe_decorator in node.decorator_list: 101 | if not isinstance(maybe_decorator, ast.Call): 102 | continue 103 | decorator: ast.Call = maybe_decorator # type: ignore 104 | func: Optional[ast.Name] = getattr(decorator, "func", None) 105 | if func is None: 106 | stop = True 107 | break 108 | dec_to_call = builder.get(func.id, None) 109 | 110 | if dec_to_call is None: 111 | stop = True 112 | break 113 | 114 | dec = dec_to_call() 115 | 116 | if len(decorator.args) > 1: 117 | raise ValueError( 118 | "experiment or param decorators cannot " 119 | + "have other decorators applied or unnamed params other than name" 120 | ) 121 | 122 | if 
len(decorator.args) == 1: 123 | try: 124 | dec.add_arg_pair("name", decorator.args[0].value) # type: ignore 125 | except Exception: 126 | stop = True 127 | break 128 | 129 | for kw in decorator.keywords: 130 | try: 131 | field_val = None 132 | if isinstance(kw.value, ast.Name): 133 | field_val = type_dict[kw.value.id] 134 | elif isinstance(kw.value, ast.Constant): 135 | field_val = kw.value.value 136 | elif isinstance(kw.value, ast.Tuple) or isinstance(kw.value, ast.List): 137 | vals = [] 138 | for elt in kw.value.elts: 139 | vals.append(elt.value) # type: ignore 140 | field_val = tuple(vals) 141 | 142 | else: 143 | raise ValueError( 144 | f"cannot find the type of kw.value: {type(kw.value)}" 145 | ) 146 | dec.add_arg_pair(kw.arg, field_val) # type: ignore 147 | except Exception as e: 148 | print(e) 149 | stop = True 150 | break 151 | 152 | if type(dec) == Expdec: 153 | if experiment is not None: 154 | raise ValueError("more than one experiment is defined") 155 | else: 156 | experiment = dec 157 | 158 | if type(dec) == Paramdec: 159 | if dec.arg_pairs["name"] in params: 160 | raise ValueError( 161 | f'more than one param with the same name {dec.arg_pairs["name"]} is defined' 162 | ) 163 | params[dec.arg_pairs["name"]] = dec 164 | if stop: 165 | return 166 | 167 | if experiment is None: 168 | return 169 | 170 | return experiment, params 171 | 172 | 173 | def build_experiment_defs( 174 | defs: List[ast.FunctionDef], 175 | ) -> List[Tuple[Expdec, Dict[str, Paramdec]]]: 176 | exps: List[Tuple[Expdec, Dict[str, Paramdec]]] = list() 177 | for node in defs: 178 | ret = build_experiment_def(node) 179 | if ret is not None: 180 | exp, params = ret 181 | exps.append((exp, params)) 182 | 183 | return exps 184 | 185 | 186 | def parse_experiments(filename: str) -> List[Tuple[Expdec, Dict[str, Paramdec]]]: 187 | parsed_code: Optional[ast.Module] = None 188 | with open(filename, "r") as f: 189 | parsed_code = ast.parse(f.read()) 190 | 191 | if parsed_code is None: 192 | return [] 193 | 194 | # search for top level experiment declarations 195 | defs = func_defs(parsed_code) 196 | 197 | # filter out by decorators, top level decorator should be "experiment" 198 | defs = filter_experiment_defs(defs) 199 | 200 | got = build_experiment_defs(defs) 201 | 202 | return got 203 | -------------------------------------------------------------------------------- /higgsfield/internal/experiment/builder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import os 4 | import dotenv 5 | 6 | from higgsfield.internal.cfg import AppConfig 7 | 8 | from .params import Param, build_gh_action_inputs, build_run_params 9 | from jinja2 import Environment, FileSystemLoader, Template 10 | from higgsfield.internal.util import templates_path 11 | from pathlib import Path 12 | from importlib.machinery import SourceFileLoader 13 | from .decorator import ExperimentDecorator 14 | from higgsfield.internal.experiment.ast_parser import parse_experiments 15 | 16 | header = """# THIS FILE WAS GENERATED BY HIGGSFIELD. 17 | # DO NOT EDIT. 18 | # IF YOUR WORKFLOW DOESN'T WORK, CREATE AN ISSUE. 
19 | """ 20 | 21 | 22 | class ActionBuilder: 23 | app_config: AppConfig 24 | wf_dir: Path 25 | template: Template 26 | template_name: str 27 | 28 | def __init__(self, app_config: AppConfig, project_path: Path): 29 | environment = Environment(loader=FileSystemLoader(templates_path())) 30 | self.template = environment.get_template(self.template_name) 31 | 32 | self.app_config = app_config 33 | 34 | wf_dir = project_path / ".github" / "workflows" 35 | if not wf_dir.exists(): 36 | wf_dir.mkdir(parents=True, exist_ok=True) 37 | elif not wf_dir.is_dir(): 38 | raise ValueError(f"{wf_dir} is not a directory, delete it and try again") 39 | 40 | self.wf_dir = wf_dir 41 | 42 | 43 | class KillBuilder(ActionBuilder): 44 | template_name: str = "kill_action.j2" 45 | 46 | def generate(self): 47 | (self.wf_dir / "kill.yml").write_text( 48 | self.template.render( 49 | header=header, 50 | project_name=self.app_config.name, 51 | ) 52 | ) 53 | 54 | print("Updated kill action") 55 | 56 | 57 | def as_keyed_repo_url(repo_url: Optional[str], project_name: str) -> str: 58 | if repo_url is None: 59 | raise ValueError("Did you add git remote origin? / push to github?") 60 | 61 | # get the index of "/" after "github.com" 62 | # and replace it with "-project_name:" 63 | return repo_url.replace("github.com:", f"github.com-{project_name}:") 64 | 65 | 66 | def insert_env_line(keys: List[str], indent: str) -> str: 67 | lines = [] 68 | for key in keys: 69 | line = indent + "echo " + key + '="${{ secrets.' + key + ' }}" >> env' 70 | lines.append(line) 71 | 72 | return "\n".join(lines) 73 | 74 | 75 | def env_keys_as_action(path: Path, indent: str) -> str: 76 | keys = list(dotenv.dotenv_values(path).keys()) 77 | keys.pop(keys.index("SSH_KEY")) 78 | 79 | return insert_env_line(keys, indent) 80 | 81 | 82 | echo_indent = " " 83 | 84 | 85 | class DeployBuilder(ActionBuilder): 86 | template_name = "deploy_action.j2" 87 | 88 | def generate(self): 89 | (self.wf_dir / "deploy.yml").write_text( 90 | self.template.render( 91 | header=header, 92 | project_name=self.app_config.name, 93 | keyed_repo_url=as_keyed_repo_url( 94 | self.app_config.github_repo_url, self.app_config.name 95 | ), 96 | env_gen=env_keys_as_action( 97 | self.wf_dir.parent.parent / "env", echo_indent 98 | ), 99 | ) 100 | ) 101 | 102 | print("Updated deploy action") 103 | 104 | 105 | class ExperimentBuilder(ActionBuilder): 106 | template_name = "experiment_action.j2" 107 | 108 | def generate(self, experiment_name: str, params: List[Param]): 109 | (self.wf_dir / f"run_{experiment_name}.yml").write_text( 110 | self.template.render( 111 | header=header, 112 | experiment_name=experiment_name, 113 | project_name=self.app_config.name, 114 | params=build_gh_action_inputs(params), 115 | rest=build_run_params(params), 116 | env_gen=env_keys_as_action( 117 | self.wf_dir.parent.parent / "env", echo_indent 118 | ), 119 | ) 120 | ) 121 | 122 | 123 | print("Updated experiment action", experiment_name) 124 | 125 | 126 | def _source_experiments(base_path: Path): 127 | """ 128 | Only used inside docker to inject experiments into the module. 129 | Do not use outside of docker. But if you do, you will have to have 130 | the same dependencies (aka environment) as the docker container. 
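    For example (illustrative, not part of the public API): loading src/train.py
    here imports it as module "train"; any @experiment(...)-decorated function in
    it registers itself in the decorator module's experiment registry as a side
    effect of the import.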
131 | """ 132 | for file in base_path.glob("**/*.py"): 133 | module_name = os.path.basename(file).split(".py")[0].replace(" ", "_").replace("-", "_") 134 | SourceFileLoader(module_name, str(file)).load_module() 135 | 136 | 137 | def build_all_experiment_actions(wd_path: Path, app_config: AppConfig): 138 | """ 139 | Builds all experiment actions. 140 | 141 | Root path should be the root of the project and must contain the src folder under which everything is defined. 142 | Project name should be the name of the project. 143 | 144 | It will parse the AST of all files under the src folder and search for experiments. 145 | Then it will build an action for each experiment and save it under the .github/workflows folder. 146 | Deletes previously generated actions (files with the name prefix run_ and the generated header inside) that are no longer needed. 147 | """ 148 | exp_params_pairs = [] 149 | for file in (wd_path / "src").glob("**/*.py"): 150 | exp_params_pairs.extend(parse_experiments(str(file.resolve()))) 151 | 152 | if len(exp_params_pairs) == 0: 153 | print("No experiments found") 154 | return 155 | 156 | experiments = ExperimentDecorator.from_ast(exp_params_pairs) 157 | actions_folder = wd_path / ".github" / "workflows" 158 | actions_folder.mkdir(parents=True, exist_ok=True) 159 | 160 | # delete previously generated actions: files with the name prefix run_ and the header inside 161 | for i in actions_folder.glob("run_*.yml"): 162 | if i.read_text().startswith(header): 163 | i.unlink() 164 | 165 | for experiment_name, experiment in experiments.items(): 166 | ExperimentBuilder(app_config, wd_path).generate( 167 | experiment_name, experiment.params 168 | ) 169 | -------------------------------------------------------------------------------- /higgsfield/internal/experiment/decorator.py: -------------------------------------------------------------------------------- 1 | from .params import Param 2 | from typing import Callable, Any, List, Tuple, Dict, Optional, Union, Set, Type 3 | from higgsfield.internal.util import check_name 4 | from .ast_parser import Expdec, Paramdec 5 | 6 | 7 | class InnerWrap: 8 | param_set: Set[Param] 9 | 10 | def __init__(self, func: Callable[..., None], param_set: Set[Param]): 11 | self.func = func 12 | self.param_set = set(param_set) 13 | 14 | def add_param(self, param: Param): 15 | self.param_set.add(param) 16 | 17 | 18 | class ExperimentDecorator: 19 | name: str 20 | params: List[Param] 21 | train: Callable[..., None] 22 | 23 | def __init__(self, name: str, *, seed: Optional[int] = None): 24 | """ 25 | Composable experiment decorator. 26 | >> @ExperimentDecorator("my_experiment", seed=42) 27 | >> @ParamDecorator("my_param", default=1, description="My param description") 28 | >> def train(params): 29 | >> pass 30 | 31 | that is equivalent to: 32 | >> def train(params): 33 | >> pass 34 | >> 35 | >> train = ParamDecorator("my_param", default=1, description="My param description")(train) 36 | >> train = ExperimentDecorator("my_experiment")(train) 37 | 38 | The decorated experiment registers itself in the module-level registry; its params are accessible via the decorator instance's .params attribute.
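        Note: a "seed" param is always added to the experiment; it defaults to
        the decorator's seed argument when one is given, otherwise to 42.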
39 | 40 | 41 | """ 42 | name = check_name(name) 43 | if name in _experiments: 44 | raise ValueError(f"Experiment with name {name} already exists") 45 | 46 | _experiments[name] = self 47 | 48 | self.name = name 49 | self.params = list() 50 | if seed is not None: 51 | self.params.append(Param.from_values(name="seed", default=seed, type=int)) 52 | else: 53 | self.params.append(Param.from_values(name="seed", type=int, required=True, default=42)) 54 | self.train = lambda x: print("this shouldn't have been called at all") 55 | 56 | def __call__(self, func: Callable[..., None]) -> Optional[Callable[..., None]]: 57 | if type(func) == InnerWrap: 58 | # if the user did not define their own seed param, 59 | if not any(param.name == "seed" for param in func.param_set): 60 | # keep the auto-added seed param and append the user's params after it 61 | self.params.extend(list(func.param_set)) 62 | else: 63 | self.params = list(func.param_set)  # the user's own seed param wins 64 | self.train = func.func 65 | 66 | return 67 | 68 | if callable(func) and len(func.__code__.co_varnames) >= 1: 69 | self.train = func 70 | 71 | return 72 | else: 73 | raise ValueError( 74 | f"Experiment decorator can only be applied to a function that accepts one argument, " 75 | + f"or on top of another param decorator. \n\tHave: \t{self.name} \n\tFunc: \t{self.train}" 76 | ) 77 | 78 | @classmethod 79 | def from_ast( 80 | cls, ast_exps: List[Tuple[Expdec, Dict[str, Paramdec]]] 81 | ) -> Dict[str, "ExperimentDecorator"]: 82 | experiments = {} 83 | for ast_exp, ast_params in ast_exps: 84 | exp = cls(ast_exp.arg_pairs["name"]) 85 | for ast_param in ast_params.values(): 86 | param = ParamDecorator.from_ast(ast_param) 87 | exp.params.append(param.param) 88 | if exp.name in experiments: 89 | raise ValueError(f"Experiment with name {exp.name} already exists") 90 | experiments[exp.name] = exp 91 | 92 | return experiments 93 | 94 | 95 | class ParamDecorator: 96 | param: Param 97 | 98 | def __init__( 99 | self, 100 | name: str, 101 | *, 102 | default: Any = None, 103 | description: Optional[str] = None, 104 | required: bool = False, 105 | type: Optional[Type] = None, 106 | options: Optional[Union[Tuple[Any, ...], List[Any]]] = None, 107 | ): 108 | self.param = Param.from_values( 109 | name=name, 110 | default=default, 111 | description=description, 112 | required=required, 113 | type=type, 114 | options=tuple(options) if options is not None else None, 115 | ) 116 | 117 | def __call__(self, func: Callable[..., None]) -> Callable: 118 | if type(func) == InnerWrap: 119 | func.add_param(self.param) 120 | return func 121 | 122 | if callable(func) and len(func.__code__.co_varnames) >= 1: 123 | params = set() 124 | params.add(self.param) 125 | return InnerWrap(func, params) # type: ignore 126 | else: 127 | raise ValueError( 128 | "Param decorator can only be applied to a function that accepts one argument, or on top of another param decorator."
129 | ) 130 | 131 | @classmethod 132 | def list_from_ast(cls, ast_params: Dict[str, Paramdec]): 133 | params = list() 134 | for ast_param in ast_params.values(): 135 | param = ParamDecorator.from_ast(ast_param) 136 | params.append(param) 137 | return params 138 | 139 | @classmethod 140 | def from_ast(cls, ast_param: Paramdec): 141 | return cls(**ast_param.arg_pairs) 142 | 143 | 144 | _experiments: Dict[str, ExperimentDecorator] = {} 145 | 146 | 147 | experiment = ExperimentDecorator 148 | param = ParamDecorator 149 | -------------------------------------------------------------------------------- /higgsfield/internal/experiment/params.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Tuple, Dict, Optional, Type 2 | from yaml import safe_dump 3 | from higgsfield.internal.util import check_name 4 | 5 | 6 | def build_run_params(params: List["Param"]) -> str: 7 | return " ".join(param.as_run_param() for param in params) 8 | 9 | 10 | def build_gh_action_inputs(params: List["Param"]) -> List[str]: 11 | return [param.as_github_action() for param in params] 12 | 13 | 14 | _arg_type_set = { 15 | int, 16 | float, 17 | str, 18 | bool, 19 | } 20 | 21 | 22 | def reset_values(values: Dict[str, Any]): 23 | if "name" not in values: 24 | raise ValueError("Field name is required in Param") 25 | if "required" not in values: 26 | values["required"] = False 27 | 28 | return values 29 | 30 | 31 | class Param: 32 | name: str 33 | default: Optional[Any] = None 34 | description: Optional[str] = None 35 | required: bool = False 36 | type: Optional[Type] = None 37 | options: Optional[Tuple[Any, ...]] = None 38 | 39 | def __init__(self, **kwargs): 40 | self.__dict__.update(kwargs) 41 | 42 | @classmethod 43 | def from_values(cls, **values) -> "Param": 44 | values = reset_values(values) 45 | values["name"] = check_name(values.get("name") or "") 46 | arg_type = values.get("type") 47 | default = values.get("default") 48 | 49 | if arg_type is None: 50 | if default is None: 51 | arg_type = str 52 | values["type"] = arg_type 53 | else: 54 | arg_type = type(default) 55 | values["type"] = arg_type 56 | 57 | if arg_type not in _arg_type_set: 58 | raise ValueError( 59 | f"Param type {arg_type} not supported." 
60 | + "Only primitive {_arg_type_map.keys()} are supported" 61 | ) 62 | 63 | try: 64 | if default is not None: 65 | values["default"] = arg_type(default) 66 | except Exception as e: 67 | raise ValueError( 68 | f"Param default {default} cannot be converted to type {arg_type}" 69 | ) from e 70 | 71 | if options := values.get("options"): 72 | values["options"] = [arg_type(opt) for opt in options] 73 | if default is None: 74 | values["default"] = values["options"][0] 75 | default = values["default"] 76 | 77 | if default is not None and default not in values["options"]: 78 | raise ValueError(f"Param default {default} not in options {options}") 79 | 80 | return Param(**values) 81 | 82 | def check(self, value: str): 83 | try: 84 | if self.type is None: 85 | self.type = str 86 | 87 | value = self.type(value) 88 | except Exception as e: 89 | raise ValueError( 90 | f"Param value {value} cannot be converted to type {self.type}" 91 | ) from e 92 | 93 | if options := self.options: 94 | if value not in options: 95 | raise ValueError(f"Param value {value} not in options {options}") 96 | 97 | def as_github_action(self) -> str: 98 | indent = " " 99 | to_join = [f"{self.name}:"] 100 | if self.description: 101 | # TODO: fix that yaml.safe_dump with some proper encoder, string esc etc. 102 | to_join.append( 103 | f"{indent}description: {remove_trailing_yaml(safe_dump(self.description))}" 104 | ) 105 | to_join.append(f"{indent}required: {self.required}") 106 | if self.default is not None: 107 | d = self.default 108 | if type(self.default) == str: 109 | d = remove_trailing_yaml(safe_dump(self.default)) 110 | to_join.append(f"{indent}default: {d}") 111 | if self.options: 112 | to_join.append(f"{indent}options: {list(self.options)}") 113 | to_join.append(f"{indent}type: choice") 114 | if self.type == bool: 115 | to_join.append(f"{indent}type: boolean") 116 | return "\n".join(to_join) 117 | 118 | def as_run_param(self) -> str: 119 | # field_name="value" 120 | return f'{pfx}{self.name}="{wrap_brackets("github.event.inputs." 
+ self.name)}"' 121 | 122 | class Config: 123 | frozen = True 124 | 125 | 126 | def wrap_brackets(s: str) -> str: 127 | left = "${{ " 128 | right = " }}" 129 | return left + s + right 130 | 131 | 132 | def remove_trailing_yaml(s: str) -> str: 133 | trailing_yaml = "\n...\n" 134 | if s.endswith(trailing_yaml): 135 | return s[: -len(trailing_yaml)] 136 | return s 137 | 138 | 139 | pfx = "hf_action_" 140 | 141 | 142 | class _ToSet: 143 | param: Param 144 | value: Any 145 | 146 | def __init__(self, param: Param, value: Optional[Any] = None): 147 | self.param = param 148 | self.value = value 149 | 150 | 151 | class ArgParams: 152 | pass 153 | 154 | 155 | def parse_kwargs_to_params( 156 | params: List[Param], 157 | kwargs: Dict[str, str], 158 | ): 159 | keys = (key[len(pfx) :] for key in kwargs if key.startswith(pfx)) 160 | 161 | fields = {param.name: param for param in params} 162 | 163 | # remove keys that are not in subclass_fields 164 | keys = [key for key in keys if key in fields] 165 | 166 | # get values from kwargs 167 | values = [kwargs[pfx + key] for key in keys] 168 | 169 | k_v = dict(zip(keys, values)) 170 | 171 | params_to_set = {key: _ToSet(param=param) for key, param in fields.items()} 172 | 173 | for key, to_set in params_to_set.items(): 174 | if key in k_v: 175 | v = k_v[key] 176 | to_set.param.check(v) 177 | to_set.value = v 178 | elif to_set.param.required and to_set.param.default is None: 179 | raise ValueError(f"Required argument {key} not provided") 180 | else: 181 | to_set.value = to_set.param.default 182 | 183 | params_to_set[key] = to_set 184 | 185 | prepare = ArgParams() 186 | for key, to_set in params_to_set.items(): 187 | setattr(prepare, key, to_set.param.type(to_set.value)) 188 | 189 | return prepare 190 | -------------------------------------------------------------------------------- /higgsfield/internal/init.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | import io 4 | 5 | from cryptography.hazmat.primitives import serialization, asymmetric 6 | from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey 7 | 8 | from jinja2 import Environment, FileSystemLoader 9 | from .util import templates_path, ROOT_DIR 10 | 11 | from typing import Tuple 12 | 13 | templates = Environment(loader=FileSystemLoader(templates_path())) 14 | 15 | 16 | def init(wd: Path, project_name: str): 17 | if project_name is None or project_name == "": 18 | raise ValueError("Project name cannot be empty") 19 | project_path = wd / project_name 20 | if project_path.exists(): 21 | raise ValueError(f"Project {project_name} already exists") 22 | 23 | project_path.mkdir(parents=True, exist_ok=True) 24 | (project_path / "src").mkdir(parents=True, exist_ok=True) 25 | 26 | config_path = project_path / "src" / "config.py" 27 | if config_path.exists(): 28 | raise ValueError(f"Config file {config_path} already exists") 29 | 30 | readme_path = project_path / "README.md" 31 | if readme_path.exists(): 32 | raise ValueError(f"README.md file {readme_path} already exists") 33 | 34 | source_path = Path(ROOT_DIR) / "static" / "project" 35 | fileset = [ 36 | ".gitignore", 37 | "env", 38 | "Dockerfile", 39 | "src/alpaca_bf16.py", 40 | "src/alpaca_fp16.py", 41 | "src/dataset.py", 42 | "requirements.txt", 43 | ] 44 | 45 | for file in fileset: 46 | (project_path / file).write_bytes((source_path / file).read_bytes()) 47 | 48 | config_path.write_text( 49 | 
templates.get_template("config_py.j2").render(project_name=project_name) 50 | ) 51 | 52 | readme_path.write_text( 53 | templates.get_template("README_md.j2").render(project_name=project_name) 54 | ) 55 | 56 | hf_deploy_ssh_folder = Path.home() / ".ssh/higgsfield/" 57 | hf_deploy_ssh_folder.mkdir(parents=True, exist_ok=True) 58 | priv, pub = generate_deploy_keys() 59 | (hf_deploy_ssh_folder / f"{project_name}-github-deploy.key").write_bytes(priv) 60 | 61 | # restrict permissions to 600 (owner read/write only) 62 | os.system(f"chmod 600 {hf_deploy_ssh_folder / f'{project_name}-github-deploy.key'}") 63 | public_key_path = hf_deploy_ssh_folder / f"{project_name}-github-deploy.key.pub" 64 | public_key_path.write_text(pub.decode() + "\n") 65 | 66 | # restrict permissions to 600 (owner read/write only) 67 | os.system( 68 | f"chmod 600 {hf_deploy_ssh_folder / f'{project_name}-github-deploy.key.pub'}" 69 | ) 70 | 71 | 72 | def generate_deploy_keys() -> Tuple[bytes, bytes]: 73 | private_key = Ed25519PrivateKey.generate() 74 | 75 | public_key = private_key.public_key() 76 | 77 | private_bytes = private_key.private_bytes( 78 | encoding=serialization.Encoding.PEM, 79 | format=serialization.PrivateFormat.OpenSSH, 80 | encryption_algorithm=serialization.NoEncryption(), 81 | ) 82 | 83 | public_bytes = public_key.public_bytes( 84 | encoding=serialization.Encoding.OpenSSH, 85 | format=serialization.PublicFormat.OpenSSH, 86 | ) 87 | 88 | return private_bytes, public_bytes -------------------------------------------------------------------------------- /higgsfield/internal/launch.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | from pathlib import Path 3 | 4 | import os 5 | from higgsfield.internal.experiment.decorator import _experiments 6 | from higgsfield.internal.experiment.params import parse_kwargs_to_params 7 | from higgsfield.internal.experiment.builder import _source_experiments 8 | 9 | 10 | class Launch: 11 | experiment_name: str 12 | run_name: str 13 | max_repeats: int 14 | kwargs: Dict[str, str] 15 | prepared: Any 16 | 17 | def __init__( 18 | self, 19 | wd: Path, 20 | project_name: str, 21 | experiment_name: str, 22 | run_name: str, 23 | max_repeats: int, 24 | rest: List[str], 25 | ): 26 | if not experiment_name or experiment_name == "": 27 | raise ValueError("Experiment name cannot be empty") 28 | 29 | if not project_name or project_name == "": 30 | raise ValueError("Project name cannot be empty") 31 | 32 | if not run_name or run_name == "": 33 | raise ValueError("Run name cannot be empty") 34 | 35 | if not max_repeats or max_repeats < -1: 36 | raise ValueError("Max repeats cannot be None, 0, or less than -1") 37 | 38 | self.wd = wd 39 | self.experiment_name = experiment_name 40 | self.project_name = project_name 41 | self.run_name = run_name 42 | self.max_repeats = max_repeats 43 | self.kwargs = self._parse(rest) 44 | self._find_route() 45 | self.eval_params() 46 | self.apply_train() 47 | 48 | def _parse(self, rest: List[str]): 49 | kwargs = {} 50 | for arg in rest: 51 | if "=" not in arg: 52 | continue 53 | key, value = arg.split("=", 1)  # split on the first "=" only; values may contain "=" 54 | kwargs[key] = value 55 | return kwargs 56 | 57 | def _find_route(self): 58 | _source_experiments(self.wd / "src") 59 | 60 | if self.experiment_name not in _experiments: 61 | raise ValueError(f"Experiment {self.experiment_name} not found") 62 | 63 | experiment = _experiments[self.experiment_name] 64 | self.experiment = experiment 65 | 66 | def eval_params(self): 67 | params = parse_kwargs_to_params(self.experiment.params, self.kwargs) 68 
| setattr(params, "experiment_name", self.experiment_name) 69 | setattr(params, "project_name", self.project_name) 70 | setattr(params, "run_name", self.run_name) 71 | setattr(params, "rank", int(os.environ.get("RANK", 0))) 72 | setattr(params, "world_size", int(os.environ.get("WORLD_SIZE", 1))) 73 | setattr(params, "local_rank", int(os.environ.get("LOCAL_RANK", 0))) 74 | 75 | self.prepared = params 76 | 77 | def apply_train(self): 78 | self.experiment.train(self.prepared) 79 | -------------------------------------------------------------------------------- /higgsfield/internal/main.py: -------------------------------------------------------------------------------- 1 | import click 2 | 3 | from higgsfield.internal.cli import ( 4 | init_cmd, 5 | ci, 6 | run_experiment, 7 | show_deploy_key, 8 | build_experiments, 9 | ci_cli, 10 | ) 11 | 12 | 13 | @click.group() 14 | def cli(): 15 | """Higgsfield CLI""" 16 | pass 17 | 18 | 19 | cli.add_command(init_cmd) 20 | cli.add_command(ci) 21 | cli.add_command(run_experiment) 22 | cli.add_command(build_experiments) 23 | cli.add_command(show_deploy_key) 24 | cli.add_command(ci_cli.setup_nodes) 25 | -------------------------------------------------------------------------------- /higgsfield/internal/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | import os 5 | 6 | from typing import Optional 7 | 8 | regex = re.compile("^[a-zA-Z_][a-zA-Z0-9_]*$") 9 | 10 | 11 | def check_name(name: str): 12 | if len(name) < 1 or len(name) > 20: 13 | raise ValueError("Name must be between 1 and 20 characters long") 14 | 15 | if not regex.match(name): 16 | raise ValueError("Name must match regex ^[a-zA-Z_][a-zA-Z0-9_]*$") 17 | 18 | return name 19 | 20 | 21 | def wd_path() -> Path: 22 | return Path.cwd() 23 | 24 | 25 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 26 | 27 | def templates_path() -> Path: 28 | return Path(ROOT_DIR) / "static" / "templates" 29 | 30 | https_repo_url_pattern = re.compile( 31 | r"^https\:\/\/github\.com\/[a-zA-Z0-9\-\_]+\/[a-zA-Z0-9\-\_]+\.git$" 32 | ) 33 | 34 | 35 | def match_https_link(link: str) -> bool: 36 | return https_repo_url_pattern.match(link) is not None 37 | 38 | 39 | def convert_https_to_ssh(link: str) -> str: 40 | gh, user, repo = link[8:-4].split("/") 41 | return f"git@{gh}:{user}/{repo}.git" 42 | 43 | 44 | def parse_origin_link_or_else(link: str) -> Optional[str]: 45 | if match_https_link(link): 46 | return convert_https_to_ssh(link) 47 | if link.startswith("git@github.com:"): 48 | return link 49 | 50 | return None 51 | -------------------------------------------------------------------------------- /higgsfield/llama/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama import ( 2 | Llama, 3 | Llama7b, 4 | Llama13b, 5 | Llama70b, 6 | ) 7 | -------------------------------------------------------------------------------- /higgsfield/llama/llama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 9 | FullyShardedDataParallel as FSDP, 10 | CPUOffload, 11 | ) 12 | 13 | from torch.distributed.fsdp import ( 14 | MixedPrecision, 15 | ShardingStrategy, 16 | ) 17 | from torch.distributed.fsdp.wrap import ( 18 | transformer_auto_wrap_policy, 19 | ) 
20 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 21 | checkpoint_wrapper, 22 | CheckpointImpl, 23 | apply_activation_checkpointing, 24 | ) 25 | 26 | from transformers import ( 27 | LlamaForCausalLM, 28 | LlamaTokenizer, 29 | LlamaConfig, 30 | default_data_collator, 31 | ) 32 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 33 | from optimum.bettertransformer import BetterTransformer 34 | 35 | from higgsfield.checkpoint.fsdp_checkpoint import ( 36 | save_distributed_model_rank0, 37 | fsdp_model_state_dict_rank0, 38 | ) 39 | 40 | from .llama_utils import ( 41 | load_llama_from_checkpoint, 42 | load_llama_from_config, 43 | ) 44 | 45 | class Llama(FSDP): 46 | def __init__( 47 | self, 48 | model_name, 49 | checkpoint_path=None, 50 | zero_stage=3, 51 | fast_attn=False, 52 | precision="bf16", 53 | cpu_init_rank0=False, 54 | cpu_offload=False, 55 | ): 56 | 57 | rank = dist.get_rank() 58 | 59 | if not checkpoint_path: 60 | if cpu_init_rank0: 61 | if rank == 0: 62 | model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False) 63 | else: 64 | llama_config = LlamaConfig.from_pretrained(model_name, use_cache=False) 65 | 66 | with torch.device('meta'): 67 | model = LlamaForCausalLM(llama_config) 68 | else: 69 | model = LlamaForCausalLM.from_pretrained(model_name, use_cache=False) 70 | else: 71 | if not cpu_init_rank0: 72 | print("Ignoring cpu_init_rank0=False while loading model from checkpoint path") 73 | cpu_init_rank0 = True 74 | 75 | if rank == 0: 76 | model = load_llama_from_checkpoint(model_name, checkpoint_path) 77 | print("LOADED FROM CHECKPOINT") 78 | 79 | else: 80 | llama_config = LlamaConfig.from_pretrained(model_name) 81 | 82 | with torch.device('meta'): 83 | model = LlamaForCausalLM(llama_config) 84 | 85 | if fast_attn: 86 | #raise NotImplementedError("Fast attention is not supported yet") 87 | model = BetterTransformer.transform(model) 88 | 89 | fpSixteen = MixedPrecision( 90 | param_dtype=torch.float16, 91 | reduce_dtype=torch.float16, 92 | buffer_dtype=torch.float16, 93 | ) 94 | 95 | bfSixteen_mixed = MixedPrecision( 96 | param_dtype=torch.float32, 97 | reduce_dtype=torch.bfloat16, 98 | buffer_dtype=torch.bfloat16, 99 | ) 100 | 101 | pure_bf16 = False 102 | if precision == "fp16": 103 | mixed_precision_policy = fpSixteen 104 | 105 | elif precision == "bf16": 106 | mixed_precision_policy = None 107 | pure_bf16 = True 108 | 109 | elif precision == "bf16_mixed": 110 | mixed_precision_policy = bfSixteen_mixed 111 | 112 | else: 113 | mixed_precision_policy = None 114 | 115 | if pure_bf16: 116 | model.to(torch.bfloat16) 117 | 118 | wrapping_policy = functools.partial( 119 | transformer_auto_wrap_policy, 120 | transformer_layer_cls={ 121 | LlamaDecoderLayer, 122 | } 123 | ) 124 | 125 | if zero_stage == 0: 126 | sharding_strategy = ShardingStrategy.NO_SHARD 127 | 128 | elif zero_stage == 1: 129 | raise NotImplementedError("stage 1 is not supported. 
Only 0 2 3") 130 | 131 | elif zero_stage == 2: 132 | sharding_strategy = ShardingStrategy.SHARD_GRAD_OP 133 | 134 | elif zero_stage == 3: 135 | sharding_strategy = ShardingStrategy.FULL_SHARD 136 | else: 137 | raise NotImplementedError("stage can be only 0 2 3") 138 | 139 | if cpu_init_rank0 and rank != 0: 140 | param_init_fn = lambda module: module.to_empty( 141 | device=torch.device('cuda'), 142 | recurse=False, 143 | ) 144 | else: 145 | param_init_fn = None 146 | 147 | if cpu_offload: 148 | cpu_offload = CPUOffload(offload_params=True) 149 | else: 150 | cpu_offload = None 151 | 152 | super().__init__( 153 | model, 154 | auto_wrap_policy=wrapping_policy, 155 | cpu_offload=cpu_offload, 156 | mixed_precision=mixed_precision_policy, 157 | sharding_strategy=sharding_strategy, 158 | device_id=torch.cuda.current_device(), 159 | limit_all_gathers=True, 160 | sync_module_states=cpu_init_rank0, 161 | param_init_fn=param_init_fn, 162 | ) 163 | 164 | non_reentrant_wrapper = functools.partial( 165 | checkpoint_wrapper, 166 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 167 | ) 168 | 169 | check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) 170 | 171 | apply_activation_checkpointing( 172 | self, 173 | checkpoint_wrapper_fn=non_reentrant_wrapper, 174 | check_fn=check_fn, 175 | ) 176 | 177 | fsdp = True 178 | self.precision = precision 179 | self.fsdp = fsdp 180 | self.model_name = model_name 181 | 182 | def __call__(self, batch): 183 | local_rank = int(os.environ["LOCAL_RANK"]) 184 | 185 | for key in batch.keys(): 186 | batch[key] = batch[key].to(local_rank) 187 | 188 | if self.precision == "fp16": 189 | with torch.cuda.amp.autocast(): 190 | loss = super().__call__(**batch).loss 191 | else: 192 | loss = super().__call__(**batch).loss 193 | 194 | return loss 195 | 196 | def save_model(self, save_path): 197 | ''' 198 | Save model's weight to master node 199 | ~/.cache/higgsfield/{save_path} 200 | ''' 201 | if "/" == save_path[0]: 202 | save_path = save_path[1:] 203 | 204 | head, tail = os.path.split(save_path) 205 | 206 | path = Path.home() / ".cache/higgsfield" / head 207 | path.mkdir(exist_ok=True, parents=True) 208 | 209 | save_distributed_model_rank0(path / tail, self) 210 | 211 | def save_huggingface_model(self, save_path): 212 | ''' 213 | Save model's weight in huggingface format to master node 214 | ~/.cache/higgsfield/{save_path} 215 | ''' 216 | if "/" == save_path[0]: 217 | save_path = save_path[1:] 218 | 219 | head, tail = os.path.split(save_path) 220 | 221 | path = Path.home() / ".cache/higgsfield" / head 222 | path.mkdir(exist_ok=True, parents=True) 223 | cpu_state = fsdp_model_state_dict_rank0(self) 224 | 225 | if dist.get_rank() == 0: 226 | model = load_llama_from_config(self.model_name) 227 | model.load_state_dict(cpu_state) 228 | model.save_pretrained(path / tail) 229 | 230 | def push_to_hub(self, repo_id): 231 | cpu_state = fsdp_model_state_dict_rank0(self) 232 | 233 | if dist.get_rank() == 0: 234 | model = load_llama_from_config(self.model_name) 235 | model.load_state_dict(cpu_state) 236 | model.push_to_hub(repo_id) 237 | 238 | 239 | 240 | class Llama7b(Llama): 241 | def __init__( 242 | self, 243 | checkpoint_path=None, 244 | zero_stage=3, 245 | fast_attn=False, 246 | precision="bf16", 247 | cpu_init_rank0=False, 248 | cpu_offload=False, 249 | ): 250 | model_name = "meta-llama/Llama-2-7b-hf" 251 | super(Llama7b, self).__init__( 252 | model_name, 253 | checkpoint_path, 254 | zero_stage, 255 | fast_attn, 256 | precision, 257 | cpu_init_rank0, 258 | cpu_offload, 259 | ) 
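# A minimal end-to-end sketch of using the class above (illustrative only: it
# assumes higgsfield's `run` command has already initialized the process group
# via torchrun, and that `dataset` is whatever LlamaLoader /
# TorchCompletionDataset accept; `_example_finetune` is not part of the API):
def _example_finetune(dataset):
    from torch.optim import AdamW
    from higgsfield.loaders import LlamaLoader

    model = Llama7b(precision="bf16", zero_stage=3)  # FSDP-wrapped module
    loader = LlamaLoader(dataset, batch_size_per_gpu=1)
    optimizer = AdamW(model.parameters(), lr=1e-5)

    for batch in loader:
        optimizer.zero_grad()
        loss = model(batch)  # __call__ moves the batch to the local GPU and returns the LM loss
        loss.backward()
        optimizer.step()

    # gathers the full state dict and writes it on rank 0 under ~/.cache/higgsfield/
    model.save_model("my_experiment/llama7b_final.pt")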
260 | 261 | class Llama13b(Llama): 262 | def __init__( 263 | self, 264 | checkpoint_path=None, 265 | zero_stage=3, 266 | fast_attn=False, 267 | precision="bf16", 268 | cpu_init_rank0=False, 269 | cpu_offload=False, 270 | ): 271 | model_name = "meta-llama/Llama-2-13b-hf" 272 | super(Llama13b, self).__init__( 273 | model_name, 274 | checkpoint_path, 275 | zero_stage, 276 | fast_attn, 277 | precision, 278 | cpu_init_rank0, 279 | cpu_offload, 280 | ) 281 | 282 | class Llama70b(Llama): 283 | def __init__( 284 | self, 285 | checkpoint_path=None, 286 | zero_stage=3, 287 | fast_attn=False, 288 | precision="bf16", 289 | cpu_init_rank0=False, 290 | cpu_offload=False, 291 | ): 292 | model_name = "meta-llama/Llama-2-70b-hf" 293 | super(Llama70b, self).__init__( 294 | model_name, 295 | checkpoint_path, 296 | zero_stage, 297 | fast_attn, 298 | precision, 299 | cpu_init_rank0, 300 | cpu_offload, 301 | ) 302 | -------------------------------------------------------------------------------- /higgsfield/llama/llama_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ( 3 | LlamaConfig, 4 | LlamaForCausalLM, 5 | ) 6 | from higgsfield.checkpoint import fsdp_model_state_dict_rank0 7 | 8 | def load_llama_from_config(model_name): 9 | config = LlamaConfig.from_pretrained(model_name) 10 | model = LlamaForCausalLM(config) 11 | return model 12 | 13 | def load_llama_from_checkpoint(model_name, checkpoint_path): 14 | model = load_llama_from_config(model_name) 15 | state_dict = torch.load(checkpoint_path) 16 | model.load_state_dict(state_dict) 17 | return model -------------------------------------------------------------------------------- /higgsfield/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama_loader import LlamaLoader -------------------------------------------------------------------------------- /higgsfield/loaders/llama_loader.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | from torch.utils.data import ( 4 | DistributedSampler, 5 | DataLoader 6 | ) 7 | 8 | from transformers import ( 9 | LlamaTokenizer, 10 | default_data_collator 11 | ) 12 | 13 | from higgsfield.dataset import TorchCompletionDataset 14 | 15 | class HiggsfieldSampler(DistributedSampler): 16 | def __init__( 17 | self, 18 | dataset, 19 | shuffle=True, 20 | seed=0, 21 | drop_last=False 22 | ): 23 | rank=dist.get_rank() 24 | num_replicas=dist.get_world_size() 25 | 26 | super(HiggsfieldSampler, self).__init__( 27 | dataset=dataset, 28 | num_replicas=num_replicas, 29 | rank=rank, 30 | shuffle=shuffle, 31 | seed=seed, 32 | drop_last=drop_last, 33 | ) 34 | 35 | class LlamaLoader(DataLoader): 36 | def __init__( 37 | self, 38 | dataset, 39 | tokenizer=None, 40 | max_sequence_length=2048, 41 | batch_size_per_gpu=1, 42 | shuffle=True, 43 | seed=0, 44 | num_workers=0, 45 | pin_memory=False, 46 | drop_last=False, 47 | timeout=0, 48 | worker_init_fn=None, 49 | multiprocessing_context=None, 50 | *, 51 | prefetch_factor=None, 52 | persistent_workers=False, 53 | pin_memory_device="" 54 | ): 55 | tokenizer = tokenizer or LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # resolved at call time, not as an import-time default argument 56 | dataset = TorchCompletionDataset( 57 | dataset, 58 | tokenizer, 59 | max_sequence_length, 60 | ) 61 | 62 | sampler = HiggsfieldSampler(dataset, shuffle=shuffle, seed=seed,) 63 | 64 | super(LlamaLoader, self).__init__( 65 | dataset, 66 | batch_size=batch_size_per_gpu, 67 | 
sampler=sampler, 68 | num_workers=num_workers, 69 | pin_memory=pin_memory, 70 | drop_last=drop_last, 71 | timeout=timeout, 72 | worker_init_fn=worker_init_fn, 73 | multiprocessing_context=multiprocessing_context, 74 | prefetch_factor=prefetch_factor, 75 | persistent_workers=persistent_workers, 76 | pin_memory_device=pin_memory_device 77 | ) 78 | -------------------------------------------------------------------------------- /higgsfield/mistral/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/higgsfield-ai/higgsfield/d12a36e66024a93d33ec61826a77d5a346c16869/higgsfield/mistral/__init__.py -------------------------------------------------------------------------------- /higgsfield/mistral/mistral.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 9 | FullyShardedDataParallel as FSDP, 10 | CPUOffload, 11 | ) 12 | 13 | from torch.distributed.fsdp import ( 14 | MixedPrecision, 15 | ShardingStrategy, 16 | ) 17 | from torch.distributed.fsdp.wrap import ( 18 | transformer_auto_wrap_policy, 19 | ) 20 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 21 | checkpoint_wrapper, 22 | CheckpointImpl, 23 | apply_activation_checkpointing, 24 | ) 25 | 26 | from transformers import ( 27 | MistralForCausalLM, 28 | MistralConfig, 29 | default_data_collator, 30 | ) 31 | from transformers.models.mistral.modeling_mistral import MistralDecoderLayer 32 | from optimum.bettertransformer import BetterTransformer 33 | 34 | from higgsfield.checkpoint.fsdp_checkpoint import ( 35 | save_distributed_model_rank0, 36 | fsdp_model_state_dict_rank0, 37 | ) 38 | 39 | from higgsfield.mistral.mistral_utils import ( 40 | load_mistral_from_checkpoint, 41 | load_mistral_from_config, 42 | ) 43 | 44 | class Mistral(FSDP): 45 | def __init__( 46 | self, 47 | model_name, 48 | checkpoint_path=None, 49 | zero_stage=3, 50 | fast_attn=False, 51 | precision="bf16", 52 | cpu_init_rank0=False, 53 | cpu_offload=False, 54 | num_embeddings=None, 55 | cache_dir=None, 56 | ): 57 | 58 | rank = dist.get_rank() 59 | 60 | 61 | model = MistralForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) 62 | 63 | if num_embeddings: 64 | model.resize_token_embeddings(num_embeddings) 65 | 66 | 67 | if fast_attn: 68 | #raise NotImplementedError("Fast attention is not supported yet") 69 | model = BetterTransformer.transform(model) 70 | 71 | fpSixteen = MixedPrecision( 72 | param_dtype=torch.float16, 73 | reduce_dtype=torch.float16, 74 | buffer_dtype=torch.float16, 75 | ) 76 | 77 | bfSixteen_mixed = MixedPrecision( 78 | param_dtype=torch.float32, 79 | reduce_dtype=torch.bfloat16, 80 | buffer_dtype=torch.bfloat16, 81 | ) 82 | 83 | pure_bf16 = False 84 | if precision == "fp16": 85 | mixed_precision_policy = fpSixteen 86 | 87 | elif precision == "bf16": 88 | mixed_precision_policy = None 89 | pure_bf16 = True 90 | 91 | elif precision == "bf16_mixed": 92 | mixed_precision_policy = bfSixteen_mixed 93 | 94 | else: 95 | mixed_precision_policy = None 96 | 97 | if pure_bf16: 98 | model.to(torch.bfloat16) 99 | 100 | wrapping_policy = functools.partial( 101 | transformer_auto_wrap_policy, 102 | transformer_layer_cls={ 103 | MistralDecoderLayer, 104 | } 105 | ) 106 | 107 | if zero_stage == 0: 108 | sharding_strategy = ShardingStrategy.NO_SHARD 
109 | 110 | elif zero_stage == 1: 111 | raise NotImplementedError("stage 1 is not supported. Only 0 2 3") 112 | 113 | elif zero_stage == 2: 114 | sharding_strategy = ShardingStrategy.SHARD_GRAD_OP 115 | 116 | elif zero_stage == 3: 117 | sharding_strategy = ShardingStrategy.FULL_SHARD 118 | else: 119 | raise NotImplementedError("stage can be only 0 2 3") 120 | 121 | if cpu_init_rank0 and rank != 0: 122 | param_init_fn = lambda module: module.to_empty( 123 | device=torch.device('cuda'), 124 | recurse=False, 125 | ) 126 | else: 127 | param_init_fn = None 128 | 129 | if cpu_offload: 130 | cpu_offload = CPUOffload(offload_params=True) 131 | else: 132 | cpu_offload = None 133 | 134 | super().__init__( 135 | model, 136 | auto_wrap_policy=wrapping_policy, 137 | cpu_offload=cpu_offload, 138 | mixed_precision=mixed_precision_policy, 139 | sharding_strategy=sharding_strategy, 140 | device_id=torch.cuda.current_device(), 141 | limit_all_gathers=True, 142 | sync_module_states=cpu_init_rank0, 143 | param_init_fn=param_init_fn, 144 | ) 145 | 146 | non_reentrant_wrapper = functools.partial( 147 | checkpoint_wrapper, 148 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 149 | ) 150 | 151 | check_fn = lambda submodule: isinstance(submodule, MistralDecoderLayer) 152 | 153 | apply_activation_checkpointing( 154 | self, 155 | checkpoint_wrapper_fn=non_reentrant_wrapper, 156 | check_fn=check_fn, 157 | ) 158 | 159 | fsdp = True 160 | self.precision = precision 161 | self.fsdp = fsdp 162 | self.model_name = model_name 163 | self.num_embeddings = num_embeddings 164 | 165 | def __call__(self, batch): 166 | local_rank = int(os.environ["LOCAL_RANK"]) 167 | 168 | for key in batch.keys(): 169 | batch[key] = batch[key].to(local_rank) 170 | 171 | if self.precision == "fp16": 172 | with torch.cuda.amp.autocast(): 173 | loss = super().__call__(**batch).loss 174 | else: 175 | loss = super().__call__(**batch).loss 176 | 177 | return loss 178 | 179 | def save_model(self, save_path): 180 | ''' 181 | Save model's weight to master node 182 | ~/.cache/higgsfield/{save_path} 183 | ''' 184 | if "/" == save_path[0]: 185 | save_path = save_path[1:] 186 | 187 | head, tail = os.path.split(save_path) 188 | 189 | path = Path.home() / ".cache/higgsfield" / head 190 | path.mkdir(exist_ok=True, parents=True) 191 | 192 | save_distributed_model_rank0(path / tail, self) 193 | 194 | def save_huggingface_model(self, save_path): 195 | ''' 196 | Save model's weight in huggingface format to master node 197 | ~/.cache/higgsfield/{save_path} 198 | ''' 199 | if "/" == save_path[0]: 200 | save_path = save_path[1:] 201 | 202 | head, tail = os.path.split(save_path) 203 | 204 | path = Path.home() / ".cache/higgsfield" / head 205 | path.mkdir(exist_ok=True, parents=True) 206 | cpu_state = fsdp_model_state_dict_rank0(self) 207 | 208 | if dist.get_rank() == 0: 209 | model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings) 210 | model.load_state_dict(cpu_state) 211 | model.save_pretrained(path / tail) 212 | 213 | def push_to_hub(self, repo_id, token): 214 | cpu_state = fsdp_model_state_dict_rank0(self) 215 | 216 | if dist.get_rank() == 0: 217 | model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings) 218 | model.load_state_dict(cpu_state) 219 | model.push_to_hub(repo_id, token=token) -------------------------------------------------------------------------------- /higgsfield/mistral/mistral_loader.py: -------------------------------------------------------------------------------- 1 | import 
torch.distributed as dist 2 | 3 | from torch.utils.data import ( 4 | DistributedSampler, 5 | DataLoader 6 | ) 7 | 8 | from transformers import ( 9 | AutoTokenizer, 10 | default_data_collator 11 | ) 12 | 13 | from higgsfield.dataset import TorchCompletionDataset 14 | 15 | IGNORE_INDEX = -100 16 | DEFAULT_PAD_TOKEN = "<|pad|>" 17 | DEFAULT_EOS_TOKEN = "<|endoftext|>" 18 | DEFAULT_UNK_TOKEN = "<|unk|>" 19 | 20 | def get_tokenizer(model_name, max_length, cache_dir=None): 21 | 22 | tokenizer = AutoTokenizer.from_pretrained( 23 | model_name, 24 | model_max_length=max_length, 25 | padding_side="right", 26 | use_fast=False, 27 | pad_token=DEFAULT_PAD_TOKEN, 28 | trust_remote_code=True, 29 | cache_dir=cache_dir, 30 | ) 31 | 32 | special_tokens_dict = dict() 33 | if tokenizer.pad_token is None: 34 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 35 | if tokenizer.eos_token is None: 36 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 37 | if tokenizer.unk_token is None: 38 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 39 | 40 | tokenizer.add_special_tokens(special_tokens_dict) 41 | 42 | return tokenizer 43 | 44 | class HiggsfieldSampler(DistributedSampler): 45 | def __init__( 46 | self, 47 | dataset, 48 | shuffle=True, 49 | seed=0, 50 | drop_last=False 51 | ): 52 | rank=dist.get_rank() 53 | num_replicas=dist.get_world_size() 54 | 55 | super(HiggsfieldSampler, self).__init__( 56 | dataset=dataset, 57 | num_replicas=num_replicas, 58 | rank=rank, 59 | shuffle=shuffle, 60 | seed=seed, 61 | drop_last=drop_last, 62 | ) 63 | 64 | class MistralLoader(DataLoader): 65 | def __init__( 66 | self, 67 | dataset, 68 | tokenizer=None, 69 | max_sequence_length=2048, 70 | batch_size_per_gpu=1, 71 | shuffle=True, 72 | seed=0, 73 | num_workers=0, 74 | pin_memory=False, 75 | drop_last=False, 76 | timeout=0, 77 | worker_init_fn=None, 78 | multiprocessing_context=None, 79 | *, 80 | prefetch_factor=None, 81 | persistent_workers=False, 82 | pin_memory_device="" 83 | ): 84 | 85 | if not tokenizer: 86 | tokenizer = get_tokenizer("mistralai/Mistral-7B-v0.1", max_sequence_length) 87 | 88 | dataset = TorchCompletionDataset( 89 | dataset, 90 | tokenizer, 91 | max_sequence_length, 92 | ) 93 | 94 | sampler = HiggsfieldSampler(dataset, shuffle=shuffle, seed=seed,) 95 | 96 | super(MistralLoader, self).__init__( 97 | dataset, 98 | batch_size=batch_size_per_gpu, 99 | sampler=sampler, 100 | num_workers=num_workers, 101 | pin_memory=pin_memory, 102 | drop_last=drop_last, 103 | timeout=timeout, 104 | worker_init_fn=worker_init_fn, 105 | multiprocessing_context=multiprocessing_context, 106 | prefetch_factor=prefetch_factor, 107 | persistent_workers=persistent_workers, 108 | pin_memory_device=pin_memory_device 109 | ) 110 | -------------------------------------------------------------------------------- /higgsfield/mistral/mistral_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import ( 3 | MistralConfig, 4 | MistralForCausalLM, 5 | ) 6 | from higgsfield.checkpoint import fsdp_model_state_dict_rank0 7 | 8 | def load_mistral_from_config(model_name, num_embeddings=None): 9 | config = MistralConfig.from_pretrained(model_name) 10 | model = MistralForCausalLM(config) 11 | 12 | if num_embeddings: 13 | model.resize_token_embeddings(num_embeddings) 14 | 15 | return model 16 | 17 | def load_mistral_from_checkpoint(model_name, checkpoint_path, num_embeddings=None): 18 | model = load_mistral_from_config(model_name, num_embeddings=num_embeddings) 19 | 
state_dict = torch.load(checkpoint_path) 20 | model.load_state_dict(state_dict) 21 | return model -------------------------------------------------------------------------------- /higgsfield/path.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | 4 | class _ProjectCachePath: 5 | project_name: str 6 | path: pathlib.Path 7 | metadata_file: str 8 | _init_path: pathlib.Path 9 | 10 | def _init( 11 | self, 12 | project_name: str, 13 | metadata_file: str, 14 | init_path: pathlib.Path, 15 | verbose: bool = True, 16 | ): 17 | self.project_name = project_name 18 | self.metadata_file = metadata_file 19 | self._init_path = init_path 20 | self.verbose = verbose 21 | self.path = self._mkdir() 22 | 23 | def experiment_path(self, experiment_name: str) -> "ExperimentPath": 24 | return ExperimentPath(self, experiment_name) 25 | 26 | def _mkdir(self) -> pathlib.Path: 27 | home = self._init_path 28 | 29 | path = home / f".cache/{self.project_name}" 30 | 31 | try: 32 | (path / "experiments").mkdir(exist_ok=True, parents=True) 33 | except Exception as e: 34 | if self.verbose: 35 | print("this error shouldn't have been thrown") 36 | print(f"error creating path {path}") 37 | print(e) 38 | 39 | return path 40 | 41 | def metadata_path(self) -> pathlib.Path: 42 | return (self.path / "experiments") / self.metadata_file 43 | 44 | 45 | class ProjectCachePath(_ProjectCachePath): 46 | def __init__( 47 | self, 48 | project_name: str, 49 | metadata_file: str = "metadata.json", 50 | ): 51 | home = pathlib.Path.home() 52 | self._init(project_name, metadata_file, home) 53 | 54 | 55 | class ExperimentPath: 56 | project_path: _ProjectCachePath 57 | experiment_name: str 58 | path: pathlib.Path 59 | metadata_file: str 60 | 61 | def __init__(self, project_path: _ProjectCachePath, experiment_name: str): 62 | self.project_path = project_path 63 | self.experiment_name = experiment_name 64 | 65 | self.path = self._mkdir() 66 | 67 | def run_path(self, run_name: str) -> "RunPath": 68 | return RunPath(self, run_name) 69 | 70 | def _mkdir(self) -> pathlib.Path: 71 | path = self.project_path.path / f"experiments/{self.experiment_name}" 72 | 73 | try: 74 | path.mkdir(exist_ok=True, parents=True) 75 | except Exception as e: 76 | if self.project_path.verbose: 77 | print("this error shouldn't have been thrown") 78 | print(f"error creating path {path}") 79 | print(e) 80 | return path 81 | 82 | 83 | class RunPath: 84 | experiment_path: ExperimentPath 85 | run_name: str 86 | 87 | def __init__(self, experiment_path: ExperimentPath, run_name: str): 88 | self.experiment_path = experiment_path 89 | self.run_name = run_name 90 | 91 | def checkpoint_path(self) -> pathlib.Path: 92 | path = self.experiment_path.path / f"checkpoints/{self.run_name}" 93 | 94 | try: 95 | path.mkdir(exist_ok=True, parents=True) 96 | except Exception as e: 97 | if self.experiment_path.project_path.verbose: 98 | print("this error shouldn't have been thrown") 99 | print(f"error creating path {path}") 100 | print(e) 101 | return path 102 | 103 | def sharded_checkpoint_path(self) -> pathlib.Path: 104 | path = self.experiment_path.path / f"sharded-checkpoints/{self.run_name}" 105 | 106 | try: 107 | path.mkdir(exist_ok=True, parents=True) 108 | except Exception as e: 109 | if self.experiment_path.project_path.verbose: 110 | print("this error shouldn't have been thrown") 111 | print(f"error creating path {path}") 112 | print(e) 113 | return path 114 | 115 | def lr_scheduler_path(self) -> pathlib.Path: 116 | path = 
self.experiment_path.path / f"lr-schedulers/{self.run_name}" 117 | 118 | try: 119 | path.mkdir(exist_ok=True, parents=True) 120 | except Exception as e: 121 | if self.experiment_path.project_path.verbose: 122 | print("this error shouldn't have been thrown") 123 | print(f"error creating path {path}") 124 | print(e) 125 | 126 | return path 127 | 128 | 129 | def working_directory() -> pathlib.Path: 130 | return pathlib.Path.cwd() 131 | -------------------------------------------------------------------------------- /higgsfield/rl/README.md: -------------------------------------------------------------------------------- 1 | # Congratulations! 2 | 3 | Soon you'll see the most advanced Reinforcement Learning library for Large Language Models! 4 | 5 | Stay tuned! 6 | -------------------------------------------------------------------------------- /higgsfield/rl/rl_adventure_2/1.actor-critic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math\n", 10 | "import random\n", 11 | "\n", 12 | "import gym\n", 13 | "import numpy as np\n", 14 | "\n", 15 | "import torch\n", 16 | "import torch.nn as nn\n", 17 | "import torch.optim as optim\n", 18 | "import torch.nn.functional as F\n", 19 | "from torch.distributions import Categorical" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from IPython.display import clear_output\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "%matplotlib inline" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "
Use CUDA" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "use_cuda = torch.cuda.is_available()\n", 47 | "device = torch.device(\"cuda\" if use_cuda else \"cpu\")" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "Create Environments" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "from common.multiprocessing_env import SubprocVecEnv\n", 64 | "\n", 65 | "num_envs = 16\n", 66 | "env_name = \"CartPole-v0\"\n", 67 | "\n", 68 | "def make_env():\n", 69 | "    def _thunk():\n", 70 | "        env = gym.make(env_name)\n", 71 | "        return env\n", 72 | "\n", 73 | "    return _thunk\n", 74 | "\n", 75 | "envs = [make_env() for i in range(num_envs)]\n", 76 | "envs = SubprocVecEnv(envs)\n", 77 | "\n", 78 | "env = gym.make(env_name)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "Neural Network" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 19, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "class ActorCritic(nn.Module):\n", 95 | "    def __init__(self, num_inputs, num_outputs, hidden_size, std=0.0):\n", 96 | "        super(ActorCritic, self).__init__()\n", 97 | "        \n", 98 | "        self.critic = nn.Sequential(\n", 99 | "            nn.Linear(num_inputs, hidden_size),\n", 100 | "            nn.ReLU(),\n", 101 | "            nn.Linear(hidden_size, 1)\n", 102 | "        )\n", 103 | "        \n", 104 | "        self.actor = nn.Sequential(\n", 105 | "            nn.Linear(num_inputs, hidden_size),\n", 106 | "            nn.ReLU(),\n", 107 | "            nn.Linear(hidden_size, num_outputs),\n", 108 | "            nn.Softmax(dim=1),\n", 109 | "        )\n", 110 | "        \n", 111 | "    def forward(self, x):\n", 112 | "        value = self.critic(x)\n", 113 | "        probs = self.actor(x)\n", 114 | "        dist = Categorical(probs)\n", 115 | "        return dist, value" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 20, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "def plot(frame_idx, rewards):\n", 125 | "    clear_output(True)\n", 126 | "    plt.figure(figsize=(20,5))\n", 127 | "    plt.subplot(131)\n", 128 | "    plt.title('frame %s. reward: %s' % (frame_idx, rewards[-1]))\n", 129 | "    plt.plot(rewards)\n", 130 | "    plt.show()\n", 131 | "    \n", 132 | "def test_env(vis=False):\n", 133 | "    state = env.reset()\n", 134 | "    if vis: env.render()\n", 135 | "    done = False\n", 136 | "    total_reward = 0\n", 137 | "    while not done:\n", 138 | "        state = torch.FloatTensor(state).unsqueeze(0).to(device)\n", 139 | "        dist, _ = model(state)\n", 140 | "        next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])\n", 141 | "        state = next_state\n", 142 | "        if vis: env.render()\n", 143 | "        total_reward += reward\n", 144 | "    return total_reward" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "A2C: Synchronous Advantage Actor Critic\n", 152 | "OpenAI Blog:\n", 153 | "The Asynchronous Advantage Actor Critic method (A3C) has been very influential since the paper was published. The algorithm combines a few key ideas:\n", 154 | "\n", 155 | "\n", 160 | "\n", 161 | "After reading the paper, AI researchers wondered whether the asynchrony led to improved performance (e.g. “perhaps the added noise would provide some regularization or exploration?“), or if it was just an implementation detail that allowed for faster training with a CPU-based implementation.\n", 162 | "\n", 163 | "As an alternative to the asynchronous implementation, researchers found you can write a synchronous, deterministic implementation that waits for each actor to finish its segment of experience before performing an update, averaging over all of the actors. One advantage of this method is that it can more effectively use GPUs, which perform best with large batch sizes. This algorithm is naturally called A2C, short for advantage actor critic. (This term has been used in several papers.)
" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 21, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "def compute_returns(next_value, rewards, masks, gamma=0.99):\n", 173 | " R = next_value\n", 174 | " returns = []\n", 175 | " for step in reversed(range(len(rewards))):\n", 176 | " R = rewards[step] + gamma * R * masks[step]\n", 177 | " returns.insert(0, R)\n", 178 | " return returns" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 22, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "num_inputs = envs.observation_space.shape[0]\n", 188 | "num_outputs = envs.action_space.n\n", 189 | "\n", 190 | "#Hyper params:\n", 191 | "hidden_size = 256\n", 192 | "lr = 3e-4\n", 193 | "num_steps = 5\n", 194 | "\n", 195 | "model = ActorCritic(num_inputs, num_outputs, hidden_size).to(device)\n", 196 | "optimizer = optim.Adam(model.parameters())" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": 23, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "max_frames = 20000\n", 206 | "frame_idx = 0\n", 207 | "test_rewards = []" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 17, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "data": { 217 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAE/CAYAAABW/Dj8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8XHW5+PHPk71Zm3TSLUmbpE0L3WnTsrXsAgIKaEVQAQFFLsK9iFcvV/yp93LdrnJV1CuCoKBS2QW9rIJCy9ampU3TPWnTNkubfd8z398fc6ZMQ9JMMss5M3ner9e8MnOWOU/OJE9Ovt/nfL9ijEEppVTki7E7AKWUUsGhCV0ppaKEJnSllIoSmtCVUipKaEJXSqkooQldKaWihCb0KCEi80Vkq4i0i8g/2x2PCi0RqRSRC+yOQzmLJvTo8XXg78aYNGPMfXYH40tE5onIcyJSLyJNIvKyiMwfss1XROSIiLSJyMMikuizLl9E/i4iXSKye2giC2TfiUBEviYiZdYf+wMi8rUh68d9foc51vnWe3RZ7zk7VN+X+jBN6NFjNrBjpJUiEhvGWIaaDDwPzAemARuB57wrReQi4C7gfDzfRyHwHz77rwPeB6YAdwNPiUh2oPuOhYjEjXWfYAjScQW4DsgELgZuE5GrfdYHcn59Y3UBzwD/D8gCSoDHgxC/8pcxRh8R/gBeBwaBHqADmAf8DvgV8ALQCVwAXIrnF7cNOAx8x+c98gED3GCtawZuAVYCpUAL8Ishx70R2GVt+zIw2894s6xjTbFePwZ8z2f9+cAR6/k8oBdI81m/Hrgl0H39iLMS+Dfr++8F4oCZwNNAPXAA+Gdr2ySgG3BZr+8GBoB06/U9wE+t5/58DjcBh4A3reXXAgeBRuu9K4ELxvnzch/w80DP7zDvezPwts/rFOucnGT378hEeegVehQwxpyH55fwNmNMqjFmr7XqM8B3gTRgA57Efh2eK+ZLgX8SkSuGvN2pQBHwaeCneJLHBcBC4CoRORtARC4HvgF8Asi2jr/Oz5DPwpMUGq3XC4FtPuu3AdNEZIq1br8xpn3I+oVB2Ncf1+A5V5MBN/AX6z1y8CS3O0TkImNMD7AJONva72w8CfhMn9dvWM/9+RzOBk4GLhKRBXj+OF+L5w/KFCDXu6GIrBaRFn++GRERYA0f/DcXyPkd6rhtjTGdQAVjO98qAJrQo9tzxpi3jDFuY0yPMeYfxpjt1utSPAn47CH73GNt+wqexLPOGFNnjKnGk7RPsba7Bfi+MWaXMWYA+B6wbLQ2UxHJBX4J3OmzOBVo9XntfZ42zDrv+rQg7OuP+4wxh40x3Xj+W8k2xvynMabPGLMfeBDwNl+8AZxtNZMswXMlfLaIJFn7vgng5+fwHWNMp3XctcBfjTFvGmN68TRpuL0bGmM2GGMm+/n9fAfP7/1vrdeBnN+hgnG+VQA0oUe3w74vRORUq6OqXkRa8SRl15B9jvo87x7mdar1fDbwMxFpsa4Om/C01eaMFIzVLvsK8L/GGN+r+Q4g3ee193n7MOu8671XlIHs6w/fczgbmOn9nq3v+xt4+gXAk9DPAZYD24FX8STq04By738kfn4Ovsed6fvauvJtZIxE5DY8/xlcav1hgMDO71DBON8qAJrQo9vQoTQfw9M5mWeMyQDux5OEx+Mw8CVjzGSfxyRjzNvDbSwimXiS+fPGmO8OWb0DWOrzeilw1EqAO4BCEUkbsn5HEPb1h+85PAwcGPI9pxljLrHWv42n4/dK4A1jzE5gFnAJHzS3gH+fg+9xa4E87wsRScbT7OI3EbkRq3PTGFPlsyqQ8zvUcduKSAowh7GdbxUATegTSxrQZIzpEZFVeNrYx+t+4N9FZCGAiGSIyKeG21BE0vF0mr5ljLlrmE0eBW4SkQUiMhn4Jp5OXaz+gK3At0UkSUSuxNOc8XQQ9h2rjUC7iPybiEwSkVgRWSQiK63jdQGbgS/zQQJ/G88VuG9CH+vn8BRwmdVWngD8J2P43RWRz+JpEvuI1Ux0TCDndxjPAotE5JNWM9O3gFJjzG5/Y1UBsrtXVh/BeQD/AL7g8/p3wH8N2WYtno66duCvwC+AP1jr8vFcFcb5bF8FnOPz+g/AN31eX4unacFbrfHwCLFdb713J55/y72PWT7b3ImneacNT/tuos+6fOv76wb2MKS6Y7z7Ap8FdpzgnFY
Oc6yZeNq8j+Cp7nl3yHt+3zpWovX6Nut7nzbez8HnHB5imCoXPJ2cHSf4Pg4A/UPO/f1BOr87gM/6vL4A2G291z+AfLt/NybSQ6wPQSmlVITTJhellIoSmtCVUipKaEJXSqkooQldKaWihCZ0pZSKEraMIDeUy+Uy+fn5doehlFKOtHnz5gZjzKijhDoioefn51NSUmJ3GEop5UgictCf7bTJRSmlooQmdKWUihKa0JVSKkpoQldKqSihCV0ppaKEJnSllIoSmtCVUipKjJrQRSTPmi5rp4jsEJF/sZZnicirIrLP+pppLRcRuU9EykWkVESWh/qbUEop5d8V+gDwVWPMAjxzI37ZmoX8LuA1Y0wR8Jr1GuCjeGaNLwJuxjNbuVJKqRAbNaEbY2qNMVus5+3ALjwTAV8OPGJt9ghwhfX8cuBR4/EuMFlEZgQ9cqVUWFW3dFNW3Wp3GBFnYNDNb986wP76jpAfa0xt6CKSD5wCvIdnSq1aa9URPpj5PIfjZyyvYpiZ4EXkZhEpEZGS+vr6MYatlAont9tw86MlXPfwRgbdOsvZWBxu7uY//rKTzQebQ36ssUw0m4pn4tg7jDFtvuuMZx67MX3KxpgHjDHFxpji7OxRx5xRStnoxbIj7Khpo6mzjx01epU+FhV1nivzuVNTQ34svxK6iMTjSeZ/NMY8Yy0+6m1Ksb7WWcurgTyf3XOtZUqpCDQw6ObeV/eQlzUJgPX7GmyOKLJUWE0thdkOSOgiIsBDwC5jzP/4rHoez0zkWF+f81l+nVXtchrQ6tM0o5SKMM++X83++k7uvuRkTpqexgZN6GNSXtdBdloiGZPiQ34sf67QzwSuBc4Tka3W4xLgB8BHRGQfcIH1GuAFYD9QDjwI3Br8sJVS4dA7MMhP/7aPxTkZXLRwOmuKXGw+2Ex336DdoUWMivoO5mSnhOVYo46HbozZAMgIq88fZnsDfDnAuJRSDvD4psNUt3TzvU8sRkRYXZTNg+sP8N6BRs6ZP9Xu8BzPGENFfSeXLQlPoZ/eKaqUGlZ33yA/f72cVQVZnFXkAmBVfhYJsTHa7OKnxs4+Wrv7w9IhCprQlVIjeOSdSurbe/naRfPxdKXBpIRYivMz2VCuCd0f5VaFy5wwdIiCJnSl1DDaevr51T8qOGd+Nivzs45bt7rIxe4j7dS199gUXeTwVrjM0St0pZRdfrP+AK3d/fzrhfM/tG7NXM99I2/pVfqoKuo6mRQfy4z0pLAcTxO6UhGio3eAe/66k22HW0J6nMaOXh5av59LFk9nUU7Gh9YvnJlOZnK81qP7oaK+gzlTU4iJGamuJLg0oSsVAZo7+/jsg+/y0IYD3PTIJmpaukN2rPvfqKC7f5A7PzJv2PUxMcKZc11s2NeAp6hNjaS8riNs7eegCV0pxzvS2sNVv36HXUfa+dZlC+jpd/Ol32+mpz/4teBHWnt45J2DXHlKLnOnpo243ZoiF3Xtvew9GvoBpyJVd98g1S3dmtCVUh4HGztZe//b1LR088gNq7hxdQE//fQyympauevp0qBfIf/89X0YY7jjgqITbre6yNOOvn6fDqw3kv0N4a1wAU3oSjnWrto21t7/Dp29A6y7+TROnzMFgAsWTOPOC+bx5601/Gb9gaAd71BjF49vOszVK2eRl5V8wm1zJk+i0JWi5YsnUFHfCcCcqeG5SxQ0oSvlSJsPNvPpX79DrAhP3nI6S3InH7f+tvPmcsni6Xz/xV28uTc4V8k//dte4mKF28+b69f2q4tcvLe/id4BHQZgOBV1HcQI5E/RhK7UhLV+Xz2f+817ZKUk8OQtpw/bli0i/GjtUuZNS+P2de9T2dAZ0DH3Hm3n2a3VXH96PlP9LLFbPddFd/8gWw6GtuomUpXXd5CXlUxSfGzYjqkJXSkHeXF7LTf+bhOzpyTzxC2nn7DpIyUxjgevK0YEvvhoCR29A+M+7r2v7CElIY5bzp7j9z6nzZlCbIywoVzb0YdTEeYKF9CErpRjPLHpMF9+bAtLcifz+JdOZ2ra6FfKeVnJ/PIzy9nf0Mmdj2/FPY7ZhLYdbuHlHUf54ppCMlMS/N4vPSmeZXmTdVyXYQy6DQcaOsM2yqKXJnSlHOA36/fz9adLWV2Uze9vWjWmsbPPnOvi7ktO5pWdR7nv9X1jPvaPX9lDZnI8N67OH/O+a4pclFa30tLVN+Z9o1lNSze9A269QldqIjHGcO8re/iv/9vFpYtn8JvriklOGHVU6w+54cx8Prk8l5/+bR8vlR3xe7939zeyfl8Dt54zl7SksU/AsKbIhTHwdkXjmPeNZuVhnHbOlyZ0pWzidhu+/fwOfv56OVevzOO+a04hIW58v5IiwnevXMTSvMl89Ymt7D3aPuo+xhh+/PIepqUncu3ps8d13KW5k0lLjNNhAIY4NiiXXqErFf36B93c+cRWHn3nIF86q5Dvf2IxsQGO95EUH8sD164gOTGOLz5aMmozyD/21lNysJnbzysadyVGXGwMp82Zwvp99ToMgI+K+g6yUhLG1CcRDJrQlQqznv5B/ukPm/nz1hq+dtF87vroScfGGw/UtPQk7v/cCmpberh93fsMDLqH3c7t9lyd52VN4qrivGG38deaIhdVzd0cbOwK6H2iSUVd+DtEQRO6UmHV3tPP9Q9v5LXdddxzxSK+fO7coCVzrxWzM7nnioWs39fAD1/aPew2L+04wo6aNr5ywbxxN/N4rZ7rmc1ovd41ekxFfUfY289BE7pSYbPnSDtrf/UOmw8289NPL+Pa08bXbu2PT6+cxfWnz+bB9Qd49v2q49YNuj0dsUVTU7l8WU7AxypwpZAzeRIbdFwXwDMyZmNnX9jbz0ETulIhZ4zhkbcr+dgvNtDY2ctvb1gZlEQ6mm9etoDTCrP4t6e3U1r1wd2cz75fTUV9J1+9cF7A7fbg6ZBdPdfF2xWNIzbxTCR2dYiCJnSlQqqho5cvPFLCt5/fwRlzpvDiv5zFGmukwlCLj43hl59ZTnZqIl/6/Wbq23vpG3Dz07/tZXFOBhctnB60Y60uctHeM0BpdWvQ3jNSaUJXKgq9sbeei3+6nvXlDXznYwv47edXkp2WGNYYpqQm8sB1K2ju6uOf/rCZ3797kKrmbv7VZ+LnYDhzrgsR9K5RPKMsJsTFkJM5KezH1oSuVJD1Dgxyz193cv3DG8lKief5287k82cWBL3z018LZ2bwo7VLKTnYzD1/3cmqgizOKnIF9RhZKQksmpmhCR3PTUWFrpSgNGeN1agJXUQeFpE6ESnzWfa4iGy1HpUistVani8i3T7r7g9l8Eo5TXldO1f88m0e2nCA60+fzfO3reak6el2h8XHls7k1nPmECPwtSBfnXutLnKx5VBzQIOERQPPPKLhb24B8Oce498BvwAe9S4wxnza+1xE7gV8G84qjDHLghWgUpHAGMNjGw9xz193kpwQx0PXF3P+ydPsDus4X7toPjeuLsCVGppmnzVzXfzqHxW8t7/Rcd97uPT0D3K4qSssnd7DGTWhG2PeFJH84daJ58/8VcB5wQ
1LqcjR1NnHvz1dyqs7j7KmyMW9n1rq95ji4SQiIUvmACvyM0mKj2H9voYJm9APNnbhNthyUxH4d4V+ImuAo8YY3yHeCkTkfaAN+KYxZn2Ax1DKsd4qb+Arj2+luauPb156MjeeWUCMDW2nTpAYF8uqgikTep5Ruwbl8gq0U/QaYJ3P61pgljHmFOBO4DERGbYBUURuFpESESmpr5+4PwAqMvUNuPn+C7v43EPvkZYUx7O3nskX1hRO2GTutWaui4r6Tmpbu+0OxRbeksVCV4QldBGJAz4BPO5dZozpNcY0Ws83AxXAvOH2N8Y8YIwpNsYUZ2eHpy5XqWDYX9/BJ3/1Nr9+cz/XrJrFX29fw6KcDLvDcoTVVvXMRB19saK+g5zJk5iUEL5p53wFcoV+AbDbGHPsvmIRyRaRWOt5IVAE7A8sRKWcY/PBZi69bwOHm7v49bUr+N6Vi2375XWik6an4UpNnLDli3ZWuIB/ZYvrgHeA+SJSJSI3Wauu5vjmFoCzgFKrjPEp4BZjTFMwA1bKTn9+v5rYGOGlfzkrqHdaRgsRYU2Ri7fKG8Y1HV4kc7uNbaMsevlT5XLNCMs/P8yyp4GnAw9LKWcqrWphcU4G0zOcV8XiFKvnunj2/Wp2HWlj4cyJ0xRV29ZDd/+gbR2ioHeKKuW3vgE3u2rbWZI3cZLUeHjb0Sdas0tFnX1juHhpQldht2FfA/e+ssfuMMZs95E2+gbdLM2dbHcojjYtPYl501LZMMHGR7dzUC4vTegqrIwxfPeFXfz89XLeirBf+G1Vnhuil+TqFfpoVs/NZuOBJnr6B+0OJWwq6jtIT4rDlRreaed8aUJXYVVa1cqu2jZE4Ecv74moeShLD7eQlZJAzuTwj6IXadYUuegdcFNS2Wx3KGFTXuepcLFrEDbQhK7CbN3GQ0yKj+XuS05m6+EWXttVZ3dIfiutamVJboatv7CR4tTCLOJjZULdNVpR38lcG5tbQBO6CqOO3gGe31bDZUtmcP0Z+eRPSebHr+yJiPK2rr4B9tW1s0Tbz/2SnBDH8lmZE+YGo9bufurbe22tQQdN6CqMnt9aQ1ffINecOov42Bi+8pF57D7Szv9tr7U7tFGVVbfhNrBU28/9dta8bHbWttHQ0Wt3KCG33wEdoqAJXYXRnzYd4qTpaZyS57nK/diSmcyflsZPXt3r+LkovXNy6hW6/1bP9ZQvRlrn93hU1HcC9o2y6KUJXYVFWXUrpVWtXL0y71gbdEyMcOeF89jf0Mkz71fbHOGJlVa1MjMjKexTyEWyRTkZZEyKnxD16OV1HcTHCrOykm2NQxO6Cos/bTpEYlwMV56Se9zyCxdMY2luBj/72z56B5xb4lZa1aJX52MUGyOcOXcKG8obIqqaaTwq6jvIn5JCXKy9KVUTugq5rr4B/vx+DZcunkFGcvxx60SEr144n+qWbh7fdNimCE+staufysYuvUN0HFbPzaa2tedYk0S0qqjvsL39HDShqzD467ZaOnoHuObUWcOuX1PkYlVBFj9/vZzuPuddpZdWe9rP9Q7RsVtzbBiA6C1f7B90c6ixizlT7W0/B03oKgzWbTrE3KmpFM/OHHa9iPC1i+ZT397Lo+9UhjU2f5Rad4jqmOdjl5eVzOwpyVE9DMDBxk4G3Eav0FX0232kjfcPtRzXGTqclflZnD0vm1+9UUF7T38YIxzdtsMtFLpSyJgUP/rG6kNWz3XxTkUj/Q6vZBqv8jpPc5Kdoyx6aUJXIfWnjYdJiI3hE8tzR932Xy+cT0tXPw9vqAx9YGPgvUNUjc+aomw6+wZ5/1CL3aGExLFp5/QKXUWznv5BntlSxcWLppOVMvqARYtzM7h44XR+s34/LV19YYhwdHVtPRxp69EKlwCcPmcKMRK97egV9R1MT08iNXHU6SVCThO6CpkXttfS1jPA1avy/N7nzgvn0dE3wP1vOGPmQu8Ii0u1wmXcMibFszRvMuujtB29oq7DER2ioAldhdC6jYfIn5LM6YVT/N5n3rQ0Ll86k9+9fYC69p4QRuef0qoWYmOEBTM0oQdizVwX2w630NrtrP6RQBljHDEol5cmdBUS5XXtbKps5ppVs8Y8OuEdF8yjf9Dwv3+vCFF0/ttW1cq8aWk6EXSAVhdl4zbwTkWj3aEEVV17Lx29A7YPyuWlCV2FxLqNh4mPFT65YvTO0KHyXSlcVZzLY+8dorqlOwTR+ccY47lDVMsVA3bKrMkkJ8Ty7v7oSuhOmHbOlyZ0FXTeztALF0zHlTq+sU9uP68IgPv+ti+YoY3J4aZuWrr69Q7RIIiPjWFOduqxipBo4YRp53xpQldB9/KOIzR39Y+pM3SomZMn8dnTZvHUlioONNhz2/i2Kr1DNJgKs1Ns+yxDpbyug9TEOKalO2PQNk3oKuj+tPEweVmTOHOOK6D3ufWcuSTExvCTV/cGKbKxKa1qISEuhvnT02w5frQpcKVQ3dIdVfOMVtR3Mic7xTGzWGlCV0F1oKGTd/Y3cvXKWcTEBPZDnp2WyA1n5vOX0hp2H2kLUoT+21bVyoIZ6cTbPIJetChwpWAMHGrqsjuUoHHKoFxe+pOqgupPmw4RGyN8ahydocP50llzSE2M495XwnuVPug2lFW36gxFQVTo8iS+/VEy8mJH7wC1rT2OqXABPxK6iDwsInUiUuaz7DsiUi0iW63HJT7r/l1EykVkj4hcFKrAlfP0Dbh5qqSK80+aytT0pKC8Z0ZyPDevKeTVnUfZejh8t45X1HfQ1Teod4gGUb7LM/lDtLSjfzDtnDNuKgL/rtB/B1w8zPKfGGOWWY8XAERkAXA1sNDa539FRAt4J4hXdx6lsbNvxGFyx+uG1QVkpSRw7yt7gvq+J7LN+uOhd4gGT1pSPNlpiRxoiI5KF6dVuIAfCd0Y8ybQ5Of7XQ78yRjTa4w5AJQDqwKIT0WQP206RM7kSZxVlB3U901NjOPWc+awfl9D2OqYt1e3kpoYd6yZQAVHgSt6Kl0q6jqJjRFmT4msK/SR3CYipVaTjHeg6xzAd9qZKmuZinKHGrtYv6+Bq4rziA2wM3Q4nzttNtPSE/nxy3vCMp3ZtqpWFuWkB9yxq45XGE0Jvb6D2VnJJMQ5pytyvJH8CpgDLANqgXvH+gYicrOIlIhISX19dI7CNpE8XnKIGIGrVganM3SopPhYbj+viJKDzfxjb2h/XvoG3OyqadP68xAocKXQ0NEXFWO6lNd1OGLIXF/jSujGmKPGmEFjjBt4kA+aVaoB37tJcq1lw73HA8aYYmNMcXZ2cP9FV+HVP+jmiZIqzp0/lRkZk0J2nKuK88jLmsS9r4T2Kn3PkXb6Bt3aIRoCBS5P80SkX6UPDLqpbOx0zCiLXuNK6CIyw+fllYC3AuZ54GoRSRSRAqAI2BhYiMrpXt9dR317L9esCm5n6FAJcTHccf48yqrbeKnsSMiO471DVCe1CD7vFW2kd4webu6mf9A4ZpRFL3/KFtcB7wDzRaRKRG4C/ltEtotIKXAu8BUAY8wO4AlgJ/AS8GVjTPTcFqaGtW7jI
aalJ3LO/ND/p3XFKTkUulL4zYYDITtGaVULWSkJ5GaG7r+NiWpWVjIxAgcivBb92KBcDqpBBxh1ig1jzDXDLH7oBNt/F/huIEGpyFHd0s0be+u5/dy5xIXhjsrYGOGqlXn84MXd7K8PTRumd8o5p9zOHU0S4mLIy0pmf4Q3uRwrWXRYFZRzumdVRHp8k6eo6aqV4x+Ia6w+cUoOsTHCU5urgv7eXX0D7D3aru3nIRQNpYvldR24UhPJSHbWxOGa0NW4DQy6ebLkMGcVZZObmRy2405NT+Lsedk8s6WaQXdwO0d31LThNugY6CHkTejhKD8NFc8YLs7qEAVN6CoAb+ytp7a1h2sCGCZ3vNauyOVIWw8bgjxPpfcOUR0DPXQKXSl09Q1S195rdyjjcmzaOYe1n4MmdBWAdRsP40pN5PyTp4X92OefPJXJyfFBb3YprWplRkYSU9OCMxaN+rCCCB+kq7HTU0fvpFv+vTShq3E50trD67uPclVxri3DyybGxXL50pm8vOMIrV3Bu0mltKpFyxVDrCA7smvRyx1a4QKa0NU4PVFyGLeBT4exM3SoTxXn0Tfg5vnSmqC8X2tXP5WNXdohGmIz0pNIio85NlphpKlw4CiLXprQ1ZgNug2PbzrM6rkuWwcmWjgznZOmpwWt2aW0WqecC4eYGCF/SuRWulTUdTIpPpaZIbwrerw0oasxe29/I9Ut3QHNGRoMIsLaFblsO9zCvqPtAb9faVUrAIu1ySXkInl+0Yr6DgqzUxw5cJsmdDVmb+5rIC5GOHf+VLtD4YpTcogLUk36tsMtFLhSyJjkrNriaFTgSuFQUxf9g267Qxkzp00750sTuhqztysaWD4rk5TEUW80DjlXaiLnnjSVZ96vZiDA5OC9Q1SFXoErlQG3oaq52+5QxqS7b5Dqlm5N6Co6tHT1sb26lTPnuuwO5ZhPrcilvr2XN/eNf1jdurYejrT1aIdomHww6mJkdYzub+jAGBw3yqKXJnQ1Ju/ub8QYOHPuFLtDOebck6YyJSWBJ0vG3+yyzWo/10mhw6PQSuiRVoteYcXrxJuKQBO6GqMN5Q2kJMSyNM85V7LxsTFccUoOf9t1lKbOvnG9R2lVC7ExwsKZmtDDITMlgcnJ8RHXMVpR14EI5Dto2jlfmtDVmLxd3siphVNsuZnoRNauyKV/0PD81mHnUxlVaVUrRVNTmZSgc5qHSyQO0lVe30FeZjJJ8c78OXHWb6VytJqWbvY3dDqq/dzr5BnpLMpJ58lxVLsYYyitatH68zCLxIReUefMQbm8NKErv71lDYTlpPZzX2uX57Kjpo2dNW1j2q+quZvmrn4dkCvM5mSnUtvaQ1ffgN2h+GXQbTjQ0OnYChfQhK7G4K3yBlypCcyflmZ3KMO6fFkO8bFjr0n3TjmnV+jhFWnzi9a0dNM74HZshyhoQld+MsbwVkUjZ8xxOXYmn8yUBC44eRp/3lpN34D/NemlVa0kxMYwz6F/qKJVpCV0Jw/K5aUJXfllX10H9e29rHZg+7mvTxXn0tTZx9/31Pm9z7bDLZw8M52EOP11CCdvpUikzC/6waBcmtBVhNuwz9N+foZD28+9zirKJjst0e+a9EG3oay6VevPbTApIZaZGUkRc4VeUd9BZnI8WSkJdocyIk3oyi9vVzSQPyU5rFPNjUdcbAyfOCWHv++po96PGXH213fQ2Teod4japCA7JWImjK6oc+YsRb40oatRDQy6eXd/E2c4vLnFa+2KXAbdhuf8qEnXO0TtVeDUvRNZAAAgAElEQVRKYX99R0TML+rkQbm8NKGrUW2raqWjd8Dx7edeRdPSWJo3mSdLqkZNFKVVLaQkxFLo8F/UaFXgSqWtZ4DmIM46FQrNnX00dvZpQleR763yBkTg9EJnt5/7+tSKXPYcbaes+sQ16duqWlmUk0GsA8e2nggKI2SQrmMdog4dlMtLE7oa1VvlDSycmU6mgzuDhvrYkpkkxMXw1ObDI27TN+BmV02bo8almWi8pYsVDq90iYQKF/AjoYvIwyJSJyJlPst+JCK7RaRURJ4VkcnW8nwR6RaRrdbj/lAGr0Kvq2+ALYeaOXNOZDS3eGUkx3PRwuk8t62G3oHBYbfZc6SdvkG3joFuo9zMScTHiuMrXSrqO0mIi3F8UYA/V+i/Ay4esuxVYJExZgmwF/h3n3UVxphl1uOW4ISp7LKpspn+QePI8VtGs3ZFLi1d/fxt5/A16XqHqP3iYmOYlZXs+Fr08roOCl0pjm+aGzWhG2PeBJqGLHvFGOMdgOFdIDcEsSkHeKu8gYTYGFbmZ9kdypitnutiRkbSiM0upVUtZCbHk5vpvMl+J5ICV6rjr9D3Hm139B2iXsFoQ78ReNHndYGIvC8ib4jImiC8v7LRW+UNLJ89OSKHlY2NET6xPIc39tZztK3nQ+s9U85NduxQBhNFYXYKBxo7cbudWbrY3NlHVXM3iyJgrPyAErqI3A0MAH+0FtUCs4wxpwB3Ao+JSPoI+94sIiUiUlJfP/6pw1ToNHX2saOmLWLKFYfzyeW5uA08+/7xNeldfQPsPdqu9ecOUOBKoW/ATU2rM+cX3WGN3rkoZ9hU5ijjTugi8nngMuCzxir2Ncb0GmMareebgQpg3nD7G2MeMMYUG2OKs7OzxxuGCqF3KhoBIuaGouEUZqdSPDuTJ0sOH1eTvqOmDbdB7xB1AKcP0lVW47n5LGqv0EXkYuDrwMeNMV0+y7NFJNZ6XggUAfuDEagKvw3lDaQlxrEkx/k/yCeydkUuFfWdbD3ccmzZNuu5VrjYr9DhCX17dSs5kydFRNmuP2WL64B3gPkiUiUiNwG/ANKAV4eUJ54FlIrIVuAp4BZjTNOwb6wc7+2KBk4tnEKcw6abG6tLl8wgKT7muNmMSqtamZ6exNT0JBsjUwDZaYmkJMQ6dsLoHdWtEdHcAhA32gbGmGuGWfzQCNs+DTwdaFDKfoebujjY2MUNZ+TbHUrA0pLi+eiiGfxlWw3fumwBSfGxbK9u1atzhxARCrKdOR1dW08/lY1drF0RGYV8kX3ppULm7QrvdHOR237u61MrcmnvGeDlHUdo7e7nQEOn3iHqIIWuVPY78PZ/73SGCyOk2VETuhrWhvJGpqYlOn64UH+dVjiFnMmTeGpzFdutERb1Ct05ClwpVDV3j3hXr13KqiOnQxQ0oathuN2Gt8sbOHOuc6ebG6uYGOGTK3LZUN7ASztqAViSo1foTlGYnYIxcKixa/SNw6is2tPXkp2WaHcoftGErj5kz9F2Gjv7oqa5xWvt8lyMgXUbD5M/JZmM5Hi7Q1IWb+mi0ya7KKtpi5gOUdCErobxVrm3/Txyhsv1x6wpyZxakMWg22j9ucPkO7B0sbN3gIr6DhZGSHMLaEJXw3irvIHC7BRmZETfGCefKs4DtP3cadKT4nGlJjpqkK5dtW0YA4sjpEMU/ChbVBNL34Cb9w408cnlkVGmNVaXLZnBniNtfHzZTLtDUUMUupxVunisQzSCErpeoavjbKtqoatvMOraz72S4mO5+9IF
TE3TG4qcpsDlrAmjt1e34UpNYFp6ZHSIgiZ0NcSGfQ3ERNh0cyo6FGSn0NDRS1uPM+YX3VHjmZ4wkiq9NKFHqNd2HeVpn1vZg+XtigYW52RoBYgKO++YLpUOuErv6R9kX11HxNSfe2kbegTaeKCJL/1+M4PGkJs5iVODdDXd2TvA+4da+OJZhUF5P6XGojDbKl2s77S9CmlXbRuDbhNR7eegV+gRp6alm1v/uJm8rGTyMpP516e20dk7MPqOfth4oIkBt4no8c9V5MrLSiZGnFGLXhZBY6D70oQeQXr6B7n59yX09Lt58LoV/PhTS6lq7ua7L+wKyvtvKG8gIS6GFbMzg/J+So1FYlwsuZnJjqh02VHdyuTkeHImR1bprib0CGGM4a6nS9lR08ZPP72MuVPTWFWQxRfXFPLYe4d4Y2/gsz69Vd7AyvxMkuIjb7o5FR0KXCkccMAgXdurW1kcYR2ioAk9Yjy4fj9/3lrDVz8yjwsWTDu2/M6PzKNoaipff2obrV3jrw5o6Ohl95H2qC1XVJGhwJXCgfrO42aXCrfegUH2Hm2PqDtEvTShR4A39tbzgxd3c8ni6Xz53LnHrUuKj+V/rlpGQ0cf3/nLjnEf421rurkz52hCV/YpzE6hs2+Q+vZe22LYd7SD/kETce3noAnd8SobOrn9sS3Mm5bGj9YuHfZfwMW5Gdx27lyefb+al8pqx3Wct/Y1kJ4UF3G9+iq6OGGQru3WHaKRdMu/lyZ0B+voHeCLj5YQGyM8eF0xKYkjV5nedt5cFuWkc/ezZTR0jO3qxhjDhvIGTp8zhdiYyGozVNHFCRNGl1W3kpYUx6ysZNtiGC9N6A7ldhu+8vhW9jd08svPLCdvlB+u+NgY/ueqZbT3DPCNZ7aPqQ3yUFMX1S3dWq6obDczYxKJcTH2JvSaNhbOTI+4DlHQhO5YP3ttH6/uPMrdl5zMGX4m2nnT0vjqhfN4ZedR/ry12u9jbbCGy/X3OEqFSkyMeMZ0sWnUxf5BN7tq2yKyuQU0oTvSS2W1/Oy1faxdkcsNZ+aPad8vrCmkeHYm33puB7Wt3X7t83Z5IzMyko7deq2UnTyDdNlTulhe10HfgDti+5I0oTvMniPt3PnENpbmTea/rlg05n/7YmOEH39qKQODhq8/VTpq04vbbXi7ooEz5kTPdHMqshW4UjjU2MXAoDvsx/YOmRuJJYugCd1RWrr6+OKjJaQkxvHAtSvGfYNPviuFb1xyEuv3NfDH9w6dcNudtW00d/WzukhHV1TOUOBKYcBtqGr27z/MYCqrbiUlITZi/1vVhO4QA4NubnvsfY609nD/51YwLT2w8bo/d9ps1hS5+N4LuzjYOHJ7pHe6uTO0/lw5hHeQLjs6Rstq2lgwM52YCK320oTuED94cTcbyhv4rysWBWUsFRHhh59cQqwIX3uylEH38E0vb1U0UjQ1NeA/IEoFS4ErFQh/Lfqg27Czpi1im1vAz4QuIg+LSJ2IlPksyxKRV0Vkn/U101ouInKfiJSLSKmILA9V8NHimS1V/GbDAa4/fTZXrcwL2vvOnDyJb398IRsrm3h4w4EPre8dGGTjgUa93V85SmZyPBmT4sM+psv++g66+wcjtsIF/L9C/x1w8ZBldwGvGWOKgNes1wAfBYqsx83ArwIPM3qVVrVw1zPbOa0wi29etiDo7//J5Tl8ZME0fvTKHvYdbT9u3fuHWujpd2tCV44iItYgXeG9Qi+ribw5RIfyK6EbY94EmoYsvhx4xHr+CHCFz/JHjce7wGQRmRGMYKNNXXsPNz+6mezURH75meXExwa/BUxE+N6Vi0lJiOXOJ7bR71M58Fa5Z7q5Uwuzgn5cpQJRaA3SFU5l1W0kxccwJzsyO0QhsDb0acYY78AhRwDvEIA5wGGf7aqsZcpH34CbW/+whZbuPh64bgVTUkM3EW12WiLfvXIx26tb+d+/Vxxb/lZ5A0vzJpOepNPNKWcpzE6hprWH7r7BsB1ze3UrJ89IJy4EF1bhEpTIjafYeUzjXYrIzSJSIiIl9fWBj+UdaZ7aXEXJwWb+e+3SsHTCXLJ4Bpcvm8nPX99HWXUr7T39bKtq1dEVlSN5O0YrT1ChFUxuq0M00uYQHSqQhH7U25Rifa2zllcDvj17uday4xhjHjDGFBtjirOzswMIIzK9u7+RaemJfGxJ+Fqj/uPjC8lKSeDOJ7by5t4GBt1G28+VIx0bdTFMzS4Hm7ro6B2IyCFzfQWS0J8HrreeXw8857P8Oqva5TSg1adpRuEZ3XBTZRPF+VlhvTtzcnICP1y7hL1HO/jGs9tJio9h+Wx7J+NVajj5Ls9gdOGqdPEOmRvJHaLgf9niOuAdYL6IVInITcAPgI+IyD7gAus1wAvAfqAceBC4NehRR7iq5m5qW3tYlR/+zshz50/lmlV5tHb3szI/i8Q4nW5OOU9yQhwzMpLCVou+o7qVhNgYiqamheV4oTLyANs+jDHXjLDq/GG2NcCXAwkq2m2q9BQMrbQhoQPcfekCyus6WLsi15bjK+WPcJYultW0Mn96GglxkdshCn4mdBVcmyqbSUuKY/50e64GUhPjePKWM2w5tlL+KnCl8H/bQ99aa4yhrLqNSxZHfnV1ZP85ilCbKptYMTtTZwdS6gQKXCm0dPXT3NkX0uNUNXfT2t0f8R2ioAk97Jo6+yiv67CtuUWpSOEdpCvU7ejeIXMjvWQRNKGHnbf9fFWBJnSlTqTQqkUPdTv69upW4mLEtibQYNKEHmYllU0kxMWwJDfyrwaUCqXczEnExUjISxfLatoompY27vkHnEQTephtrGxmaW6GlgsqNYq42BhmTUkO6c1Fxhh2VLeyaGbkt5+DJvSw6uobYEd1q7afK+WnwhCXLta29tDY2cfiKPmPWRN6GG091MKA27BS28+V8ou3Ft09wgQtgYr0OUSH0oQeRhsrmxAhKDMSKTURFLhS6R1wU9vWE5L3L6tpI0ZgwQxtclFjtKmyiZOmp+twtUr5yTtIV6jGRi+rbmXu1FQmJURHn5Ym9DDpH3Sz5WALq/L16lwpf30wYXRoKl3Kqlujov7cSxN6mOysaaO7f1Dbz5Uag6lpiaQkxIbk5qK6th7q2nsjfoRFX5rQw8TuAbmUikQiQkF2aCpdomEO0aE0oYfJxgNNzMpKZlp6kt2hKBVRClypoUno1W2IwIIoqUEHTehhYYyh5GCzXp0rNQ4FrhQON3XRN+AefeMxKKtupcCVQmpi9Aw6qwk9DCrqO2nq7GNVgXaIKjVWha4U3AYONQX3Kj3aOkRBE3pYaPu5UuMXivlFGzt6qWntiYohc31pQg+DTQeacKUmHPvBVEr5L99bix7EdvQdNW1AdHWIgib0sNhY2UTx7PBOCK1UtMiYFI8rNSGoCX17lN3y76UJPcRqW7upau7W+nOlAlDgSglqLfqOmlZmZSWTMSm67trWhB5imyqbAVil7edKjVuwJ4wuq25jcZQ1t4Am9JDbdKCJlIRYTp4R+bOhKGWXAlcq9e29tPf0B/xerV39HGrqYmGUdYi
CJvSQ21TZxPLZmcTF6qlWary8Y7rstDozA7GjJnrmEB1Ks0wItXb1s+dou5YrKhWg4tmZTElJ4I7Ht3K4qSug94rGW/69NKGH0OZDTRij9edKBWpKaiJ/+MKpdPcPcvUD71Ld0j3u99pe3UbO5ElkpSQEMUJnGHdCF5H5IrLV59EmIneIyHdEpNpn+SXBDDiSbDzQTHyssCxvst2hKBXxTp6Rzh9uOpW2nn6ueeBdjrSOb9KLHdWtLIyi8Vt8jTuhG2P2GGOWGWOWASuALuBZa/VPvOuMMS8EI9BItKmyiUU5GVEzeL5SdluUk8GjN66iqbOPzzz4LnXtY0vq7T397G/ojMoKFwhek8v5QIUx5mCQ3i/i9fQPUlrVouWKSgXZKbMy+e0NKznS1sNnH3yPxo5ev/fdGaV3iHoFK6FfDazzeX2biJSKyMMiMiFHpNp2uIX+QaPt50qFwMr8LH5zfTGHmrr43EMbaenq82u/MiuhR2PJIgQhoYtIAvBx4Elr0a+AOcAyoBa4d4T9bhaREhEpqa+vDzQMx/EOyFWsU84pFRJnzHHx4HXFVNR1cO1DG2ntHr1GfUd1K9PSE5maFp3zEgTjCv2jwBZjzFEAY8xRY8ygMcYNPAisGm4nY8wDxphiY0xxdnZ2EMJwlo2Vzcyblsrk5OjrSVfKKc6al8391y5n95E2Pv/bjXT0Dpxw++1ROGSur2Ak9GvwaW4RkRk+664EyoJwjIgy6DZs0QktlAqL806axs+vWU5pVSs3/HYjXX3DJ/WuvgEq6jtYGKXt5xBgQheRFOAjwDM+i/9bRLaLSClwLvCVQI4RiXbVttHRO8AqHZBLqbC4eNF0fnb1MjYfbOYLj5TQ0z/4oW121bbjNkRthQtAQHMvGWM6gSlDll0bUERRQCe0UCr8Llsyk4FBw1ee2MoXHy3hweuKSYr/oGS4rNp7h2h0doiC3ikaEpsqm8iZPImZkyfZHYpSE8oVp+Tww08sYf2+Bm7945bj5iEtq25lSkoC06N4onZN6EFmjGFTZTMrtbpFKVtctTKP/7piEa/vruP2dVvoH/Qk9bKaNhblZET1RDOa0IPsYGMX9e29OqGFUjb63Gmz+fbHFvDyjqN85fGtdPUNsO9oe1Q3t0CAbejqwzZa7ed6h6hS9rrhzAL6Btx8/8Xd1Lb2MOA2UV2yCJrQg66ksonJyfHMyU61OxSlJrwvnT2HvgE39766F4jeW/69NKEH2abKZopnZxETE73tdEpFktvPL0IE3tzXQG5mdBcqaBt6ENW193CgoZNVBdohqpST3HZeEU986fSo7hAFTehBVWJNCK3150opO2hCD6JNlU0kxcewMMo7XpRSzqQJPYg2VTZxSl4mCXF6WpVS4aeZJ0jae/rZWdOm9edKKdtoQg+SLYdacButP1dK2UcTepCUVDYRGyOcMksnhFZK2UMTepBsPNDEwpnppCRqab9Syh6a0IOgd2CQrYdbtFxRKWUrTehBUFbdSu+AWxO6UspWmtCDYJN1Q5FOCK2UspMm9CDYdKCJwuwUXKmJdoeilJrANKEHyO02lBxs1nJFpZTtNKEHaG9dO63d/dp+rpSynSb0AHnbz1fpHaJKKZtpQg/QpgNNTEtPjPpxlpVSzqcJPQCeCaGbWJmfFfXjLCulnE8TegCqmrupbe3R5hallCNoQg9AyUHPhNDaIaqUcoKABx4RkUqgHRgEBowxxSKSBTwO5AOVwFXGmOZAj+U0Gw80k5YUx7xpaXaHopRSQbtCP9cYs8wYU2y9vgt4zRhTBLxmvY46myqbKJ6dSaxOCK2UcoBQNblcDjxiPX8EuCJEx7FNRX0H5XUdOqGFUsoxgpHQDfCKiGwWkZutZdOMMbXW8yPAtCAcxzHcbsNdT5eSnhTH2hW5doejlFJAENrQgdXGmGoRmQq8KiK7fVcaY4yImKE7Wcn/ZoBZs2YFIYzw+cN7B9lU2cyP1i5halqS3eEopRQQhCt0Y0y19bUOeBZYBRwVkRkA1te6YfZ7wBhTbIwpzs7ODjSMsKlq7uKHL+5mTZFLr86VUo4SUEIXkRQRSfM+By4EyoDngeutza4HngvkOE5hjOEbz5ZhgO9duVhvJlJKOUqgTS7TgGetxBYHPGaMeUlENgFPiMhNwEHgqgCP4whPb6nmzb31/MfHF5KXlWx3OEopdZyAEroxZj+wdJjljcD5gby309S193DPX3dSPDuTa0+bbXc4Sin1IXqnqJ++/dwOuvsH+eHaJcRo3blSyoE0ofvhxe21vFh2hDsuKGJOdqrd4Sil1LA0oY+ipauP//fcDhbOTOeLawrtDkcppUYUjDr0qHbPX3fR3NXHIzeuJD5W//4ppZxLM9QJvLG3nqe3VHHL2YUsnJlhdzhKKXVCmtBH0NE7wDee2c6c7BRuP6/I7nCUUmpU2uQygh+9tJua1m6euuV0kuJj7Q5HKaVGpVfow9hU2cQj7xzk+tPzWTFbR1NUSkUGTehD9PQP8m9PlZKbOYmvXTTf7nCUUspv2uQyxM9e28f+hk5+f9MqUhL19CilIodeofsoq27lgTf3c1VxLmuKImcESKWUAk3ox/QPuvnaU6VkpSRw9yUL7A5HKaXGTNsULL9+o4JdtW38+toVZCTH2x2OUkqNmV6hA+V17dz3WjmXLp7BRQun2x2OUkqNy4RP6INuw9efKiU5MZbvfHyh3eEopdS4TfiE/sjblWw51MK3P7aA7LREu8NRSqlxm9AJ/XBTFz96eQ/nzM/mimU5doejlFIBmbAJ3dvUEiM6P6hSKjpM2IT+s7/t5Z39jXz74wuZOXmS3eEopVTAJmRCf2NvPT//ezlrV+RyVXGe3eEopVRQTLiEXtPSzR1/ep/509K45/JFdoejlFJBM6ESev+gm9se20LfgJtffnY5kxJ0WFylVPSYUHeK/uDF3Ww51MIvPnOKTvaslIo6E+YK/aWyWh7acIDrT5/NZUtm2h2OUkoF3YRI6JUNnXztyVKW5mbwjUtPtjscpZQKiXEndBHJE5G/i8hOEdkhIv9iLf+OiFSLyFbrcUnwwh27nv5Bbv3jFmJihF9+djmJcdpurpSKToG0oQ8AXzXGbBGRNGCziLxqrfuJMebHgYcXuP/4yw521rbx8OeLyc1MtjscpZQKmXEndGNMLVBrPW8XkV2Ao+6ff2ZLFes2HuafzpnDeSdNszscpZQKqaC0oYtIPnAK8J616DYRKRWRh0UkMxjHGKu9R9u5+9kyTi3I4qsfmWdHCEopFVYBJ3QRSQWeBu4wxrQBvwLmAMvwXMHfO8J+N4tIiYiU1NfXBxrGcTp7B/inP2wmJTGOn19zCnGxE6LvVyk1wQWU6UQkHk8y/6Mx5hkAY8xRY8ygMcYNPAisGm5fY8wDxphiY0xxdnbw5u80xvDvz2znQEMn912zjKnpSUF7b6WUcrJAqlwEeAjYZYz5H5/lM3w2uxIoG394Y/eH9w7x/LYa7vzIPM6Y4wrnoZVSylaBVLmcCVwLbBeRrdaybwDXiMgywACVwJcCinAMSqtauOcvOz
lnfja3njM3XIdVSilHCKTKZQMw3CDiL4w/nPFr7ern1j9uwZWawE+uWkZMjI5vrpSaWKJiLBdjDF99chtH23p4/Eunk5mSYHdISikVdlFR/vHAm/v5266jfOOSk1k+y5YqSaWUsl3EJ/SNB5r475f3cOniGXz+jHy7w1FKKdtEdEJv6Ojl9nVbmJWVzA8+qfOCKqUmtohO6L94vZyWrn7+97PLSUuKtzscpZSyVUR3iv77JSdx2ZIZnDwj3e5QlFLKdhF9hZ4YF0txfpbdYSillCNEdEJXSin1AU3oSikVJTShK6VUlNCErpRSUUITulJKRQlN6EopFSU0oSulVJTQhK6UUlFCE7pSSkUJTehKKRUlxBhjdwyISD1wcJy7u4CGIIYTbE6OT2MbHyfHBs6OT2Mbn9nGmOzRNnJEQg+EiJQYY4rtjmMkTo5PYxsfJ8cGzo5PYwstbXJRSqkooQldKaWiRDQk9AfsDmAUTo5PYxsfJ8cGzo5PYwuhiG9DV0op5RENV+hKKaWIoIQuIheLyB4RKReRu4ZZnygij1vr3xOR/DDFlScifxeRnSKyQ0T+ZZhtzhGRVhHZaj2+FY7YfI5fKSLbrWOXDLNeROQ+69yVisjyMMU13+ecbBWRNhG5Y8g2YTt3IvKwiNSJSJnPsiwReVVE9llfM0fY93prm30icn2YYvuRiOy2PrNnRWTyCPue8PMPYXzfEZFqn8/ukhH2PeHvdohie9wnrkoR2TrCviE/d0FljHH8A4gFKoBCIAHYBiwYss2twP3W86uBx8MU2wxgufU8Ddg7TGznAH+18fxVAq4TrL8EeBEQ4DTgPZs+4yN46m1tOXfAWcByoMxn2X8Dd1nP7wJ+OMx+WcB+62um9TwzDLFdCMRZz384XGz+fP4hjO87wL/68bmf8Hc7FLENWX8v8C27zl0wH5Fyhb4KKDfG7DfG9AF/Ai4fss3lwCPW86eA80VEQh2YMabWGLPFet4O7AJyQn3cILsceNR4vAtMFpEZYY7hfKDCGDPeG8wCZox5E2gastj35+oR4Iphdr0IeNUY02SMaQZeBS4OdWzGmFeMMQPWy3eB3GAecyxGOHf+8Od3O2SxWTniKmBdMI9pl0hJ6DnAYZ/XVXw4aR7bxvohbwWmhCU6i9XMcwrw3jCrTxeRbSLyoogsDGdcgAFeEZHNInLzMOv9Ob+hdjUj/1LZee6mGWNqredHgGnDbOOE83cjnv+yhjPa5x9Kt1lNQg+P0Fxl97lbAxw1xuwbYb2d527MIiWhO56IpAJPA3cYY9qGrN6CpylhKfBz4M9hDm+1MWY58FHgyyJyVpiPf0IikgB8HHhymNV2n7tjjOd/cMeVhYnI3cAA8McRNrHr8/8VMAdYBtTiadpwmms48dW5o393hoqUhF4N5Pm8zrWWDbuNiMQBGUBjOIITkXg8yfyPxphnhq43xrQZYzqs5y8A8SLiCkds1jGrra91wLN4/s315c/5DaWPAluMMUeHrrD73AFHvc1P1te6Ybax7fyJyOeBy4DPWn9wPsSPzz8kjDFHjTGDxhg38OAIx7Xz3MUBnwAeH2kbu87deEVKQt8EFIlIgXU1dzXw/JBtnge81QVrgddH+gEPJqsN7iFglzHmf0bYZrq3PV9EVuE57+H6Y5MiImne53g60sqGbPY8cJ1V7XIa0OrTzBAOI14l2XnuLL4/V9cDzw2zzcvAhSKSaTUrXGgtCykRuRj4OvBxY0zXCNv48/mHKj7ffpgrRziuP7/boXIBsNsYUzXcSjvP3bjZ3Svr7wNPJcZePD3id1vL/hPPDzNAEp5/2cuBjUBhmOJajeff8FJgq/W4BLgFuMXa5jZgB54e/HeBM8J43gqt426zYvCeO9/4BPildW63A8VhjC8FT4LO8Flmy7nD80elFujH05Z7E55+mNeAfcDfgCxr22LgNz773mj97JUDN4QptnI87c/enztvlddM4IUTff5hiu/31s9TKZ4kPWNofNbrD/1uhzo2a/nvvD9nPtuG/dwF86F3iiqlVJSIlCYXpZRSo5TNNOsAAAAySURBVNCErpRSUUITulJKRQlN6EopFSU0oSulVJTQhK6UUlFCE7pSSkUJTehKKRUl/j/uGOcosdriMAAAAABJRU5ErkJggg==\n", 218 | "text/plain": [ 219 | "
" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "output_type": "display_data" 224 | } 225 | ], 226 | "source": [ 227 | "state = envs.reset()\n", 228 | "\n", 229 | "while frame_idx < max_frames:\n", 230 | "\n", 231 | " log_probs = []\n", 232 | " values = []\n", 233 | " rewards = []\n", 234 | " masks = []\n", 235 | " entropy = 0\n", 236 | "\n", 237 | " for _ in range(num_steps):\n", 238 | " state = torch.FloatTensor(state).to(device)\n", 239 | " dist, value = model(state)\n", 240 | "\n", 241 | " action = dist.sample()\n", 242 | " next_state, reward, done, _ = envs.step(action.cpu().numpy())\n", 243 | "\n", 244 | " log_prob = dist.log_prob(action)\n", 245 | " entropy += dist.entropy().mean()\n", 246 | " \n", 247 | " log_probs.append(log_prob)\n", 248 | " values.append(value)\n", 249 | " rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))\n", 250 | " masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))\n", 251 | " \n", 252 | " state = next_state\n", 253 | " frame_idx += 1\n", 254 | " \n", 255 | " if frame_idx % 1000 == 0:\n", 256 | " test_rewards.append(np.mean([test_env() for _ in range(10)]))\n", 257 | " plot(frame_idx, test_rewards)\n", 258 | " \n", 259 | " next_state = torch.FloatTensor(next_state).to(device)\n", 260 | " _, next_value = model(next_state)\n", 261 | " returns = compute_returns(next_value, rewards, masks)\n", 262 | " \n", 263 | " log_probs = torch.cat(log_probs)\n", 264 | " returns = torch.cat(returns).detach()\n", 265 | " values = torch.cat(values)\n", 266 | "\n", 267 | " advantage = returns - values\n", 268 | "\n", 269 | " actor_loss = -(log_probs * advantage.detach()).mean()\n", 270 | " critic_loss = advantage.pow(2).mean()\n", 271 | "\n", 272 | " loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy\n", 273 | "\n", 274 | " optimizer.zero_grad()\n", 275 | " loss.backward()\n", 276 | " optimizer.step()" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 26, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "text/plain": [ 287 | "200.0" 288 | ] 289 | }, 290 | "execution_count": 26, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "test_env(True)" 297 | ] 298 | } 299 | ], 300 | "metadata": { 301 | "kernelspec": { 302 | "display_name": "Python [conda env:pytorch4]", 303 | "language": "python", 304 | "name": "conda-env-pytorch4-py" 305 | }, 306 | "language_info": { 307 | "codemirror_mode": { 308 | "name": "ipython", 309 | "version": 3 310 | }, 311 | "file_extension": ".py", 312 | "mimetype": "text/x-python", 313 | "name": "python", 314 | "nbconvert_exporter": "python", 315 | "pygments_lexer": "ipython3", 316 | "version": "3.5.5" 317 | } 318 | }, 319 | "nbformat": 4, 320 | "nbformat_minor": 2 321 | } 322 | -------------------------------------------------------------------------------- /higgsfield/rl/rl_adventure_2/README.md: -------------------------------------------------------------------------------- 1 | # RL-Adventure-2: Policy Gradients 2 | 3 | 4 | 5 | 6 | PyTorch tutorial of: actor critic / proximal policy optimization / acer / ddpg / twin dueling ddpg / soft actor critic / generative adversarial imitation learning / hindsight experience replay 7 | 8 | The deep reinforcement learning community has made several improvements to the [policy gradient](http://rll.berkeley.edu/deeprlcourse/f17docs/lecture_4_policy_gradient.pdf) algorithms. This tutorial presents latest extensions in the following order: 9 | 10 | 1. 
Advantage Actor Critic (A2C) 11 | - [actor-critic.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/1.actor-critic.ipynb) 12 | - [A3C Paper](https://arxiv.org/pdf/1602.01783.pdf) 13 | - [OpenAI blog](https://blog.openai.com/baselines-acktr-a2c/#a2canda3c) 14 | 2. High-Dimensional Continuous Control Using Generalized Advantage Estimation 15 | - [gae.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/2.gae.ipynb) 16 | - [GAE Paper](https://arxiv.org/abs/1506.02438) 17 | 3. Proximal Policy Optimization Algorithms 18 | - [ppo.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/3.ppo.ipynb) 19 | - [PPO Paper](https://arxiv.org/abs/1707.06347) 20 | - [OpenAI blog](https://blog.openai.com/openai-baselines-ppo/) 21 | 4. Sample Efficient Actor-Critic with Experience Replay 22 | - [acer.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/4.acer.ipynb) 23 | - [ACER Paper](https://arxiv.org/abs/1611.01224) 24 | 5. Continuous control with deep reinforcement learning 25 | - [ddpg.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/5.ddpg.ipynb) 26 | - [DDPG Paper](https://arxiv.org/abs/1509.02971) 27 | 6. Addressing Function Approximation Error in Actor-Critic Methods 28 | - [td3.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/6.td3.ipynb) 29 | - [Twin Dueling DDPG Paper](https://arxiv.org/abs/1802.09477) 30 | 7. Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor 31 | - [soft actor-critic.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/7.soft%20actor-critic.ipynb) 32 | - [Soft Actor-Critic Paper](https://arxiv.org/abs/1801.01290) 33 | 8. Generative Adversarial Imitation Learning 34 | - [gail.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/8.gail.ipynb) 35 | - [GAIL Paper](https://arxiv.org/abs/1606.03476) 36 | 9. Hindsight Experience Replay 37 | - [her.ipynb](https://github.com/higgsfield/RL-Adventure-2/blob/master/9.her.ipynb) 38 | - [HER Paper](https://arxiv.org/abs/1707.01495) 39 | - [OpenAI Blog](https://blog.openai.com/ingredients-for-robotics-research/#understandingher) 40 | 41 | # If you get stuck… 42 | - Remember you are not stuck unless you have spent more than a week on a single algorithm. It is perfectly normal if you do not have all the required knowledge of mathematics and CS. 43 | - Carefully go through the paper. Try to see what problem the authors are solving. Understand the high-level idea of the approach, then read the code (skipping the proofs), and afterwards go over the mathematical details and proofs. 
44 | 45 | # RL Algorithms 46 | Deep Q Learning tutorial: [DQN Adventure: from Zero to State of the Art](https://github.com/higgsfield/RL-Adventure) 47 | [![N|Solid](https://planspace.org/20170830-berkeley_deep_rl_bootcamp/img/annotated.jpg)]() 48 | Awesome RL libs: rlkit [@vitchyr](https://github.com/vitchyr), pytorch-a2c-ppo-acktr [@ikostrikov](https://github.com/ikostrikov), 49 | ACER [@Kaixhin](https://github.com/Kaixhin) 50 | 51 | # Best RL courses 52 | - Berkeley deep RL [link](http://rll.berkeley.edu/deeprlcourse/) 53 | - Deep RL Bootcamp [link](https://sites.google.com/view/deep-rl-bootcamp/lectures) 54 | - David Silver's course [link](http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching.html) 55 | - Practical RL [link](https://github.com/yandexdataschool/Practical_RL) 56 | -------------------------------------------------------------------------------- /higgsfield/rl/rl_adventure_2/common/__init__.py: -------------------------------------------------------------------------------- 1 | import multiprocessing_env -------------------------------------------------------------------------------- /higgsfield/rl/rl_adventure_2/common/multiprocessing_env.py: -------------------------------------------------------------------------------- 1 | #This code is from openai baseline 2 | #https://github.com/openai/baselines/tree/master/baselines/common/vec_env 3 | 4 | import numpy as np 5 | from multiprocessing import Process, Pipe 6 | 7 | def worker(remote, parent_remote, env_fn_wrapper): 8 | parent_remote.close() 9 | env = env_fn_wrapper.x() 10 | while True: 11 | cmd, data = remote.recv() 12 | if cmd == 'step': 13 | ob, reward, done, info = env.step(data) 14 | if done: 15 | ob = env.reset() 16 | remote.send((ob, reward, done, info)) 17 | elif cmd == 'reset': 18 | ob = env.reset() 19 | remote.send(ob) 20 | elif cmd == 'reset_task': 21 | ob = env.reset_task() 22 | remote.send(ob) 23 | elif cmd == 'close': 24 | remote.close() 25 | break 26 | elif cmd == 'get_spaces': 27 | remote.send((env.observation_space, env.action_space)) 28 | else: 29 | raise NotImplementedError 30 | 31 | class VecEnv(object): 32 | """ 33 | An abstract asynchronous, vectorized environment. 34 | """ 35 | def __init__(self, num_envs, observation_space, action_space): 36 | self.num_envs = num_envs 37 | self.observation_space = observation_space 38 | self.action_space = action_space 39 | 40 | def reset(self): 41 | """ 42 | Reset all the environments and return an array of 43 | observations, or a tuple of observation arrays. 44 | If step_async is still doing work, that work will 45 | be cancelled and step_wait() should not be called 46 | until step_async() is invoked again. 47 | """ 48 | pass 49 | 50 | def step_async(self, actions): 51 | """ 52 | Tell all the environments to start taking a step 53 | with the given actions. 54 | Call step_wait() to get the results of the step. 55 | You should not call this if a step_async run is 56 | already pending. 57 | """ 58 | pass 59 | 60 | def step_wait(self): 61 | """ 62 | Wait for the step taken with step_async(). 63 | Returns (obs, rews, dones, infos): 64 | - obs: an array of observations, or a tuple of 65 | arrays of observations. 66 | - rews: an array of rewards 67 | - dones: an array of "episode done" booleans 68 | - infos: a sequence of info objects 69 | """ 70 | pass 71 | 72 | def close(self): 73 | """ 74 | Clean up the environments' resources. 
75 | """ 76 | pass 77 | 78 | def step(self, actions): 79 | self.step_async(actions) 80 | return self.step_wait() 81 | 82 | 83 | class CloudpickleWrapper(object): 84 | """ 85 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 86 | """ 87 | def __init__(self, x): 88 | self.x = x 89 | def __getstate__(self): 90 | import cloudpickle 91 | return cloudpickle.dumps(self.x) 92 | def __setstate__(self, ob): 93 | import pickle 94 | self.x = pickle.loads(ob) 95 | 96 | 97 | class SubprocVecEnv(VecEnv): 98 | def __init__(self, env_fns, spaces=None): 99 | """ 100 | envs: list of gym environments to run in subprocesses 101 | """ 102 | self.waiting = False 103 | self.closed = False 104 | nenvs = len(env_fns) 105 | self.nenvs = nenvs 106 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 107 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 108 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 109 | for p in self.ps: 110 | p.daemon = True # if the main process crashes, we should not cause things to hang 111 | p.start() 112 | for remote in self.work_remotes: 113 | remote.close() 114 | 115 | self.remotes[0].send(('get_spaces', None)) 116 | observation_space, action_space = self.remotes[0].recv() 117 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 118 | 119 | def step_async(self, actions): 120 | for remote, action in zip(self.remotes, actions): 121 | remote.send(('step', action)) 122 | self.waiting = True 123 | 124 | def step_wait(self): 125 | results = [remote.recv() for remote in self.remotes] 126 | self.waiting = False 127 | obs, rews, dones, infos = zip(*results) 128 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 129 | 130 | def reset(self): 131 | for remote in self.remotes: 132 | remote.send(('reset', None)) 133 | return np.stack([remote.recv() for remote in self.remotes]) 134 | 135 | def reset_task(self): 136 | for remote in self.remotes: 137 | remote.send(('reset_task', None)) 138 | return np.stack([remote.recv() for remote in self.remotes]) 139 | 140 | def close(self): 141 | if self.closed: 142 | return 143 | if self.waiting: 144 | for remote in self.remotes: 145 | remote.recv() 146 | for remote in self.remotes: 147 | remote.send(('close', None)) 148 | for p in self.ps: 149 | p.join() 150 | self.closed = True 151 | 152 | def __len__(self): 153 | return self.nenvs -------------------------------------------------------------------------------- /higgsfield/static/project/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | env 162 | -------------------------------------------------------------------------------- /higgsfield/static/project/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM higgsfield/pytorch:latest 2 | 3 | COPY requirements.txt requirements.txt 4 | 5 | RUN python3 -m pip install -r requirements.txt 6 | -------------------------------------------------------------------------------- /higgsfield/static/project/env: -------------------------------------------------------------------------------- 1 | WAN_DB_TOKEN="MAYBE:TH4T5H3R34R3N0TTH3D4T4Y0U4R3L00K1NGF0R" 2 | HUGGINGFACE_TOKEN="THIS:IS:THE:TOKEN:YOU:ARE:LOOKING:FOR" 3 | SSH_KEY="/this/better/be/a/path/to/your/ssh/private/key" 4 | -------------------------------------------------------------------------------- /higgsfield/static/project/requirements.txt: -------------------------------------------------------------------------------- 1 | optimum==1.13.2 2 | wandb==0.15.11 3 | polars==0.19.3 4 | peft==0.5.0 5 | datasets==2.14.5 6 | higgsfield==0.0.3 7 | -------------------------------------------------------------------------------- /higgsfield/static/project/src/alpaca_bf16.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from torch.optim.lr_scheduler import StepLR 3 | 4 | from higgsfield.llama import Llama 5 | from higgsfield.loaders import LlamaLoader 6 | from higgsfield.checkpoint import Checkpoint 7 | from higgsfield.training import clip_grad_norm 8 | from higgsfield.experiment import experiment, param 9 | 10 | from src.dataset import AlpacaDataset 11 | 12 | @experiment("alpaca_bf16") 13 | @param("size", options=["7b", "13b", "70b"]) 14 | @param("num_epochs", default=1, description="Number of epochs") 15 | def train(params): 16 | 17 | if params.size == "7b": 18 | model_name = "meta-llama/Llama-2-7b-hf" 19 | elif params.size == "13b": 20 | model_name = "meta-llama/Llama-2-13b-hf" 21 | elif params.size == "70b": 22 | model_name = "meta-llama/Llama-2-70b-hf" 23 | 24 | model = Llama( 25 | model_name=model_name, 26 | zero_stage=3, 27 | cpu_init_rank0=True, 28 | fast_attn=False, 29 | precision="bf16", 30 | cpu_offload=False, 31 | ) 32 | 33 | optimizer = optim.AdamW( 34 | model.parameters(), 35 | lr=1e-5, 36 | weight_decay=0.0, 37 | ) 38 | 39 | lr_scheduler = StepLR( 40 | optimizer, 41 | step_size=1, 42 | gamma=0.85, 43 | ) 44 | 45 | # ~/.cache/{project-name}/experiments/{experiment_name}/{run_name}/ 46 | checkpoint = Checkpoint( 47 | model, 48 | optimizer, 49 | lr_scheduler, 50 | ) 51 | 52 | dataset_name = "tatsu-lab/alpaca" 53 | dataset = AlpacaDataset(dataset_name, split="train") 54 | 55 | train_loader = LlamaLoader( 56 | dataset, 57 | max_sequence_length=2048, 58 | batch_size_per_gpu=1, 59 | ) 60 | 61 | for epoch in range(params.num_epochs): 62 | for i, batch in enumerate(train_loader): 63 | 64 | optimizer.zero_grad() 65 | loss = model(batch) 66 | 67 | if params.rank == 0: 68 | print("Loss: ", loss) 69 | 70 | loss.backward() 71 | 72 | clip_grad_norm(1.0, model, optimizer) 73 | optimizer.step() 74 | 75 | if i % 30 == 0 or i == len(train_loader) - 1: 76 | checkpoint.save(epoch, i) 77 | 78 | lr_scheduler.step() 79 | 80 | model.save_huggingface_model("my-alpaca") 81 | 82 | 83 | -------------------------------------------------------------------------------- /higgsfield/static/project/src/alpaca_fp16.py: -------------------------------------------------------------------------------- 1 | import torch.optim 
as optim 2 | from torch.optim.lr_scheduler import StepLR 3 | 4 | from higgsfield.llama import Llama 5 | from higgsfield.loaders import LlamaLoader 6 | from higgsfield.checkpoint import Checkpoint 7 | from higgsfield.training import clip_grad_norm, Scaler 8 | from higgsfield.experiment import experiment, param 9 | 10 | from src.dataset import AlpacaDataset 11 | 12 | @experiment("alpaca_fp16") 13 | @param("size", options=["7b", "13b", "70b"]) 14 | @param("num_epochs", default=1, description="Number of epochs") 15 | def train(params): 16 | 17 | if params.size == "7b": 18 | model_name = "meta-llama/Llama-2-7b-hf" 19 | elif params.size == "13b": 20 | model_name = "meta-llama/Llama-2-13b-hf" 21 | elif params.size == "70b": 22 | model_name = "meta-llama/Llama-2-70b-hf" 23 | 24 | model = Llama( 25 | model_name=model_name, 26 | zero_stage=3, 27 | cpu_init_rank0=True, 28 | fast_attn=False, 29 | precision="fp16", 30 | cpu_offload=False, 31 | ) 32 | 33 | optimizer = optim.AdamW( 34 | model.parameters(), 35 | lr=1e-5, 36 | weight_decay=0.0, 37 | ) 38 | 39 | lr_scheduler = StepLR( 40 | optimizer, 41 | step_size=1, 42 | gamma=0.85, 43 | ) 44 | 45 | scaler = Scaler(model) 46 | 47 | # ~/.cache/{project-name}/experiments/{experiment_name}/{run_name}/ 48 | checkpoint = Checkpoint( 49 | model, 50 | optimizer, 51 | lr_scheduler, 52 | scaler, 53 | ) 54 | 55 | dataset_name = "tatsu-lab/alpaca" 56 | dataset = AlpacaDataset(dataset_name, split="train") 57 | 58 | train_loader = LlamaLoader( 59 | dataset, 60 | max_sequence_length=2048, 61 | batch_size_per_gpu=1, 62 | ) 63 | 64 | for epoch in range(params.num_epochs): 65 | for i, batch in enumerate(train_loader): 66 | 67 | optimizer.zero_grad() 68 | loss = model(batch) 69 | 70 | if params.rank == 0: 71 | print("Loss: ", loss) 72 | 73 | scaler.scale(loss).backward() 74 | 75 | clip_grad_norm(1.0, model, optimizer, scaler) 76 | scaler.step(optimizer) 77 | scaler.update() 78 | 79 | if i % 30 == 0 or i == len(train_loader) - 1: 80 | checkpoint.save(epoch, i) 81 | 82 | lr_scheduler.step() 83 | 84 | model.save_huggingface_model("my-alpaca") 85 | -------------------------------------------------------------------------------- /higgsfield/static/project/src/dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | class AlpacaDataset: 4 | def __init__(self, dataset_name, split="train"): 5 | self.dataset = load_dataset(dataset_name, split=split) 6 | 7 | def __len__(self): 8 | return len(self.dataset) 9 | 10 | def __getitem__(self, idx): 11 | item = self.dataset[idx] 12 | 13 | instruction = item["instruction"] 14 | 15 | if not item.get("input"):  # rows with an empty "input" use the no-input template 16 | prompt = ( 17 | "Below is an instruction that describes a task. " 18 | "Write a response that appropriately completes the request.\n\n" 19 | f"### Instruction:\n{instruction}\n\n### Response:" 20 | ) 21 | else: 22 | input = item["input"] 23 | 24 | prompt = ( 25 | "Below is an instruction that describes a task, paired with an input that provides further context. 
" 26 | "Write a response that appropriately completes the request.\n\n" 27 | f"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 28 | ) 29 | 30 | completion = item["output"] 31 | 32 | return { 33 | "prompt": prompt, 34 | "completion": completion, 35 | } -------------------------------------------------------------------------------- /higgsfield/static/templates/README_md.j2: -------------------------------------------------------------------------------- 1 | # {{ project_name }} 2 | 3 | ## Getting started 4 | 5 | #### [Setup](./setup.md) 6 | Here you can find the quick start guide on how to setup your nodes and start training. 7 | - [Initialize the project](https://github.com/higgsfield/higgsfield/blob/main/setup.md#initialize-the-project) 8 | - [Setup the environment](https://github.com/higgsfield/higgsfield/blob/main/setup.md#setup-the-environment) 9 | - [Setup git](https://github.com/higgsfield/higgsfield/blob/main/setup.md#setup-git) 10 | - [Time to setup your nodes!](https://github.com/higgsfield/higgsfield/blob/main/setup.md#time-to-setup-your-nodes) 11 | - [Run your very first experiment](https://github.com/higgsfield/higgsfield/blob/main/setup.md#run-your-very-first-experiment) 12 | - [Fasten your seatbelt, it's time to deploy!](https://github.com/higgsfield/higgsfield/blob/main/setup.md#fasten-your-seatbelt-its-time-to-deploy) 13 | 14 | #### [Tutorial](./tutorial.md) 15 | API for common tasks in Large Language Models training. 16 | - [Working with distributed model](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#working-with-distributed-model) 17 | - [Preparing Data](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#preparing-data) 18 | - [Optimizing the Model Parameters](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#optimizing-the-model-parameters) 19 | - [Saving Model](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#saving-model) 20 | - [Training stabilization techniques](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#training-stabilization-techniques) 21 | - [Monitoring](https://github.com/higgsfield/higgsfield/blob/main/tutorial.md#monitoring) 22 | 23 | | Platform | Purpose | Estimated Response Time | Support Level | 24 | | -------- | ------- | ----------------------- | ------------- | 25 | | [Github Issues](https://github.com/higgsfield/higgsfield/issues/) | Bug reports, feature requests, install issues, usage issues, etc. | < 1 day | Higgsfield Team | 26 | | [Twitter](https://twitter.com/higgsfield_ai/) | For staying up-to-date on new features. | Daily | Higgsfield Team | 27 | | [Website](https://higgsfield.ai/) | Discussion, news. | < 2 days | Higgsfield Team | 28 | -------------------------------------------------------------------------------- /higgsfield/static/templates/config_py.j2: -------------------------------------------------------------------------------- 1 | # RULES FOR CONFIG FILE: 2 | # 1. If you want something from env file, put it there, and just call the os.getenv, we will load it for you. 3 | # 2. Do not put any imports here. None but 'os'. 4 | # 3. Only simple definitions. 5 | # 4. Always under $folder_of_your_project/src/config.py. 
6 | 7 | 8 | import os 9 | 10 | NAME = "{{ project_name }}" 11 | 12 | HOSTS = ["1.2.3.4"] 13 | HOSTS_USER = "ubuntu" 14 | HOSTS_PORT = 22 15 | 16 | NUMBER_OF_PROCESSES_PER_NODE = 2 17 | 18 | -------------------------------------------------------------------------------- /higgsfield/static/templates/deploy_action.j2: -------------------------------------------------------------------------------- 1 | {{ header }} 2 | name: Deploy Experiments 3 | 4 | on: 5 | push: 6 | branches: 7 | - main 8 | 9 | concurrency: 10 | cancel-in-progress: false 11 | group: main 12 | {% raw %} 13 | jobs: 14 | deploy: 15 | name: Deploy Experiments 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v3 20 | - name: Set credentials 21 | run: | 22 | pip install higgsfield==0.0.3 --quiet 23 | 24 | echo "SSH_KEY='NOTHING'" >> env 25 | eval "$(jq -r '["key", "user", "port", "hosts"] as $names | . as $json | $names[] as $name | "export \($name | ascii_upcase | @sh)=\($json[$name] | @sh)"' < <(higgsfield ci get-ssh-details))" 26 | 27 | echo "HOSTS=$HOSTS" >> $GITHUB_ENV 28 | echo "USER=$USER" >> $GITHUB_ENV 29 | echo "PORT=$PORT" >> $GITHUB_ENV 30 | - name: Deploy code 31 | uses: appleboy/ssh-action@v1.0.0 32 | env: 33 | HOSTS: ${{ env.HOSTS }} 34 | USER: ${{ env.USER }} 35 | PORT: ${{ env.PORT }} 36 | with: 37 | host: ${{ env.HOSTS }} 38 | username: ${{ env.USER }} 39 | key: ${{ secrets.SSH_KEY }} 40 | port: ${{ env.PORT }} 41 | sync: true {% endraw %} 42 | script: | 43 | mkdir -p ~/higgsfield/ 44 | cd ~/higgsfield 45 | [ -d ~/higgsfield/{{ project_name }} ] && \ 46 | (cd ~/higgsfield/{{ project_name }} && \ 47 | git fetch --all && \ 48 | git reset --hard origin/main && \ 49 | git pull origin main) || git clone {{ keyed_repo_url }} {{ project_name }} || rm -rf ~/higgsfield/{{ project_name }} && git clone {{ keyed_repo_url }} {{ project_name }} 50 | 51 | echo "SSH_KEY=NOTHING" > env 52 | {{ env_gen }} 53 | -------------------------------------------------------------------------------- /higgsfield/static/templates/experiment_action.j2: -------------------------------------------------------------------------------- 1 | {{ header }} 2 | name: Run {{ experiment_name }} 3 | 4 | on: 5 | {% if params|length == 0 %} 6 | workflow_dispatch: 7 | inputs: 8 | nothing: 9 | description: 'Just run the experiment, no params supplied!' 10 | {% else %} 11 | workflow_dispatch: 12 | inputs: 13 | run_name: 14 | description: 'Name of the run, if not set will be chosen randomly, if exists will be reused' 15 | required: false 16 | {% for param in params %} 17 | {{ param }}{% endfor %} 18 | {% endif %} 19 | 20 | 21 | concurrency: 22 | cancel-in-progress: false 23 | group: main 24 | 25 | {% raw %} 26 | jobs: 27 | run-training: 28 | name: Run experiment 29 | runs-on: ubuntu-latest 30 | steps: 31 | - name: Checkout 32 | uses: actions/checkout@v3 33 | - name: Install invoker 34 | run: | 35 | wget https://github.com/ml-doom/invoker/releases/download/latest/invoker-latest-linux-amd64.tar.gz 36 | tar -xvf invoker-latest-linux-amd64.tar.gz 37 | sudo mv invoker /usr/bin/invoker 38 | rm invoker-latest-linux-amd64.tar.gz 39 | - name: Set Port and Run Name 40 | run: | 41 | pip install higgsfield==0.0.3 --quiet 42 | 43 | echo "SSH_KEY='NOTHING'" >> env 44 | eval "$(jq -r '["key", "user", "port", "hosts"] as $names | . 
as $json | $names[] as $name | "export \($name | ascii_upcase | @sh)=\($json[$name] | @sh)"' < <(higgsfield ci get-ssh-details))" 45 | 46 | echo "HOSTS=$HOSTS" >> $GITHUB_ENV 47 | echo "USER=$USER" >> $GITHUB_ENV 48 | echo "PORT=$PORT" >> $GITHUB_ENV 49 | 50 | echo "CHOSEN_PORT=$(invoker random-port)" >> $GITHUB_ENV 51 | echo "CHOSEN_RUN_NAME=$(invoker random-name)" >> $GITHUB_ENV 52 | echo "NPROC_PER_NODE=$(higgsfield ci get-nproc-per-node)" >> $GITHUB_ENV 53 | - name: Run experiment 54 | uses: appleboy/ssh-action@v1.0.0 55 | env: 56 | RUN_PORT: ${{ env.CHOSEN_PORT }} 57 | RUN_NAME: ${{ github.event.inputs.run_name || env.CHOSEN_RUN_NAME }} 58 | NPROC_PER_NODE: ${{ env.NPROC_PER_NODE }} 59 | HOSTS: ${{ env.HOSTS }} 60 | USER: ${{ env.USER }} 61 | PORT: ${{ env.PORT }} 62 | with: 63 | host: ${{ env.HOSTS }} 64 | username: ${{ env.USER }} 65 | key: ${{ secrets.SSH_KEY }} 66 | port: ${{ env.PORT }} 67 | script: | 68 | cd ~/higgsfield/{% endraw %}{{ project_name }} 69 | echo "SSH_KEY=NOTHING" > env 70 | {{ env_gen }} 71 | invoker experiment run --project_name {{ project_name }} --experiment_name {{ experiment_name }} {% raw %} --run_name ${{ env.RUN_NAME }} --port ${{ env.RUN_PORT }} --nproc_per_node ${{ env.NPROC_PER_NODE }} --hosts ${{ env.HOSTS }} {% endraw %} {{ rest }} 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /higgsfield/static/templates/kill_action.j2: -------------------------------------------------------------------------------- 1 | {{ header }} 2 | name: Kill training experiments 3 | 4 | on: 5 | workflow_dispatch: 6 | inputs: 7 | experiment_name: 8 | description: 'Kill training on which experiment, if not specified kill all' 9 | required: true 10 | 11 | concurrency: 12 | cancel-in-progress: false 13 | group: main 14 | 15 | jobs: 16 | kill-training: 17 | name: Kill training 18 | runs-on: ubuntu-latest 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v3 22 | - name: Set credentials 23 | run: | 24 | pip install higgsfield==0.0.3 --quiet 25 | 26 | echo "SSH_KEY='NOTHING'" >> env 27 | eval "$(jq -r '["key", "user", "port", "hosts"] as $names | . 
as $json | $names[] as $name | "export \($name | ascii_upcase | @sh)=\($json[$name] | @sh)"' < <(higgsfield ci get-ssh-details))" 28 | 29 | echo "HOSTS=$HOSTS" >> $GITHUB_ENV 30 | echo "USER=$USER" >> $GITHUB_ENV 31 | echo "PORT=$PORT" >> $GITHUB_ENV 32 | - name: Send kill signal to invoker 33 | uses: appleboy/ssh-action@v1.0.0 34 | env: {% raw %} 35 | HOSTS: ${{ env.HOSTS }} 36 | USER: ${{ env.USER }} 37 | PORT: ${{ env.PORT }} 38 | with: 39 | host: ${{ env.HOSTS }} 40 | username: ${{ env.USER }} 41 | key: ${{ secrets.SSH_KEY }} 42 | port: ${{ env.PORT }} 43 | script: | 44 | {% endraw %} 45 | cd ~/higgsfield/{{ project_name }} 46 | invoker experiment kill --project_name {{ project_name }} {% raw %} --hosts ${{ env.HOSTS }} --experiment_name ${{ github.event.inputs.experiment_name }} {% endraw %} 47 | 48 | -------------------------------------------------------------------------------- /higgsfield/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .grads import clip_grad_norm 2 | from .scaler import Scaler -------------------------------------------------------------------------------- /higgsfield/training/grads.py: -------------------------------------------------------------------------------- 1 | def clip_grad_norm(max_grad_norm, model, optimizer, scaler=None): 2 | # unscale gradients first (when an fp16 scaler is in use), then clip 3 | 4 | if scaler: 5 | scaler.unscale_(optimizer) 6 | 7 | if hasattr(optimizer, 'clip_grad_norm'): 8 | optimizer.clip_grad_norm(max_grad_norm) 9 | 10 | elif hasattr(model.model, 'clip_grad_norm_'): 11 | model.model.clip_grad_norm_(max_grad_norm) -------------------------------------------------------------------------------- /higgsfield/training/scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler 3 | 4 | class Scaler(object): 5 | def __init__(self, model): 6 | if model.fsdp: 7 | self.scaler = ShardedGradScaler() 8 | else: 9 | self.scaler = torch.cuda.amp.GradScaler() 10 | 11 | def __getattr__(self, name): 12 | # delegate scale/step/update/unscale_ to the wrapped scaler 13 | return getattr(self.scaler, name) 14 | -------------------------------------------------------------------------------- /higgsfield/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flush import empty_cache -------------------------------------------------------------------------------- /higgsfield/utils/flush.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | 4 | def get_tensors_from_gc(gpu_only=True): 5 | for obj in gc.get_objects(): 6 | try: 7 | if torch.is_tensor(obj): 8 | tensor = obj 9 | elif hasattr(obj, "data") and torch.is_tensor(obj.data): 10 | tensor = obj.data 11 | else: 12 | continue 13 | 14 | if tensor.is_cuda or not gpu_only: 15 | yield tensor 16 | except Exception: # nosec B112 pylint: disable=broad-exception-caught 17 | continue 18 | 19 | def empty_cache(): 20 | cnt = 0 21 | for obj in get_tensors_from_gc(): 22 | obj.detach() 23 | obj.grad = None 24 | obj.untyped_storage().resize_(0) 25 | cnt += 1 26 | gc.collect() 27 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "higgsfield" 3 | version = "0.0.3" 4 | description = "" 5 | authors = ["Yerzat Dulat ", "Anwar Omar "] 6 | readme = "README.md" 7 | include = ["static/*"] 8 | license = "Apache-2.0" 9 | 10 | 
[tool.poetry.dependencies] 11 | python = "^3.8.13" 12 | click = "^8.1.7" 13 | pyyaml = "^6.0.1" 14 | asyncer = "^0.0.2" 15 | jinja2 = "^3.1.2" 16 | python-dotenv = "^1.0.0" 17 | cryptography = "^41.0.4" 18 | asyncssh = {extras = ["bcrypt", "libnacl", "pyopenssl"], version = "^2.14.0"} 19 | bcrypt = "^4.0.1" 20 | libsodium = "^2.6.1" 21 | pyopenssl = "^23.2.0" 22 | 23 | [tool.poetry.scripts] 24 | higgsfield = "higgsfield.internal.main:cli" 25 | 26 | 27 | [build-system] 28 | requires = ["poetry-core"] 29 | build-backend = "poetry.core.masonry.api" 30 | -------------------------------------------------------------------------------- /setup.md: -------------------------------------------------------------------------------- 1 | ## Setup tutorial 2 | 3 |

"Simplicity is prerequisite for reliability."

— Edsger W. Dijkstra
4 | 5 | ## Initialize the project 6 | 7 | ```bash 8 | $ higgsfield init my_llama_project 9 | ``` 10 | 11 |
It creates a folder named my_llama_project with the following structure: 12 | 13 | ``` 14 | my_llama_project 15 | ├── src 16 | │ ├── __init__.py 17 | │ ├── experiment.py 18 | │ └── config.py 19 | ├── Dockerfile 20 | ├── env 21 | ├── requirements.txt 22 | └── README.md 23 | ``` 24 |
25 | 26 | ## Setup the environment 27 | Get into the project folder: 28 | ```bash 29 | $ cd my_llama_project 30 | ``` 31 | Then start editing the `env` file. It should contain the path to a valid SSH key for your training nodes. Make sure the key actually exists at that path on your machine. 32 | For example: 33 | 34 | ```bash 35 | $ cat env 36 | SSH_KEY=~/.ssh/id_rsa 37 | ``` 38 | Great! Now you should edit the `src/config.py` file. It contains your experiments' configuration.
39 | Example 40 | 41 | ```python 42 | import os 43 | 44 | NAME = "my_llama_project" 45 | 46 | # Fill this in with the IPs of your training nodes. 47 | HOSTS = [ 48 | "1.2.3.4", 49 | ] 50 | 51 | # The user name on your training nodes. 52 | # It should be the same for all nodes, 53 | # and it might be different from 'ubuntu'. 54 | HOSTS_USER = "ubuntu" 55 | 56 | # The SSH port of your training nodes, same for all nodes. 57 | HOSTS_PORT = 22 58 | 59 | # Number of processes per node. Depends on the number of GPUs on each node. 60 | NUM_PROCESSES = 4 61 | 62 | # You can list other environment variables here. 63 | WAN_DB_TOKEN = os.environ.get("WAN_DB_TOKEN", None) 64 | ``` 65 | You should fill those fields with your own configuration. 66 | 
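Once you've filled those fields, a quick sanity check is to import the config from the project root (this assumes the generated layout above, where `src` is a package):

```bash
$ python -c "from src import config; print(config.HOSTS, config.NUM_PROCESSES)"
['1.2.3.4'] 4
```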
67 | 68 | ## Setup git 69 | 70 | You should create [a new git repository](https://github.com/new) on Github. Make sure you don't create any `README`, `.gitignore` or `LICENSE` files. 71 | 
72 | Just an empty repository. 73 | 74 | ![Alt text](./docs/static/image.png) 75 | 76 |
77 | 78 | 79 | Then follow the first option on the Github page to push an existing repository from your terminal. 80 | 81 | 
82 | Details screen. 83 | 84 | ![Alt text](./docs/static/image-1.png) 85 | 86 |
87 | 88 | ## Time to set up your nodes! 89 | 90 | Now you should set up your nodes. You can do it by running: 91 | ```bash 92 | $ higgsfield setup-nodes 93 | ``` 94 | This will install all the required tools on your nodes. You might need some patience here; don't worry, it's a one-time process. Like this: 95 | ``` 96 | $ higgsfield setup-nodes 97 | INSTALLING DOCKER 98 | ... 99 | INSTALLING INVOKER 100 | ... 101 | SETTING UP DEPLOY KEY 102 | ... 103 | PULLING DOCKER IMAGE 104 | ``` 105 | 106 | 107 | 
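If `setup-nodes` hangs or fails, a useful first check is whether each node is reachable over SSH at all. A minimal sketch using the example values from `env` and `src/config.py` above (substitute your own key path, user, port and host):

```bash
$ ssh -i ~/.ssh/id_rsa -p 22 ubuntu@1.2.3.4 'nvidia-smi --list-gpus'
```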
108 | But if you're stuck... 109 | 110 | 111 | If you're stuck on this step because you haven't added your git origin, try toggling between the `SSH | HTTPS` options at the top of the Github page, then run the `git remote add origin` command again (see the snippet right after this section). 112 | If that doesn't fix it, double-check that the SSH key in your `env` file and the settings in `src/config.py` are set up properly. 113 | 114 | 115 | 
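For reference, here is how to inspect and reset the git remote from your terminal (the repository URL is a placeholder; use your own):

```bash
$ git remote -v
$ git remote remove origin
$ git remote add origin git@github.com:<your-user>/<your-repo>.git
```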
116 | 117 | 118 | ## Run your very first experiment 119 | 120 | You're very close to running your first experiment. Take a look at `src/experiment.py`. 121 | ```python 122 | @experiment("llama") 123 | @param("size", options=["70b", "13b", "7b"]) 124 | def train_llama(params): 125 | print(f"Training llama with size {params.size}") 126 | ... 127 | ``` 128 | That's exactly how you will define experiments from now on. No need for `hydra`, `argparse` or any other boilerplate code. Just define your experiment, then run the following command: 129 | ```bash 130 | $ higgsfield build-experiments 131 | ``` 132 | 133 | Notice anything new? It's a new folder named `.github/workflows` with the following structure: 134 | ``` 135 | .github 136 | └── workflows 137 | ├── run_llama.yml 138 | └── deploy.yml 139 | ``` 140 | 
141 | Curious about them? 142 | These files are your entrypoint to the simplified deployment of your experiments. Now you can just push your code to Github, and it will automatically be deployed to your nodes. Not only that, it will also let you run your training experiments and save the checkpoints! 143 | 
144 | 145 | 146 | ### Fasten your seatbelt, it's time to deploy! 147 | You should add your `SSH_KEY` contents to Github secrets. To do that, go to your Github repository page, open the `Settings` tab, then `Secrets`, and click the `New repository secret` button. Add your `SSH_KEY` contents as a secret named `SSH_KEY`. 148 | 149 | 
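If you prefer the command line, the same secret can be set with the GitHub CLI (this assumes `gh` is installed and authenticated for this repository, and that your key lives at the path from your `env` file):

```bash
$ gh secret set SSH_KEY < ~/.ssh/id_rsa
```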
150 | Like this. 151 | 152 | ![Alt text](./docs/static/image-3.png) 153 | ![Alt text](./docs/static/image-4.png) 154 |
155 | 156 | Then add your deploy key to the repository's deploy keys. You can get it by running the following command: 157 | ```bash 158 | $ higgsfield show-deploy-key 159 | ssh-ed25519 AAAAC3NzaC1lZDI1NTE5A000THERESHOULDBEYOURDEPLOYKEYHEREso83os// 160 | 161 | ``` 162 | Copy the output and add it as a deploy key in your repository settings. You can name it `DEPLOY_KEY`. 163 | 
164 | Like this. 165 | 166 | ![Alt text](./docs/static/image-5.png) 167 | 168 |
169 | 170 | 171 | Push your code: 172 | ```bash 173 | git add . 174 | git commit -m "Make the way for the LLAMA!" 175 | git push origin main 176 | ``` 177 | 178 | Now you should go to the `Actions` tab in your Github repository page. You should see something like this: 179 | 180 | ![Alt text](./docs/static/image-6.png) 181 | 182 | As soon as it turns green (which means it's done), you can go to the left side, pick the `run_llama.yml` workflow, put in any run name you like, and click the `Run workflow` button: 183 | ![Alt text](./docs/static/image-8.png) 184 | 185 | The run is in progress... 186 | ![Alt text](./docs/static/image-9.png) 187 | ![Alt text](./docs/static/image-10.png) 188 | 189 | And it has finished! The experiment is now training on your nodes! 190 | ![Alt text](./docs/static/image-11.png) 191 | -------------------------------------------------------------------------------- /tutorial.md: -------------------------------------------------------------------------------- 1 | ## Tutorial 2 | 3 | This section runs through the API for common tasks in Large Language Model training. 4 | 5 | ### Working with distributed models 6 | Higgsfield provides simple primitives to work with distributed models. 7 | ```python 8 | from higgsfield.llama import Llama70b 9 | from higgsfield.loaders import LlamaLoader 10 | 11 | import torch.optim as optim 12 | from torch.optim.lr_scheduler import StepLR 13 | 14 | from datasets import load_dataset 15 | ``` 16 | 17 | `Llama70b` is a ready-to-use sharded class of the Llama 70b model. You can control the sharding strategy with its arguments. 18 | 19 | ```python 20 | model = Llama70b( 21 | zero_stage=3, 22 | fast_attn=False, 23 | precision="bf16", 24 | ) 25 | ``` 26 | - The `zero_stage` argument controls which sharding strategy to use. `zero_stage=3` fully shards the model parameters, gradients and optimizer states. This makes training some very large models feasible and helps fit larger models or batch sizes into the training job, at the cost of increased communication volume. `zero_stage=2` shards only optimizer states and gradients, reducing the communication overhead. For more information check the [Deepspeed](https://arxiv.org/pdf/1910.02054.pdf) and [FSDP](https://arxiv.org/pdf/2304.11277.pdf) papers. 27 | 28 | - The `precision` argument supports flexible mixed precision training, allowing types such as bf16 or fp16. The former is well-suited for deep learning tasks where numerical stability and convergence are essential, but bfloat16 is natively supported only on Ampere and newer GPUs, so confirm native support before you use it. 29 | 30 | - `fast_attn` leverages classical techniques (tiling, recomputation) to significantly speed up attention computation and reduce memory usage from quadratic to linear in sequence length. 31 | 32 | ### Preparing Data 33 | 34 | ```python 35 | class AlpacaDataset: 36 | def __init__(self, dataset_name, split="train"): 37 | self.dataset = load_dataset(dataset_name, split=split) 38 | 39 | def __len__(self): 40 | return len(self.dataset) 41 | 42 | def __getitem__(self, idx): 43 | item = self.dataset[idx] 44 | 45 | instruction = item["instruction"] 46 | 47 | if not item.get("input"): 48 | prompt = ( 49 | "Below is an instruction that describes a task. 
" 50 | "Write a response that appropriately completes the request.\n\n" 51 | f"### Instruction:\n{instruction}\n\n### Response:" 52 | ) 53 | else: 54 | input = item["input"] 55 | 56 | prompt = ( 57 | "Below is an instruction that describes a task, paired with an input that provides further context. " 58 | "Write a response that appropriately completes the request.\n\n" 59 | f"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 60 | ) 61 | 62 | completion = item["output"] 63 | 64 | return { 65 | "prompt": prompt, 66 | "completion": completion, 67 | } 68 | ``` 69 | 70 | ```python 71 | dataset = AplacaDataset("tatsu-lab/alpaca", split="train") 72 | 73 | train_loader = LlamaLoader( 74 | alpaca, 75 | max_sequence_length=2048, 76 | batch_size=64*6, 77 | ) 78 | ``` 79 | 80 | ### Optimizing the Model Parameters 81 | Higgsfield's distributed model works with standard PyTorch training flow. 82 | Creating optimizer and learning scheduler. 83 | ```python 84 | optimizer = optim.AdamW( 85 | model.parameters(), 86 | lr=1e-5, 87 | weight_decay=0.0, 88 | ) 89 | 90 | lr_scheduler = StepLR( 91 | optimizer, 92 | step_size=1, 93 | gamma=0.85, 94 | ) 95 | ``` 96 | 97 | In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and backpropagates the prediction error to adjust the model’s parameters. 98 | 99 | ```python 100 | for epoch in range(3): 101 | for i, batch in enumerate(train_loader): 102 | 103 | optimizer.zero_grad() 104 | loss = model(batch) 105 | loss.backward() 106 | optimizer.step() 107 | 108 | lr_scheduler.step() 109 | ``` 110 | 111 | ### Saving Model 112 | Saving pytorch model. 113 | ```python 114 | model.save("alpaca-70b/model.pt") 115 | ``` 116 | 117 | Saving in hugginface format or push it to the hub 118 | ```python 119 | model.save_huggingface_model("alpaca-hf-70b") 120 | ``` 121 | 122 | Or push it the hub 123 | ```python 124 | model.push_to_hub("alpaca-70b") 125 | ``` 126 | 127 | ## Training stabilization techniques 128 | It's easy to use and customize different training techniques because we follow standard PyTorch workflow. 
129 | 130 | ### Gradient accumulation 131 | 132 | ```python 133 | grad_accumulation_steps = 16 134 | 135 | for epoch in range(3): 136 | for i, batch in enumerate(train_loader): 137 | loss = model(batch) 138 | loss = loss / grad_accumulation_steps 139 | loss.backward() 140 | 141 | if (i + 1) % grad_accumulation_steps == 0 or i == len(train_loader) - 1: 142 | optimizer.step() 143 | optimizer.zero_grad() 144 | ``` 145 | ### Gradient clipping 146 | 147 | ```python 148 | from higgsfield.training import clip_grad_norm 149 | 150 | max_grad_norm = 1.0 151 | 152 | for epoch in range(3): 153 | for i, batch in enumerate(train_loader): 154 | optimizer.zero_grad() 155 | loss = model(batch) 156 | loss.backward() 157 | 158 | if max_grad_norm: 159 | clip_grad_norm(max_grad_norm, model, optimizer) 160 | optimizer.step() 161 | ``` 162 | 163 | ### FP16 gradient scaling 164 | ```python 165 | from higgsfield.training import Scaler, clip_grad_norm 166 | 167 | scaler = Scaler(model) 168 | 169 | for epoch in range(3): 170 | for i, batch in enumerate(train_loader): 171 | optimizer.zero_grad() 172 | loss = model(batch) 173 | scaler.scale(loss).backward() 174 | 175 | if max_grad_norm: 176 | clip_grad_norm(max_grad_norm, model, optimizer, scaler) 177 | 178 | scaler.step(optimizer) 179 | scaler.update() 180 | ``` 181 | 182 | ## Monitoring 183 | 184 | ### Wandb support 185 | You can use Wandb inside the project; the only requirement is to place Wandb calls under the condition `if params.rank == 0:`, so that only the rank-0 process logs. 186 | 187 | ```python 188 | import wandb 189 | 190 | @experiment("alpaca") 191 | def train(params): 192 | ... 193 | 194 | if params.rank == 0: 195 | wandb.init( 196 | project="My Llama2", 197 | ) 198 | 199 | 200 | for epoch in range(1): 201 | for i, batch in enumerate(train_loader): 202 | 203 | optimizer.zero_grad() 204 | loss = model(batch) 205 | 206 | loss.backward() 207 | optimizer.step() 208 | 209 | if params.rank == 0: 210 | wandb.log({ 211 | "train/loss": loss.item(), 212 | }) 213 | ``` -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Dataset formats 2 | 3 | [Dataset upload tutorial (video)](https://github.com/higgsfield-ai/higgsfield/assets/14979358/ab040041-9cf5-498b-a06d-b9d2a5fc02df) 4 | 5 | 6 | 7 | We support the following dataset formats: 8 | 9 | - **Prompt Completion format** 10 | - **The ChatGPT format** 11 | - **The Plain Text format** 12 | 13 | 14 | Before uploading your dataset to Hugging Face, please make sure that it is in one of the above formats. 15 | 16 | We provide a tutorial on how to convert your dataset to each format. 
17 | 18 | - **Prompt Completion format**: [tutorials/prompt_completion.ipynb](https://github.com/higgsfield-ai/higgsfield/blob/main/tutorials/prompt_completion.ipynb) 19 | - **ChatGPT format**: [tutorials/chatgpt.ipynb](https://github.com/higgsfield-ai/higgsfield/blob/main/tutorials/chatgpt.ipynb) 20 | - **Plain Text format**: [tutorials/text_format.ipynb](https://github.com/higgsfield-ai/higgsfield/blob/main/tutorials/text_format.ipynb) 21 | 22 | ### Upload your datasets to Hugging Face: 23 | 24 | ```python 25 | from datasets import Dataset 26 | dataset = Dataset.from_dict(format) 27 | dataset.push_to_hub("", token="") # Example: 'test/test_dataset' 28 | ``` 29 | ### Format: Prompt Completion 30 | ```python 31 | prompt_completion = { 32 | "prompt": [ 33 | "prompt1", 34 | "prompt2", 35 | ], 36 | "completion": [ 37 | "completion1", 38 | "completion2", 39 | ] 40 | } 41 | ``` 42 | 43 | ### Format: ChatGPT 44 | ```python 45 | chatgpt_format = { 46 | "chatgpt": [ 47 | [ 48 | {"role": "system", "content": "You are a human."}, 49 | {"role": "user", "content": "No I am not."}, 50 | {"role": "assistant", "content": "I am a robot."}, 51 | ], 52 | ] 53 | } 54 | ``` 55 | 56 | ### Format: Text 57 | ```python 58 | text_format = { 59 | "text": ["text"] 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /tutorials/chatgpt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "56b08e7a-f016-4783-a4ff-1bc51bf5534b", 6 | "metadata": {}, 7 | "source": [ 8 | "## The ChatGPT format\n", 9 | "```python\n", 10 | "chatgpt_format = {\n", 11 | " \"chatgpt\": [\n", 12 | " [\n", 13 | " {\"role\": \"system\", \"content\": \"You are a human.\"},\n", 14 | " {\"role\": \"user\", \"content\": \"No I am not.\"},\n", 15 | " {\"role\": \"assistant\", \"content\": \"I am a robot.\"},\n", 16 | " ],\n", 17 | " ]\n", 18 | "}\n", 19 | "```\n", 20 | "### Example: converting a dataset from Hugging Face to the ChatGPT format and uploading to Hugging Face repo" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 3, 26 | "id": "c27e49b3-9ce2-481f-9e55-f008abeadc3d", 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "{'text': '<HUMAN>: What is a panic attack?\\n<ASSISTANT>: Panic attacks come on suddenly and involve intense and often overwhelming fear. They’re accompanied by very challenging physical symptoms, like a racing heartbeat, shortness of breath, or nausea. Unexpected panic attacks occur without an obvious cause. Expected panic attacks are cued by external stressors, like phobias. 
Panic attacks can happen to anyone, but having more than one may be a sign of panic disorder, a mental health condition characterized by sudden and repeated panic attacks.'} 172\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "from datasets import load_dataset\n", 39 | "data = load_dataset(\"heliosbrahma/mental_health_chatbot_dataset\")[\"train\"]\n", 40 | "print(data[0], len(data))" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 4, 46 | "id": "351d1a81-c1c8-4bb0-9b8c-e14b54ff73f6", 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "[{'role': 'system', 'content': 'You are a mental health assistant.'}, {'role': 'user', 'content': 'What is a panic attack?\\n'}, {'role': 'assistant', 'content': 'Panic attacks come on suddenly and involve intense and often overwhelming fear. They’re accompanied by very challenging physical symptoms, like a racing heartbeat, shortness of breath, or nausea. Unexpected panic attacks occur without an obvious cause. Expected panic attacks are cued by external stressors, like phobias. Panic attacks can happen to anyone, but having more than one may be a sign of panic disorder, a mental health condition characterized by sudden and repeated panic attacks.'}]\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "chatgpt_format = {\n", 59 | " \"chatgpt\": []\n", 60 | "}\n", 61 | "\n", 62 | "SYSTEM_PROMPT = \"You are a mental health assistant.\"\n", 63 | "for d in data:\n", 64 | " text = d[\"text\"]\n", 65 | " assistant_word_i = text.find(\"<ASSISTANT>: \")\n", 66 | " human_text = text[len(\"<HUMAN>: \"):assistant_word_i]\n", 67 | " assistant_text = text[assistant_word_i + len(\"<ASSISTANT>: \"):]\n", 68 | "\n", 69 | " chatgpt_format[\"chatgpt\"].append(\n", 70 | " [\n", 71 | " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", 72 | " {\"role\": \"user\", \"content\": human_text},\n", 73 | " {\"role\": \"assistant\", \"content\": assistant_text}\n", 74 | " ])\n", 75 | "\n", 76 | "print(chatgpt_format[\"chatgpt\"][0])" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "id": "ab541655-c234-41ca-9f7a-2f20f1ad3c01", 82 | "metadata": {}, 83 | "source": [ 84 | "### Publish to Hugging Face repo" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "88fd3ef0-1bc4-42d6-af93-b670b64a44c9", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "from datasets import Dataset\n", 95 | "dataset = Dataset.from_dict(chatgpt_format)\n", 96 | "dataset.push_to_hub(\"\", token=\"\") # Example of a dataset repo 'test/test_dataset'" 97 | ] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "Python 3 (ipykernel)", 103 | "language": "python", 104 | "name": "python3" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.8.18" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 5 121 | } 122 | -------------------------------------------------------------------------------- /tutorials/prompt_completion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "68c0040c-7c2f-44a3-934e-4969b14d1ebf", 6 | "metadata": {}, 7 | "source": [ 8 | "## The Prompt Completion Format\n", 9 | "```python\n", 10 | " prompt_completion = {\n", 11 | " \"prompt\": [\n", 12 | " \"prompt1\",\n", 13 | " \"prompt2\",\n", 14 | " ],\n", 15 | " \"completion\": [\n", 16 | " \"completion1\",\n", 17 | " 
\"completion2\",\n", 18 | " ]\n", 19 | " }\n", 20 | "```\n", 21 | "### Example: converting a dataset from Hugging Face to the Prompt Completion format and uploading to Hugging Face repo" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "id": "7c19feb1-28e6-4ecb-9545-af1f61fef7b9", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from datasets import load_dataset\n", 32 | "from datasets import Dataset \n", 33 | "\n", 34 | "data = load_dataset(\"tatsu-lab/alpaca\")[\"train\"]" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "id": "960dce00-cc00-479e-9b1b-6d8e2fd089d6", 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "{'instruction': 'Give three tips for staying healthy.', 'input': '', 'output': '1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. Get enough sleep and maintain a consistent sleep schedule.', 'text': 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\nGive three tips for staying healthy.\\n\\n### Response:\\n1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. Get enough sleep and maintain a consistent sleep schedule.'}\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "print(data[0])" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 5, 58 | "id": "893c8c1e-ce97-4d13-a718-ecc24fe5a97f", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "###Instruction: Give three tips for staying healthy.\n", 66 | "###Input: \n", 67 | "###Assistant:\n", 68 | "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n", 69 | "2. Exercise regularly to keep your body active and strong. \n", 70 | "3. 
Get enough sleep and maintain a consistent sleep schedule.\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "prompt_completion = {\n", 76 | " \"prompt\": [],\n", 77 | " \"completion\": []\n", 78 | "}\n", 79 | "\n", 80 | "for d in data:\n", 81 | " if \"input\" in d.keys():\n", 82 | " prompt = f\"\"\"###Instruction: {d['instruction']}\n", 83 | "###Input: {d['input']}\n", 84 | "###Assistant:\"\"\"\n", 85 | " else:\n", 86 | " prompt = f\"\"\"###Instruction: {d['instruction']}\n", 87 | "###Assistant:\"\"\"\n", 88 | " prompt_completion[\"prompt\"].append(prompt)\n", 89 | " prompt_completion[\"completion\"].append(d[\"output\"])\n", 90 | "\n", 91 | "print(prompt_completion[\"prompt\"][0])\n", 92 | "print(prompt_completion[\"completion\"][0])" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "id": "dd5c509d-5fd5-43b5-b119-6029904ec266", 98 | "metadata": {}, 99 | "source": [ 100 | "### Publish to Hugging Face repo" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "6db95a3f-9bb7-403f-882b-b7e428be7a2c", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "dataset = Dataset.from_dict(prompt_completion)\n", 111 | "dataset.push_to_hub(\"\", token=\"\") # Example of a dataset repo 'test/test_dataset'" 112 | ] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": "Python 3 (ipykernel)", 118 | "language": "python", 119 | "name": "python3" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.8.18" 132 | } 133 | }, 134 | "nbformat": 4, 135 | "nbformat_minor": 5 136 | } 137 | -------------------------------------------------------------------------------- /tutorials/text_format.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "595308c8-9feb-4b6c-bfea-922ef8b031e8", 6 | "metadata": {}, 7 | "source": [ 8 | "## The Text format\n", 9 | "### Example: converting the TINY SHAKESPEARE dataset into the \"Text\" format" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "id": "546f924a-df45-48ba-849f-51cf919aca9d", 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "name": "stdout", 20 | "output_type": "stream", 21 | "text": [ 22 | "Text len: 1115394\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "import urllib.request\n", 28 | "URL = \"https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\"\n", 29 | "shakespeare_data = \"\".join([line.decode('utf-8') for line in urllib.request.urlopen(URL)])\n", 30 | "print(\"Text len: \", len(shakespeare_data))" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "3b84ca7b-d344-4ebe-9751-22d3e1482a3b", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "text_format = {\n", 41 | " \"text\": [shakespeare_data]\n", 42 | "}\n", 43 | "\n", 44 | "from datasets import Dataset\n", 45 | "dataset = Dataset.from_dict(text_format)\n", 46 | "dataset.push_to_hub(\"\", token=\"\") # Example of a dataset repo 'test/test_dataset'" 47 | ] 48 | } 49 | ], 50 | "metadata": { 51 | "kernelspec": { 52 | "display_name": "Python 3 (ipykernel)", 53 | "language": "python", 54 | "name": "python3" 55 | }, 56 | "language_info": { 57 | "codemirror_mode": { 58 | 
"name": "ipython", 59 | "version": 3 60 | }, 61 | "file_extension": ".py", 62 | "mimetype": "text/x-python", 63 | "name": "python", 64 | "nbconvert_exporter": "python", 65 | "pygments_lexer": "ipython3", 66 | "version": "3.8.18" 67 | } 68 | }, 69 | "nbformat": 4, 70 | "nbformat_minor": 5 71 | } 72 | --------------------------------------------------------------------------------