├── .cirun.yml ├── .github └── workflows │ ├── cla.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── LICENSE.txt ├── README.md ├── docs ├── .gitignore ├── Makefile ├── make.bat └── source │ ├── __init__.py │ ├── _static │ ├── css │ │ └── custom.css │ ├── images │ │ ├── cover.png │ │ ├── demo.svg │ │ ├── demo_code.png │ │ ├── favicon.svg │ │ └── logo.svg │ └── js │ │ └── custom.js │ ├── _templates │ ├── apidoc │ │ ├── module.rst_t │ │ └── package.rst_t │ ├── base.html │ ├── metatags.html │ ├── page.html │ └── sidebar │ │ └── brand.html │ ├── conf.py │ ├── docstrings.py │ ├── index.rst │ └── pages │ ├── advanced_usage │ ├── caching_and_saved_outputs.rst │ ├── creating_a_new_datadreamer_... │ │ ├── index.rst │ │ ├── llm.rst │ │ ├── other.rst │ │ ├── step.rst │ │ └── trainer.rst │ ├── index.rst │ ├── parallelization │ │ ├── index.rst │ │ ├── running_models_on_multiple_gpus.rst │ │ ├── running_steps_in_parallel.rst │ │ └── training_models_on_multiple_gpus.rst │ ├── parameter_efficient_training.rst │ └── quantization.rst │ ├── contributing.rst │ └── get_started │ ├── installation.rst │ ├── motivation_and_design.rst │ ├── overview_guide.rst │ └── quick_tour │ ├── abstract_to_tweet.rst │ ├── aligning.rst │ ├── attributed_prompts.rst │ ├── bootstrapping_machine_translation.rst │ ├── dataset_augmentation.rst │ ├── dataset_cleaning.rst │ ├── index.rst │ ├── instruction_tuning.rst │ ├── openai_distillation.rst │ └── self_rewarding.rst ├── pyproject.toml ├── scripts ├── .cluster │ ├── .gitignore │ ├── _boot.sh │ ├── _command.sh │ ├── _lock.sh │ ├── _setup.sh │ ├── args_log.sh │ ├── config_log.sh │ ├── direct │ │ ├── .gitignore │ │ ├── _direct_config.sh │ │ ├── _direct_env.sh │ │ ├── cancel.sh │ │ ├── cancel_all.sh │ │ ├── interactive_submit.sh │ │ ├── reset_venv.sh │ │ ├── status.sh │ │ ├── status_all.sh │ │ └── submit.sh │ ├── full_reset.sh │ ├── install_log.sh │ ├── reset.sh │ ├── reset_venv.sh │ ├── series_log.sh │ ├── sge │ │ ├── .gitignore │ │ ├── _qsub_config.sh │ │ ├── cancel.sh │ │ ├── cancel_all.sh │ │ ├── htop.sh │ │ ├── interactive_submit.sh │ │ ├── reset_venv.sh │ │ ├── shell.sh │ │ ├── status.sh │ │ ├── status_all.sh │ │ └── submit.sh │ ├── slurm │ │ ├── .gitignore │ │ ├── _sbatch_config.sh │ │ ├── cancel.sh │ │ ├── cancel_all.sh │ │ ├── htop.sh │ │ ├── interactive_submit.sh │ │ ├── nvidia-smi.sh │ │ ├── reset_venv.sh │ │ ├── shell.sh │ │ ├── status.sh │ │ ├── status_all.sh │ │ └── submit.sh │ ├── stderr_log.sh │ ├── stdout_log.sh │ ├── submission_log.sh │ └── tag.sh ├── .githooks │ ├── post-checkout │ └── pre-commit ├── .python-version ├── .vscode │ ├── extensions.json │ ├── settings.json │ └── tasks.json ├── docs.sh ├── format.sh ├── lint.sh ├── package.sh ├── package_publish.sh ├── project.env └── run.sh └── src ├── .env ├── .gitignore ├── .secrets.template.env ├── __cli__.py ├── __init__.py ├── __main__.py ├── _cachable ├── __init__.py ├── _cachable.py └── _parallel_cachable.py ├── _patches ├── __init__.py ├── datasets_reset_state_hack.py └── setfit_import_hack.py ├── _stubs ├── .gitkeep └── datasets │ └── __init__.pyi ├── datadreamer.py ├── datasets ├── __init__.py ├── datasets.py └── utils.py ├── embedders ├── __init__.py ├── embedder.py ├── openai_embedder.py ├── parallel_embedder.py ├── sentence_transformers_embedder.py └── together_embedder.py ├── errors ├── __init__.py └── steps │ ├── __init__.py │ └── step.py ├── llms ├── __init__.py ├── _chat_prompt_templates.py ├── _litellm.py ├── _llm_api.py ├── _tokenizers.py ├── ai21.py ├── anthropic.py ├── bedrock.py ├── cohere.py 
├── ctransformers.py ├── google_ai_studio.py ├── hf_api_endpoint.py ├── hf_transformers.py ├── llm.py ├── mistral_ai.py ├── openai.py ├── openai_assistant.py ├── parallel_llm.py ├── petals.py ├── together.py ├── vertex_ai.py └── vllm.py ├── logging ├── __init__.py └── logger.py ├── pickling ├── __init__.py └── pickle.py ├── project ├── __init__.py ├── builtin_tasks.py ├── debug.py ├── devices.py ├── environment.py ├── pennnlp.py ├── persistent_storage.py ├── report.py └── serve.py ├── py.typed ├── requirements-accelerator-device.txt ├── requirements-cpu.txt ├── requirements-dev.txt ├── requirements-test.txt ├── requirements.txt ├── retrievers ├── __init__.py ├── embedding_retriever.py ├── parallel_retriever.py └── retriever.py ├── steps ├── __init__.py ├── data_card.py ├── data_sources │ ├── csv_data_source.py │ ├── data_source.py │ ├── hf_dataset_data_source.py │ ├── hf_hub_data_source.py │ ├── json_data_source.py │ └── text_data_source.py ├── prompt │ ├── _prompt_base.py │ ├── data_from_attributed_prompt.py │ ├── data_from_prompt.py │ ├── few_shot_prompt.py │ ├── few_shot_prompt_with_retrieval.py │ ├── filter_with_prompt.py │ ├── judge_generation_pairs_with_prompt.py │ ├── judge_pairs_with_prompt.py │ ├── process_with_prompt.py │ ├── prompt.py │ ├── rag_prompt.py │ └── rank_with_prompt.py ├── step.py ├── step_background.py ├── step_export.py ├── step_operations.py ├── step_output.py └── tasks │ ├── cosine_similarity.py │ ├── embed.py │ ├── retrieve.py │ └── run_task_model.py ├── task_models ├── __init__.py ├── hf_classification_task_model.py ├── parallel_task_model.py └── task_model.py ├── tests ├── __init__.py ├── conftest.py ├── datasets │ ├── __init__.py │ ├── test_datasets.py │ └── test_utils.py ├── embedders │ ├── __init__.py │ └── test_embedders.py ├── llms │ ├── __init__.py │ └── test_llms.py ├── retrievers │ ├── __init__.py │ └── test_retrievers.py ├── steps │ ├── __init__.py │ ├── prompt │ │ ├── __init__.py │ │ └── test_prompt.py │ ├── tasks │ │ ├── __init__.py │ │ └── test_tasks.py │ ├── test_data_sources.py │ ├── test_step.py │ ├── test_step_background.py │ ├── test_step_export.py │ ├── test_step_operations.py │ └── test_step_output.py ├── task_models │ ├── __init__.py │ └── test_task_models.py ├── test_cli.py ├── test_datadreamer.py ├── test_package.py ├── test_utils │ ├── __init__.py │ ├── config.py │ └── fixtures │ │ ├── __init__.py │ │ ├── bitsandbytes_fixture.py │ │ ├── clear_space.py │ │ ├── cli_runner.py │ │ ├── create_datadreamer.py │ │ ├── create_test_step.py │ │ ├── mock_llm.py │ │ └── restore_os_environ.py ├── trainers │ ├── __init__.py │ ├── test_distributed.py │ └── test_trainers.py └── utils │ ├── __init__.py │ └── test_device_utils.py ├── trainers ├── __init__.py ├── _train_hf_base.py ├── _vendored │ ├── __init__.py │ ├── _dpo_helper.py │ ├── _sentence_transformer_helper.py │ ├── _setfit_helper.py │ └── dpo_trainer.py ├── train_hf_classifier.py ├── train_hf_dpo.py ├── train_hf_finetune.py ├── train_hf_ppo.py ├── train_hf_reward_model.py ├── train_openai_finetune.py ├── train_sentence_transformer.py ├── train_setfit_classifier.py └── trainer.py └── utils ├── __init__.py ├── arg_utils.py ├── background_utils.py ├── collection_utils.py ├── device_utils.py ├── distributed_utils.py ├── fingerprint_utils.py ├── fs_utils.py ├── hf_chat_prompt_templates.py ├── hf_hub_utils.py ├── hf_model_utils.py ├── hf_structured_decoding_utils.py ├── hf_training_utils.py ├── import_utils.py ├── ring_utils.py ├── str_utils.py └── time_utils.py /.cirun.yml: 
-------------------------------------------------------------------------------- 1 | runners: 2 | - name: "aws-runner" 3 | cloud: "aws" 4 | region: us-east-1 5 | instance_type: "t3.xlarge" 6 | # Ubuntu-22.04, ami image -> Created via EC2 Image Builder: https://us-east-1.console.aws.amazon.com/imagebuilder/home?region=us-east-1#/viewPipelines 7 | machine_image: "ami-0dac1a7772e1ccf34" 8 | preemptible: true 9 | labels: 10 | - "cirun-aws-runner" -------------------------------------------------------------------------------- /.github/workflows/cla.yml: -------------------------------------------------------------------------------- 1 | name: "CLA Assistant" 2 | on: 3 | issue_comment: 4 | types: [created] 5 | pull_request_target: 6 | types: [opened,closed,synchronize] 7 | permissions: 8 | actions: write 9 | contents: write 10 | pull-requests: write 11 | statuses: write 12 | jobs: 13 | CLAAssistant: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: "CLA Assistant" 17 | if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' 18 | uses: contributor-assistant/github-action@v2.3.0 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | with: 22 | path-to-signatures: 'signatures/version1/cla.json' 23 | path-to-document: 'https://gist.github.com/AjayP13/b657de111d8d0907f48ba32eababd911' 24 | branch: 'cla_signatures' 25 | allowlist: AjayP13,dependabot[bot] 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.cluster 2 | /.vscode 3 | .ruff_cache 4 | .venv 5 | .venv_dev 6 | 7 | .venv_poetry 8 | dist/ 9 | poetry.lock 10 | 11 | wandb/ 12 | 13 | .DS_Store 14 | .jupyter_ystore.db 15 | .ipynb_checkpoints 16 | .coverage 17 | .coverage.* 18 | coverage.json 19 | coverage.json* 20 | 21 | .mypy_cache 22 | .pytest_cache 23 | /.python-version 24 | 25 | .tests_data 26 | 27 | # This directory is temporarily generated when building docs or publishing to PyPI. Remove this line before running poetry. (See: https://github.com/python-poetry/poetry/issues/5547) 28 | datadreamer -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | ----------- 3 | 4 | Copyright (c) 2024 Ajay Patel 5 | Permission is hereby granted, free of charge, to any person 6 | obtaining a copy of this software and associated documentation 7 | files (the "Software"), to deal in the Software without 8 | restriction, including without limitation the rights to use, 9 | copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the 11 | Software is furnished to do so, subject to the following 12 | conditions: 13 | 14 | The above copyright notice and this permission notice shall be 15 | included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 19 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 21 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 22 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 23 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 24 | OTHER DEALINGS IN THE SOFTWARE. 25 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 4 | build 5 | source/*.rst 6 | !source/index.rst -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/docs/source/__init__.py -------------------------------------------------------------------------------- /docs/source/_static/images/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/docs/source/_static/images/cover.png -------------------------------------------------------------------------------- /docs/source/_static/images/demo_code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/docs/source/_static/images/demo_code.png -------------------------------------------------------------------------------- /docs/source/_static/js/custom.js: -------------------------------------------------------------------------------- 1 | (function() { 2 | const buttons = document.getElementsByClassName("theme-toggle"); 3 | Array.from(buttons).forEach((btn) => { 4 | btn.addEventListener("click", (e) => { 5 | const currentTheme = localStorage.getItem("theme"); 6 | const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches; 7 | if (currentTheme === (prefersDark ? "dark" : "light")) { 8 | // Skip the "auto" theme 9 | document.querySelector('.theme-toggle').click(); 10 | } 11 | }); 12 | }); 13 | })(); -------------------------------------------------------------------------------- /docs/source/_templates/apidoc/module.rst_t: -------------------------------------------------------------------------------- 1 | {%- if show_headings %} 2 | {{- basename.split('.')[1:] | join('.') | e | heading }} 3 | 4 | {% endif -%} 5 | .. automodule:: {{ qualname }} 6 | {%- for option in automodule_options %} 7 | :{{ option }}: 8 | {%- endfor %} 9 | -------------------------------------------------------------------------------- /docs/source/_templates/apidoc/package.rst_t: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. toctree:: 10 | :maxdepth: {{ maxdepth }} 11 | {% for docname in docnames %} 12 | {{ docname }} 13 | {%- endfor %} 14 | {%- endmacro %} 15 | 16 | {%- if is_namespace %} 17 | {{- pkgname.split('.')[1:] | join('.') | e | heading }} 18 | {% else %} 19 | {{- (pkgname.split('.')[1:] if pkgname.split('.')|length > 1 else [pkgname] ) | join('.') | e | heading }} 20 | {% endif %} 21 | 22 | {%- if is_namespace %} 23 | .. 
py:module:: {{ pkgname }} 24 | {% endif %} 25 | 26 | {%- if modulefirst and not is_namespace %} 27 | {{ automodule(pkgname, automodule_options) }} 28 | {% endif %} 29 | 30 | {%- if subpackages %} 31 | Subpackages 32 | ----------- 33 | 34 | {{ toctree(subpackages) }} 35 | {% endif %} 36 | 37 | {%- if submodules %} 38 | Submodules 39 | ---------- 40 | {% if separatemodules %} 41 | {{ toctree(submodules) }} 42 | {% else %} 43 | {%- for submodule in submodules %} 44 | {% if show_headings %} 45 | {{- submodule.split('.')[1:] | join('.') | e | heading(2) }} 46 | {% endif %} 47 | {{ automodule(submodule, automodule_options) }} 48 | {% endfor %} 49 | {%- endif %} 50 | {%- endif %} 51 | 52 | {%- if not modulefirst and not is_namespace %} 53 | Module contents 54 | --------------- 55 | 56 | {{ automodule(pkgname, automodule_options) }} 57 | {% endif %} -------------------------------------------------------------------------------- /docs/source/_templates/metatags.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 22 | -------------------------------------------------------------------------------- /docs/source/_templates/sidebar/brand.html: -------------------------------------------------------------------------------- 1 | 20 | 23 | -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/caching_and_saved_outputs.rst: -------------------------------------------------------------------------------- 1 | Caching and Saved Outputs 2 | ####################################################### 3 | 4 | DataDreamer aggressively caches and saves its work at multiple levels to avoid re-computing when possible to be as time- and cost-efficient as possible. 5 | 6 | - **Step Outputs**: DataDreamer caches the results of each step run within a session to the output folder. If a session is interrupted and re-run, DataDreamer will automatically load the results of previously completed steps from disk and resume where it left off. 7 | - **Model Generations and Outputs**: DataDreamer caches the results computed by a :py:class:`~datadreamer.llms.LLM`, :py:class:`~datadreamer.embedders.Embedder` model, etc. 8 | - **Training Checkpoints**: DataDreamer will automatically save and resume from checkpoints when training a model with :py:class:`~datadreamer.trainers.Trainer`. 9 | 10 | Output Folder File Structure 11 | =============================== 12 | 13 | :py:class:`~datadreamer.DataDreamer` sessions write to an output folder where all outputs and caches are saved. Below is a brief description of the output folder structure. 14 | 15 | - **Step Folders**: Each :py:class:`~datadreamer.steps.Step` will produce a named folder within the output folder. The name of the folder is the name of the step, and the folder contains the output dataset of the step within a ``_dataset`` folder. ``step.json`` contains metadata about the step. If a step is run within another step, its folder will be nested under the parent step's folder. 16 | - **Trainer Folders**: Each :py:class:`~datadreamer.trainers.Trainer` will produce a named folder within the output folder. The name of the folder is the name of the trainer, and the folder contains saved checkpoints during training to a ``_checkpoints`` folder and the final trained model to a ``_model`` folder. Various JSON files inside the ``_model`` folder like ``training_args.json`` contain metadata about the training configuration. 
17 | - **Cache Folder**: The ``.cache`` folder in the output folder holds the SQLite databases that are used to cache the generations and outputs produced by models like :py:class:`~datadreamer.llms.LLM` or :py:class:`~datadreamer.embedders.Embedder`. 18 | - **Backups Folder**: The ``_backups`` folder in the output folder holds backups of step or trainer folders that have since been invalidated by a newer configuration of that step or trainer. They are kept in case a user reverts to a previous configuration of the step or trainer. -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../index.rst: -------------------------------------------------------------------------------- 1 | Creating a new DataDreamer... 2 | ####################################################### 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | 7 | Step 8 | LLM 9 | Trainer 10 | Other -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../llm.rst: -------------------------------------------------------------------------------- 1 | Creating a new LLM 2 | ####################################################### 3 | 4 | To create a new DataDreamer LLM class to support a new LLM library or API service, you will want to subclass 5 | the :py:class:`~datadreamer.llms.LLM` class. You can see example implementations of various LLMs by clicking on the 6 | ``[source]`` links on the :doc:`LLMs <../../../datadreamer.llms>` page. These may be helpful as reference implementations. 7 | 8 | Contributing 9 | ============ 10 | 11 | If you would like to contribute the new LLM class you created to DataDreamer for others to use, see the :doc:`Contributing <../../../pages/contributing>` page. 12 | If applicable, please ensure your implementation includes model metadata, such as a link to the model card, the model's license, and the model's citation 13 | information. -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../other.rst: -------------------------------------------------------------------------------- 1 | Creating a new ... 2 | ####################################################### 3 | 4 | To create a new implementation of an embedder, retriever, or other DataDreamer class, you will want to subclass its base class. 5 | You can see example implementations of various classes by clicking on the ``[source]`` links through the :doc:`API Reference <../../../datadreamer>` pages. 6 | These may be helpful as reference implementations. 7 | 8 | Contributing 9 | ============ 10 | 11 | If you would like to contribute the new class you created to DataDreamer for others to use, see the :doc:`Contributing <../../../pages/contributing>` page. 12 | If applicable, please ensure your implementation includes appropriate metadata such as license and citation information. -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../trainer.rst: -------------------------------------------------------------------------------- 1 | Creating a new Trainer 2 | ####################################################### 3 | 4 | To create a new DataDreamer trainer class to support a new training library or training technique, you will want to subclass 5 | the :py:class:`~datadreamer.trainers.Trainer` class. 
You can see example implementations of various trainers by clicking on the 6 | ``[source]`` links on the :doc:`Trainers <../../../datadreamer.trainers>` page. These may be helpful as reference implementations. 7 | 8 | Contributing 9 | ============ 10 | 11 | If you would like to contribute the new trainer class you created to DataDreamer for others to use, see the :doc:`Contributing <../../../pages/contributing>` page. 12 | If applicable, please ensure your implementation includes appropriate metadata, such as a link to the model card of the model being trained, the model's license, and 13 | the model and training technique's citation information. -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/index.rst: -------------------------------------------------------------------------------- 1 | Advanced Usage 2 | ####################################################### 3 | 4 | Advanced users may want to configure DataDreamer to their needs or learn more about DataDreamer internals. 5 | This section provides a deeper look into DataDreamer and covers various advanced topics. 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | Caching and Saved Outputs 11 | Creating a New DataDreamer ... 12 | Parallelization 13 | Quantization 14 | Parameter-Efficient Training -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/parallelization/index.rst: -------------------------------------------------------------------------------- 1 | Parallelization 2 | ####################################################### 3 | 4 | Parallelization is a core feature of DataDreamer that enables performant workflows. Advanced users may want to implement 5 | parallelization in different ways, depending on their use case. This section covers the different ways parallelization 6 | can be implemented in DataDreamer, including device parallelization (multi-GPU inference and training). 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | Running Steps in Parallel 12 | Running Models on Multiple GPUs 13 | Training Models on Multiple GPUs 14 | -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/parallelization/running_models_on_multiple_gpus.rst: -------------------------------------------------------------------------------- 1 | Running Models on Multiple GPUs 2 | ####################################################### 3 | 4 | There are various ways to run models on multiple GPUs in DataDreamer. 5 | 6 | Large LLMs on Multiple GPUs 7 | =========================== 8 | 9 | To split a large model that cannot fit on a single GPU, you can set the ``device_map`` parameter of the 10 | :py:class:`~datadreamer.llms.HFTransformers` class to ``'auto'``. This will automatically split the model by layer 11 | onto your available GPUs. You can also manually specify 12 | `how and where the model should be split `_. 13 | 14 | Smaller Models 15 | ============== 16 | 17 | For smaller models, the :py:class:`~datadreamer.llms.ParallelLLM` wrapper takes in multiple :py:class:`~datadreamer.llms.LLM` objects 18 | and behaves like a single unified :py:class:`~datadreamer.llms.LLM` object that can then be passed to a step like :py:class:`~datadreamer.steps.Prompt`. 19 | :py:class:`~datadreamer.llms.ParallelLLM` will run any inputs it receives against all of the models in parallel.
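For example, here is a minimal sketch of this pattern (the exact constructor arguments shown, such as ``device``, are illustrative assumptions; consult the :py:class:`~datadreamer.llms.HFTransformers` API reference for the precise parameters):

.. code-block:: python

    from datadreamer.llms import HFTransformers, ParallelLLM

    # Load two copies of the same model, one per GPU (assumes two CUDA devices are available)
    llm_on_gpu_0 = HFTransformers("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device="cuda:0")
    llm_on_gpu_1 = HFTransformers("TinyLlama/TinyLlama-1.1B-Chat-v1.0", device="cuda:1")

    # ParallelLLM behaves like a single LLM and fans incoming inputs out across both copies
    parallel_llm = ParallelLLM(llm_on_gpu_0, llm_on_gpu_1)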
This approach is useful for running smaller models on multiple GPUs 20 | as each :py:class:`~datadreamer.llms.LLM` passed to the wrapper can be on a different GPU. Your model must be able to fit on a single GPU 21 | for this to work. 22 | 23 | Similarly, we have other parallelization wrappers for other types of models like :py:class:`~datadreamer.embedders.ParallelEmbedder`, 24 | :py:class:`~datadreamer.retrievers.ParallelRetriever`, etc. 25 | -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/parallelization/running_steps_in_parallel.rst: -------------------------------------------------------------------------------- 1 | Running Steps in Parallel 2 | ####################################################### 3 | 4 | There are two ways to run steps in parallel: 5 | 6 | 1. **Running steps in different processes:** Steps can be run asynchronously in a background process. 7 | To run multiple steps in parallel, you can run them all in the background and then wait for them to complete. 8 | 9 | 2. **Running steps in different threads:** You can group steps into Python functions. You can then run these functions 10 | in parallel using :py:func:`~datadreamer.steps.concurrent`. 11 | 12 | Running steps in different processes 13 | ==================================== 14 | You can run steps in the background by passing the ``background=True`` keyword argument when constructing a :py:class:`~datadreamer.steps.Step`. 15 | This will run the step in its own background process asynchronously. 16 | 17 | Waiting for :py:attr:`~datadreamer.steps.Step.output` 18 | ----------------------------------------------------- 19 | When you run a step in the background, its output may not be immediately ready, and trying to access 20 | :py:attr:`~datadreamer.steps.Step.output` may raise an exception until the step has completed running in 21 | the background. To wait for a step's output to be ready, you can call :py:func:`~datadreamer.steps.wait` 22 | on the step. This will block until the step's output is ready. 23 | 24 | Running steps in different threads 25 | ================================== 26 | 27 | To run multiple steps in parallel, you can group them into Python functions and run these functions in parallel using threads. You can pass 28 | the functions to :py:func:`~datadreamer.steps.concurrent` to run them in parallel using threading. 29 | -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/parameter_efficient_training.rst: -------------------------------------------------------------------------------- 1 | Parameter-Efficient Training 2 | ####################################################### 3 | 4 | DataDreamer makes setting up parameter-efficient training simple. 5 | You can pass a :py:class:`~peft.PeftConfig` to the ``peft_config`` argument of a class 6 | like :py:class:`~datadreamer.trainers.TrainHFFineTune` to enable parameter-efficient training. -------------------------------------------------------------------------------- /docs/source/pages/advanced_usage/quantization.rst: -------------------------------------------------------------------------------- 1 | Quantization 2 | ####################################################### 3 | 4 | DataDreamer makes setting up quantization simple. You can pass a 5 | `quantization config object `_ 6 | to the ``quantization_config`` argument of a class like :py:class:`~datadreamer.llms.HFTransformers` or 7 | :py:class:`~datadreamer.trainers.TrainHFFineTune` to enable quantization.
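For example, a minimal sketch using a 4-bit ``BitsAndBytesConfig`` from Hugging Face ``transformers`` (this particular config class and its arguments are illustrative assumptions; any supported quantization config object can be passed):

.. code-block:: python

    import torch
    from transformers import BitsAndBytesConfig

    from datadreamer.llms import HFTransformers

    # Quantize the model weights to 4-bit NF4 to reduce GPU memory usage
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Pass the config to the quantization_config argument
    llm = HFTransformers(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        quantization_config=quantization_config,
    )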
-------------------------------------------------------------------------------- /docs/source/pages/get_started/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ####################################################### 3 | 4 | You can install DataDreamer via `PyPI `_: 5 | 6 | .. code-block:: bash 7 | 8 | pip3 install datadreamer.dev 9 | 10 | .. 11 | This page is redirected via the sphinx_reredirect extension. -------------------------------------------------------------------------------- /docs/source/pages/get_started/motivation_and_design.rst: -------------------------------------------------------------------------------- 1 | Motivation and Design 2 | ####################################################### 3 | 4 | DataDreamer is an open-source Python library made to help streamline and accelerate increasingly common 5 | LLM-related workflows for ML / NLP researchers and users, and to help increase the rate of research progress through 6 | features encouraging open science and reproducibility. 7 | 8 | Design Principles 9 | ================= 10 | 11 | A few design principles of DataDreamer are: 12 | 13 | - 🔬 **Research-Grade and Production-Grade:** Implementations of techniques that are consistent with the established work and best practices for correctness and efficiency. 14 | - 💪 **Reproducibility and Robustness:** A focus on `reproducibility and robustness <../../pages/get_started/overview_guide.html#reproducibility>`_. 15 | - 🧩 **Simple with Sensible Defaults:** Simple and easy to get started with little configuration through sensible defaults. 16 | - 🛠️ **Adaptable, Extensible, and Customizable:** Selectively overridable advanced configuration and the ability to support :doc:`new techniques <../../pages/advanced_usage/creating_a_new_datadreamer_.../step>` or :doc:`models <../../pages/advanced_usage/creating_a_new_datadreamer_.../llm>`. 17 | - 👥 **Accessible:** :doc:`Aggressive caching and efficiency techniques <../../pages/get_started/overview_guide>` to make both computationally- and financially-expensive LLM-related workflows more accessible to resource-constrained researchers. 18 | - 🤝 **Community-Driven:** Community members can :doc:`contribute <../../pages/contributing>` to extend DataDreamer's abilities. 19 | 20 | For Anyone 21 | ========== 22 | 23 | While DataDreamer was designed *for researchers, by researchers*, it is also meant to be accessible to 24 | anyone who wants to use it. 25 | 26 | Use in Teaching 27 | =============== 28 | 29 | While DataDreamer was built to help researchers and practitioners implement complex LLM-related workflows, it is extremely simple to use, making bleeding-edge models, 30 | techniques, and training accessible to reasonably technical students. 31 | 32 | If you are a university professor of a graduate-level NLP or machine learning course 33 | and would like to trial using DataDreamer in your course for instruction, assignments, or projects, please reach out to 34 | `Ajay Patel (ajayp@upenn.edu) `_ and 35 | `Professor Chris Callison-Burch (ccb@upenn.edu) `_ at the University of Pennsylvania.
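A Minimal Example
=================

As a brief illustration of the "simple with sensible defaults" principle, a complete DataDreamer session can be just a few lines. The sketch below mirrors the examples in the Quick Tour; the instruction, step name, and Hugging Face Hub repository name are placeholders:

.. code-block:: python

    from datadreamer import DataDreamer
    from datadreamer.llms import OpenAI
    from datadreamer.steps import DataFromPrompt

    with DataDreamer("./output"):
        # Load an LLM
        gpt_4 = OpenAI(model_name="gpt-4")

        # Generate a small synthetic dataset from a single instruction
        trivia = DataFromPrompt(
            "Generate Trivia Questions",
            args={
                "llm": gpt_4,
                "n": 100,
                "instruction": "Generate a trivia question about world geography.",
            },
            outputs={"generations": "questions"},
        )

        # Publish and share the synthetic dataset
        trivia.publish_to_hf_hub("your-username/geography_trivia")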
-------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/abstract_to_tweet.rst: -------------------------------------------------------------------------------- 1 | Training an "Abstract to Tweet Model" with Fully Synthetic Data 2 | ############################################################### 3 | 4 | In this demonstration, we show how to train a small model to generate tweets summarizing ML research paper abstracts. We use DataDreamer to generate a fully synthetic dataset, distill the knowledge to a small T5 model, and publish both the dataset and model. 5 | 6 | .. raw:: html 7 | 8 | See the resulting synthetic dataset and the trained model. 9 | 10 | .. code-block:: python 11 | 12 | from datadreamer import DataDreamer 13 | from datadreamer.llms import OpenAI 14 | from datadreamer.steps import DataFromPrompt, ProcessWithPrompt 15 | from datadreamer.trainers import TrainHFFineTune 16 | from peft import LoraConfig 17 | 18 | with DataDreamer("./output"): 19 | # Load GPT-4 20 | gpt_4 = OpenAI(model_name="gpt-4") 21 | 22 | # Generate synthetic arXiv-style research paper abstracts with GPT-4 23 | arxiv_dataset = DataFromPrompt( 24 | "Generate Research Paper Abstracts", 25 | args={ 26 | "llm": gpt_4, 27 | "n": 1000, 28 | "temperature": 1.2, 29 | "instruction": ( 30 | "Generate an arXiv abstract of an NLP research paper." 31 | " Return just the abstract, no titles." 32 | ), 33 | }, 34 | outputs={"generations": "abstracts"}, 35 | ) 36 | 37 | # Ask GPT-4 to convert the abstracts to tweets 38 | abstracts_and_tweets = ProcessWithPrompt( 39 | "Generate Tweets from Abstracts", 40 | inputs={"inputs": arxiv_dataset.output["abstracts"]}, 41 | args={ 42 | "llm": gpt_4, 43 | "instruction": ( 44 | "Given the abstract, write a tweet to summarize the work." 45 | ), 46 | "top_p": 1.0, 47 | }, 48 | outputs={"inputs": "abstracts", "generations": "tweets"}, 49 | ) 50 | 51 | # Create training data splits 52 | splits = abstracts_and_tweets.splits(train_size=0.90, validation_size=0.10) 53 | 54 | # Train a model to convert research paper abstracts to tweets 55 | # with the synthetic dataset 56 | trainer = TrainHFFineTune( 57 | "Train an Abstract => Tweet Model", 58 | model_name="google/t5-v1_1-base", 59 | peft_config=LoraConfig(), 60 | ) 61 | trainer.train( 62 | train_input=splits["train"].output["abstracts"], 63 | train_output=splits["train"].output["tweets"], 64 | validation_input=splits["validation"].output["abstracts"], 65 | validation_output=splits["validation"].output["tweets"], 66 | epochs=30, 67 | batch_size=8, 68 | ) 69 | 70 | # Publish and share the synthetic dataset 71 | abstracts_and_tweets.publish_to_hf_hub( 72 | "datadreamer-dev/abstracts_and_tweets", 73 | train_size=0.90, 74 | validation_size=0.10, 75 | ) 76 | 77 | # Publish and share the trained model 78 | trainer.publish_to_hf_hub("datadreamer-dev/abstracts_to_tweet_model") -------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/aligning.rst: -------------------------------------------------------------------------------- 1 | Aligning a LLM with Human Preferences 2 | ##################################### 3 | 4 | In order to better align the responses :doc:`instruction-tuned LLMs ` generate to what humans would prefer, we can train LLMs against a reward model or a dataset of human preferences in a process known as `RLHF (Reinforcement Learning with Human Feedback) `_. 
5 | 6 | DataDreamer makes this process extremely simple and straightforward to accomplish. We demonstrate it below using LoRA to only train 7 | a fraction of the weights with `DPO `_ (a more stable, and efficient alignment method than traditional RLHF). 8 | 9 | .. code-block:: python 10 | 11 | from datadreamer import DataDreamer 12 | from datadreamer.steps import HFHubDataSource 13 | from datadreamer.trainers import TrainHFDPO 14 | from peft import LoraConfig 15 | 16 | with DataDreamer("./output"): 17 | # Get the DPO dataset 18 | dpo_dataset = HFHubDataSource( 19 | "Get DPO Dataset", "Intel/orca_dpo_pairs", split="train" 20 | ) 21 | 22 | # Keep only 1000 examples as a quick demo 23 | dpo_dataset = dpo_dataset.take(1000) 24 | 25 | # Create training data splits 26 | splits = dpo_dataset.splits(train_size=0.90, validation_size=0.10) 27 | 28 | # Align the TinyLlama chat model with human preferences 29 | trainer = TrainHFDPO( 30 | "Align TinyLlama-Chat", 31 | model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", 32 | peft_config=LoraConfig(), 33 | device=["cuda:0", "cuda:1"], 34 | dtype="bfloat16", 35 | ) 36 | trainer.train( 37 | train_prompts=splits["train"].output["question"], 38 | train_chosen=splits["train"].output["chosen"], 39 | train_rejected=splits["train"].output["rejected"], 40 | validation_prompts=splits["validation"].output["question"], 41 | validation_chosen=splits["validation"].output["chosen"], 42 | validation_rejected=splits["validation"].output["rejected"], 43 | epochs=3, 44 | batch_size=1, 45 | gradient_accumulation_steps=32, 46 | ) 47 | -------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/attributed_prompts.rst: -------------------------------------------------------------------------------- 1 | Generating Training Data with Attributed Prompts 2 | ################################################ 3 | 4 | By using the `attributed prompt method `_ of generating training data, we can create a diverse dataset that is more representative of real-world data. 5 | We demonstrate this below by generating movie reviews. 6 | 7 | .. raw:: html 8 | 9 | See the resulting synthetic dataset. 10 | 11 | .. 
code-block:: python 12 | 13 | from datadreamer import DataDreamer 14 | from datadreamer.llms import OpenAI 15 | from datadreamer.steps import ( 16 | Prompt, 17 | DataSource, 18 | DataFromAttributedPrompt, 19 | ) 20 | 21 | with DataDreamer("./output"): 22 | # Load GPT-4 23 | gpt_4 = OpenAI(model_name="gpt-4") 24 | 25 | # Create prompts to generate attributes for movie reviews 26 | attribute_generation_prompts = DataSource( 27 | "Attribute Generation Prompts", 28 | data={ 29 | "prompts": [ 30 | "Generate the names of 10 movies released in theatres in the past, in a comma separated list.", 31 | "Generate 10 elements of a movie a reviewer might consider, in a comma separated list.", 32 | "Generate 10 adjectives that could describe a movie reviewer's style, in a comma separated list.", 33 | ], 34 | }, 35 | ) 36 | 37 | # Generate the attributes for movie reviews 38 | attributes = Prompt( 39 | "Generate Attributes", 40 | inputs={ 41 | "prompts": attribute_generation_prompts.output["prompts"], 42 | }, 43 | args={ 44 | "llm": gpt_4, 45 | }, 46 | ).output["generations"] 47 | 48 | # Generate movie reviews with varied attributes 49 | movie_reviews = ( 50 | DataFromAttributedPrompt( 51 | "Generate Movie Reviews", 52 | args={ 53 | "llm": gpt_4, 54 | "n": 1000, 55 | "instruction": "Generate a few sentence {review_style} movie review about {movie_name} that focuses on {movie_element}.", 56 | "attributes": { 57 | "movie_name": attributes[0].split(","), 58 | "movie_element": attributes[1].split(","), 59 | "review_style": attributes[2].split(","), 60 | }, 61 | }, 62 | outputs={"generations": "reviews"}, 63 | ) 64 | .select_columns(["reviews"]) 65 | .shuffle() 66 | ) 67 | 68 | # Publish and share the synthetic dataset 69 | movie_reviews.publish_to_hf_hub( 70 | "datadreamer-dev/movie_reviews", 71 | ) 72 | -------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/dataset_augmentation.rst: -------------------------------------------------------------------------------- 1 | Augmenting an Existing Dataset 2 | ############################## 3 | 4 | DataDreamer can help augment existing datasets using LLMs. We demonstrate this below by augmenting questions from HotpotQA 5 | with a decomposition of what steps a user would need to take to solve the complex question. 6 | 7 | .. raw:: html 8 | 9 | See the resulting synthetic dataset. 10 | 11 | .. code-block:: python 12 | 13 | from datadreamer import DataDreamer 14 | from datadreamer.llms import OpenAI 15 | from datadreamer.steps import ProcessWithPrompt, HFHubDataSource 16 | 17 | with DataDreamer("./output"): 18 | # Load GPT-4 19 | gpt_4 = OpenAI(model_name="gpt-4") 20 | 21 | # Get HotPot QA questions 22 | hotpot_qa_dataset = HFHubDataSource( 23 | "Get Hotpot QA Questions", 24 | "hotpot_qa", 25 | config_name="distractor", 26 | split="train", 27 | ).select_columns(["question"]) 28 | 29 | # Keep only 1000 questions as a quick demo 30 | hotpot_qa_dataset = hotpot_qa_dataset.take(1000) 31 | 32 | # Ask GPT-4 to decompose the question 33 | questions_and_decompositions = ProcessWithPrompt( 34 | "Generate Decompositions", 35 | inputs={"inputs": hotpot_qa_dataset.output["question"]}, 36 | args={ 37 | "llm": gpt_4, 38 | "instruction": ( 39 | "Given the question which requires multiple steps to solve, give a numbered list of intermediate questions required to solve the question." 40 | "Return only the list, nothing else." 
41 | ), 42 | }, 43 | outputs={"inputs": "questions", "generations": "decompositions"}, 44 | ).select_columns(["questions", "decompositions"]) 45 | 46 | # Publish and share the synthetic dataset 47 | questions_and_decompositions.publish_to_hf_hub( 48 | "datadreamer-dev/hotpot_qa_augmented", 49 | ) 50 | -------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/dataset_cleaning.rst: -------------------------------------------------------------------------------- 1 | Cleaning an Existing Dataset 2 | ############################ 3 | 4 | DataDreamer can help clean or filter existing datasets using LLMs. We demonstrate this below by filtering a dataset of 5 | news articles to only include those that are about sports. 6 | 7 | .. raw:: html 8 | 9 | See the resulting synthetic dataset. 10 | 11 | .. code-block:: python 12 | 13 | from datadreamer import DataDreamer 14 | from datadreamer.llms import OpenAI 15 | from datadreamer.steps import FilterWithPrompt, HFHubDataSource 16 | 17 | with DataDreamer("./output"): 18 | # Load GPT-4 19 | gpt_4 = OpenAI(model_name="gpt-4") 20 | 21 | # Get news articles 22 | news_dataset = HFHubDataSource( 23 | "Get CNN & Daily Mail News Articles", 24 | "cnn_dailymail", 25 | config_name="3.0.0", 26 | split="test", 27 | ) 28 | 29 | # Keep only 1000 articles as a quick demo 30 | news_dataset = news_dataset.take(1000) 31 | 32 | # Ask GPT-4 to filter the dataset 33 | sports_news_dataset = FilterWithPrompt( 34 | "Filter to only keep sports articles", 35 | inputs={"inputs": news_dataset.output["article"]}, 36 | args={ 37 | "llm": gpt_4, 38 | "instruction": "Is the article about sports? Answer 'Yes' or 'No'.", 39 | }, 40 | ) 41 | 42 | # Publish and share the synthetic dataset 43 | sports_news_dataset.publish_to_hf_hub( 44 | "datadreamer-dev/cnn_dailymail_sports", 45 | ) 46 | -------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/index.rst: -------------------------------------------------------------------------------- 1 | Quick Tour 2 | ####################################################### 3 | 4 | Below we outline a few examples of using DataDreamer for various use cases to help you get started. It is by no means exhaustive, but should give you a good idea of what 5 | is possible with DataDreamer and how to use it. For more details on the various components of DataDreamer, please see the :doc:`Overview Guide <../overview_guide>`. 6 | 7 | Synthetic Data Generation 8 | ========================= 9 | 10 | - :doc:`Training an "Abstract to Tweet Model" with Fully Synthetic Data ` 11 | - :doc:`Generating Training Data with Attributed Prompts ` 12 | - :doc:`Distilling GPT-4 Capabilities to GPT-3.5 ` 13 | - :doc:`Augmenting an Existing Dataset ` 14 | - :doc:`Cleaning an Existing Dataset ` 15 | - :doc:`Bootstrapping Synthetic Few-Shot Examples ` 16 | 17 | Instruction-Tuning and Aligning Models 18 | ====================================== 19 | 20 | - :doc:`Instruction-Tuning a LLM ` 21 | - :doc:`Aligning a LLM with Human Preferences (RLHF) ` 22 | - :doc:`Training a Self-Improving LLM with Self-Rewarding (RLAIF) ` 23 | 24 | 25 | 26 | .. 
toctree:: 27 | :hidden: 28 | 29 | Synthetic Data Generation 30 | ../motivation_and_design 31 | abstract_to_tweet 32 | attributed_prompts 33 | openai_distillation 34 | dataset_augmentation 35 | dataset_cleaning 36 | bootstrapping_machine_translation 37 | Instruction-Tuning and Aligning Models 38 | instruction_tuning 39 | aligning 40 | self_rewarding -------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/instruction_tuning.rst: -------------------------------------------------------------------------------- 1 | Instruction-Tuning a LLM 2 | ######################## 3 | 4 | When LLMs are pre-trained, they are pre-trained in a self-supervised manner to simply predict the next word in a sentence. This can yield 5 | a model that can follow human instructions to some degree, but is often not very effective until this "base model" is fine-tuned on a dataset 6 | of example instructions and responses in a process known as `"instruction-tuning" `_, essentially allowing the model to learn to follow natural 7 | language instructions. 8 | 9 | DataDreamer makes this process extremely simple and straightforward to accomplish. We demonstrate it below using LoRA to only train 10 | a fraction of the weights. 11 | 12 | .. code-block:: python 13 | 14 | from datadreamer import DataDreamer 15 | from datadreamer.steps import HFHubDataSource 16 | from datadreamer.trainers import TrainHFFineTune 17 | from peft import LoraConfig 18 | 19 | with DataDreamer("./output"): 20 | # Get the Alpaca instruction-tuning dataset (cleaned version) 21 | instruction_tuning_dataset = HFHubDataSource( 22 | "Get Alpaca Instruction-Tuning Dataset", "yahma/alpaca-cleaned", split="train" 23 | ) 24 | 25 | # Keep only 1000 examples as a quick demo 26 | instruction_tuning_dataset = instruction_tuning_dataset.take(1000) 27 | 28 | # Some examples take in an "input", we'll format those into the instruction 29 | instruction_tuning_dataset = instruction_tuning_dataset.map( 30 | lambda row: { 31 | "instruction": ( 32 | row["instruction"] 33 | if len(row["input"]) == 0 34 | else f"Input: {row['input']}\n\n{row['instruction']}" 35 | ), 36 | "output": row["output"], 37 | }, 38 | lazy=False, 39 | ) 40 | 41 | # Create training data splits 42 | splits = instruction_tuning_dataset.splits(train_size=0.90, validation_size=0.10) 43 | 44 | # Define what the prompt template should be when instruction-tuning 45 | chat_prompt_template = "### Instruction:\n{{prompt}}\n\n### Response:\n" 46 | 47 | # Instruction-tune the base TinyLlama model to make it follow instructions 48 | trainer = TrainHFFineTune( 49 | "Instruction-Tune TinyLlama", 50 | model_name="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", 51 | chat_prompt_template=chat_prompt_template, 52 | peft_config=LoraConfig(), 53 | device=["cuda:0", "cuda:1"], 54 | dtype="bfloat16", 55 | ) 56 | trainer.train( 57 | train_input=splits["train"].output["instruction"], 58 | train_output=splits["train"].output["output"], 59 | validation_input=splits["validation"].output["instruction"], 60 | validation_output=splits["validation"].output["output"], 61 | epochs=3, 62 | batch_size=1, 63 | gradient_accumulation_steps=32, 64 | ) 65 | -------------------------------------------------------------------------------- /docs/source/pages/get_started/quick_tour/openai_distillation.rst: -------------------------------------------------------------------------------- 1 | Distilling GPT-4 Capabilities to GPT-3.5 2 | ######################################## 3 | 4 | If you want to make
GPT-3.5 (a cheaper, smaller, and faster model) more capable, you can use DataDreamer to distill the capabilities of GPT-4 into GPT-3.5. This will allow you to create a more capable model that is cheaper and faster than GPT-4. 5 | 6 | We demonstrate an example below on the `ELI5 ("Explain it like I'm 5") `_ task. 7 | 8 | .. code-block:: python 9 | 10 | from datadreamer import DataDreamer 11 | from datadreamer.llms import OpenAI 12 | from datadreamer.steps import ProcessWithPrompt, HFHubDataSource 13 | from datadreamer.trainers import TrainOpenAIFineTune 14 | 15 | with DataDreamer("./output"): 16 | # Load GPT-4 17 | gpt_4 = OpenAI(model_name="gpt-4") 18 | 19 | # Get ELI5 questions 20 | eli5_dataset = HFHubDataSource( 21 | "Get ELI5 Questions", 22 | "eli5_category", 23 | split="train", 24 | trust_remote_code=True, 25 | ).select_columns(["title"]) 26 | 27 | # Keep only 1000 examples as a quick demo 28 | eli5_dataset = eli5_dataset.take(1000) 29 | 30 | # Ask GPT-4 to ELI5 31 | questions_and_answers = ProcessWithPrompt( 32 | "Generate Explanations", 33 | inputs={"inputs": eli5_dataset.output["title"]}, 34 | args={ 35 | "llm": gpt_4, 36 | "instruction": ( 37 | 'Given the question, give an "Explain it like I\'m 5" answer.' 38 | ), 39 | "top_p": 1.0, 40 | }, 41 | outputs={"inputs": "questions", "generations": "answers"}, 42 | ) 43 | 44 | # Create training data splits 45 | splits = questions_and_answers.splits(train_size=0.90, validation_size=0.10) 46 | 47 | # Train a model to answer questions in ELI5 style 48 | trainer = TrainOpenAIFineTune( 49 | "Distill capabilities to GPT-3.5", 50 | model_name="gpt-3.5-turbo-1106", 51 | ) 52 | trainer.train( 53 | train_input=splits["train"].output["questions"], 54 | train_output=splits["train"].output["answers"], 55 | validation_input=splits["validation"].output["questions"], 56 | validation_output=splits["validation"].output["answers"], 57 | epochs=30, 58 | batch_size=8, 59 | ) 60 | 61 | -------------------------------------------------------------------------------- /scripts/.cluster/.gitignore: -------------------------------------------------------------------------------- 1 | .lock 2 | .bin/ 3 | output 4 | output/ -------------------------------------------------------------------------------- /scripts/.cluster/_boot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Run setup 4 | . 
./.cluster/_setup.sh 5 | 6 | # Run the python script 7 | if [ "$PROJECT_CLUSTER_TYPE" == "slurm" ] && [ "$PROJECT_INTERACTIVE" != "1" ]; then 8 | srun -u .cluster/_command.sh "$@" 9 | else 10 | .cluster/_command.sh "$@" 11 | fi 12 | -------------------------------------------------------------------------------- /scripts/.cluster/_command.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Write the command arguments to a file 4 | export PROJECT_COMMAND_PATH=$(mktemp /tmp/projectcmd.XXXXXX) 5 | ( 6 | printf '%s' "$COMMAND_PRE" 7 | printf " " 8 | printf '%s' "$COMMAND_ENTRYPOINT" 9 | printf " " 10 | 11 | # For each argument 12 | for ARG in "$@"; do 13 | printf "$'" 14 | # echo the argument, except: 15 | # * Replace backslashes with escaped backslashes 16 | # * Replace single quotes with escaped single quotes 17 | echo -n "$ARG" | sed -e "s/\\\\/\\\\\\\\/g;" | sed -e "s/'/\\\\'/g;" 18 | # echo `'` 19 | printf "' " 20 | done 21 | 22 | printf " " 23 | printf '%s' "$COMMAND_POST" 24 | ) >"$PROJECT_COMMAND_PATH" 25 | chmod +x "$PROJECT_COMMAND_PATH" 26 | 27 | # Run the script 28 | if [ "$PROJECT_CLUSTER" == "1" ]; then 29 | if [ "$PROJECT_INTERACTIVE" == "1" ]; then 30 | exec 2>&4 1>&3 31 | script -efq "$PROJECT_STDOUT_FILE" -c "$PROJECT_COMMAND_PATH 2> >(tee -a $PROJECT_STDERR_FILE >&2)" 32 | else 33 | $PROJECT_COMMAND_PATH 1>>"$PROJECT_STDOUT_FILE" 2>>"$PROJECT_STDERR_FILE" 34 | fi 35 | else 36 | $PROJECT_COMMAND_PATH 37 | fi 38 | -------------------------------------------------------------------------------- /scripts/.cluster/_lock.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Adapted from: https://stackoverflow.com/questions/1715137/what-is-the-best-way-to-ensure-only-one-instance-of-a-bash-script-is-running 4 | 5 | function lockfile_waithold() { 6 | declare -ir time_beg=$(date '+%s') 7 | declare -ir time_max=7200 8 | 9 | while ! ( 10 | set -o noclobber 11 | echo -e "DATE:$(date)\nUSER:$(whoami)\nPID:$$" \ >.cluster/.lock 12 | ) 2>/dev/null; do 13 | if [ $(($(date '+%s') - ${time_beg})) -gt ${time_max} ]; then 14 | echo "Error: waited too long for lock file .cluster/.lock" 1>&2 15 | return 1 16 | fi 17 | sleep 2 18 | done 19 | 20 | return 0 21 | } 22 | 23 | function lockfile_release() { 24 | rm -f .cluster/.lock 25 | } 26 | 27 | if ! lockfile_waithold; then 28 | exit 1 29 | fi 30 | trap lockfile_release EXIT 31 | -------------------------------------------------------------------------------- /scripts/.cluster/args_log.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Get the job name and task ID 7 | if [[ "$1" =~ ':'$ ]]; then 8 | export FIRST_ARGUMENT="$1" 9 | else 10 | export FIRST_ARGUMENT="$1:" 11 | fi 12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:) 13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:) 14 | 15 | # Fetch the args log 16 | if [ -z "$JOB_NAME" ]; then 17 | echo "Fetching the args log of the last job..." 1>&2; 18 | if [ -f "./output/named/_latest/job/.array" ]; then 19 | if [ -z "$TASK_ID" ]; then 20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 
1>&2; 21 | exit 1 22 | else 23 | export LOG_PATH=./output/named/_latest/$TASK_ID/args.json 24 | fi 25 | else 26 | export LOG_PATH=./output/named/_latest/args.json 27 | fi 28 | if [ -f "$LOG_PATH" ]; then 29 | cat $LOG_PATH 30 | else 31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2; 32 | exit 1 33 | fi 34 | else 35 | echo "Fetching the args log of the '$JOB_NAME' job..." 1>&2; 36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then 37 | if [ -z "$TASK_ID" ]; then 38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2; 39 | exit 1 40 | else 41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/args.json 42 | fi 43 | else 44 | export LOG_PATH=./output/named/$JOB_NAME/args.json 45 | fi 46 | if [ -f "$LOG_PATH" ]; then 47 | cat "$LOG_PATH" 48 | else 49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2; 50 | exit 1 51 | fi 52 | fi -------------------------------------------------------------------------------- /scripts/.cluster/config_log.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Get the job name and task ID 7 | if [[ "$1" =~ ':'$ ]]; then 8 | export FIRST_ARGUMENT="$1" 9 | else 10 | export FIRST_ARGUMENT="$1:" 11 | fi 12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:) 13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:) 14 | 15 | # Fetch the config log 16 | if [ -z "$JOB_NAME" ]; then 17 | echo "Fetching the config log of the last job..." 1>&2 18 | if [ -f "./output/named/_latest/job/.array" ]; then 19 | if [ -z "$TASK_ID" ]; then 20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 21 | exit 1 22 | else 23 | export LOG_PATH=./output/named/_latest/$TASK_ID/config.json 24 | fi 25 | else 26 | export LOG_PATH=./output/named/_latest/config.json 27 | fi 28 | if [ -f "$LOG_PATH" ]; then 29 | cat $LOG_PATH 30 | else 31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 32 | exit 1 33 | fi 34 | else 35 | echo "Fetching the config log of the '$JOB_NAME' job..." 1>&2 36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then 37 | if [ -z "$TASK_ID" ]; then 38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 39 | exit 1 40 | else 41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/config.json 42 | fi 43 | else 44 | export LOG_PATH=./output/named/$JOB_NAME/config.json 45 | fi 46 | if [ -f "$LOG_PATH" ]; then 47 | cat "$LOG_PATH" 48 | else 49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 
1>&2 50 | exit 1 51 | fi 52 | fi 53 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/.gitignore: -------------------------------------------------------------------------------- 1 | .last_job/ -------------------------------------------------------------------------------- /scripts/.cluster/direct/_direct_config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ################################################################ 4 | # To run a job array define array task IDs here, 5 | # example: ARRAY_TASK_IDS=( 10 20 30 40 50) 6 | # if only one value is provided, this job will not be considered a job array 7 | ARRAY_TASK_IDS=(0) 8 | MIN_MEMORY_PER_TASK=10G 9 | MAX_PARALLEL_TASKS=8 10 | ################################################################ 11 | 12 | # Load environment variables for direct running 13 | source project.env 14 | 15 | # Ensure direct environment variables are provided 16 | source .cluster/direct/_direct_env.sh 17 | 18 | # Redirect to submission log 19 | if [ "$PROJECT_INTERACTIVE" != "1" ]; then 20 | exec 3>&1 4>&2 >.cluster/direct/.last_job/submission.out 2>&1 21 | fi 22 | 23 | # Define task runner 24 | TASK_RUNNER() { 25 | TASK_ID=$1 26 | ARRAY_TASK_IDS_LENGTH=$2 27 | 28 | # Source the user's bashrc 29 | # shellcheck disable=SC1090 30 | source ~/.bashrc 31 | 32 | # Mark that we are using direct 33 | export PROJECT_CLUSTER=1 34 | export PROJECT_CLUSTER_TYPE=direct 35 | 36 | # Set direct dependent environment variables 37 | export PROJECT_VENV=.venv/direct 38 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache 39 | if [ "$ARRAY_TASK_IDS_LENGTH" == "1" ]; then 40 | # Only one task ID detected, therefore this is not a job array 41 | : 42 | else 43 | export PROJECT_TASK_ID=$TASK_ID 44 | fi 45 | 46 | # Store the direct last job information 47 | cp .cluster/direct/_direct_config.sh .cluster/direct/.last_job/resources 48 | echo $PROJECT_CLUSTER_TYPE >.cluster/direct/.last_job/type 49 | echo "$PROJECT_JOB_NAME" >.cluster/direct/.last_job/job_name 50 | echo $$ >.cluster/direct/.last_job/job_id 51 | hostname >.cluster/direct/.last_job/hostname 52 | echo "$PROJECT_CURRENT_DATE" >.cluster/direct/.last_job/date 53 | echo "$PROJECT_CURRENT_COMMIT" >.cluster/direct/.last_job/commit 54 | 55 | # Run the boot script 56 | shift 57 | shift 58 | .cluster/_boot.sh "$@" 59 | } 60 | export -f TASK_RUNNER 61 | 62 | # Run 63 | if [ "$PROJECT_INTERACTIVE" != "1" ]; then 64 | # shellcheck disable=SC1083 65 | parallel --memfree $MIN_MEMORY_PER_TASK --jobs $MAX_PARALLEL_TASKS TASK_RUNNER {.} "${#ARRAY_TASK_IDS[@]}" "$@" ::: "${ARRAY_TASK_IDS[@]}" 66 | else 67 | # shellcheck disable=SC1083 68 | parallel --tty --memfree $MIN_MEMORY_PER_TASK --jobs $MAX_PARALLEL_TASKS TASK_RUNNER {.} "${#ARRAY_TASK_IDS[@]}" "$@" ::: "${ARRAY_TASK_IDS[@]}" 69 | fi 70 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/_direct_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check for the existence of environment variables 4 | if [ -z "$PROJECT_DATA" ]; then 5 | echo "You must define the 'PROJECT_DATA' environment variable in project.env. Set it to a directory where output files will be written and large data can be saved while running." 
1>&2 6 | exit 1 7 | fi 8 | if [ -z "$PROJECT_ACCELERATOR_TYPE" ]; then 9 | if [ "$(uname)" != "Darwin" ] && (env -0 | cut -z -f1 -d= | tr '\0' '\n' | grep -q "TPU_"); then 10 | export PROJECT_ACCELERATOR_TYPE=tpu 11 | export PROJECT_VISIBLE_ACCELERATOR_DEVICES=all 12 | echo "Detected an environment with TPUs. For this reason, accelerator device dependencies will not be installed to a virtual environment." 1>&2 13 | export PROJECT_DISABLE_ACCELERATOR_REQUIREMENTS=1 14 | elif [ -x "$(command -v nvidia-smi)" ]; then 15 | export PROJECT_ACCELERATOR_TYPE=cuda 16 | echo "Detected an environment with CUDA GPUs." 1>&2 17 | export PROJECT_VISIBLE_ACCELERATOR_DEVICES=$(seq --separator="," 0 $(($(nvidia-smi --list-gpus | wc -l) - 1))) 18 | fi 19 | fi 20 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/cancel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Confirm 7 | read -r -p "Are you sure? [y/N] " response 8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm again 15 | read -r -p "Are you really sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | : 18 | else 19 | exit 1 20 | fi 21 | 22 | # Fetch the status 23 | if [ -z "$1" ]; then 24 | echo "Canceling the last job..." 1>&2 25 | if [ -f "../output/named/_latest/job/job_id" ]; then 26 | JOB_PID=$(ps -o sid= -p "$(cat ../output/named/_latest/job/job_id)") 27 | kill -9 "$JOB_PID" 28 | else 29 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 30 | exit 1 31 | fi 32 | else 33 | echo "Canceling the '$1' job..." 1>&2 34 | if [ -f "../output/named/$1/job/job_id" ]; then 35 | JOB_PID=$(ps -o sid= -p "$(cat "../output/named/$1/job/job_id")") 36 | kill -9 "$JOB_PID" 37 | else 38 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 39 | exit 1 40 | fi 41 | fi 42 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/cancel_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Confirm 7 | read -r -p "Are you sure? [y/N] " response 8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm again 15 | read -r -p "Are you really sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | echo "Canceling..." 
18 | else 19 | exit 1 20 | fi 21 | 22 | # Cancel all jobs 23 | for JOB_FOLDER in ../output/dated/*/; do 24 | if [ -f "$JOB_FOLDER"/job/type ] && [ -f "$JOB_FOLDER"/job/job_id ]; then 25 | if grep -q "direct" "$JOB_FOLDER"/job/type; then 26 | JOB_PID=$(ps -o sid= -p "$(cat "$JOB_FOLDER"/job/job_id)") 27 | if [ -n "$JOB_PID" ]; then 28 | kill -9 "$JOB_PID" 29 | fi 30 | fi 31 | fi 32 | done 33 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/interactive_submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Change directory to project root location 7 | cd ../../ 8 | 9 | # Check if there is already a job running 10 | if ( (.cluster/direct/status_all.sh | grep -q "$USER") 2>/dev/null); then 11 | echo -e "WARNING: There is already a job running!\n" 12 | .cluster/direct/status_all.sh 13 | 14 | # Confirm 15 | echo "" 16 | read -r -p "Are you sure you want to submit another job? [y/N] " response 17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 18 | : 19 | else 20 | exit 1 21 | fi 22 | elif ( 23 | # shellcheck disable=SC2009 24 | ps -aux | grep -v "grep" | grep -q .cluster/direct/_direct_config.sh 25 | ); then 26 | echo -e "WARNING: There is already a job running! It is still initializing...\n" 27 | 28 | # Confirm 29 | echo "" 30 | read -r -p "Are you sure you want to submit another job? [y/N] " response 31 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 32 | : 33 | else 34 | exit 1 35 | fi 36 | fi 37 | 38 | # Get the job name 39 | if [ -z "$1" ]; then 40 | echo "You must submit with a job name as the first argument to this script." 41 | exit 1 42 | else 43 | export JOB_NAME="$1" 44 | fi 45 | 46 | # Submit 47 | mkdir -p .cluster/direct/.last_job 48 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T) 49 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited") 50 | export PROJECT_JOB_NAME="$JOB_NAME" 51 | rm -rf .cluster/.lock 52 | export PROJECT_INTERACTIVE=1 53 | touch .cluster/direct/.last_job/.interactive 54 | echo -n "" >.cluster/direct/.last_job/submission.out 55 | shift 56 | .cluster/direct/_direct_config.sh "$@" 57 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/reset_venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Reset the virtual env 7 | rm -rf ../../.venv/direct/* 8 | rm -rf ../../.venv/direct/ 9 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/status.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the status 7 | if [ -z "$1" ]; then 8 | echo "Fetching the status of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/job_id" ]; then 10 | JOB_PID=$(ps -o sid= -p "$(cat ../output/named/_latest/job/job_id)") 11 | if [ -n "$JOB_PID" ]; then 12 | # shellcheck disable=SC2086 13 | ps --forest -wwo user,sid,pid,stat,start,time,%cpu,%mem,args -g $JOB_PID 2>/dev/null 14 | fi 15 | else 16 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 
1>&2 17 | exit 1 18 | fi 19 | else 20 | echo "Fetching the status of the '$1' job..." 1>&2 21 | if [ -f "../output/named/$1/job/job_id" ]; then 22 | JOB_PID=$(ps -o sid= -p "$(cat "../output/named/$1/job/job_id")") 23 | if [ -n "$JOB_PID" ]; then 24 | # shellcheck disable=SC2086 25 | ps --forest -wwo user,sid,pid,stat,start,time,%cpu,%mem,args -g $JOB_PID 2>/dev/null 26 | fi 27 | else 28 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 29 | exit 1 30 | fi 31 | fi 32 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/status_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the status of all jobs 7 | echo "Fetching the status of all jobs..." 1>&2 8 | for JOB_FOLDER in ../output/dated/*/; do 9 | if [ -f "$JOB_FOLDER"/job/type ] && [ -f "$JOB_FOLDER"/job/job_id ] && [ -f "$JOB_FOLDER"/job/job_name ]; then 10 | if grep -q "direct" "$JOB_FOLDER"/job/type; then 11 | JOB_PID=$(ps -o sid= -p "$(cat "$JOB_FOLDER"/job/job_id)") 12 | if [ -n "$JOB_PID" ]; then 13 | echo "------------------------------------------------------------------------" 14 | echo "Job: $(cat "$JOB_FOLDER"/job/job_name) ($(cat "$JOB_FOLDER"/job/date))" 15 | echo "------------------------------------------------------------------------" 16 | # shellcheck disable=SC2086 17 | ps --forest -wwo user,sid,pid,stat,start,time,%cpu,%mem,args -g $JOB_PID 2>/dev/null 18 | fi 19 | fi 20 | fi 21 | done 22 | -------------------------------------------------------------------------------- /scripts/.cluster/direct/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Change directory to project root location 7 | cd ../../ 8 | 9 | # Check if there is already a job running 10 | if ( (.cluster/direct/status_all.sh | grep -q "$USER") 2>/dev/null); then 11 | echo -e "WARNING: There is already a job running!\n" 12 | .cluster/direct/status_all.sh 13 | 14 | # Confirm 15 | echo "" 16 | read -r -p "Are you sure you want to submit another job? [y/N] " response 17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 18 | : 19 | else 20 | exit 1 21 | fi 22 | elif ( 23 | # shellcheck disable=SC2009 24 | ps -aux | grep -v "grep" | grep -q .cluster/direct/_direct_config.sh 25 | ); then 26 | echo -e "WARNING: There is already a job running! It is still initializing...\n" 27 | 28 | # Confirm 29 | echo "" 30 | read -r -p "Are you sure you want to submit another job? [y/N] " response 31 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 32 | : 33 | else 34 | exit 1 35 | fi 36 | fi 37 | 38 | # Get the job name 39 | if [ -z "$1" ]; then 40 | echo "You must submit with a job name as the first argument to this script." 
41 | exit 1 42 | else 43 | export JOB_NAME="$1" 44 | fi 45 | 46 | # Submit 47 | rm -rf .cluster/direct/.last_job/* 48 | mkdir -p .cluster/direct/.last_job 49 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T) 50 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited") 51 | export PROJECT_JOB_NAME="$JOB_NAME" 52 | rm -rf .cluster/.lock 53 | shift 54 | setsid .cluster/direct/_direct_config.sh "$@" & 55 | -------------------------------------------------------------------------------- /scripts/.cluster/full_reset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Warn 7 | read -r -p $'Are you sure? You will lose ALL data, logs, experiments, etc. associated with this project.\nType "delete all my data" to confirm: ' response 8 | if [[ "$response" =~ "delete all my data" ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm 15 | read -r -p "Are you sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | : 18 | else 19 | exit 1 20 | fi 21 | 22 | # Confirm again 23 | read -r -p "Are you really sure? [y/N] " response 24 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 25 | echo "Fully resetting..." 26 | else 27 | exit 1 28 | fi 29 | 30 | # Reset all data 31 | rm -rf ../.cluster/.lock 32 | rm -rf ../.cluster/output/*/* || true 33 | rm -rf ../.cluster/output/ || true 34 | rm ../.cluster/output 2>/dev/null || true 35 | rm -rf ../.cluster/*/.last_job || true 36 | rm -rf ../.venv/*/* || true 37 | rm -rf ../.venv/ || true 38 | rm -rf ../.venv_dev/*/* || true 39 | rm -rf ../.venv_dev/ || true 40 | -------------------------------------------------------------------------------- /scripts/.cluster/install_log.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Get the job name and task ID 7 | if [[ "$1" =~ ':'$ ]]; then 8 | export FIRST_ARGUMENT="$1" 9 | else 10 | export FIRST_ARGUMENT="$1:" 11 | fi 12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:) 13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:) 14 | 15 | # Fetch the install log 16 | if [ -z "$JOB_NAME" ]; then 17 | echo "Fetching the install log of the last job..." 1>&2 18 | if [ -f "./output/named/_latest/job/.array" ]; then 19 | if [ -z "$TASK_ID" ]; then 20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 21 | exit 1 22 | else 23 | export LOG_PATH=./output/named/_latest/$TASK_ID/install.out 24 | fi 25 | else 26 | export LOG_PATH=./output/named/_latest/install.out 27 | fi 28 | if [ -f "$LOG_PATH" ]; then 29 | cat $LOG_PATH 30 | else 31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 32 | exit 1 33 | fi 34 | else 35 | echo "Fetching the install log of the '$JOB_NAME' job..." 1>&2 36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then 37 | if [ -z "$TASK_ID" ]; then 38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 39 | exit 1 40 | else 41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/install.out 42 | fi 43 | else 44 | export LOG_PATH=./output/named/$JOB_NAME/install.out 45 | fi 46 | if [ -f "$LOG_PATH" ]; then 47 | cat "$LOG_PATH" 48 | else 49 | echo "Job does not exist. 
If you just submitted the job, try again in a few seconds." 1>&2 50 | exit 1 51 | fi 52 | fi 53 | -------------------------------------------------------------------------------- /scripts/.cluster/reset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Warn 7 | read -r -p $'Are you sure? You will lose all data, logs, etc. associated with this project that are not tagged as an experiment.\nType "i only want experiments" to confirm: ' response 8 | if [[ "$response" =~ "i only want experiments" ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm 15 | read -r -p "Are you sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | : 18 | else 19 | exit 1 20 | fi 21 | 22 | # Confirm again 23 | read -r -p "Are you really sure? [y/N] " response 24 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 25 | echo "Resetting..." 26 | else 27 | exit 1 28 | fi 29 | 30 | # Reset all data 31 | rm -rf ../.cluster/output/committed || true 32 | rm -rf ../.cluster/output/named || true 33 | ( 34 | cd ../.cluster/output/dated || exit 35 | for x in *; do readlink -f ../experiments/* | grep -q "$x" || rm -rf "$x"; done 36 | ) 37 | ( 38 | cd ../.cluster/output/persistent_data || exit 39 | for x in *; do readlink -f ../dated/*/data/* | grep -q "$x" || rm -rf "$x"; done 40 | ) 41 | rm -rf ../.cluster/*/.last_job || true 42 | -------------------------------------------------------------------------------- /scripts/.cluster/reset_venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Change directory to project root location 7 | cd ../../ 8 | 9 | # Reset the virtual env 10 | rm -rf ../.venv/*/* || true 11 | rm -rf ../.venv/ || true 12 | rm -rf ../.venv_dev/*/* || true 13 | rm -rf ../.venv_dev/ || true 14 | -------------------------------------------------------------------------------- /scripts/.cluster/series_log.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Get the job name, task ID, and series 7 | if [[ "$1" =~ '/'$ ]]; then 8 | export FIRST_ARGUMENT="$1" 9 | else 10 | export FIRST_ARGUMENT="$1/" 11 | fi 12 | export FIRST_ARGUMENT_=$(echo "$FIRST_ARGUMENT" | cut -f1 -d/) 13 | if [[ "$FIRST_ARGUMENT_" =~ ':'$ ]]; then 14 | export FIRST_ARGUMENT_="$FIRST_ARGUMENT_" 15 | else 16 | export FIRST_ARGUMENT_="$FIRST_ARGUMENT_:" 17 | fi 18 | export SERIES_NAME=$(echo "$FIRST_ARGUMENT" | cut -f2 -d/) 19 | export JOB_NAME=$(echo "$FIRST_ARGUMENT_" | cut -f1 -d:) 20 | export TASK_ID=$(echo "$FIRST_ARGUMENT_" | cut -f2 -d:) 21 | 22 | # Make sure a series name was provided 23 | if [ -z "$SERIES_NAME" ]; then 24 | echo "You must provide the series name with a slash (/)." 1>&2 25 | exit 1 26 | fi 27 | 28 | # Fetch the $SERIES_NAME series log 29 | if [ -z "$JOB_NAME" ]; then 30 | echo "Fetching the $SERIES_NAME series log of the last job..." 1>&2 31 | if [ -f "./output/named/_latest/job/.array" ]; then 32 | if [ -z "$TASK_ID" ]; then 33 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 
1>&2 34 | exit 1 35 | else 36 | export LOG_PATH=./output/named/_latest/$TASK_ID/series/$SERIES_NAME.csv 37 | fi 38 | else 39 | export LOG_PATH=./output/named/_latest/series/$SERIES_NAME.csv 40 | fi 41 | if [ -f "$LOG_PATH" ]; then 42 | cat "$LOG_PATH" 43 | else 44 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 45 | exit 1 46 | fi 47 | else 48 | echo "Fetching the $SERIES_NAME series log of the '$JOB_NAME' job..." 1>&2 49 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then 50 | if [ -z "$TASK_ID" ]; then 51 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 52 | exit 1 53 | else 54 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/series/$SERIES_NAME.csv 55 | fi 56 | else 57 | export LOG_PATH=./output/named/$JOB_NAME/series/$SERIES_NAME.csv 58 | fi 59 | if [ -f "$LOG_PATH" ]; then 60 | perl -pe 's/((?<=,)|(?<=^)),/ ,/g;' <"$LOG_PATH" | column -t -s, | less -S 61 | else 62 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 63 | exit 1 64 | fi 65 | fi 66 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/.gitignore: -------------------------------------------------------------------------------- 1 | .last_job/ -------------------------------------------------------------------------------- /scripts/.cluster/sge/_qsub_config.sh: -------------------------------------------------------------------------------- 1 | #$ -l h_rt=60:00:00 2 | #$ -o .cluster/sge/.last_job/submission.out 3 | #$ -e .cluster/sge/.last_job/submission.out 4 | #$ -cwd 5 | #$ -l mem=1G 6 | #$ -pe parallel-onenode 4 7 | #$ -l h=nlpgrid12 8 | 9 | # Source the user's bashrc 10 | # shellcheck disable=SC1090 11 | source ~/.bashrc 12 | 13 | # Mark that we are using sge 14 | export PROJECT_CLUSTER=1 15 | export PROJECT_CLUSTER_TYPE=sge 16 | 17 | # Set sge dependent environment variables 18 | export PROJECT_VENV=.venv/sge 19 | export PROJECT_DATA=/nlp/data/$USER 20 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache 21 | export PROJECT_JOB_NAME=$JOB_NAME 22 | if [ "$SGE_TASK_ID" != "undefined" ]; then 23 | export PROJECT_TASK_ID=$SGE_TASK_ID 24 | fi 25 | 26 | # Set up global cache directories 27 | if [ ! -e "$PROJECT_CACHE_DIR/huggingface_cache" ]; then 28 | ln -s /nlp/data/huggingface_cache "$PROJECT_CACHE_DIR/huggingface_cache" 29 | fi 30 | if [ ! 
-e "$PROJECT_CACHE_DIR/sentence_transformers_cache" ]; then 31 | ln -s /nlp/data/huggingface_cache/sentence_transformers "$PROJECT_CACHE_DIR/sentence_transformers_cache" 32 | fi 33 | 34 | # Change directory to submit location 35 | cd "$SGE_O_WORKDIR" || exit 36 | 37 | # Store the sge last job information 38 | cp .cluster/sge/_qsub_config.sh .cluster/sge/.last_job/resources 39 | echo $PROJECT_CLUSTER_TYPE >.cluster/sge/.last_job/type 40 | echo "$PROJECT_JOB_NAME" >.cluster/sge/.last_job/job_name 41 | if [ -z "$PROJECT_TASK_ID" ]; then 42 | echo "$JOB_ID" >.cluster/sge/.last_job/job_id 43 | else 44 | echo "$JOB_ID" >.cluster/sge/.last_job/job_id 45 | fi 46 | echo "$HOSTNAME" >.cluster/sge/.last_job/nodelist 47 | echo "$PROJECT_CURRENT_DATE" >.cluster/sge/.last_job/date 48 | echo "$PROJECT_CURRENT_COMMIT" >.cluster/sge/.last_job/commit 49 | 50 | # Run the boot script 51 | .cluster/_boot.sh "$@" 52 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/cancel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Confirm 7 | read -r -p "Are you sure? [y/N] " response 8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm again 15 | read -r -p "Are you really sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | : 18 | else 19 | exit 1 20 | fi 21 | 22 | # Fetch the status 23 | if [ -z "$1" ]; then 24 | echo "Canceling the last job..." 1>&2 25 | if [ -f "../output/named/_latest/job/job_id" ]; then 26 | qdel "$(cat ../output/named/_latest/job/job_id)" 27 | else 28 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 29 | exit 1 30 | fi 31 | else 32 | echo "Canceling the '$1' job..." 1>&2 33 | if [ -f "../output/named/$1/job/job_id" ]; then 34 | qdel "$(cat ../output/named/"$1"/job/job_id)" 35 | else 36 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 37 | exit 1 38 | fi 39 | fi 40 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/cancel_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Confirm 7 | read -r -p "Are you sure? [y/N] " response 8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm again 15 | read -r -p "Are you really sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | echo "Canceling..." 18 | else 19 | exit 1 20 | fi 21 | 22 | qdel -u "$USER" 23 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/htop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the nodelist 7 | if [ -z "$1" ]; then 8 | echo "Fetching the compute environment of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/nodelist" ]; then 10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist) 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 
1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the compute environment of the '$1' job..." 1>&2 17 | if [ -f "../output/named/$1/job/nodelist" ]; then 18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist) 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | 25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 1 hour, this is a time limit to prevent hanging resources)" 1>&2 26 | sleep 2 27 | echo -e "\n" 1>&2 28 | qrsh -now no -cwd -pty y -V -N shell -l h_rt=1:00:00 -pe parallel-onenode 1 -l mem=100M -l h="$NODELIST" htop -u "$USER" 29 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/interactive_submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Change directory to project root location 7 | cd ../../ 8 | 9 | # Check if there is already a job running 10 | if ( (.cluster/sge/status_all.sh | grep -q "$USER") 2>/dev/null); then 11 | echo -e "WARNING: There is already a job running!\n" 12 | .cluster/sge/status_all.sh 13 | 14 | # Confirm 15 | echo "" 16 | read -r -p "Are you sure you want to submit another job? [y/N] " response 17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 18 | : 19 | else 20 | exit 1 21 | fi 22 | fi 23 | 24 | # Get the job name 25 | if [ -z "$1" ]; then 26 | echo "You must submit with a job name as the first argument to this script." 27 | exit 1 28 | else 29 | export JOB_NAME="$1" 30 | fi 31 | 32 | # Submit 33 | mkdir -p .cluster/sge/.last_job 34 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T) 35 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited") 36 | rm -rf .cluster/.lock 37 | export QSUB_ARGS=$(grep -E '^#\$' .cluster/sge/_qsub_config.sh | sed -E 's/\s*#\$\s*//g' | grep -v '^-o' | grep -v '^-e' | tr '\n' ' ') 38 | export PROJECT_INTERACTIVE=1 39 | touch .cluster/sge/.last_job/.interactive 40 | echo -n "" >.cluster/sge/.last_job/submission.out 41 | shift 42 | # shellcheck disable=SC2086 43 | qrsh -now no -cwd -pty y -V -N "$JOB_NAME" $QSUB_ARGS .cluster/sge/_qsub_config.sh "$@" 44 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/reset_venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Reset the virtual env 7 | rm -rf ../../.venv/sge/* 8 | rm -rf ../../.venv/sge/ 9 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/shell.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the nodelist 7 | if [ -z "$1" ]; then 8 | echo "Fetching the compute environment of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/nodelist" ]; then 10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist) 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the compute environment of the '$1' job..." 
1>&2 17 | if [ -f "../output/named/$1/job/nodelist" ]; then 18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist) 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | 25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 3 hours, this is a time limit to prevent hanging resources)" 1>&2 26 | sleep 2 27 | qrsh -now no -cwd -pty y -N shell -l h_rt=3:00:00 -pe parallel-onenode 1 -l mem=100M -l h="$NODELIST" /bin/bash -c "echo ""; cd ../../; /bin/bash -i" 28 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/status.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the status 7 | if [ -z "$1" ]; then 8 | echo "Fetching the status of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/job_id" ]; then 10 | qstat | grep --color=never -E "job-ID|----------------------------------|$(cat ../output/named/_latest/job/job_id)" 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the status of the '$1' job..." 1>&2 17 | if [ -f "../output/named/$1/job/job_id" ]; then 18 | qstat | grep --color=never -E "job-ID|----------------------------------|$(cat ../output/named/"$1"/job/job_id)" 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/status_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the status of all jobs 7 | echo "Fetching the status of all jobs..." 1>&2 8 | qstat -u "$USER" 9 | -------------------------------------------------------------------------------- /scripts/.cluster/sge/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Change directory to project root location 7 | cd ../../ 8 | 9 | # Check if there is already a job running 10 | if ( (.cluster/sge/status_all.sh | grep -q "$USER") 2>/dev/null); then 11 | echo -e "WARNING: There is already a job running!\n" 12 | .cluster/sge/status_all.sh 13 | 14 | # Confirm 15 | echo "" 16 | read -r -p "Are you sure you want to submit another job? [y/N] " response 17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 18 | : 19 | else 20 | exit 1 21 | fi 22 | fi 23 | 24 | # Get the job name 25 | if [ -z "$1" ]; then 26 | echo "You must submit with a job name as the first argument to this script." 
27 | exit 1 28 | else 29 | export JOB_NAME="$1" 30 | fi 31 | 32 | # Submit 33 | rm -rf .cluster/sge/.last_job/* 34 | mkdir -p .cluster/sge/.last_job 35 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T) 36 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited") 37 | rm -rf .cluster/.lock 38 | shift 39 | qsub -N "$JOB_NAME" -V .cluster/sge/_qsub_config.sh "$@" 40 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/.gitignore: -------------------------------------------------------------------------------- 1 | .last_job/ -------------------------------------------------------------------------------- /scripts/.cluster/slurm/_sbatch_config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | #SBATCH --time=2:00:00 4 | #SBATCH --partition=interactive 5 | #SBATCH --output=.cluster/slurm/.last_job/submission.out 6 | #SBATCH --ntasks 1 7 | #SBATCH --cpus-per-task 16 8 | #SBATCH --mem=30G 9 | #SBATCH --gpus=2 10 | 11 | # Source the user's bashrc 12 | # shellcheck disable=SC1090 13 | source ~/.bashrc 14 | 15 | # Mark that we are using slurm 16 | export PROJECT_CLUSTER=1 17 | export PROJECT_CLUSTER_TYPE=slurm 18 | 19 | # Set slurm dependent environment variables 20 | export PROJECT_VENV=.venv/slurm 21 | export PROJECT_DATA=/nlp/data/$USER 22 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache 23 | export PROJECT_JOB_NAME=$SLURM_JOB_NAME 24 | export PROJECT_TASK_ID=$SLURM_ARRAY_TASK_ID 25 | 26 | # Set up global cache directories 27 | if [ ! -e "$PROJECT_CACHE_DIR/huggingface_cache" ]; then 28 | ln -s /nlp/data/huggingface_cache "$PROJECT_CACHE_DIR/huggingface_cache" 29 | fi 30 | if [ ! -e "$PROJECT_CACHE_DIR/sentence_transformers_cache" ]; then 31 | ln -s /nlp/data/huggingface_cache/sentence_transformers "$PROJECT_CACHE_DIR/sentence_transformers_cache" 32 | fi 33 | 34 | # Change directory to submit location 35 | cd "$SLURM_SUBMIT_DIR" || exit 36 | 37 | # Store the slurm last job information 38 | cp .cluster/slurm/_sbatch_config.sh .cluster/slurm/.last_job/resources 39 | echo $PROJECT_CLUSTER_TYPE >.cluster/slurm/.last_job/type 40 | echo "$PROJECT_JOB_NAME" >.cluster/slurm/.last_job/job_name 41 | if [ -z "$PROJECT_TASK_ID" ]; then 42 | echo "$SLURM_JOBID" >.cluster/slurm/.last_job/job_id 43 | else 44 | echo "$SLURM_ARRAY_JOB_ID" >.cluster/slurm/.last_job/job_id 45 | fi 46 | echo "$SLURM_JOB_NODELIST" >.cluster/slurm/.last_job/nodelist 47 | echo "$PROJECT_CURRENT_DATE" >.cluster/slurm/.last_job/date 48 | echo "$PROJECT_CURRENT_COMMIT" >.cluster/slurm/.last_job/commit 49 | 50 | # Run the boot script 51 | .cluster/_boot.sh "$@" 52 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/cancel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Confirm 7 | read -r -p "Are you sure? [y/N] " response 8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm again 15 | read -r -p "Are you really sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | : 18 | else 19 | exit 1 20 | fi 21 | 22 | # Fetch the status 23 | if [ -z "$1" ]; then 24 | echo "Canceling the last job..." 
1>&2 25 | if [ -f "../output/named/_latest/job/job_id" ]; then 26 | scancel "$(cat ../output/named/_latest/job/job_id)" 27 | else 28 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 29 | exit 1 30 | fi 31 | else 32 | echo "Canceling the '$1' job..." 1>&2 33 | if [ -f "../output/named/$1/job/job_id" ]; then 34 | scancel "$(cat ../output/named/"$1"/job/job_id)" 35 | else 36 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 37 | exit 1 38 | fi 39 | fi 40 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/cancel_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Confirm 7 | read -r -p "Are you sure? [y/N] " response 8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 9 | : 10 | else 11 | exit 1 12 | fi 13 | 14 | # Confirm again 15 | read -r -p "Are you really sure? [y/N] " response 16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 17 | echo "Canceling..." 18 | else 19 | exit 1 20 | fi 21 | 22 | scancel -u "$USER" 23 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/htop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the nodelist 7 | if [ -z "$1" ]; then 8 | echo "Fetching the compute environment of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/nodelist" ]; then 10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist) 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the compute environment of the '$1' job..." 1>&2 17 | if [ -f "../output/named/$1/job/nodelist" ]; then 18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist) 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | 25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 1 hour, this is a time limit to prevent hanging resources)" 1>&2 26 | sleep 2 27 | echo -e "\n" 1>&2 28 | srun -u --job-name=shell --time=1:00:00 --ntasks 1 --cpus-per-task 1 --mem=100M --nodelist="$NODELIST" --pty htop -u "$USER" 29 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/interactive_submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Change directory to project root location 7 | cd ../../ 8 | 9 | # Check if there is already a job running 10 | if ( (.cluster/slurm/status_all.sh | grep -q "$USER") 2>/dev/null); then 11 | echo -e "WARNING: There is already a job running!\n" 12 | .cluster/slurm/status_all.sh 13 | 14 | # Confirm 15 | echo "" 16 | read -r -p "Are you sure you want to submit another job? [y/N] " response 17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 18 | : 19 | else 20 | exit 1 21 | fi 22 | fi 23 | 24 | # Get the job name 25 | if [ -z "$1" ]; then 26 | echo "You must submit with a job name as the first argument to this script." 
27 | exit 1 28 | else 29 | export JOB_NAME="$1" 30 | fi 31 | 32 | # Submit 33 | mkdir -p .cluster/slurm/.last_job 34 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T) 35 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited") 36 | rm -rf .cluster/.lock 37 | export SBATCH_ARGS=$(grep -E '^#SBATCH' .cluster/slurm/_sbatch_config.sh | sed -E 's/\s*#SBATCH\s*//g' | grep -v '^--output' | tr '\n' ' ') 38 | export PROJECT_INTERACTIVE=1 39 | touch .cluster/slurm/.last_job/.interactive 40 | echo -n "" >.cluster/slurm/.last_job/submission.out 41 | shift 42 | # shellcheck disable=SC2086 43 | srun -u --job-name="$JOB_NAME" $SBATCH_ARGS --pty .cluster/slurm/_sbatch_config.sh "$@" 44 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/nvidia-smi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the nodelist 7 | if [ -z "$1" ]; then 8 | echo "Fetching the compute environment of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/nodelist" ]; then 10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist) 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the compute environment of the '$1' job..." 1>&2 17 | if [ -f "../output/named/$1/job/nodelist" ]; then 18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist) 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | 25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 1 hour, this is a time limit to prevent hanging resources)" 1>&2 26 | sleep 2 27 | echo -e "\n" 1>&2 28 | srun -u --job-name=shell --time=1:00:00 --ntasks 1 --cpus-per-task 1 --mem=100M --nodelist="$NODELIST" --pty watch -n0.5 nvidia-smi 29 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/reset_venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Reset the virtual env 7 | rm -rf ../../.venv/slurm/* 8 | rm -rf ../../.venv/slurm/ 9 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/shell.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the nodelist 7 | if [ -z "$1" ]; then 8 | echo "Fetching the compute environment of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/nodelist" ]; then 10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist) 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the compute environment of the '$1' job..." 1>&2 17 | if [ -f "../output/named/$1/job/nodelist" ]; then 18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist) 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 
1>&2 21 | exit 1 22 | fi 23 | fi 24 | 25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 3 hours, this is a time limit to prevent hanging resources)" 1>&2 26 | sleep 2 27 | echo -e "\n" 1>&2 28 | srun -u --job-name=shell --time=3:00:00 --ntasks 1 --cpus-per-task 1 --mem=100M --nodelist="$NODELIST" --pty /bin/bash -c "cd ../../; /bin/bash" 29 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/status.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the status 7 | if [ -z "$1" ]; then 8 | echo "Fetching the status of the last job..." 1>&2 9 | if [ -f "../output/named/_latest/job/job_id" ]; then 10 | squeue -j "$(cat ../output/named/_latest/job/job_id)" 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the status of the '$1' job..." 1>&2 17 | if [ -f "../output/named/$1/job/job_id" ]; then 18 | squeue -j "$(cat ../output/named/"$1"/job/job_id)" 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/status_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the status of all jobs 7 | echo "Fetching the status of all jobs..." 1>&2 8 | squeue -u "$USER" 9 | -------------------------------------------------------------------------------- /scripts/.cluster/slurm/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Change directory to project root location 7 | cd ../../ 8 | 9 | # Check if there is already a job running 10 | if ( (.cluster/slurm/status_all.sh | grep -q "$USER") 2>/dev/null); then 11 | echo -e "WARNING: There is already a job running!\n" 12 | .cluster/slurm/status_all.sh 13 | 14 | # Confirm 15 | echo "" 16 | read -r -p "Are you sure you want to submit another job? [y/N] " response 17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 18 | : 19 | else 20 | exit 1 21 | fi 22 | fi 23 | 24 | # Get the job name 25 | if [ -z "$1" ]; then 26 | echo "You must submit with a job name as the first argument to this script." 
27 | exit 1 28 | else 29 | export JOB_NAME="$1" 30 | fi 31 | 32 | # Submit 33 | rm -rf .cluster/slurm/.last_job/* 34 | mkdir -p .cluster/slurm/.last_job 35 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T) 36 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited") 37 | rm -rf .cluster/.lock 38 | shift 39 | sbatch --job-name="$JOB_NAME" .cluster/slurm/_sbatch_config.sh "$@" 40 | -------------------------------------------------------------------------------- /scripts/.cluster/stderr_log.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Get the job name and task ID 7 | if [[ "$1" =~ ':'$ ]]; then 8 | export FIRST_ARGUMENT="$1" 9 | else 10 | export FIRST_ARGUMENT="$1:" 11 | fi 12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:) 13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:) 14 | 15 | # Fetch the stderr log 16 | if [ -z "$JOB_NAME" ]; then 17 | echo "Fetching the stderr log of the last job..." 1>&2 18 | if [ -f "./output/named/_latest/job/.array" ]; then 19 | if [ -z "$TASK_ID" ]; then 20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 21 | exit 1 22 | else 23 | export LOG_PATH=./output/named/_latest/$TASK_ID/stderr.out 24 | fi 25 | else 26 | export LOG_PATH=./output/named/_latest/stderr.out 27 | fi 28 | if [ -f "$LOG_PATH" ]; then 29 | cat "$LOG_PATH" 30 | else 31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 32 | exit 1 33 | fi 34 | else 35 | echo "Fetching the stderr log of the '$JOB_NAME' job..." 1>&2 36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then 37 | if [ -z "$TASK_ID" ]; then 38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 39 | exit 1 40 | else 41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/stderr.out 42 | fi 43 | else 44 | export LOG_PATH=./output/named/$JOB_NAME/stderr.out 45 | fi 46 | if [ -f "$LOG_PATH" ]; then 47 | cat "$LOG_PATH" 48 | else 49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 50 | exit 1 51 | fi 52 | fi 53 | -------------------------------------------------------------------------------- /scripts/.cluster/stdout_log.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Get the job name and task ID 7 | if [[ "$1" =~ ':'$ ]]; then 8 | export FIRST_ARGUMENT="$1" 9 | else 10 | export FIRST_ARGUMENT="$1:" 11 | fi 12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:) 13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:) 14 | 15 | # Fetch the stdout log 16 | if [ -z "$JOB_NAME" ]; then 17 | echo "Fetching the stdout log of the last job..." 1>&2 18 | if [ -f "./output/named/_latest/job/.array" ]; then 19 | if [ -z "$TASK_ID" ]; then 20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 21 | exit 1 22 | else 23 | export LOG_PATH=./output/named/_latest/$TASK_ID/stdout.out 24 | fi 25 | else 26 | export LOG_PATH=./output/named/_latest/stdout.out 27 | fi 28 | if [ -f "$LOG_PATH" ]; then 29 | cat $LOG_PATH 30 | else 31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 
1>&2 32 | exit 1 33 | fi 34 | else 35 | echo "Fetching the stdout log of the '$JOB_NAME' job..." 1>&2 36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then 37 | if [ -z "$TASK_ID" ]; then 38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2 39 | exit 1 40 | else 41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/stdout.out 42 | fi 43 | else 44 | export LOG_PATH=./output/named/$JOB_NAME/stdout.out 45 | fi 46 | if [ -f "$LOG_PATH" ]; then 47 | cat "$LOG_PATH" 48 | else 49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 50 | exit 1 51 | fi 52 | fi 53 | -------------------------------------------------------------------------------- /scripts/.cluster/submission_log.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the submission log 7 | if [ -z "$1" ]; then 8 | echo "Fetching the submission log of the last job..." 1>&2 9 | if [ -f "./output/named/_latest/job/submission.out" ]; then 10 | cat ./output/named/_latest/job/submission.out 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the submission log of the '$1' job..." 1>&2 17 | if [ -f "./output/named/$1/job/submission.out" ]; then 18 | cat ./output/named/"$1"/job/submission.out 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | -------------------------------------------------------------------------------- /scripts/.cluster/tag.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")" || exit 5 | 6 | # Fetch the run dir 7 | if [ -z "$1" ]; then 8 | echo "Fetching the run output of the last job..." 1>&2 9 | if [ -e "./output/named/_latest" ]; then 10 | export RUN_DIR=$(readlink -f ./output/named/_latest) 11 | else 12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 13 | exit 1 14 | fi 15 | else 16 | echo "Fetching the run output of the '$1' job..." 1>&2 17 | if [ -e "./output/named/$1" ]; then 18 | export RUN_DIR=$(readlink -f ./output/named/"$1") 19 | else 20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2 21 | exit 1 22 | fi 23 | fi 24 | 25 | # Fetch the experiment dir 26 | export JOB_NAME=$(cat "$RUN_DIR"/job/job_name) 27 | read -r -p "What do you want to name the experiment (default: '$JOB_NAME')? " EXPERIMENT_NAME 28 | if [ -z "$EXPERIMENT_NAME" ]; then 29 | export EXPERIMENT_NAME=$JOB_NAME 30 | fi 31 | export EXPERIMENT_DIR=./output/experiments/$EXPERIMENT_NAME 32 | 33 | # Store the run as an experiment 34 | if [ -e "$EXPERIMENT_DIR" ]; then 35 | echo "This '$EXPERIMENT_NAME' experiment already exists. Do you want to replace it?" 36 | 37 | # Confirm 38 | read -r -p "Are you sure? [y/N] " response 39 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 40 | : 41 | else 42 | exit 1 43 | fi 44 | 45 | # Confirm again 46 | read -r -p "Are you really sure? [y/N] " response 47 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then 48 | : 49 | else 50 | exit 1 51 | fi 52 | 53 | rm "$EXPERIMENT_DIR" 54 | fi 55 | ln -s "$RUN_DIR" "$EXPERIMENT_DIR" && echo -e "\nDone. 
You can find the experiment at: $( 56 | cd "$EXPERIMENT_DIR" || exit 57 | pwd 58 | )" 59 | -------------------------------------------------------------------------------- /scripts/.githooks/post-checkout: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Change directory to script location 5 | cd "$(dirname "$0")/../../" || exit 6 | 7 | # Symlink .cluster and .vscode at the top-level 8 | if [ ! -L ./.cluster ] || [ ! -e ./.cluster ]; then 9 | rm ./.cluster 1>/dev/null 2>/dev/null || true 10 | rm -rf ./.cluster 1>/dev/null 2>/dev/null || true 11 | ln -s "$(realpath ./scripts/.cluster)" ./.cluster 12 | fi 13 | if [ ! -L ./.vscode ] || [ ! -e ./.vscode ]; then 14 | rm ./.vscode 1>/dev/null 2>/dev/null || true 15 | rm -rf ./.vscode 1>/dev/null 2>/dev/null || true 16 | ln -s "$(realpath ./scripts/.vscode)" ./.vscode 17 | fi 18 | 19 | # Don't track local changes to project.env 20 | git update-index --skip-worktree ./scripts/project.env -------------------------------------------------------------------------------- /scripts/.githooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")/../../" || exit 5 | 6 | # Format and lint 7 | ./scripts/format.sh 8 | 9 | # Check if lint failed 10 | retval=$? 11 | if [ $retval -ne 0 ]; then 12 | echo "Lint failed. Please fix lint errors. Commit aborted." 13 | exit $retval 14 | fi 15 | -------------------------------------------------------------------------------- /scripts/.python-version: -------------------------------------------------------------------------------- 1 | 3.10.9 2 | -------------------------------------------------------------------------------- /scripts/.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-vscode-remote.remote-ssh", 4 | "mechatroner.rainbow-csv", 5 | "ms-python.python", 6 | "ms-python.vscode-pylance", 7 | "charliermarsh.ruff", 8 | "ms-python.mypy-type-checker", 9 | "visualstudioexptteam.vscodeintellicode", 10 | "njpwerner.autodocstring", 11 | "iliazeus.vscode-ansi", 12 | "timonwong.shellcheck", 13 | "foxundermoon.shell-format", 14 | "mikestead.dotenv", 15 | "bungcip.better-toml", 16 | "github.copilot" 17 | ] 18 | } -------------------------------------------------------------------------------- /scripts/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "terminal.integrated.scrollback": 100000, 3 | "[python]": { 4 | "editor.tabSize": 4, 5 | "editor.insertSpaces": true, 6 | "editor.rulers": [88], 7 | "editor.defaultFormatter": "charliermarsh.ruff", 8 | "editor.formatOnSave": true, 9 | "editor.codeActionsOnSave": { 10 | "source.organizeImports": "explicit", 11 | "source.fixAll": "explicit" 12 | }, 13 | }, 14 | "[shellscript]": { 15 | "editor.tabSize": 4, 16 | "editor.insertSpaces": true, 17 | "editor.formatOnSave": true, 18 | "editor.formatOnPaste": true, 19 | "editor.rulers": [88], 20 | }, 21 | "python.terminal.activateEnvironment": false, 22 | "python.linting.flake8Enabled": false, 23 | "python.analysis.extraPaths": [ 24 | "./src", 25 | "./.venv/slurm/lib64/python3.10/site-packages", 26 | "./.venv/sge/lib64/python3.10/site-packages", 27 | "./.venv/direct/lib64/python3.10/site-packages" 28 | ], 29 | "autoDocstring.docstringFormat": "google", 30 | "autoDocstring.guessTypes": false, 
31 | "autoDocstring.quoteStyle": "\"\"\"", 32 | "shellcheck.exclude": ["2155", "1091", "2004"], 33 | "search.exclude": { 34 | "**/.git": true, 35 | "**/.cluster": true, 36 | "**/.venv": true, 37 | "**/.venv_dev": true, 38 | "**/wandb/**": true, 39 | "**/docs/build": true, 40 | "**/docs/source/!(index.rst)": true, 41 | }, 42 | "files.watcherExclude": { 43 | "**/.git/**": true, 44 | "**/.cluster/**": true, 45 | "**/.venv/**": true, 46 | "**/.venv_dev/**": true, 47 | "**/wandb/**": true, 48 | "**/docs/build/**": true, 49 | "**/docs/source/!(index.rst)": true, 50 | } 51 | } -------------------------------------------------------------------------------- /scripts/.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "Format Project", 8 | "type": "shell", 9 | "command": "./scripts/format.sh", 10 | "problemMatcher": [] 11 | } 12 | ] 13 | } -------------------------------------------------------------------------------- /scripts/format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")/../" || exit 5 | 6 | # Create and activate a virtual environment for dev requirements 7 | export PROJECT_VENV=.venv_dev/dev 8 | mkdir -p $PROJECT_VENV 9 | python3 -m venv $PROJECT_VENV 10 | source $PROJECT_VENV/bin/activate 11 | pip3 install -q -U pip 12 | pip3 install -q -r src/requirements-dev.txt 13 | 14 | # Ruff format: https://github.com/astral-sh/ruff/issues/8232 15 | ruff check --select I --fix . 16 | python3 -m ruff format src/ 17 | 18 | # Lint 19 | python3 -m ruff check src/ --fix 20 | -------------------------------------------------------------------------------- /scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")/../" || exit 5 | 6 | # Create and activate a virtual environment for dev requirements 7 | export PROJECT_VENV=.venv_dev/dev 8 | mkdir -p $PROJECT_VENV 9 | python3 -m venv $PROJECT_VENV 10 | source $PROJECT_VENV/bin/activate 11 | pip3 install -q -U pip 12 | pip3 install -q -r src/requirements-dev.txt 13 | 14 | # Lint 15 | python3 -m ruff check src/ --fix 16 | -------------------------------------------------------------------------------- /scripts/package.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Change directory to script location 5 | cd "$(dirname "$0")/../" || exit 6 | 7 | # Create and activate a virtual environment for the build 8 | export PROJECT_VENV=.venv_poetry 9 | mkdir -p $PROJECT_VENV 10 | python3 -m venv $PROJECT_VENV 11 | source $PROJECT_VENV/bin/activate 12 | pip3 install -q -U pip 13 | 14 | export PACKAGE_NAME=$(grep "include = " pyproject.toml | head -n1 | cut -d'"' -f 2 | awk '{print tolower($0)}') 15 | export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring 16 | export POETRY_HOME=/nlp/data/ajayp/.cache/pypoetry 17 | mkdir -p $POETRY_HOME 18 | export POETRY_CACHE_DIR=/nlp/data/ajayp/.cache/pypoetry/cache 19 | mkdir -p $POETRY_CACHE_DIR 20 | 21 | echo "Setting up..." 
22 | cp pyproject.toml pyproject.toml.bak 23 | poetry add $(cat src/requirements.txt | grep -v "pandas-stubs>=") 24 | poetry add $(cat src/requirements-cpu.txt | sed 's/+cpu//' | grep -v "find-links" | grep -v "torchvision=" | grep -v "torchaudio=" | grep -v "tensorflow=") 25 | poetry add --group dev $(cat src/requirements-dev.txt) 26 | mv src/ $PACKAGE_NAME/ 27 | 28 | echo "Building 'dist' folder..." 29 | rm -rf dist || true 30 | cp .gitignore .gitignore.bak 31 | cat .gitignore.bak | grep -v '^\s*datadreamer\s*$' >.gitignore 32 | poetry build 33 | mv .gitignore.bak .gitignore 34 | 35 | echo "Cleaning up..." 36 | mv pyproject.toml.bak pyproject.toml 37 | mv $PACKAGE_NAME/ src/ 38 | if [[ $* != *--keep-venv* ]]; then 39 | rm -rf ./.venv_poetry 40 | fi 41 | rm -rf poetry.lock 42 | -------------------------------------------------------------------------------- /scripts/package_publish.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")/../" || exit 5 | 6 | poetry publish "$@" 7 | -------------------------------------------------------------------------------- /scripts/project.env: -------------------------------------------------------------------------------- 1 | export PROJECT_JOB_NAME="test" # This makes run.sh run with `pytest` instead of `python3` 2 | export PROJECT_DATA=~/.datadreamer_dev/ # Where project dependencies will be installed and stored 3 | export PROJECT_DISABLE_TUNNEL=1 # Disables certain dependencies that are not required 4 | 5 | # API Keys and Tokens 6 | # export HUGGING_FACE_HUB_TOKEN="your huggingface_hub token" # (optional) Some tests require a Hugging Face Hub token 7 | # export OPENAI_API_KEY="your_openai_api_key" # (optional) Some tests OpenAI API key 8 | 9 | # You can un-comment the line below to make subsequent runs faster 10 | # after project dependencies have been installed. 11 | # export PROJECT_SKIP_INSTALL_REQS=1 # Skip installing reqs -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Change directory to script location 4 | cd "$(dirname "$0")/../" || exit 5 | 6 | # Set environment variables 7 | export PROJECT_VENV=.venv/prod 8 | 9 | # Load user-specified project environment variables 10 | source ./scripts/project.env 11 | 12 | # Ensure direct environment variables are provided 13 | source .cluster/direct/_direct_env.sh 14 | 15 | # Load user-specified project environment variables 16 | source ./scripts/project.env 17 | 18 | # Set environment variables 19 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache 20 | 21 | # Run the boot script 22 | .cluster/_boot.sh "$@" 23 | -------------------------------------------------------------------------------- /src/.env: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # If you have secret environment variables (API keys or passwords), do NOT 3 | # place them in this file. 4 | # Instead, create a file called ".secrets.env". 5 | # ------------------------------------------------------------------------ 6 | # Define any environment variables you want set when running below. 
7 | # Example: 8 | # export FOO="bar" 9 | # ------------------------------------------------------------------------ 10 | 11 | # Output in unbuffered mode 12 | export PYTHONUNBUFFERED=1 13 | 14 | # Control log level 15 | export LOGURU_LEVEL="TRACE" 16 | 17 | # Disable TensorFlow from allocating all GPU memory on startup 18 | export TF_FORCE_GPU_ALLOW_GROWTH=true 19 | 20 | # Set cache directories 21 | export NLTK_DATA=$PROJECT_CACHE_DIR/nltk 22 | mkdir -p $NLTK_DATA 23 | export HF_HOME=$PROJECT_CACHE_DIR/huggingface_cache_datadreamer 24 | mkdir -p $HF_HOME 25 | export SENTENCE_TRANSFORMERS_HOME=$PROJECT_CACHE_DIR/sentence_transformers_cache_datadreamer 26 | mkdir -p $SENTENCE_TRANSFORMERS_HOME -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .secrets.env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ -------------------------------------------------------------------------------- /src/.secrets.template.env: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # If you have secret environment variables (API keys or passwords), place 3 | # them in this file and rename it to ".secrets.env" instead of 4 | # ".secrets.template.env". 5 | # ------------------------------------------------------------------------ 6 | # Define any environment variables you want set when running below. 7 | # Example: 8 | # export FOO="bar" 9 | # ------------------------------------------------------------------------ 10 | 11 | # ngrok 12 | export NGROK_AUTHTOKEN=your_api_key 13 | 14 | # Weights & Biases 15 | export WANDB_ENTITY="username" 16 | export WANDB_MODE="online" # or disabled 17 | export WANDB_API_KEY=your_api_key 18 | 19 | # Hugging Face Hub 20 | export HUGGINGFACEHUB_API_TOKEN=your_api_key 21 | 22 | # OpenAI 23 | export OPENAI_API_KEY=your_api_key 24 | -------------------------------------------------------------------------------- /src/__cli__.py: -------------------------------------------------------------------------------- 1 | import click 2 | 3 | from . import project 4 | 5 | 6 | # Register main 7 | @click.group() 8 | @click.pass_context 9 | def _main(*args, **kwargs): # pragma: no cover 10 | # Run init 11 | project.init() 12 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | DataDreamer Sessions 4 | ==================== 5 | 6 | You can run prompting, synthetic data generation, and training workflows within 7 | a DataDreamer session using a context manager like so: 8 | 9 | .. code-block:: python 10 | 11 | from datadreamer import DataDreamer 12 | 13 | with DataDreamer('./output/'): 14 | # ... run steps or trainers here ... 15 | 16 | Inside the ``with`` block, you can run any :py:class:`~datadreamer.steps.Step` or 17 | :py:class:`~datadreamer.trainers.Trainer` you want. DataDreamer will automatically 18 | organize, cache, and save the results of each step run within a session to the output 19 | folder. 20 | 21 | In-Memory Sessions 22 | ------------------------------------------------ 23 | 24 | Optionally, you can run DataDreamer fully in-memory, without it saving anything to disk, 25 | by passing ``':memory:'`` as the ``output_folder_path`` argument like 26 | ``with DataDreamer(':memory:'):``. 
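For example, the same session as above, but kept entirely in memory (nothing is cached
or saved to the output folder):

.. code-block:: python

    from datadreamer import DataDreamer

    with DataDreamer(':memory:'):
        # ... run steps or trainers here ...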
27 | 28 | Sessions in Interactive Environments 29 | ------------------------------------------------ 30 | 31 | As an alternative to using a Python context manager (``with`` block), you can also 32 | structure your code with :py:meth:`~DataDreamer.start` and :py:meth:`~DataDreamer.stop` 33 | to achieve the same result. Using the context manager, however, is recommended and 34 | preferred. Using :py:meth:`~DataDreamer.start` and :py:meth:`~DataDreamer.stop` may be 35 | useful if you want to run DataDreamer in a Jupyter or Google Colab notebook or 36 | other interactive environments. 37 | 38 | .. code-block:: python 39 | 40 | from datadreamer import DataDreamer 41 | 42 | dd = DataDreamer('./output/') 43 | dd.start() 44 | # ... run steps or trainers here ... 45 | dd.stop() 46 | 47 | Caching 48 | ======= 49 | 50 | DataDreamer caches the results of each step or trainer run within a session to the 51 | output folder. If a session is interrupted and re-run, DataDreamer will automatically 52 | load the results of previously completed steps from disk and resume where it left off. 53 | 54 | 55 | Attributes: 56 | __version__ (str): The version of DataDreamer installed. 57 | """ 58 | 59 | from .utils import import_utils # isort: skip # noqa: F401 60 | 61 | import importlib.metadata 62 | import os 63 | 64 | from .datadreamer import DataDreamer 65 | 66 | try: 67 | project_root_dir = os.path.dirname(os.path.dirname(__file__)) 68 | with open(os.path.join(project_root_dir, "./pyproject.toml")) as pyproject_fp: 69 | version_line = [ 70 | line.strip() for line in pyproject_fp if line.startswith("version") 71 | ][0] 72 | __version__ = version_line[version_line.find('"') + 1 : version_line.rfind('"')] 73 | except FileNotFoundError: # pragma: no cover 74 | __version__ = importlib.metadata.version( 75 | os.path.basename(os.path.dirname(__file__)) + "-dev" 76 | ) 77 | 78 | __all__ = ["__version__", "DataDreamer"] 79 | -------------------------------------------------------------------------------- /src/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from . import project 4 | 5 | try: 6 | from . import __entry__ # type: ignore[attr-defined] # noqa: F401 7 | except ImportError: 8 | pass 9 | from .__cli__ import _main 10 | 11 | if __name__ == "__main__": # pragma: no cover 12 | # Set the initial cwd 13 | project.INITIAL_CWD = os.path.abspath(os.getcwd()) 14 | 15 | _main() 16 | -------------------------------------------------------------------------------- /src/_cachable/__init__.py: -------------------------------------------------------------------------------- 1 | from ._cachable import _Cachable, _default_batch_scheduler_buffer_size 2 | from ._parallel_cachable import _ParallelCachable 3 | 4 | __all__ = ["_Cachable", "_ParallelCachable", "_default_batch_scheduler_buffer_size"] 5 | -------------------------------------------------------------------------------- /src/_patches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/_patches/__init__.py -------------------------------------------------------------------------------- /src/_patches/datasets_reset_state_hack.py: -------------------------------------------------------------------------------- 1 | # An update in datasets 2.20.0 adding state_dict to IterableDataset seems to have 2 | # broken IterableDataset. 
This patch is a temporary fix until the issue is resolved. 3 | 4 | import contextlib 5 | from unittest.mock import patch 6 | 7 | from datasets.iterable_dataset import ( 8 | ArrowExamplesIterable, 9 | ExamplesIterable, 10 | TypedExamplesIterable, 11 | ) 12 | 13 | __original_init_state_dict = TypedExamplesIterable._init_state_dict 14 | __original_examples__iter__ = ExamplesIterable.__iter__ 15 | __original_arrowexamples__iter__ = ArrowExamplesIterable.__iter__ 16 | _should_reset_state_dict = False 17 | 18 | 19 | def patched_examples__iter__(self): 20 | global _should_reset_state_dict 21 | if _should_reset_state_dict: 22 | self._init_state_dict() 23 | return __original_examples__iter__(self) 24 | 25 | 26 | def patched_arrowexamples__iter__(self): 27 | global _should_reset_state_dict 28 | if _should_reset_state_dict: 29 | self._init_state_dict() 30 | return __original_arrowexamples__iter__(self) 31 | 32 | 33 | ExamplesIterable.__iter__ = patched_examples__iter__ 34 | ArrowExamplesIterable.__iter__ = patched_arrowexamples__iter__ 35 | 36 | 37 | @contextlib.contextmanager 38 | def apply_datasets_reset_state_hack(): 39 | def patched_init_state_dict(self): 40 | self._state_dict = None # Set to None to ensure it is reset 41 | return __original_init_state_dict(self) 42 | 43 | with patch( 44 | "datasets.iterable_dataset.TypedExamplesIterable._init_state_dict", 45 | patched_init_state_dict, 46 | ): 47 | yield None 48 | 49 | 50 | def start_datasets_reset_state_hack(): 51 | global _should_reset_state_dict 52 | _should_reset_state_dict = True 53 | 54 | 55 | def stop_datasets_reset_state_hack(): 56 | global _should_reset_state_dict 57 | _should_reset_state_dict = False 58 | -------------------------------------------------------------------------------- /src/_patches/setfit_import_hack.py: -------------------------------------------------------------------------------- 1 | # SetFit is out-of-date with huggingface_hub and throws an error when trying to import 2 | # from it 3 | # like this: ImportError: cannot import name 'DatasetFilter' from 'huggingface_hub' 4 | 5 | # To fix this, we need to monkey patch huggingface_hub to prevent the import error 6 | 7 | from ..utils.import_utils import ignore_pydantic_warnings 8 | 9 | 10 | def apply_setfit_import_hack(): 11 | with ignore_pydantic_warnings(): 12 | import huggingface_hub 13 | 14 | huggingface_hub.DatasetFilter = None 15 | -------------------------------------------------------------------------------- /src/_stubs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/_stubs/.gitkeep -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | :py:class:`OutputDataset` and :py:class:`OutputIterableDataset` dataset objects are 3 | returned as outputs from :py:class:`~datadreamer.steps.Step` objects under the 4 | :py:attr:`~datadreamer.steps.Step.output` attribute. 5 | 6 | .. tip:: 7 | 8 | You never need to construct a dataset object yourself. They are returned as 9 | :py:attr:`~datadreamer.steps.Step.output` from 10 | :py:class:`~datadreamer.steps.Step` objects. If you need to convert in-memory Python 11 | data or data in files to a DataDreamer dataset object, see the 12 | `DataSource steps <./datadreamer.steps.html#types-of-steps>`_ 13 | available in :py:mod:`datadreamer.steps`. 
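For example, a rough sketch of turning in-memory Python data into a dataset object
through a DataSource step (the ``DataSource`` arguments below are illustrative; see the
DataSource steps documentation for the real signatures):

.. code-block:: python

    from datadreamer import DataDreamer
    from datadreamer.steps import DataSource

    with DataDreamer('./output/'):
        step = DataSource('My Data', data={'inputs': ['a', 'b'], 'outputs': ['x', 'y']})
        dataset = step.output  # an OutputDataset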
14 | 15 | Accessing Columns 16 | ================= 17 | 18 | To access a column on the dataset objects you can use the ``__getitem__`` operator like 19 | so: ``step.output['column_name']``. This will return a :py:class:`OutputDatasetColumn` 20 | or :py:class:`OutputIterableDatasetColumn` column object that can be passed as an input 21 | to the ``inputs`` argument of a :py:class:`~datadreamer.steps.Step`. 22 | """ 23 | 24 | from .datasets import ( 25 | OutputDataset, 26 | OutputDatasetColumn, 27 | OutputIterableDataset, 28 | OutputIterableDatasetColumn, 29 | ) 30 | 31 | __all__ = [ 32 | "OutputDataset", 33 | "OutputDatasetColumn", 34 | "OutputIterableDataset", 35 | "OutputIterableDatasetColumn", 36 | ] 37 | -------------------------------------------------------------------------------- /src/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from functools import partial 3 | from itertools import chain 4 | from typing import Any, Iterator, cast 5 | 6 | from datasets import Dataset, IterableDataset 7 | from datasets.features.features import Features 8 | 9 | from .. import DataDreamer 10 | 11 | 12 | def get_column_names(dataset: Dataset | IterableDataset) -> list[str]: 13 | column_names = cast(None | list[str], dataset.column_names) 14 | if column_names: 15 | return column_names 16 | else: 17 | try: 18 | first_row = next(iter(dataset)) 19 | except StopIteration: 20 | return [] 21 | return list(first_row.keys()) 22 | 23 | 24 | def drop_unsupported_features(dataset: Dataset | IterableDataset): 25 | if isinstance(dataset, Dataset): 26 | dataset.reset_format() 27 | for index_name in dataset.list_indexes(): 28 | dataset.drop_index(index_name) # pragma: no cover 29 | 30 | 31 | def dataset_zip( 32 | *datasets: Dataset, 33 | writer_batch_size: None | int = 1000, 34 | num_proc: None | int = None, 35 | ) -> Dataset: 36 | if len(datasets) == 0: 37 | raise ValueError("You must provide at least one dataset to zip.") 38 | datasets = tuple([deepcopy(d) for d in datasets]) 39 | for d in datasets: 40 | drop_unsupported_features(d) 41 | smallest_dataset = min(datasets, key=lambda d: len(d)) 42 | 43 | def merge_rows(datasets, x, idx): 44 | result_row = {} 45 | for d in datasets: 46 | result_row.update(d[idx]) 47 | return result_row 48 | 49 | DataDreamer._enable_hf_datasets_logging() 50 | zip_results = smallest_dataset.map( 51 | partial(merge_rows, datasets), 52 | with_indices=True, 53 | desc="Zipping datasets together", 54 | writer_batch_size=writer_batch_size, 55 | num_proc=num_proc, 56 | ) 57 | DataDreamer._disable_hf_datasets_logging() 58 | return zip_results 59 | 60 | 61 | def iterable_dataset_zip(*datasets: Dataset | IterableDataset) -> IterableDataset: 62 | if len(datasets) == 0: 63 | raise ValueError("You must provide at least one dataset to zip.") 64 | datasets = tuple([deepcopy(d) for d in datasets]) 65 | for d in datasets: 66 | drop_unsupported_features(d) 67 | 68 | def merged_generator(datasets): 69 | iters: list[Iterator[dict[str, Any]]] = [iter(d) for d in datasets] 70 | for row_dicts in zip(*iters): 71 | row = {} 72 | for d in row_dicts: 73 | for k, v in d.items(): 74 | row[k] = v 75 | yield row 76 | 77 | column_names: list[str] = list( 78 | chain.from_iterable([get_column_names(d) for d in datasets]) 79 | ) 80 | features = Features([(n, None) for n in column_names]) 81 | return IterableDataset.from_generator( 82 | partial(merged_generator, datasets), features=features 83 | ) 84 | 
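A rough usage sketch of the two zip helpers above. This is illustrative only: they are
internal helpers (``src/datasets/utils.py``), the import path assumes the packaged
module layout, and ``dataset_zip`` keeps the length of the smallest input.

```python
# Internal helpers; import path is an assumption based on the package layout.
from datadreamer.datasets.utils import dataset_zip, iterable_dataset_zip
from datasets import Dataset

left = Dataset.from_dict({"prompt": ["a", "b", "c"]})
right = Dataset.from_dict({"response": ["x", "y"]})

# Materialized zip: maps over the smallest dataset and merges rows by index.
zipped = dataset_zip(left, right)
# len(zipped) == 2 and zipped[0] == {"prompt": "a", "response": "x"}

# Streaming zip: lazily yields merged rows as an IterableDataset.
streamed = iterable_dataset_zip(left, right)
# next(iter(streamed)) == {"prompt": "a", "response": "x"}
```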
-------------------------------------------------------------------------------- /src/embedders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | :py:class:`Embedder` objects help convert texts to :wikipedia:`embeddings `. 3 | All embedders derive from the :py:class:`Embedder` base class. 4 | 5 | .. tip:: 6 | 7 | Instead of using :py:meth:`~Embedder.run` directly, use a 8 | :py:class:`step ` that takes an :py:class:`Embedder` as an ``args`` 9 | argument such as :py:class:`~datadreamer.steps.Embed` or construct an 10 | :py:class:`~datadreamer.retrievers.EmbeddingRetriever` with the embedder and then use 11 | a retrieval step such as :py:class:`~datadreamer.steps.Retrieve`. 12 | 13 | Caching 14 | ======= 15 | Embedders internally perform caching to disk, so if you embed the same text multiple 16 | times, the embedder will only embed the text once and then cache the results for 17 | future runs. 18 | """ 19 | 20 | from .embedder import Embedder 21 | from .openai_embedder import OpenAIEmbedder 22 | from .parallel_embedder import ParallelEmbedder 23 | from .sentence_transformers_embedder import SentenceTransformersEmbedder 24 | from .together_embedder import TogetherEmbedder 25 | 26 | __all__ = [ 27 | "Embedder", 28 | "OpenAIEmbedder", 29 | "SentenceTransformersEmbedder", 30 | "TogetherEmbedder", 31 | "ParallelEmbedder", 32 | ] 33 | -------------------------------------------------------------------------------- /src/embedders/parallel_embedder.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | from ..task_models.parallel_task_model import ParallelTaskModel 4 | from .embedder import Embedder 5 | 6 | 7 | class ParallelEmbedder(ParallelTaskModel, Embedder): 8 | def __init__(self, *embedders: Embedder): 9 | """ 10 | Creates an embedder that will run multiple embedders in parallel. See 11 | :doc:`running models in parallel 12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>` 13 | for more details. 14 | 15 | Args: 16 | *embedders: The embedders to run in parallel. 17 | """ 18 | super().__init__(*embedders) 19 | self.embedders = cast(list[Embedder], self.cachables) 20 | 21 | @property 22 | def model_max_length(self) -> int: 23 | return self.embedders[0].model_max_length 24 | 25 | @property 26 | def dims(self) -> int: 27 | return self.embedders[0].dims 28 | 29 | 30 | __all__ = ["ParallelEmbedder"] 31 | -------------------------------------------------------------------------------- /src/errors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various exceptions that may be raised when using DataDreamer. 
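For example, a :py:class:`StepOutputTypeError` is raised when a step's output column
mixes Python types or shapes; a minimal, purely illustrative sketch of handling it:

.. code-block:: python

    from datadreamer.errors import StepOutputTypeError

    try:
        ...  # construct a step whose output column mixes types/shapes
    except StepOutputTypeError:
        ...  # pickle that column via the step's .pickle() method instead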
3 | """ 4 | 5 | from .steps.step import StepOutputError, StepOutputTypeError 6 | 7 | __all__ = ["StepOutputError", "StepOutputTypeError"] 8 | -------------------------------------------------------------------------------- /src/errors/steps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/errors/steps/__init__.py -------------------------------------------------------------------------------- /src/errors/steps/step.py: -------------------------------------------------------------------------------- 1 | class StepOutputError(Exception): 2 | """Raised when a :py:class:`~datadreamer.steps.Step` is constructing its 3 | :py:attr:`~datadreamer.steps.Step.output` and an error occurs.""" 4 | 5 | pass 6 | 7 | 8 | class StepOutputTypeError(TypeError, StepOutputError): 9 | """Raised when a :py:class:`~datadreamer.steps.Step` is constructing its 10 | :py:attr:`~datadreamer.steps.Step.output` and a type error occurs.""" 11 | 12 | def __init__(self, message: str): 13 | if message: 14 | super().__init__( 15 | "Error processing dataset, make sure all values for each output of the" 16 | " dataset are of the same Python type/shape. If you need more" 17 | " flexibility you can pickle your data using the .pickle() method on" 18 | " a Step object. Data will automatically be un-pickled when read." 19 | f" Detailed error: {message.replace('struct', 'dict')}" 20 | ) 21 | -------------------------------------------------------------------------------- /src/llms/_tokenizers.py: -------------------------------------------------------------------------------- 1 | from functools import cache 2 | 3 | TOGETHER_TOKENIZERS = { 4 | "togethercomputer/Pythia-Chat-Base-7B-v0.16": "togethercomputer/Pythia-Chat-Base-7B", 5 | "togethercomputer/Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", 6 | "togethercomputer/Koala-13B": "meta-llama/Llama-2-13b-hf", 7 | "togethercomputer/llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", 8 | "togethercomputer/llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", 9 | "togethercomputer/llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", 10 | "togethercomputer/CodeLlama-7b-Instruct": "codellama/CodeLlama-7b-Instruct-hf", 11 | "togethercomputer/CodeLlama-13b-Instruct": "codellama/CodeLlama-13b-Instruct-hf", 12 | "togethercomputer/CodeLlama-34b-Instruct": "codellama/CodeLlama-34b-Instruct-hf", 13 | "togethercomputer/mpt-7b-chat": "mosaicml/mpt-7b-chat", 14 | "togethercomputer/mpt-30b-chat": "mosaicml/mpt-30b-chat", 15 | "togethercomputer/alpaca-7b": "huggyllama/llama-7b", 16 | "togethercomputer/falcon-7b-instruct": "tiiuae/falcon-7b-instruct", 17 | "togethercomputer/falcon-40b-instruct": "tiiuae/falcon-40b-instruct", 18 | "togethercomputer/guanaco-7b": "huggyllama/llama-7b", 19 | "togethercomputer/guanaco-13b": "huggyllama/llama-13b", 20 | "togethercomputer/guanaco-33b": "huggyllama/llama-30b", 21 | "togethercomputer/guanaco-65b": "huggyllama/llama-65b", 22 | "togethercomputer/Qwen-7B": "Qwen/Qwen-7B", 23 | "togethercomputer/llama-2-7b": "meta-llama/Llama-2-7b-hf", 24 | "togethercomputer/llama-2-13b": "meta-llama/Llama-2-13b-hf", 25 | "togethercomputer/llama-2-70b": "meta-llama/Llama-2-70b-hf", 26 | "togethercomputer/mpt-30b": "mosaicml/mpt-30b", 27 | "togethercomputer/mpt-30b-instruct": "mosaicml/mpt-30b-instruct", 28 | "togethercomputer/falcon-7b": "tiiuae/falcon-7b", 29 | "togethercomputer/falcon-40b": "tiiuae/falcon-40b", 30 | 
"togethercomputer/codegen2-7B": "Salesforce/codegen2-7B", 31 | "togethercomputer/codegen2-16B": "Salesforce/codegen2-16B", 32 | "togethercomputer/CodeLlama-7b": "codellama/CodeLlama-7b-hf", 33 | "togethercomputer/CodeLlama-13b": "codellama/CodeLlama-13b-hf", 34 | "togethercomputer/CodeLlama-34b": "codellama/CodeLlama-34b-hf", 35 | "togethercomputer/CodeLlama-7b-Python": "codellama/CodeLlama-7b-Python-hf", 36 | "togethercomputer/CodeLlama-13b-Python": "codellama/CodeLlama-13b-Python-hf", 37 | "togethercomputer/CodeLlama-34b-Python": "codellama/CodeLlama-34b-Python-hf", 38 | } 39 | 40 | 41 | @cache 42 | def _model_name_to_tokenizer_model_name(model_name: str) -> str: # pragma: no cover 43 | model_name_lower = model_name.lower() 44 | if all(fragment in model_name_lower for fragment in ["llama-", "-2-", "-chat"]): 45 | return "meta-llama/Llama-2-7b-chat-hf" 46 | return "gpt2" 47 | 48 | 49 | __all__ = ["TOGETHER_TOKENIZERS"] 50 | -------------------------------------------------------------------------------- /src/llms/ai21.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | 3 | from ._litellm import LiteLLM 4 | 5 | 6 | class AI21(LiteLLM): 7 | def __init__( 8 | self, 9 | model_name: str, 10 | api_key: None | str = None, 11 | retry_on_fail: bool = True, 12 | cache_folder_path: None | str = None, 13 | **kwargs, 14 | ): 15 | super().__init__( 16 | model_name=model_name, 17 | api_key=api_key, 18 | retry_on_fail=retry_on_fail, 19 | cache_folder_path=cache_folder_path, 20 | **kwargs, 21 | ) 22 | self._model_name_prefix = "" 23 | 24 | @cached_property 25 | def model_card(self) -> None | str: 26 | return "https://www.ai21.com/blog/introducing-j2" 27 | 28 | @cached_property 29 | def license(self) -> None | str: 30 | return "https://www.ai21.com/terms-of-use" 31 | 32 | 33 | __all__ = ["AI21"] 34 | -------------------------------------------------------------------------------- /src/llms/anthropic.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Callable 3 | 4 | from ._litellm import LiteLLM 5 | from .llm import DEFAULT_BATCH_SIZE 6 | 7 | 8 | class Anthropic(LiteLLM): 9 | def __init__( 10 | self, 11 | model_name: str, 12 | api_key: None | str = None, 13 | retry_on_fail: bool = True, 14 | cache_folder_path: None | str = None, 15 | **kwargs, 16 | ): 17 | super().__init__( 18 | model_name=model_name, 19 | api_key=api_key, 20 | retry_on_fail=retry_on_fail, 21 | cache_folder_path=cache_folder_path, 22 | **kwargs, 23 | ) 24 | self._model_name_prefix = "" 25 | 26 | def _run_batch( 27 | self, 28 | max_length_func: Callable[[list[str]], int], 29 | inputs: list[str], 30 | max_new_tokens: None | int = None, 31 | temperature: float = 1.0, 32 | top_p: float = 0.0, 33 | n: int = 1, 34 | stop: None | str | list[str] = None, 35 | repetition_penalty: None | float = None, 36 | logit_bias: None | dict[int, float] = None, 37 | batch_size: int = DEFAULT_BATCH_SIZE, 38 | seed: None | int = None, 39 | **kwargs, 40 | ) -> list[str] | list[list[str]]: 41 | assert ( 42 | repetition_penalty is None 43 | ), f"`repetition_penalty` is not supported for {type(self).__name__}" 44 | assert n == 1, f"Only `n` = 1 is supported for {type(self).__name__}" 45 | return super()._run_batch( 46 | max_length_func=max_length_func, 47 | inputs=inputs, 48 | max_new_tokens=max_new_tokens, 49 | temperature=temperature, 50 | top_p=top_p, 51 | n=n, 52 | stop=stop, 53 | 
repetition_penalty=repetition_penalty, 54 | logit_bias=logit_bias, 55 | batch_size=batch_size, 56 | seed=seed, 57 | **kwargs, 58 | ) 59 | 60 | @cached_property 61 | def model_card(self) -> None | str: 62 | return "https://www.ai21.com/blog/introducing-j2" 63 | 64 | @cached_property 65 | def license(self) -> None | str: 66 | return "https://console.anthropic.com/legal/terms" 67 | 68 | 69 | __all__ = ["Anthropic"] 70 | -------------------------------------------------------------------------------- /src/llms/bedrock.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Callable 3 | 4 | from ._litellm import LiteLLM 5 | from .llm import DEFAULT_BATCH_SIZE 6 | 7 | 8 | class Bedrock(LiteLLM): 9 | def __init__( 10 | self, 11 | model_name: str, 12 | aws_access_key_id: None | str = None, 13 | aws_secret_access_key: None | str = None, 14 | aws_region_name: None | str = None, 15 | retry_on_fail: bool = True, 16 | cache_folder_path: None | str = None, 17 | **kwargs, 18 | ): 19 | super().__init__( 20 | model_name=model_name, 21 | aws_access_key_id=aws_access_key_id, 22 | aws_secret_access_key=aws_secret_access_key, 23 | aws_region_name=aws_region_name, 24 | retry_on_fail=retry_on_fail, 25 | cache_folder_path=cache_folder_path, 26 | **kwargs, 27 | ) 28 | self._model_name_prefix = "" 29 | 30 | def _run_batch( 31 | self, 32 | max_length_func: Callable[[list[str]], int], 33 | inputs: list[str], 34 | max_new_tokens: None | int = None, 35 | temperature: float = 1.0, 36 | top_p: float = 0.0, 37 | n: int = 1, 38 | stop: None | str | list[str] = None, 39 | repetition_penalty: None | float = None, 40 | logit_bias: None | dict[int, float] = None, 41 | batch_size: int = DEFAULT_BATCH_SIZE, 42 | seed: None | int = None, 43 | **kwargs, 44 | ) -> list[str] | list[list[str]]: 45 | assert ( 46 | repetition_penalty is None 47 | ), f"`repetition_penalty` is not supported for {type(self).__name__}" 48 | assert n == 1, f"Only `n` = 1 is supported for {type(self).__name__}" 49 | return super()._run_batch( 50 | max_length_func=max_length_func, 51 | inputs=inputs, 52 | max_new_tokens=max_new_tokens, 53 | temperature=temperature, 54 | top_p=top_p, 55 | n=n, 56 | stop=stop, 57 | repetition_penalty=repetition_penalty, 58 | logit_bias=logit_bias, 59 | batch_size=batch_size, 60 | seed=seed, 61 | **kwargs, 62 | ) 63 | 64 | @cached_property 65 | def model_card(self) -> None | str: 66 | return "https://aws.amazon.com/bedrock/" 67 | 68 | @cached_property 69 | def license(self) -> None | str: 70 | return "https://aws.amazon.com/terms/" 71 | 72 | 73 | __all__ = ["Bedrock"] 74 | -------------------------------------------------------------------------------- /src/llms/cohere.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | 3 | from ._litellm import LiteLLM 4 | 5 | 6 | class Cohere(LiteLLM): 7 | def __init__( 8 | self, 9 | model_name: str, 10 | api_key: None | str = None, 11 | retry_on_fail: bool = True, 12 | cache_folder_path: None | str = None, 13 | **kwargs, 14 | ): 15 | super().__init__( 16 | model_name=model_name, 17 | api_key=api_key, 18 | retry_on_fail=retry_on_fail, 19 | cache_folder_path=cache_folder_path, 20 | **kwargs, 21 | ) 22 | self._model_name_prefix = "" 23 | 24 | @cached_property 25 | def model_card(self) -> None | str: 26 | return "https://docs.cohere.com/docs/models" 27 | 28 | @cached_property 29 | def license(self) -> None | str: 30 | return 
"https://cohere.com/saas-agreement" 31 | 32 | 33 | __all__ = ["Cohere"] 34 | -------------------------------------------------------------------------------- /src/llms/google_ai_studio.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Callable 3 | 4 | from ._litellm import LiteLLM 5 | from .llm import DEFAULT_BATCH_SIZE 6 | 7 | 8 | class GoogleAIStudio(LiteLLM): 9 | def __init__( 10 | self, 11 | model_name: str, 12 | api_key: None | str = None, 13 | retry_on_fail: bool = True, 14 | cache_folder_path: None | str = None, 15 | **kwargs, 16 | ): 17 | super().__init__( 18 | model_name=model_name, 19 | api_key=api_key, 20 | retry_on_fail=retry_on_fail, 21 | cache_folder_path=cache_folder_path, 22 | **kwargs, 23 | ) 24 | self._model_name_prefix = "gemini/" 25 | 26 | def _run_batch( 27 | self, 28 | max_length_func: Callable[[list[str]], int], 29 | inputs: list[str], 30 | max_new_tokens: None | int = None, 31 | temperature: float = 1.0, 32 | top_p: float = 0.0, 33 | n: int = 1, 34 | stop: None | str | list[str] = None, 35 | repetition_penalty: None | float = None, 36 | logit_bias: None | dict[int, float] = None, 37 | batch_size: int = DEFAULT_BATCH_SIZE, 38 | seed: None | int = None, 39 | **kwargs, 40 | ) -> list[str] | list[list[str]]: # pragma: no cover 41 | assert ( 42 | repetition_penalty is None 43 | ), f"`repetition_penalty` is not supported for {type(self).__name__}" 44 | return super()._run_batch( 45 | max_length_func=max_length_func, 46 | inputs=inputs, 47 | max_new_tokens=max_new_tokens, 48 | temperature=temperature, 49 | top_p=top_p, 50 | n=n, 51 | stop=stop, 52 | repetition_penalty=repetition_penalty, 53 | logit_bias=logit_bias, 54 | batch_size=batch_size, 55 | seed=seed, 56 | **kwargs, 57 | ) 58 | 59 | @cached_property 60 | def model_card(self) -> None | str: 61 | return "https://arxiv.org/abs/2312.11805" 62 | 63 | @cached_property 64 | def license(self) -> None | str: 65 | return "https://ai.google.dev/gemini-api/terms" 66 | 67 | @cached_property 68 | def citation(self) -> None | list[str]: 69 | citations = [] 70 | citations.append( 71 | """@article{anil2023gemini, 72 | title={Gemini: A family of highly capable multimodal models}, 73 | author={Anil, Rohan and Borgeaud, Sebastian and Wu, Yonghui and Alayrac, Jean-Baptiste and Yu, Jiahui and Soricut, Radu and Schalkwyk, Johan and Dai, Andrew M and Hauth, Anja and Millican, Katie and others}, 74 | journal={arXiv preprint arXiv:2312.11805}, 75 | volume={1}, 76 | year={2023} 77 | }""".strip() 78 | ) 79 | return citations 80 | 81 | 82 | __all__ = ["GoogleAIStudio"] 83 | -------------------------------------------------------------------------------- /src/llms/parallel_llm.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, Iterable, cast 2 | 3 | from .._cachable import _ParallelCachable 4 | from .llm import DEFAULT_BATCH_SIZE, LLM 5 | 6 | 7 | class ParallelLLM(_ParallelCachable, LLM): 8 | def __init__(self, *llms: LLM): 9 | """ 10 | Creates a LLM that will run multiple LLMs in parallel. See 11 | :doc:`running models in parallel 12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>` 13 | for more details. 14 | 15 | Args: 16 | *llms: The LLMs to run in parallel. 
17 | """ 18 | super().__init__(*llms, cls=LLM) 19 | self.llms = cast(list[LLM], self.cachables) 20 | 21 | def count_tokens(self, value: str) -> int: 22 | """Counts the number of tokens in a string. 23 | 24 | Args: 25 | value: The string to count tokens for. 26 | 27 | Returns: 28 | The number of tokens in the string. 29 | """ 30 | pass 31 | return self.llms[0].count_tokens(value=value) 32 | 33 | def get_max_context_length(self, max_new_tokens: int) -> int: 34 | """Gets the maximum context length for the model. When ``max_new_tokens`` is 35 | greater than 0, the maximum number of tokens that can be used for the prompt 36 | context is returned. 37 | 38 | Args: 39 | max_new_tokens: The maximum number of tokens that can be generated. 40 | 41 | Returns: 42 | The maximum context length. 43 | """ 44 | return self.llms[0].get_max_context_length(max_new_tokens=max_new_tokens) 45 | 46 | def format_prompt( # noqa: C901 47 | self, 48 | max_new_tokens: None | int = None, 49 | beg_instruction: None | str = None, 50 | in_context_examples: None | list[str] = None, 51 | end_instruction: None | str = None, 52 | sep="\n", 53 | min_in_context_examples: None | int = None, 54 | max_in_context_examples: None | int = None, 55 | ) -> str: 56 | return self.llms[0].format_prompt( 57 | max_new_tokens=max_new_tokens, 58 | beg_instruction=beg_instruction, 59 | in_context_examples=in_context_examples, 60 | end_instruction=end_instruction, 61 | sep=sep, 62 | min_in_context_examples=min_in_context_examples, 63 | max_in_context_examples=max_in_context_examples, 64 | ) 65 | 66 | def run( 67 | self, prompts: Iterable[str], *args, **kwargs 68 | ) -> Generator[str | list[str], None, None] | list[str | list[str]]: 69 | kwargs["batch_size"] = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE) 70 | results_generator = self._run_in_parallel(prompts, *args, **kwargs) 71 | if not kwargs.get("return_generator", False): 72 | return list(results_generator) 73 | else: 74 | return results_generator 75 | 76 | def unload_model(self): 77 | for llm in self.llms: 78 | llm.unload_model() 79 | 80 | 81 | __all__ = ["ParallelLLM"] 82 | -------------------------------------------------------------------------------- /src/llms/vertex_ai.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Callable 3 | 4 | from ._litellm import LiteLLM 5 | from .llm import DEFAULT_BATCH_SIZE 6 | 7 | 8 | class VertexAI(LiteLLM): 9 | def __init__( 10 | self, 11 | model_name: str, 12 | vertex_project: None | str = None, 13 | vertex_location: None | str = None, 14 | retry_on_fail: bool = True, 15 | cache_folder_path: None | str = None, 16 | **kwargs, 17 | ): 18 | super().__init__( 19 | model_name=model_name, 20 | vertex_project=vertex_project, 21 | vertex_location=vertex_location, 22 | retry_on_fail=retry_on_fail, 23 | cache_folder_path=cache_folder_path, 24 | **kwargs, 25 | ) 26 | self._model_name_prefix = "" 27 | 28 | def _run_batch( 29 | self, 30 | max_length_func: Callable[[list[str]], int], 31 | inputs: list[str], 32 | max_new_tokens: None | int = None, 33 | temperature: float = 1.0, 34 | top_p: float = 0.0, 35 | n: int = 1, 36 | stop: None | str | list[str] = None, 37 | repetition_penalty: None | float = None, 38 | logit_bias: None | dict[int, float] = None, 39 | batch_size: int = DEFAULT_BATCH_SIZE, 40 | seed: None | int = None, 41 | **kwargs, 42 | ) -> list[str] | list[list[str]]: 43 | assert stop is None, f"`stop` is not supported for {type(self).__name__}" 44 | assert ( 
45 | repetition_penalty is None 46 | ), f"`repetition_penalty` is not supported for {type(self).__name__}" 47 | assert n == 1, f"Only `n` = 1 is supported for {type(self).__name__}" 48 | return super()._run_batch( 49 | max_length_func=max_length_func, 50 | inputs=inputs, 51 | max_new_tokens=max_new_tokens, 52 | temperature=temperature, 53 | top_p=top_p, 54 | n=n, 55 | stop=stop, 56 | repetition_penalty=repetition_penalty, 57 | logit_bias=logit_bias, 58 | batch_size=batch_size, 59 | seed=seed, 60 | **kwargs, 61 | ) 62 | 63 | @cached_property 64 | def model_card(self) -> None | str: 65 | return "https://arxiv.org/abs/2312.11805" 66 | 67 | @cached_property 68 | def license(self) -> None | str: 69 | return "https://ai.google.dev/gemini-api/terms" 70 | 71 | @cached_property 72 | def citation(self) -> None | list[str]: 73 | citations = [] 74 | citations.append( 75 | """@article{anil2023gemini, 76 | title={Gemini: A family of highly capable multimodal models}, 77 | author={Anil, Rohan and Borgeaud, Sebastian and Wu, Yonghui and Alayrac, Jean-Baptiste and Yu, Jiahui and Soricut, Radu and Schalkwyk, Johan and Dai, Andrew M and Hauth, Anja and Millican, Katie and others}, 78 | journal={arXiv preprint arXiv:2312.11805}, 79 | volume={1}, 80 | year={2023} 81 | }""".strip() 82 | ) 83 | return citations 84 | 85 | 86 | __all__ = ["VertexAI"] 87 | -------------------------------------------------------------------------------- /src/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import DATEFMT, DATETIME_FORMAT, STANDARD_FORMAT, logger 2 | 3 | __all__ = ["logger", "DATEFMT", "DATETIME_FORMAT", "STANDARD_FORMAT"] 4 | -------------------------------------------------------------------------------- /src/logging/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging import Logger 3 | 4 | from ..project.environment import RUNNING_IN_PYTEST 5 | 6 | DATEFMT: str = "[%Y-%m-%d %H:%M:%S %z]" 7 | STANDARD_FORMAT: str = "[ \N{ESC}[35m🤖 Data\N{ESC}[33mDr\N{ESC}[31mea\N{ESC}[35mmer\u001b[0m 💤 ] %(message)s" # noqa: B950 8 | DATETIME_FORMAT: str = "%(asctime)s [ \N{ESC}[35m🤖 Data\N{ESC}[33mDr\N{ESC}[31mea\N{ESC}[35mmer\u001b[0m 💤 ] %(message)s" # noqa: B950 9 | 10 | # stderr Handler 11 | STDERR_HANDLER = logging.StreamHandler() 12 | STDERR_HANDLER.setLevel(logging.DEBUG) 13 | 14 | # Logger 15 | logger: Logger = logging.getLogger("datadreamer") 16 | if RUNNING_IN_PYTEST: 17 | logger.propagate = True 18 | else: 19 | logger.propagate = False # pragma: no cover 20 | formatter = logging.Formatter( 21 | STANDARD_FORMAT, datefmt="[%Y-%m-%d %H:%M:%S %z]", validate=False 22 | ) 23 | STDERR_HANDLER.setFormatter(formatter) 24 | logger.addHandler(STDERR_HANDLER) 25 | logger.setLevel(logging.CRITICAL + 1) 26 | -------------------------------------------------------------------------------- /src/pickling/__init__.py: -------------------------------------------------------------------------------- 1 | from .pickle import unpickle, unpickle_transform 2 | 3 | __all__ = ["unpickle", "unpickle_transform"] 4 | -------------------------------------------------------------------------------- /src/pickling/pickle.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any 3 | 4 | from datasets.features.features import Features, Value 5 | from dill import dumps, loads 6 | 7 | _INTERNAL_PICKLE_KEY = "__DataDreamer__pickle_internal__" 
8 | _PICKLE_KEY = "__DataDreamer__pickle__" 9 | __FEATURES_DEFAULT = Features() 10 | 11 | 12 | def _pickle(value: Any, *args: Any, **kwargs: Any) -> bytes: 13 | if _INTERNAL_PICKLE_KEY not in kwargs: 14 | warnings.warn( 15 | "Do not call pickle() directly. You should instead use the .pickle()" 16 | " method on a Step object.", 17 | stacklevel=2, 18 | ) 19 | else: 20 | del kwargs[_INTERNAL_PICKLE_KEY] 21 | return dumps({_PICKLE_KEY: value}, *args, **kwargs) 22 | 23 | 24 | def unpickle(value: bytes) -> Any: 25 | return loads(value)[_PICKLE_KEY] 26 | 27 | 28 | def _unpickle_transform_value(value): 29 | if ( 30 | isinstance(value, bytes) 31 | and len(value) >= 2 32 | and value[0] == 128 33 | and value[-1] == 46 34 | and _PICKLE_KEY.encode("utf8") in value[:100] 35 | ): 36 | return unpickle(value) 37 | else: 38 | return value 39 | 40 | 41 | def unpickle_transform(batch, features=__FEATURES_DEFAULT, batched=False): 42 | for column in batch: 43 | feature = features.get(column, None) 44 | if not isinstance(feature, Value) or feature.dtype != "binary": 45 | continue 46 | if batched: 47 | for i in range(len(batch[column])): 48 | batch[column][i] = _unpickle_transform_value(batch[column][i]) 49 | else: 50 | batch[column] = _unpickle_transform_value(batch[column]) 51 | return batch 52 | -------------------------------------------------------------------------------- /src/project/__init__.py: -------------------------------------------------------------------------------- 1 | """``project`` provides project-wide helpers and utilities useful in machine learning projects. 2 | 3 | Attributes: 4 | INITIAL_CWD (None | str): The initial current working directory path. 5 | context (dict): A dictionary to use to store global context. 6 | RUNNING_IN_PYTEST (bool): Whether or not the project is running in ``pytest``. 7 | RUNNING_IN_CLUSTER (bool): Whether or not the project is running on a cluster. 
8 | """ 9 | 10 | import json 11 | import os 12 | import sys 13 | 14 | from loguru import logger 15 | 16 | from .debug import bash, context, debugger 17 | from .devices import ( 18 | get_jax_cpu_device, 19 | get_jax_device, 20 | get_jax_devices, 21 | get_tf_cpu_device, 22 | get_tf_device, 23 | get_tf_devices, 24 | get_torch_cpu_device, 25 | get_torch_device, 26 | get_torch_devices, 27 | ) 28 | from .environment import RUNNING_IN_CLUSTER, RUNNING_IN_PYTEST 29 | from .persistent_storage import get_persistent_dir 30 | from .report import reporter # type:ignore[attr-defined] 31 | from .serve import run_ngrok 32 | 33 | # Initial cwd (defined in __main__.py) 34 | INITIAL_CWD: None | str = None 35 | 36 | # Make sure CUDA/NVIDIA_VISIBLE_DEVICES is set if it is needed 37 | if os.environ.get("PROJECT_ACCELERATOR_TYPE", None) == "cuda": 38 | if "PROJECT_VISIBLE_ACCELERATOR_DEVICES" in os.environ: 39 | os.environ["NVIDIA_VISIBLE_DEVICES"] = os.environ[ 40 | "PROJECT_VISIBLE_ACCELERATOR_DEVICES" 41 | ] 42 | os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[ 43 | "PROJECT_VISIBLE_ACCELERATOR_DEVICES" 44 | ] 45 | 46 | # Make sure if CUDA/NVIDIA_VISIBLE_DEVICES is set, PROJECT_*_ACCELERATOR_* is set 47 | if ( 48 | "CUDA_VISIBLE_DEVICES" in os.environ 49 | and os.environ.get("PROJECT_ACCELERATOR_TYPE", None) is None 50 | ): 51 | os.environ["PROJECT_ACCELERATOR_TYPE"] = "cuda" 52 | os.environ["PROJECT_VISIBLE_ACCELERATOR_DEVICES"] = os.environ[ 53 | "CUDA_VISIBLE_DEVICES" 54 | ] 55 | elif ( 56 | "NVIDIA_VISIBLE_DEVICES" in os.environ 57 | and os.environ.get("PROJECT_ACCELERATOR_TYPE", None) is None 58 | ): 59 | os.environ["PROJECT_ACCELERATOR_TYPE"] = "cuda" 60 | os.environ["PROJECT_VISIBLE_ACCELERATOR_DEVICES"] = os.environ[ 61 | "NVIDIA_VISIBLE_DEVICES" 62 | ] 63 | 64 | 65 | def init(): 66 | """Initializes the project. Adds logging and does any other project setup.""" 67 | # Setup logger 68 | logger.remove() 69 | logger.add( 70 | sys.stderr, 71 | colorize=True, 72 | format="[{process}] | {time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", # noqa: B950 73 | ) 74 | 75 | # Write args 76 | if RUNNING_IN_CLUSTER: 77 | with open(os.environ["PROJECT_ARGS_FILE"], "w+") as f: 78 | f.write(json.dumps(sys.argv, indent=2)) 79 | 80 | 81 | __all__ = [ 82 | "RUNNING_IN_CLUSTER", 83 | "RUNNING_IN_PYTEST", 84 | "bash", 85 | "context", 86 | "debugger", 87 | "get_jax_cpu_device", 88 | "get_jax_device", 89 | "get_jax_devices", 90 | "get_tf_cpu_device", 91 | "get_tf_device", 92 | "get_tf_devices", 93 | "get_torch_cpu_device", 94 | "get_torch_device", 95 | "get_torch_devices", 96 | "get_persistent_dir", 97 | "reporter", 98 | "run_ngrok", 99 | "init", 100 | ] 101 | -------------------------------------------------------------------------------- /src/project/builtin_tasks.py: -------------------------------------------------------------------------------- 1 | import click 2 | from loguru import logger 3 | 4 | from .serve import ( 5 | run_cloudflared, 6 | run_http_server, 7 | run_jupyter, 8 | run_ngrok, 9 | sleep_infinity, 10 | ) 11 | 12 | 13 | @click.option( 14 | "--tunnel", 15 | "-t", 16 | default="cloudflare", 17 | type=click.Choice(["cloudflare", "ngrok"]), 18 | help="The tunneling service to use.", 19 | ) 20 | @click.option( 21 | "--hostname", "-h", default=None, type=str, help="The hostname to serve at." 
22 | ) 23 | @click.option("--password", "-p", default=None, type=str, help="The password to use.") 24 | def jupyter(ctx, tunnel, hostname, password): 25 | """This command runs Jupyter Lab.""" 26 | logger.info("Running Jupyter Lab...") 27 | port = run_jupyter(password=password) 28 | if tunnel == "cloudflare": 29 | url = run_cloudflared(port, hostname=hostname) 30 | else: 31 | url = run_ngrok(port, hostname=hostname) 32 | logger.info(f"Jupyter Lab is available at URL: {url}") 33 | sleep_infinity() 34 | 35 | 36 | @click.option( 37 | "--tunnel", 38 | "-t", 39 | default="cloudflare", 40 | type=click.Choice(["cloudflare", "ngrok"]), 41 | help="The tunneling service to use.", 42 | ) 43 | @click.option( 44 | "--hostname", "-h", default=None, type=str, help="The hostname to serve at." 45 | ) 46 | def http_server(ctx, tunnel, hostname): 47 | """This command runs a HTTP server.""" 48 | logger.info("Running HTTP server...") 49 | port = run_http_server() 50 | if tunnel == "cloudflare": 51 | url = run_cloudflared(port, hostname=hostname) 52 | else: 53 | url = run_ngrok(port, hostname=hostname) 54 | logger.info(f"HTTP server is available at URL: {url}") 55 | sleep_infinity() 56 | 57 | 58 | def register_builtin_tasks(_main): 59 | _main.command(hidden=True)(click.pass_context(jupyter)) 60 | _main.command(hidden=True)(click.pass_context(http_server)) 61 | -------------------------------------------------------------------------------- /src/project/debug.py: -------------------------------------------------------------------------------- 1 | import code 2 | import inspect 3 | import os 4 | import pty 5 | import tempfile 6 | from time import sleep 7 | 8 | from .report import _deep_defaultdict # type:ignore[attr-defined] 9 | 10 | 11 | def _get_callers_locals_and_globals(): 12 | """Gets the local and global variables from the caller's frame. 13 | 14 | Returns: 15 | tuple[dict, dict]: A tuple of a dictionary of local variables and global 16 | variables. 17 | """ 18 | frame = inspect.currentframe() 19 | if frame and frame.f_back and frame.f_back.f_back: 20 | try: 21 | return frame.f_back.f_back.f_locals, frame.f_back.f_back.f_globals 22 | finally: 23 | del frame 24 | 25 | 26 | def debugger(rank=None, launch_on_rank=0): 27 | """Pauses execution and opens an interactive REPL with access to local and global 28 | variables. 29 | 30 | Args: 31 | rank (Any, optional): The current rank. Defaults to None (will always 32 | launch the debugger). 33 | launch_on_rank (Any, optional): What rank the debugger should be launched on. 34 | Defaults to 0. 35 | """ 36 | if "PROJECT_INTERACTIVE" in os.environ: 37 | from filelock import FileLock 38 | 39 | lock = FileLock( 40 | os.path.join( 41 | os.path.dirname(tempfile.mkdtemp()), 42 | f"{os.environ['PROJECT_NAME']}-debugger.lock", 43 | ) 44 | ) 45 | if rank is None or rank == launch_on_rank: 46 | with lock.acquire(): 47 | ls, gs = _get_callers_locals_and_globals() 48 | all_items = list(ls.items()) + list(gs.items()) 49 | code.interact( 50 | banner="Opening Python REPL (press 'Ctrl-D' to exit the shell)...", 51 | local=dict(all_items), 52 | ) 53 | else: 54 | sleep(5) 55 | while True: 56 | with lock.acquire(): 57 | break 58 | 59 | 60 | def bash(rank=None, launch_on_rank=0): 61 | """Pauses execution and opens a bash shell. 62 | 63 | Args: 64 | rank (Any, optional): The current rank. Defaults to None (will always 65 | launch bash). 66 | launch_on_rank (Any, optional): What rank bash should be launched on. 67 | Defaults to 0. 
68 | """ 69 | if "PROJECT_INTERACTIVE" in os.environ: 70 | from filelock import FileLock 71 | 72 | lock = FileLock( 73 | os.path.join( 74 | os.path.dirname(tempfile.mkdtemp()), 75 | f"{os.environ['PROJECT_NAME']}-bash.lock", 76 | ) 77 | ) 78 | if rank is None or rank == launch_on_rank: 79 | with lock.acquire(): 80 | print("Opening bash shell (type 'exit' to exit the shell)...") 81 | pty.spawn("/bin/bash") 82 | else: 83 | sleep(5) 84 | while True: 85 | with lock.acquire(): 86 | break 87 | 88 | 89 | # Create a context to help store global context when debugging 90 | context = _deep_defaultdict() 91 | -------------------------------------------------------------------------------- /src/project/environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # Detect environments 5 | RUNNING_IN_PYTEST = os.path.basename(sys.argv[0]) == "pytest" 6 | RUNNING_IN_CLUSTER = "PROJECT_CLUSTER" in os.environ 7 | -------------------------------------------------------------------------------- /src/project/pennnlp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | 5 | from loguru import logger 6 | 7 | """This file contains PennNLP cluster specific utilities.""" 8 | 9 | NLPDATA_PATH = os.path.join("/nlp/data/", os.environ["USER"]) 10 | 11 | SCRATCH_PATH = os.path.join("/scratch/", os.environ["USER"], os.environ["PROJECT_NAME"]) 12 | 13 | 14 | def detect_pennnlp(): 15 | """Detect if running on PennNLP's cluster. 16 | 17 | Returns: 18 | bool: Whether or not we are running on PennNLP's cluster. 19 | """ 20 | return os.path.exists(NLPDATA_PATH) 21 | 22 | 23 | def copy_file(src, dest): 24 | """Copies a file from the source path to the destination path, but skips the copying 25 | if the file already exists (determined by last modified time or file size). 26 | 27 | Args: 28 | src (str): The source path. 29 | dest (str): The destination path. 30 | """ 31 | if ( 32 | (not os.path.exists(dest)) 33 | or (os.stat(src).st_mtime - os.stat(dest).st_mtime > 1) 34 | or (os.stat(src).st_size != os.stat(dest).st_size) 35 | ): 36 | shutil.copy2(src, dest) 37 | 38 | 39 | def copy_files_to_ssd(*paths, subfolder=None): 40 | if detect_pennnlp(): 41 | # Create scratch dir for SSD disk speed on PennNLP cluster 42 | scratch_path = SCRATCH_PATH 43 | if subfolder is True: 44 | scratch_path = os.path.join(SCRATCH_PATH, sys.argv[1]) 45 | elif subfolder: 46 | scratch_path = os.path.join(SCRATCH_PATH, subfolder) 47 | os.makedirs(scratch_path, exist_ok=True) 48 | 49 | # Copy files to scratch dir for SSD disk speed on PennNLP cluster 50 | new_paths = [] 51 | for path in paths: 52 | path = os.path.normpath(os.path.abspath(path)) 53 | basename = os.path.basename(path) 54 | new_path = os.path.join(scratch_path, basename) 55 | new_paths.append(new_path) 56 | logger.debug(f"Copying file {path} to PennNLP scratch path: {new_path}...") 57 | copy_file(path, new_path) 58 | logger.debug("Done copying file.") 59 | return new_paths 60 | else: 61 | return paths 62 | -------------------------------------------------------------------------------- /src/project/persistent_storage.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os 4 | 5 | from .report import reporter # type:ignore[attr-defined] 6 | 7 | 8 | def _dict_hash(dictionary): 9 | """Returns the MD5 hash of a dictionary. 
10 | 11 | Args: 12 | dictionary (dict[str, Any]): The dictionary to hash. 13 | 14 | Returns: 15 | str: The MD5 hash. 16 | """ 17 | dhash = hashlib.md5() 18 | # We need to sort arguments so {'a': 1, 'b': 2} is 19 | # the same as {'b': 2, 'a': 1} 20 | encoded = json.dumps(dictionary, sort_keys=True).encode() 21 | dhash.update(encoded) 22 | return dhash.hexdigest() 23 | 24 | 25 | def get_persistent_dir(name, config_path): 26 | """Returns the path to a persistent directory that will be usable across jobs. Any 27 | future jobs 28 | 29 | Args: 30 | name (str): [description] 31 | config_path (str): [description] 32 | """ 33 | config_hash = _dict_hash(reporter.get(config_path)) 34 | persistent_dir = os.path.join( 35 | os.environ["PROJECT_DATA_OUTPUT_PERSISTENT_DATA"], config_hash 36 | ) 37 | local_persistent_dir = os.path.join(os.environ["PROJECT_WRITE_DIR"], name) 38 | os.makedirs(persistent_dir, exist_ok=True) 39 | try: 40 | os.symlink(persistent_dir, local_persistent_dir, target_is_directory=True) 41 | except FileExistsError: 42 | pass 43 | return local_persistent_dir 44 | -------------------------------------------------------------------------------- /src/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/py.typed -------------------------------------------------------------------------------- /src/requirements-accelerator-device.txt: -------------------------------------------------------------------------------- 1 | torch>=2.5.1,<3.0.0 -------------------------------------------------------------------------------- /src/requirements-cpu.txt: -------------------------------------------------------------------------------- 1 | torch>=2.5.1,<3.0.0 -------------------------------------------------------------------------------- /src/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | ruff==0.1.9 2 | Sphinx==7.0.1 3 | furo==2023.5.20 4 | sphinx-inline-tabs==2023.4.21 5 | sphinx-autobuild==2021.3.14 6 | sphinx-copybutton==0.5.2 7 | sphinx-click==4.4.0 8 | sphinx-sitemap==2.5.0 9 | sphinx-autodoc-typehints==1.23.0 10 | sphinx-toolbox==3.4.0 11 | sphinx_design==0.5.0 12 | sphinx-reredirects==0.1.3 -------------------------------------------------------------------------------- /src/requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest==7.3.1 2 | pytest-cov==4.1.0 3 | flaky==3.7.0 4 | mypy==1.4.1 5 | pytest-mock==3.11.1 6 | pytest-timeout==2.2.0 7 | pytest-order==1.2.0 -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | click>=8.1.3 2 | loguru>=0.7.0,<1.0.0 3 | filelock>=3.13.1,<4.0.0 4 | jsonlines>=4.0.0,<7.0.0 5 | numpy>=1.26.4,<2.0.0 6 | sortedcontainers>=2.4.0,<3.0.0 7 | sqlitedict>=2.1.0,<3.0.0 8 | pandas>=1.5.3,<2.0.0 9 | pandas-stubs>=1.5.3.230321,<2.0.0 10 | tenacity>=9.0.0 11 | dill>=0.3.8,<1.0.0 12 | ring>=0.10.1,<1.0.0 13 | psutil>=6.1.1 14 | faiss-cpu>=1.9.0.post1,<2.0.0 15 | evaluate>=0.4.3,<1.0.0 16 | tiktoken>=0.7.0,<1.0.0 17 | sentence-transformers>=3.4.0,<4.0.0 18 | setfit>=1.1.1,<2.0.0 19 | openai>=1.59.6,<2.0.0 20 | datasets>=3.2.0,<4.0.0 21 | peft>=0.14.0,<1.0.0 22 | bitsandbytes>=0.45.0,<1.0.0 23 | huggingface-hub>=0.27.1,<1.0.0 24 | optimum>=1.21.2,<2.0.0 25 | accelerate>=1.3.0,<2.0.0 
26 | transformers>=4.48.1,<4.50.0 27 | ctransformers>=0.2.27,<1.0.0 28 | outlines-core>=0.1.26 29 | Pyro5>=5.15 30 | litellm==1.57.8 31 | trl==0.9.6 32 | -------------------------------------------------------------------------------- /src/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | :py:class:`Retriever` objects help retrieve texts based on a set of queries. 3 | All retrievers derive from the :py:class:`Retriever` base class. 4 | 5 | .. tip:: 6 | 7 | Instead of using :py:meth:`~Retriever.run` directly, use a 8 | :py:class:`step ` that takes a :py:class:`Retriever` as an ``args`` 9 | argument such as :py:class:`~datadreamer.steps.Retrieve` and 10 | :py:class:`~datadreamer.steps.RAGPrompt`. Some other steps like 11 | :py:class:`~datadreamer.steps.FewShotPromptWithRetrieval` 12 | use retrievers internally. 13 | 14 | Caching 15 | ======= 16 | Retrievers typically build an index once when first run and cache the index to disk. 17 | Retrievers also internally cache results to disk, so if you retrieve results 18 | for the same query multiple times, the retriever will only retrieve results for the 19 | query once and then cache the results for future runs. 20 | """ 21 | 22 | from .embedding_retriever import EmbeddingRetriever 23 | from .parallel_retriever import ParallelRetriever 24 | from .retriever import Retriever 25 | 26 | __all__ = ["Retriever", "EmbeddingRetriever", "ParallelRetriever"] 27 | -------------------------------------------------------------------------------- /src/retrievers/parallel_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, Iterable, cast 2 | 3 | from .._cachable import _ParallelCachable 4 | from .retriever import DEFAULT_BATCH_SIZE, Retriever 5 | 6 | 7 | class ParallelRetriever(_ParallelCachable, Retriever): 8 | def __init__(self, *retrievers: Retriever): 9 | """ 10 | Creates a retriever that will run multiple retrievers in parallel. See 11 | :doc:`running models in parallel 12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>` 13 | for more details. 14 | 15 | Args: 16 | *retrievers: The retrievers to run in parallel.
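        Example (a rough sketch showing two retrievers over the same texts being
        run in parallel; the exact ``EmbeddingRetriever`` constructor arguments
        used here are assumed, not taken from this file)::

            retriever = ParallelRetriever(
                EmbeddingRetriever(texts=texts, embedder=embedder_1),
                EmbeddingRetriever(texts=texts, embedder=embedder_2),
            )
            results = retriever.run(queries=["first query", "second query"])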
17 | """ 18 | super().__init__(*retrievers, cls=Retriever) 19 | self.retrievers = cast(list[Retriever], self.cachables) 20 | 21 | @property 22 | def index(self): # pragma: no cover 23 | return self.retrievers[0].index 24 | 25 | def run( 26 | self, queries: Iterable[str], *args, **kwargs 27 | ) -> Generator[str | list[str], None, None] | list[str | list[str]]: 28 | kwargs["batch_size"] = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE) 29 | results_generator = self._run_in_parallel(queries, *args, **kwargs) 30 | if not kwargs.get("return_generator", False): 31 | return list(results_generator) 32 | else: 33 | return results_generator 34 | 35 | def unload_model(self): 36 | for llm in self.retrievers: 37 | llm.unload_model() 38 | 39 | 40 | __all__ = ["ParallelRetriever"] 41 | -------------------------------------------------------------------------------- /src/steps/data_card.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from ..utils.collection_utils import sort_keys 4 | 5 | 6 | class DataCardType: 7 | """The types of data card entries.""" 8 | 9 | DATETIME = "Date & Time" 10 | MODEL_NAME = "Model Name" 11 | DATASET_NAME = "Dataset Name" 12 | LICENSE = "License Information" 13 | CITATION = "Citation Information" 14 | DATASET_CARD = "Dataset Card" 15 | MODEL_CARD = "Model Card" 16 | URL = "URL" 17 | 18 | 19 | def sort_data_card(data_card: dict[str, list[Any]]) -> dict[str, list[Any]]: 20 | return sort_keys( 21 | data_card, 22 | key_order=[ 23 | DataCardType.DATETIME, 24 | DataCardType.DATASET_NAME, 25 | DataCardType.MODEL_NAME, 26 | DataCardType.URL, 27 | DataCardType.DATASET_CARD, 28 | DataCardType.MODEL_CARD, 29 | DataCardType.LICENSE, 30 | DataCardType.CITATION, 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /src/steps/data_sources/csv_data_source.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Sequence 3 | 4 | from datasets import DatasetDict, load_dataset 5 | from datasets.fingerprint import Hasher 6 | 7 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY 8 | from .data_source import DataSource 9 | 10 | 11 | class CSVDataSource(DataSource): 12 | """Loads a CSV dataset from a local path. See :py:func:`datasets.load_dataset` for 13 | more details. 14 | 15 | Args: 16 | name: The name of the step. 17 | data_folder: The path to the dataset folder. 18 | data_files: The name of files from the folder to load. 19 | progress_interval: How often to log progress in seconds. 20 | force: Whether to force run the step (ignore saved results). 21 | verbose: Whether or not to print verbose logs. 22 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.). 23 | save_num_proc: The number of processes to use if saving to disk. 24 | save_num_shards: The number of shards on disk to save the dataset into. 25 | background: Whether to run the operation in the background. 26 | **config_kwargs: Additional keyword arguments to pass to 27 | :py:func:`datasets.load_dataset`. 
28 | """ 29 | 30 | def __init__( 31 | self, 32 | name: str, 33 | data_folder: None | str = None, 34 | data_files: None | str | Sequence[str] = None, 35 | sep: str = ",", 36 | progress_interval: None | int = 60, 37 | force: bool = False, 38 | verbose: None | bool = None, 39 | log_level: None | int = None, 40 | save_num_proc: None | int = None, 41 | save_num_shards: None | int = None, 42 | background: bool = False, 43 | **config_kwargs, 44 | ): 45 | self.data_folder = data_folder 46 | self.data_files = data_files 47 | self.sep = sep 48 | self.config_kwargs = config_kwargs 49 | super().__init__( 50 | name, 51 | data=None, # type: ignore[arg-type] 52 | progress_interval=progress_interval, 53 | force=force, 54 | verbose=verbose, 55 | log_level=log_level, 56 | save_num_proc=save_num_proc, 57 | save_num_shards=save_num_shards, 58 | background=background, 59 | ) 60 | 61 | def setup(self): 62 | pass 63 | 64 | def run(self): 65 | if isinstance(self.data_files, dict): 66 | raise ValueError( 67 | "You supplied a dict to data_files, multiple splits are not supported." 68 | ) 69 | result = load_dataset( 70 | "csv", 71 | data_dir=self.data_folder, 72 | data_files=self.data_files, 73 | num_proc=self.save_num_proc, 74 | sep=self.sep, 75 | **self.config_kwargs, 76 | ) 77 | if isinstance(result, DatasetDict): 78 | result = result["train"] 79 | return result 80 | 81 | @cached_property 82 | def fingerprint(self) -> str: 83 | return Hasher.hash( 84 | [super().fingerprint, self.data_folder, self.data_files, self.config_kwargs] 85 | ) 86 | 87 | 88 | setattr(CSVDataSource, _INTERNAL_STEP_OPERATION_KEY, True) 89 | 90 | __all__ = ["CSVDataSource"] 91 | -------------------------------------------------------------------------------- /src/steps/data_sources/hf_dataset_data_source.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | 3 | from datasets import Dataset 4 | from datasets.fingerprint import Hasher 5 | 6 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY 7 | from .data_source import DataSource 8 | 9 | 10 | class HFDatasetDataSource(DataSource): 11 | """Loads a Hugging Face :py:class:`~datasets.Dataset` from a local path. See 12 | :py:func:`datasets.load_from_disk` for more details. 13 | 14 | Args: 15 | name: The name of the step. 16 | dataset_path: The path to the :py:class:`datasets.Dataset` folder. 17 | progress_interval: How often to log progress in seconds. 18 | force: Whether to force run the step (ignore saved results). 19 | verbose: Whether or not to print verbose logs. 20 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.). 21 | save_num_proc: The number of processes to use if saving to disk. 22 | save_num_shards: The number of shards on disk to save the dataset into. 23 | background: Whether to run the operation in the background. 
24 | """ 25 | 26 | def __init__( 27 | self, 28 | name: str, 29 | dataset_path: str, 30 | progress_interval: None | int = 60, 31 | force: bool = False, 32 | verbose: None | bool = None, 33 | log_level: None | int = None, 34 | save_num_proc: None | int = None, 35 | save_num_shards: None | int = None, 36 | background: bool = False, 37 | ): 38 | self.path_to_dataset = dataset_path 39 | super().__init__( 40 | name, 41 | data=None, # type: ignore[arg-type] 42 | progress_interval=progress_interval, 43 | force=force, 44 | verbose=verbose, 45 | log_level=log_level, 46 | save_num_proc=save_num_proc, 47 | save_num_shards=save_num_shards, 48 | background=background, 49 | ) 50 | 51 | def setup(self): 52 | pass 53 | 54 | def run(self): 55 | return Dataset.load_from_disk(self.path_to_dataset) 56 | 57 | @cached_property 58 | def fingerprint(self) -> str: 59 | return Hasher.hash([super().fingerprint, self.path_to_dataset]) 60 | 61 | 62 | setattr(HFDatasetDataSource, _INTERNAL_STEP_OPERATION_KEY, True) 63 | 64 | __all__ = ["HFDatasetDataSource"] 65 | -------------------------------------------------------------------------------- /src/steps/data_sources/json_data_source.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Sequence 3 | 4 | from datasets import DatasetDict, load_dataset 5 | from datasets.fingerprint import Hasher 6 | 7 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY 8 | from .data_source import DataSource 9 | 10 | 11 | class JSONDataSource(DataSource): 12 | """Loads a JSON dataset from a local path. See :py:func:`datasets.load_dataset` for 13 | more details. 14 | 15 | Args: 16 | name: The name of the step. 17 | data_folder: The path to the dataset folder. 18 | data_files: The name of files from the folder to load. 19 | progress_interval: How often to log progress in seconds. 20 | force: Whether to force run the step (ignore saved results). 21 | verbose: Whether or not to print verbose logs. 22 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.). 23 | save_num_proc: The number of processes to use if saving to disk. 24 | save_num_shards: The number of shards on disk to save the dataset into. 25 | background: Whether to run the operation in the background. 26 | **config_kwargs: Additional keyword arguments to pass to 27 | :py:func:`datasets.load_dataset`. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name: str, 33 | data_folder: None | str = None, 34 | data_files: None | str | Sequence[str] = None, 35 | progress_interval: None | int = 60, 36 | force: bool = False, 37 | verbose: None | bool = None, 38 | log_level: None | int = None, 39 | save_num_proc: None | int = None, 40 | save_num_shards: None | int = None, 41 | background: bool = False, 42 | **config_kwargs, 43 | ): 44 | self.data_folder = data_folder 45 | self.data_files = data_files 46 | self.config_kwargs = config_kwargs 47 | super().__init__( 48 | name, 49 | data=None, # type: ignore[arg-type] 50 | progress_interval=progress_interval, 51 | force=force, 52 | verbose=verbose, 53 | log_level=log_level, 54 | save_num_proc=save_num_proc, 55 | save_num_shards=save_num_shards, 56 | background=background, 57 | ) 58 | 59 | def setup(self): 60 | pass 61 | 62 | def run(self): 63 | if isinstance(self.data_files, dict): 64 | raise ValueError( 65 | "You supplied a dict to data_files, multiple splits are not supported." 
66 | ) 67 | result = load_dataset( 68 | "json", 69 | data_dir=self.data_folder, 70 | data_files=self.data_files, 71 | num_proc=self.save_num_proc, 72 | **self.config_kwargs, 73 | ) 74 | if isinstance(result, DatasetDict): 75 | result = result["train"] 76 | return result 77 | 78 | @cached_property 79 | def fingerprint(self) -> str: 80 | return Hasher.hash( 81 | [super().fingerprint, self.data_folder, self.data_files, self.config_kwargs] 82 | ) 83 | 84 | 85 | setattr(JSONDataSource, _INTERNAL_STEP_OPERATION_KEY, True) 86 | 87 | __all__ = ["JSONDataSource"] 88 | -------------------------------------------------------------------------------- /src/steps/data_sources/text_data_source.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Sequence 3 | 4 | from datasets import DatasetDict, load_dataset 5 | from datasets.fingerprint import Hasher 6 | 7 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY 8 | from .data_source import DataSource 9 | 10 | 11 | class TextDataSource(DataSource): 12 | """Loads a text dataset from a local path. See :py:func:`datasets.load_dataset` for 13 | more details. 14 | 15 | Args: 16 | name: The name of the step. 17 | data_folder: The path to the dataset folder. 18 | data_files: The name of files from the folder to load. 19 | progress_interval: How often to log progress in seconds. 20 | force: Whether to force run the step (ignore saved results). 21 | verbose: Whether or not to print verbose logs. 22 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.). 23 | save_num_proc: The number of processes to use if saving to disk. 24 | save_num_shards: The number of shards on disk to save the dataset into. 25 | background: Whether to run the operation in the background. 26 | **config_kwargs: Additional keyword arguments to pass to 27 | :py:func:`datasets.load_dataset`. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name: str, 33 | data_folder: None | str = None, 34 | data_files: None | str | Sequence[str] = None, 35 | progress_interval: None | int = 60, 36 | force: bool = False, 37 | verbose: None | bool = None, 38 | log_level: None | int = None, 39 | save_num_proc: None | int = None, 40 | save_num_shards: None | int = None, 41 | background: bool = False, 42 | **config_kwargs, 43 | ): 44 | self.data_folder = data_folder 45 | self.data_files = data_files 46 | self.config_kwargs = config_kwargs 47 | super().__init__( 48 | name, 49 | data=None, # type: ignore[arg-type] 50 | progress_interval=progress_interval, 51 | force=force, 52 | verbose=verbose, 53 | log_level=log_level, 54 | save_num_proc=save_num_proc, 55 | save_num_shards=save_num_shards, 56 | background=background, 57 | ) 58 | 59 | def setup(self): 60 | pass 61 | 62 | def run(self): 63 | if isinstance(self.data_files, dict): 64 | raise ValueError( 65 | "You supplied a dict to data_files, multiple splits are not supported." 
66 | ) 67 | result = load_dataset( 68 | "text", 69 | data_dir=self.data_folder, 70 | data_files=self.data_files, 71 | num_proc=self.save_num_proc, 72 | **self.config_kwargs, 73 | ) 74 | if isinstance(result, DatasetDict): 75 | result = result["train"] 76 | return result 77 | 78 | @cached_property 79 | def fingerprint(self) -> str: 80 | return Hasher.hash( 81 | [super().fingerprint, self.data_folder, self.data_files, self.config_kwargs] 82 | ) 83 | 84 | 85 | setattr(TextDataSource, _INTERNAL_STEP_OPERATION_KEY, True) 86 | 87 | __all__ = ["TextDataSource"] 88 | -------------------------------------------------------------------------------- /src/steps/prompt/data_from_prompt.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from ..._cachable._cachable import _StrWithSeed 4 | from ._prompt_base import _PromptBase 5 | 6 | 7 | class DataFromPrompt(_PromptBase): 8 | """Generates ``n`` rows of data using an instruction with a 9 | :py:class:`~datadreamer.llms.LLM`.""" 10 | 11 | def setup(self): 12 | self._prompt_input_type = "none" 13 | self._register_prompt_args() 14 | self.register_arg( 15 | "instruction", 16 | required=True, 17 | help="The instruction to use to generate data.", 18 | ) 19 | self.register_arg( 20 | "n", required=True, help="The number of rows to generate from the prompt." 21 | ) 22 | self.register_arg( 23 | "temperature", 24 | required=False, 25 | default=1.0, 26 | help="The temperature to use when generating data.", 27 | ) 28 | self.register_arg( 29 | "top_p", 30 | required=False, 31 | default=1.0, 32 | help="The top_p to use when generating data.", 33 | ) 34 | self._register_prompt_optional_args() 35 | self._register_prompt_outputs() 36 | 37 | def run(self): 38 | # Get inputs and arguments 39 | args = self.args 40 | instruction = args.pop("instruction") 41 | n = args.pop("n") 42 | _seed = args.pop("_seed", None) 43 | 44 | def create_prompts(instruction, n, seed): 45 | for prompt_idx in range(n): 46 | yield _StrWithSeed( 47 | instruction, 48 | seed=((_seed, prompt_idx) if _seed is not None else prompt_idx), 49 | ) 50 | 51 | return self._run_prompts( 52 | args=args, 53 | prompts=partial(create_prompts, instruction, n, _seed), 54 | total_num_prompts=n, 55 | ) 56 | 57 | 58 | __all__ = ["DataFromPrompt"] 59 | -------------------------------------------------------------------------------- /src/steps/prompt/process_with_prompt.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from ._prompt_base import _PromptBase 4 | 5 | 6 | class ProcessWithPrompt(_PromptBase): 7 | """Processes a set of inputs using an instruction with a 8 | :py:class:`~datadreamer.llms.LLM`.""" 9 | 10 | def setup(self): 11 | self._register_prompt_inputs(prompt_input_type="input") 12 | self._register_prompt_args() 13 | self.register_arg( 14 | "instruction", 15 | required=True, 16 | help="An instruction that describes how to process the input.", 17 | ) 18 | self.register_arg( 19 | "input_label", 20 | required=False, 21 | default="Input:", 22 | help="The label to use for inputs.", 23 | ) 24 | self.register_arg( 25 | "instruction_label", 26 | required=False, 27 | default="Instruction:", 28 | help="The label to use for the instruction.", 29 | ) 30 | self.register_arg( 31 | "max_new_tokens", 32 | required=False, 33 | help="The maximum number of tokens to generate.", 34 | ) 35 | self.register_arg( 36 | "sep", 37 | required=False, 38 | default="\n\n", 39 | help="The 
separator to use between instructions and the input.", 40 | ) 41 | self._register_prompt_optional_args() 42 | self._register_prompt_outputs() 43 | 44 | def run(self): 45 | # Get inputs and arguments 46 | args = self.args 47 | llm = args["llm"] 48 | inputs = self.inputs["inputs"] 49 | input_label = args.pop("input_label") 50 | instruction_label = args.pop("instruction_label") 51 | max_new_tokens = args["max_new_tokens"] 52 | format_prompt_args = dict( 53 | max_new_tokens=max_new_tokens, 54 | end_instruction=( 55 | f"{instruction_label} {args.pop('instruction')}" 56 | if instruction_label 57 | else args.pop("instruction") 58 | ), 59 | sep=args.pop("sep"), 60 | ) 61 | 62 | def create_process_input_with_instruction_prompts( 63 | llm, inputs, input_label, format_prompt_args 64 | ): 65 | for input in inputs: 66 | beg_instruction = f"{input_label} {input}" if input_label else input 67 | yield llm.format_prompt( 68 | beg_instruction=beg_instruction, **format_prompt_args 69 | ) 70 | 71 | # Generate 72 | return self._run_prompts( 73 | args=args, 74 | prompts=partial( 75 | create_process_input_with_instruction_prompts, 76 | llm, 77 | inputs, 78 | input_label, 79 | format_prompt_args, 80 | ), 81 | ) 82 | 83 | 84 | __all__ = ["ProcessWithPrompt"] 85 | -------------------------------------------------------------------------------- /src/steps/prompt/prompt.py: -------------------------------------------------------------------------------- 1 | from ._prompt_base import _PromptBase 2 | 3 | 4 | class Prompt(_PromptBase): 5 | "Runs a set of prompts against a :py:class:`~datadreamer.llms.LLM`." 6 | 7 | def setup(self): 8 | self._register_prompt_inputs() 9 | self._register_prompt_args() 10 | self._register_prompt_optional_args() 11 | self._register_prompt_outputs() 12 | 13 | def run(self): 14 | return self._run_prompts(args=self.args) 15 | 16 | 17 | __all__ = ["Prompt"] 18 | -------------------------------------------------------------------------------- /src/steps/step_background.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from multiprocessing.dummy import Pool as ThreadPool 3 | from time import sleep 4 | from typing import TYPE_CHECKING, Callable 5 | 6 | from .. import DataDreamer 7 | from ..errors import StepOutputError 8 | from ..utils.background_utils import get_thread_id, run_in_background_thread 9 | 10 | if TYPE_CHECKING: # pragma: no cover 11 | from .step import Step 12 | 13 | 14 | def _check_step_output(step: "Step") -> bool: 15 | try: 16 | step.output # noqa: B018 17 | return True 18 | except StepOutputError: 19 | return False 20 | 21 | 22 | def _waiter(steps, poll_interval=1.0): 23 | while len(steps) > 0: 24 | step = steps[-1] 25 | if _check_step_output(step): 26 | steps.pop() 27 | else: 28 | sleep(poll_interval) 29 | 30 | 31 | def wait(*steps: "Step", poll_interval=1.0): 32 | """Wait for all steps to complete if they are running in the background. 33 | 34 | Args: 35 | poll_interval: How often to poll in seconds. 36 | """ 37 | from ..steps import Step 38 | 39 | if not all([isinstance(s, Step) for s in steps]): 40 | raise TypeError("All arguments to wait() must be of type Step.") 41 | if all([_check_step_output(step) for step in steps]): 42 | return 43 | steps_list = list(steps) 44 | wait_thread = run_in_background_thread( 45 | _waiter, steps_list, poll_interval=poll_interval 46 | ) 47 | wait_thread.join() 48 | 49 | 50 | def concurrent(*funcs: Callable): 51 | """Run a set of functions (which run steps) concurrently. 
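    A rough usage sketch, building two steps concurrently (``SomeStep`` is a
    placeholder for any real step class, not something defined here)::

        def build_step_a():
            return SomeStep("step-a")

        def build_step_b():
            return SomeStep("step-b")

        step_a, step_b = concurrent(build_step_a, build_step_b)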
52 | 53 | Args: 54 | *funcs: The functions to run concurrently. 55 | """ 56 | parent_thread_id = get_thread_id() 57 | 58 | def wrapper_func(parent_thread_id, func): 59 | if DataDreamer.initialized(): 60 | DataDreamer._register_child_thread(parent_thread_id) 61 | return func() 62 | 63 | if not all([callable(f) for f in funcs]): 64 | raise TypeError("All arguments to concurrent() must be functions.") 65 | thread_pool = ThreadPool(len(funcs)) 66 | results = thread_pool.map(partial(wrapper_func, parent_thread_id), funcs) 67 | thread_pool.close() 68 | thread_pool.join() 69 | return results 70 | -------------------------------------------------------------------------------- /src/steps/step_export.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | from datasets import DatasetDict 5 | 6 | from ..datasets import OutputDataset 7 | from ..pickling import unpickle_transform 8 | 9 | 10 | def _path_to_split_paths(path: str, dataset_dict: DatasetDict) -> dict[str, str]: 11 | os.makedirs(os.path.dirname(path), exist_ok=True) 12 | base, extension = os.path.splitext(path) 13 | paths: dict[str, str] = {} 14 | for split_name in dataset_dict: 15 | if split_name == "validation": 16 | path_split_name = "val" 17 | else: 18 | path_split_name = split_name 19 | split_path = f"{base}.{path_split_name}{extension}" 20 | paths[split_name] = split_path 21 | return paths 22 | 23 | 24 | def _unpickle_export(export: DatasetDict | list | dict, output_dataset: OutputDataset): 25 | if output_dataset._pickled: 26 | if isinstance(export, DatasetDict): 27 | export.set_transform( 28 | partial( 29 | unpickle_transform, features=output_dataset._features, batched=True 30 | ) 31 | ) 32 | return export 33 | elif isinstance(export, list): 34 | return [ 35 | unpickle_transform( 36 | row, features=output_dataset._features, batched=False 37 | ) 38 | for row in export 39 | ] 40 | else: 41 | return unpickle_transform( 42 | export, features=output_dataset._features, batched=True 43 | ) 44 | else: 45 | return export 46 | 47 | 48 | __all__ = ["_path_to_split_paths", "_unpickle_export"] 49 | -------------------------------------------------------------------------------- /src/steps/tasks/embed.py: -------------------------------------------------------------------------------- 1 | from itertools import tee 2 | 3 | from ..data_card import DataCardType 4 | from ..step import Step 5 | from ..step_output import LazyRows 6 | 7 | 8 | class Embed(Step): 9 | "Embeds a set of texts with an :py:class:`~datadreamer.embedders.Embedder`." 10 | 11 | def setup(self): 12 | self.register_input("texts", help="The texts to embed.") 13 | self.register_arg("embedder", help="The Embedder to use.") 14 | self.register_arg( 15 | "truncate", 16 | required=False, 17 | default=False, 18 | help="Whether or not to truncate inputs.", 19 | ) 20 | self.register_arg( 21 | "instruction", 22 | required=False, 23 | help="The instruction to prefix inputs to the embedding model with.", 24 | ) 25 | self.register_arg( 26 | "lazy", required=False, default=False, help="Whether to run lazily or not." 
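            # Here "lazy" means the computed rows are not saved to disk (see save=(not lazy) in run() below).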
27 | ) 28 | self.register_arg( 29 | "**kwargs", 30 | required=False, 31 | help="Any other arguments you want to pass to the .run() method of the Embedder.", 32 | ) 33 | self.register_output("texts", help="The texts that were embedded.") 34 | self.register_output("embeddings", help="The embeddings by the Embedder.") 35 | 36 | def run(self): 37 | args = self.args 38 | 39 | # Get inputs and arguments 40 | embedder = args.pop("embedder") 41 | lazy = args.pop("lazy") 42 | 43 | # Register trace info from the Embedder model 44 | if hasattr(embedder, "model_name"): 45 | self.register_data_card(DataCardType.MODEL_NAME, embedder.model_name) 46 | self.register_data_card(DataCardType.MODEL_CARD, embedder.model_card) 47 | self.register_data_card(DataCardType.LICENSE, embedder.license) 48 | for citation in embedder.citation or []: 49 | self.register_data_card(DataCardType.CITATION, citation) 50 | 51 | # Get the total number of texts 52 | texts = self.inputs["texts"] 53 | total_num_texts = texts.num_rows 54 | 55 | # Define a function that yields embeddings 56 | def get_embeddings(): 57 | # Get an iterator over texts 58 | texts_iter_1, texts_iter_2 = tee(iter(texts), 2) 59 | 60 | # Generate 61 | embeddings_iter = iter( 62 | embedder.run( 63 | texts=texts_iter_1, 64 | progress_interval=self.progress_interval, 65 | total_num_texts=total_num_texts, 66 | return_generator=True, 67 | _step=self, 68 | **args, 69 | ) 70 | ) 71 | 72 | yield from zip(texts_iter_2, embeddings_iter) 73 | 74 | # Return embeddings 75 | return LazyRows( 76 | get_embeddings, 77 | total_num_rows=total_num_texts, 78 | auto_progress=False, 79 | save=(not lazy), 80 | ) 81 | 82 | 83 | __all__ = ["Embed"] 84 | -------------------------------------------------------------------------------- /src/steps/tasks/run_task_model.py: -------------------------------------------------------------------------------- 1 | from itertools import tee 2 | 3 | from ..data_card import DataCardType 4 | from ..step import Step 5 | from ..step_output import LazyRows 6 | 7 | 8 | class RunTaskModel(Step): 9 | "Runs a set of texts against a :py:class:`~datadreamer.task_models.TaskModel`." 10 | 11 | def setup(self): 12 | self.register_input("texts", help="The texts to process with the TaskModel.") 13 | self.register_arg("model", help="The TaskModel to use.") 14 | self.register_arg( 15 | "truncate", 16 | required=False, 17 | default=False, 18 | help="Whether or not to truncate inputs.", 19 | ) 20 | self.register_arg( 21 | "lazy", required=False, default=False, help="Whether to run lazily or not." 
22 | ) 23 | self.register_arg( 24 | "**kwargs", 25 | required=False, 26 | help="Any other arguments you want to pass to the .run() method of the TaskModel.", 27 | ) 28 | self.register_output("texts", help="The texts processed with the TaskModel.") 29 | self.register_output("results", help="The results from the TaskModel.") 30 | 31 | def run(self): 32 | args = self.args 33 | 34 | # Get inputs and arguments 35 | model = args.pop("model") 36 | lazy = args.pop("lazy") 37 | 38 | # Register trace info from the TaskModel model 39 | if hasattr(model, "model_name"): 40 | self.register_data_card(DataCardType.MODEL_NAME, model.model_name) 41 | self.register_data_card(DataCardType.MODEL_CARD, model.model_card) 42 | self.register_data_card(DataCardType.LICENSE, model.license) 43 | for citation in model.citation or []: 44 | self.register_data_card(DataCardType.CITATION, citation) 45 | 46 | # Get the total number of texts 47 | texts = self.inputs["texts"] 48 | total_num_texts = texts.num_rows 49 | 50 | # Define a function that yields results 51 | def get_results(): 52 | # Get an iterator over texts 53 | texts_iter_1, texts_iter_2 = tee(iter(texts), 2) 54 | 55 | # Generate 56 | results_iter = iter( 57 | model.run( 58 | texts=texts_iter_1, 59 | progress_interval=self.progress_interval, 60 | total_num_texts=total_num_texts, 61 | return_generator=True, 62 | _step=self, 63 | **args, 64 | ) 65 | ) 66 | 67 | yield from zip(texts_iter_2, results_iter) 68 | 69 | # Return results 70 | return LazyRows( 71 | get_results, 72 | total_num_rows=total_num_texts, 73 | auto_progress=False, 74 | save=(not lazy), 75 | ) 76 | 77 | 78 | __all__ = ["RunTaskModel"] 79 | -------------------------------------------------------------------------------- /src/task_models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | :py:class:`TaskModel` objects help perform some sort of arbitrary NLP task 3 | (classification, etc.). 4 | All task models derive from the :py:class:`TaskModel` base class. 5 | 6 | .. tip:: 7 | 8 | Instead of using :py:meth:`~TaskModel.run` directly, use a 9 | :py:class:`step ` that takes a :py:class:`TaskModel` as an 10 | ``args`` argument such as :py:class:`~datadreamer.steps.RunTaskModel`. 11 | 12 | Caching 13 | ======= 14 | Task models internally perform caching to disk, so if you run the same text multiple 15 | times, the task model will only run once and then cache the results for future runs. 16 | """ 17 | 18 | from .hf_classification_task_model import HFClassificationTaskModel 19 | from .parallel_task_model import ParallelTaskModel 20 | from .task_model import TaskModel 21 | 22 | __all__ = ["TaskModel", "HFClassificationTaskModel", "ParallelTaskModel"] 23 | -------------------------------------------------------------------------------- /src/task_models/parallel_task_model.py: -------------------------------------------------------------------------------- 1 | from typing import Generator, Iterable, cast 2 | 3 | from .._cachable import _ParallelCachable 4 | from .task_model import DEFAULT_BATCH_SIZE, TaskModel 5 | 6 | 7 | class ParallelTaskModel(_ParallelCachable, TaskModel): 8 | def __init__(self, *task_models: TaskModel): 9 | """ 10 | Creates a task model that will run multiple task models in parallel. See 11 | :doc:`running models in parallel 12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>` 13 | for more details. 14 | 15 | Args: 16 | *task_models: The task models to run in parallel. 
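        Example (a rough sketch; the ``HFClassificationTaskModel`` constructor
        argument used here is assumed, not taken from this file)::

            task_model = ParallelTaskModel(
                HFClassificationTaskModel("some-hf-classification-model"),
                HFClassificationTaskModel("some-hf-classification-model"),
            )
            results = task_model.run(texts=["first text", "second text"])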
17 | """ 18 | super().__init__(*task_models, cls=TaskModel) 19 | self.task_models = cast(list[TaskModel], self.cachables) 20 | 21 | def count_tokens(self, value: str) -> int: 22 | """Counts the number of tokens in a string. 23 | 24 | Args: 25 | value: The string to count tokens for. 26 | 27 | Returns: 28 | The number of tokens in the string. 29 | """ 30 | pass 31 | return self.task_models[0].count_tokens(value=value) 32 | 33 | @property 34 | def model_max_length(self) -> int: # pragma: no cover 35 | return self.task_models[0].model_max_length 36 | 37 | def run( # type:ignore[override] 38 | self, texts: Iterable[str], *args, **kwargs 39 | ) -> Generator[str | list[str], None, None] | list[str | list[str]]: 40 | kwargs["batch_size"] = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE) 41 | results_generator = self._run_in_parallel(texts, *args, **kwargs) 42 | if not kwargs.get("return_generator", False): 43 | return list(results_generator) 44 | else: 45 | return results_generator 46 | 47 | def unload_model(self): 48 | for llm in self.task_models: 49 | llm.unload_model() 50 | 51 | 52 | __all__ = ["ParallelTaskModel"] 53 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/__init__.py -------------------------------------------------------------------------------- /src/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from .. import project 5 | from ..utils.fs_utils import clear_dir 6 | 7 | 8 | # Register pytest fixtures 9 | def refactor(string: str) -> str: 10 | return string.replace("/", ".").replace("\\", ".").replace(".py", "") 11 | 12 | 13 | pytest_plugins = [ 14 | refactor(fixture) 15 | for fixture in glob("src/tests/test_utils/fixtures/*.py") 16 | if "__" not in fixture 17 | ] 18 | 19 | # Set the initial cwd 20 | project.INITIAL_CWD = os.path.abspath(os.getcwd()) 21 | 22 | # Clear the tests data directory 23 | try: 24 | clear_dir("./.tests_data") 25 | except FileNotFoundError: 26 | pass 27 | -------------------------------------------------------------------------------- /src/tests/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/datasets/__init__.py -------------------------------------------------------------------------------- /src/tests/embedders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/embedders/__init__.py -------------------------------------------------------------------------------- /src/tests/llms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/llms/__init__.py -------------------------------------------------------------------------------- /src/tests/retrievers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/retrievers/__init__.py -------------------------------------------------------------------------------- /src/tests/steps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/steps/__init__.py -------------------------------------------------------------------------------- /src/tests/steps/prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/steps/prompt/__init__.py -------------------------------------------------------------------------------- /src/tests/steps/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/steps/tasks/__init__.py -------------------------------------------------------------------------------- /src/tests/task_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/task_models/__init__.py -------------------------------------------------------------------------------- /src/tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from ..__cli__ import _main 2 | 3 | 4 | class TestCli: 5 | def test_help(self, cli_runner): 6 | result = cli_runner.invoke(_main, ["--help"]) 7 | assert result.output.count("Show this message and exit.") == 1 8 | -------------------------------------------------------------------------------- /src/tests/test_package.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | 4 | import pytest 5 | from mypy import api 6 | 7 | from .. 
import __version__ 8 | from ..project import RUNNING_IN_PYTEST 9 | 10 | 11 | class TestPackage: 12 | def test_version(self): 13 | assert len(__version__.split(".")) == 3 14 | 15 | def test_running_in_pytest(self): 16 | assert RUNNING_IN_PYTEST 17 | 18 | @pytest.mark.skipif( 19 | "GITHUB_ACTIONS" not in os.environ and "PROJECT_CLUSTER" not in os.environ, 20 | reason="only run on CI", 21 | ) 22 | def test_python_version(self): 23 | with open("./scripts/.python-version", "r") as f: 24 | python_version = f.read().strip() 25 | assert python_version == platform.python_version() 26 | 27 | def test_mypy(self): 28 | result = api.run(["src/", "--sqlite-cache", "--explicit-package-bases"]) 29 | if result[0]: 30 | print("\nType checking report:\n") 31 | print(result[0]) 32 | if result[1]: 33 | print("\nError report:\n") 34 | print(result[1]) 35 | assert result[2] == 0 36 | -------------------------------------------------------------------------------- /src/tests/test_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/test_utils/__init__.py -------------------------------------------------------------------------------- /src/tests/test_utils/config.py: -------------------------------------------------------------------------------- 1 | TEST_DIR = "./.tests_data" 2 | -------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/test_utils/fixtures/__init__.py -------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/bitsandbytes_fixture.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | 5 | imported_banb = False 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def reset_bandb_import(): 10 | """We need to import banb as it throws a warning on first import""" 11 | global imported_banb 12 | 13 | # Code that will run before your test 14 | with warnings.catch_warnings(): 15 | warnings.filterwarnings( 16 | "ignore", 17 | category=UserWarning, 18 | message="The installed version of bitsandbytes was compiled without GPU.*", 19 | module="bitsandbytes.cextension", 20 | ) 21 | if not imported_banb: 22 | print( 23 | "\nDataDreamer test suite is importing bitsandbyes," 24 | " ignore any warnings below this...\n" 25 | ) 26 | import bitsandbytes # noqa: F401 27 | 28 | if not imported_banb: 29 | print("\nDataDreamer test suite is done importing bitsandbyes.\n") 30 | 31 | imported_banb = True 32 | 33 | yield 34 | # Code that will run after your test 35 | -------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/clear_space.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from ....utils.fs_utils import clear_dir 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def clear_github_space(): 10 | yield 11 | if "GITHUB_ACTIONS" in os.environ: 12 | # Clear the tests data directory to make more disk space available 13 | try: 14 | clear_dir("./.tests_data") 15 | os.system("rm -rf ~/.cache/huggingface/") 16 | except FileNotFoundError: 17 | pass 18 | 
-------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/cli_runner.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from click.testing import CliRunner 3 | 4 | 5 | @pytest.fixture 6 | def cli_runner(): 7 | return CliRunner() 8 | -------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/create_datadreamer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | from typing import Callable 4 | 5 | import pytest 6 | 7 | from .... import DataDreamer 8 | from ..config import TEST_DIR 9 | 10 | 11 | @pytest.fixture 12 | def create_datadreamer() -> Callable[..., DataDreamer]: 13 | def _create_datadreamer(path: None | str = None, **kwargs) -> DataDreamer: 14 | if path is None: 15 | path = uuid.uuid4().hex[0:10] 16 | if path == ":memory:": 17 | return DataDreamer(path, **kwargs) 18 | else: 19 | return DataDreamer(os.path.join(TEST_DIR, path), **kwargs) 20 | 21 | return _create_datadreamer 22 | -------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/create_test_step.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import pytest 4 | 5 | from ....steps import Step 6 | from ....steps.step import _INTERNAL_TEST_KEY 7 | 8 | 9 | @pytest.fixture 10 | def create_test_step() -> Callable[..., Step]: 11 | def _create_test_step( 12 | name="my-step", 13 | inputs=None, 14 | args=None, 15 | outputs=None, 16 | output_names=None, 17 | setup=None, 18 | **kwargs, 19 | ) -> Step: 20 | if output_names is None: 21 | output_names = [] 22 | 23 | class TestStep(Step): 24 | def setup(self): 25 | if isinstance(output_names, str): 26 | self.register_output(output_names) 27 | else: 28 | for o in output_names: 29 | self.register_output(o) 30 | if setup is not None: 31 | setup(self) 32 | 33 | setattr(TestStep, _INTERNAL_TEST_KEY, True) 34 | 35 | return TestStep(name, inputs=inputs, args=args, outputs=outputs, **kwargs) 36 | 37 | return _create_test_step 38 | -------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/mock_llm.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import pytest 4 | 5 | from ....llms import LLM 6 | 7 | 8 | @pytest.fixture 9 | def mock_llm( 10 | allowed_kwargs=frozenset( 11 | { 12 | "inputs", 13 | "batch_size", 14 | "max_new_tokens", 15 | "temperature", 16 | "top_p", 17 | "n", 18 | "stop", 19 | "repetition_penalty", 20 | "logit_bias", 21 | "seed", 22 | "max_length_func", 23 | "cached_tokenizer", 24 | } 25 | ), 26 | ) -> Callable[..., LLM]: 27 | def _mock_llm(llm: LLM, responses: dict[str, str]) -> LLM: 28 | def _run_batch_mocked(**kwargs): 29 | for kwarg in kwargs: 30 | assert kwarg in allowed_kwargs, f"LLM got unexpected keyword: {kwarg}" 31 | return [responses[prompt] for prompt in kwargs["inputs"]] 32 | 33 | llm._run_batch = _run_batch_mocked # type: ignore[attr-defined] 34 | 35 | return llm 36 | 37 | return _mock_llm 38 | -------------------------------------------------------------------------------- /src/tests/test_utils/fixtures/restore_os_environ.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(autouse=True) 7 | def 
restore_os_environ(): 8 | orig_environ = os.environ.copy() 9 | yield 10 | os.environ.clear() 11 | os.environ.update(orig_environ) 12 | -------------------------------------------------------------------------------- /src/tests/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/trainers/__init__.py -------------------------------------------------------------------------------- /src/tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/utils/__init__.py -------------------------------------------------------------------------------- /src/trainers/_vendored/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/trainers/_vendored/__init__.py -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/arg_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | 4 | class Default: 5 | def __init__(self, name: str): 6 | self.name = name 7 | 8 | def __repr__(self): # pragma: no cover 9 | return self.name 10 | 11 | 12 | DEFAULT = Default("DEFAULT") 13 | AUTO = Default("AUTO") 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | def default_to(val: T | Default, default_val: T) -> T: 19 | if isinstance(val, Default): 20 | return default_val 21 | else: 22 | return val 23 | -------------------------------------------------------------------------------- /src/utils/collection_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Iterable 3 | 4 | 5 | def uniq_str(collection: Iterable) -> str: 6 | seen = set() 7 | uniqed = tuple([x for x in collection if not (x in seen or seen.add(x))]) # type: ignore[func-returns-value] 8 | return re.sub(r",}$", "}", f"{{{str(uniqed)[1:-1]}}}") if uniqed else "{}" 9 | 10 | 11 | def sort_keys( 12 | d: dict[Any, Any], key_order: list[Any] 13 | ) -> dict[Any, Any]: # pragma: no cover 14 | d_copy = dict(d) 15 | all_keys = set(d_copy.keys()) 16 | other_keys = all_keys.difference(set(key_order)) 17 | d.clear() 18 | for key in key_order: 19 | if key in d_copy: 20 | d[key] = d_copy[key] 21 | for key in other_keys: 22 | d[key] = d_copy[key] 23 | return d 24 | -------------------------------------------------------------------------------- /src/utils/fs_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from typing import Any 4 | 5 | 6 | def mkdir(path: str): 7 | try: 8 | os.makedirs(path, exist_ok=True) 9 | except FileExistsError: # pragma: no cover 10 | pass 11 | 12 | 13 | def safe_fn(value: str, allow_slashes=False, to_lower=False) -> str: 14 | if allow_slashes: 15 | value = value.replace(" / ", "/") 16 | value = value.replace(" ", "-") 17 | if not 
allow_slashes: 18 | value = value.replace("/", "-") 19 | safe_chars: Any = ("-", "_") 20 | if allow_slashes: 21 | safe_chars = ("-", "_", "/") 22 | strip_chars = "".join(c for c in value if c.isalnum() or c in safe_chars).strip() 23 | if to_lower: 24 | return strip_chars.lower() 25 | else: 26 | return strip_chars 27 | 28 | 29 | def rm_dir(path: str): 30 | if os.path.isdir(path): 31 | shutil.rmtree(path, ignore_errors=True) 32 | 33 | 34 | def clear_dir(path: str): 35 | if os.path.isdir(path): 36 | shutil.rmtree(path, ignore_errors=True) 37 | mkdir(path) 38 | 39 | 40 | def move_dir(src_path: str, dst_path: str): 41 | mkdir(src_path) 42 | clear_dir(dst_path) 43 | shutil.copytree(src_path, dst_path, dirs_exist_ok=True) 44 | clear_dir(src_path) 45 | 46 | 47 | def dir_size(path: str) -> int: # pragma: no cover 48 | total_size = 0 49 | for dirpath, _, filenames in os.walk(path): 50 | for f in filenames: 51 | fp = os.path.join(dirpath, f) 52 | if not os.path.islink(fp): 53 | total_size += os.path.getsize(fp) 54 | return total_size 55 | -------------------------------------------------------------------------------- /src/utils/ring_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import ring 4 | 5 | 6 | def lru(*args, **kwargs): 7 | if "SPHINX_BUILD" in os.environ: # pragma: no cover 8 | 9 | def noop_decorator(func): 10 | return func 11 | 12 | return noop_decorator 13 | else: 14 | return ring.lru(*args, **kwargs) 15 | -------------------------------------------------------------------------------- /src/utils/str_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from itertools import chain 3 | 4 | 5 | def replace_many(text, substitions): 6 | pattern = re.compile("|".join(map(re.escape, substitions.keys()))) 7 | return pattern.sub(lambda match: substitions[match.group(0)], text) 8 | 9 | 10 | def get_templated_var_names(templated_str): 11 | escaped_components = re.split(r"{{|}}", templated_str) # Ignore {{ and }} 12 | template_var_pattern = r"{([a-zA-Z0-9:_\.]+)}" 13 | return list( 14 | chain.from_iterable( 15 | [ 16 | re.findall(template_var_pattern, component) 17 | for component in escaped_components 18 | ] 19 | ) 20 | ) 21 | 22 | 23 | def replace_templated_vars(templated_str, var_name_to_values): 24 | return replace_many( 25 | templated_str, {"{" + k + "}": str(v) for k, v in var_name_to_values.items()} 26 | ) 27 | -------------------------------------------------------------------------------- /src/utils/time_utils.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | TIME_DURATION_UNITS = ( 4 | ("week", 60 * 60 * 24 * 7), 5 | ("day", 60 * 60 * 24), 6 | ("hour", 60 * 60), 7 | ("min", 60), 8 | ("sec", 1), 9 | ) 10 | 11 | 12 | def human_time_duration(seconds: float) -> str: # pragma: no cover 13 | parts = [] 14 | for unit, div in TIME_DURATION_UNITS: 15 | amount, seconds = divmod(int(seconds), div) 16 | if amount > 0: 17 | parts.append("{} {}{}".format(amount, unit, "" if amount == 1 else "s")) 18 | if len(parts) == 0: 19 | return "0 secs" 20 | return ", ".join(parts) 21 | 22 | 23 | def progress_eta(progress: float, start_time: float) -> str: 24 | elapsed_time = time() - start_time 25 | eta = ( 26 | human_time_duration((elapsed_time / progress) - elapsed_time) 27 | if progress > 0 28 | else "calculating..." 
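        # No ETA can be extrapolated until some progress has been reported.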
29 | ) 30 | return f"(Estimated time left: {eta})" 31 | --------------------------------------------------------------------------------
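A minimal usage sketch for the two helpers in src/utils/time_utils.py above (the import path shown is an assumption about how the module resolves, and the inputs are made up for illustration):

    from time import time

    from src.utils.time_utils import human_time_duration, progress_eta

    print(human_time_duration(3725))       # -> "1 hour, 2 mins, 5 secs"
    start_time = time() - 30               # pretend 30 seconds have elapsed
    print(progress_eta(0.25, start_time))  # -> roughly "(Estimated time left: 1 min, 30 secs)"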