├── .cirun.yml
├── .github
└── workflows
│ ├── cla.yml
│ ├── release.yml
│ └── tests.yml
├── .gitignore
├── LICENSE.txt
├── README.md
├── docs
├── .gitignore
├── Makefile
├── make.bat
└── source
│ ├── __init__.py
│ ├── _static
│ ├── css
│ │ └── custom.css
│ ├── images
│ │ ├── cover.png
│ │ ├── demo.svg
│ │ ├── demo_code.png
│ │ ├── favicon.svg
│ │ └── logo.svg
│ └── js
│ │ └── custom.js
│ ├── _templates
│ ├── apidoc
│ │ ├── module.rst_t
│ │ └── package.rst_t
│ ├── base.html
│ ├── metatags.html
│ ├── page.html
│ └── sidebar
│ │ └── brand.html
│ ├── conf.py
│ ├── docstrings.py
│ ├── index.rst
│ └── pages
│ ├── advanced_usage
│ ├── caching_and_saved_outputs.rst
│ ├── creating_a_new_datadreamer_...
│ │ ├── index.rst
│ │ ├── llm.rst
│ │ ├── other.rst
│ │ ├── step.rst
│ │ └── trainer.rst
│ ├── index.rst
│ ├── parallelization
│ │ ├── index.rst
│ │ ├── running_models_on_multiple_gpus.rst
│ │ ├── running_steps_in_parallel.rst
│ │ └── training_models_on_multiple_gpus.rst
│ ├── parameter_efficient_training.rst
│ └── quantization.rst
│ ├── contributing.rst
│ └── get_started
│ ├── installation.rst
│ ├── motivation_and_design.rst
│ ├── overview_guide.rst
│ └── quick_tour
│ ├── abstract_to_tweet.rst
│ ├── aligning.rst
│ ├── attributed_prompts.rst
│ ├── bootstrapping_machine_translation.rst
│ ├── dataset_augmentation.rst
│ ├── dataset_cleaning.rst
│ ├── index.rst
│ ├── instruction_tuning.rst
│ ├── openai_distillation.rst
│ └── self_rewarding.rst
├── pyproject.toml
├── scripts
├── .cluster
│ ├── .gitignore
│ ├── _boot.sh
│ ├── _command.sh
│ ├── _lock.sh
│ ├── _setup.sh
│ ├── args_log.sh
│ ├── config_log.sh
│ ├── direct
│ │ ├── .gitignore
│ │ ├── _direct_config.sh
│ │ ├── _direct_env.sh
│ │ ├── cancel.sh
│ │ ├── cancel_all.sh
│ │ ├── interactive_submit.sh
│ │ ├── reset_venv.sh
│ │ ├── status.sh
│ │ ├── status_all.sh
│ │ └── submit.sh
│ ├── full_reset.sh
│ ├── install_log.sh
│ ├── reset.sh
│ ├── reset_venv.sh
│ ├── series_log.sh
│ ├── sge
│ │ ├── .gitignore
│ │ ├── _qsub_config.sh
│ │ ├── cancel.sh
│ │ ├── cancel_all.sh
│ │ ├── htop.sh
│ │ ├── interactive_submit.sh
│ │ ├── reset_venv.sh
│ │ ├── shell.sh
│ │ ├── status.sh
│ │ ├── status_all.sh
│ │ └── submit.sh
│ ├── slurm
│ │ ├── .gitignore
│ │ ├── _sbatch_config.sh
│ │ ├── cancel.sh
│ │ ├── cancel_all.sh
│ │ ├── htop.sh
│ │ ├── interactive_submit.sh
│ │ ├── nvidia-smi.sh
│ │ ├── reset_venv.sh
│ │ ├── shell.sh
│ │ ├── status.sh
│ │ ├── status_all.sh
│ │ └── submit.sh
│ ├── stderr_log.sh
│ ├── stdout_log.sh
│ ├── submission_log.sh
│ └── tag.sh
├── .githooks
│ ├── post-checkout
│ └── pre-commit
├── .python-version
├── .vscode
│ ├── extensions.json
│ ├── settings.json
│ └── tasks.json
├── docs.sh
├── format.sh
├── lint.sh
├── package.sh
├── package_publish.sh
├── project.env
└── run.sh
└── src
├── .env
├── .gitignore
├── .secrets.template.env
├── __cli__.py
├── __init__.py
├── __main__.py
├── _cachable
├── __init__.py
├── _cachable.py
└── _parallel_cachable.py
├── _patches
├── __init__.py
├── datasets_reset_state_hack.py
└── setfit_import_hack.py
├── _stubs
├── .gitkeep
└── datasets
│ └── __init__.pyi
├── datadreamer.py
├── datasets
├── __init__.py
├── datasets.py
└── utils.py
├── embedders
├── __init__.py
├── embedder.py
├── openai_embedder.py
├── parallel_embedder.py
├── sentence_transformers_embedder.py
└── together_embedder.py
├── errors
├── __init__.py
└── steps
│ ├── __init__.py
│ └── step.py
├── llms
├── __init__.py
├── _chat_prompt_templates.py
├── _litellm.py
├── _llm_api.py
├── _tokenizers.py
├── ai21.py
├── anthropic.py
├── bedrock.py
├── cohere.py
├── ctransformers.py
├── google_ai_studio.py
├── hf_api_endpoint.py
├── hf_transformers.py
├── llm.py
├── mistral_ai.py
├── openai.py
├── openai_assistant.py
├── parallel_llm.py
├── petals.py
├── together.py
├── vertex_ai.py
└── vllm.py
├── logging
├── __init__.py
└── logger.py
├── pickling
├── __init__.py
└── pickle.py
├── project
├── __init__.py
├── builtin_tasks.py
├── debug.py
├── devices.py
├── environment.py
├── pennnlp.py
├── persistent_storage.py
├── report.py
└── serve.py
├── py.typed
├── requirements-accelerator-device.txt
├── requirements-cpu.txt
├── requirements-dev.txt
├── requirements-test.txt
├── requirements.txt
├── retrievers
├── __init__.py
├── embedding_retriever.py
├── parallel_retriever.py
└── retriever.py
├── steps
├── __init__.py
├── data_card.py
├── data_sources
│ ├── csv_data_source.py
│ ├── data_source.py
│ ├── hf_dataset_data_source.py
│ ├── hf_hub_data_source.py
│ ├── json_data_source.py
│ └── text_data_source.py
├── prompt
│ ├── _prompt_base.py
│ ├── data_from_attributed_prompt.py
│ ├── data_from_prompt.py
│ ├── few_shot_prompt.py
│ ├── few_shot_prompt_with_retrieval.py
│ ├── filter_with_prompt.py
│ ├── judge_generation_pairs_with_prompt.py
│ ├── judge_pairs_with_prompt.py
│ ├── process_with_prompt.py
│ ├── prompt.py
│ ├── rag_prompt.py
│ └── rank_with_prompt.py
├── step.py
├── step_background.py
├── step_export.py
├── step_operations.py
├── step_output.py
└── tasks
│ ├── cosine_similarity.py
│ ├── embed.py
│ ├── retrieve.py
│ └── run_task_model.py
├── task_models
├── __init__.py
├── hf_classification_task_model.py
├── parallel_task_model.py
└── task_model.py
├── tests
├── __init__.py
├── conftest.py
├── datasets
│ ├── __init__.py
│ ├── test_datasets.py
│ └── test_utils.py
├── embedders
│ ├── __init__.py
│ └── test_embedders.py
├── llms
│ ├── __init__.py
│ └── test_llms.py
├── retrievers
│ ├── __init__.py
│ └── test_retrievers.py
├── steps
│ ├── __init__.py
│ ├── prompt
│ │ ├── __init__.py
│ │ └── test_prompt.py
│ ├── tasks
│ │ ├── __init__.py
│ │ └── test_tasks.py
│ ├── test_data_sources.py
│ ├── test_step.py
│ ├── test_step_background.py
│ ├── test_step_export.py
│ ├── test_step_operations.py
│ └── test_step_output.py
├── task_models
│ ├── __init__.py
│ └── test_task_models.py
├── test_cli.py
├── test_datadreamer.py
├── test_package.py
├── test_utils
│ ├── __init__.py
│ ├── config.py
│ └── fixtures
│ │ ├── __init__.py
│ │ ├── bitsandbytes_fixture.py
│ │ ├── clear_space.py
│ │ ├── cli_runner.py
│ │ ├── create_datadreamer.py
│ │ ├── create_test_step.py
│ │ ├── mock_llm.py
│ │ └── restore_os_environ.py
├── trainers
│ ├── __init__.py
│ ├── test_distributed.py
│ └── test_trainers.py
└── utils
│ ├── __init__.py
│ └── test_device_utils.py
├── trainers
├── __init__.py
├── _train_hf_base.py
├── _vendored
│ ├── __init__.py
│ ├── _dpo_helper.py
│ ├── _sentence_transformer_helper.py
│ ├── _setfit_helper.py
│ └── dpo_trainer.py
├── train_hf_classifier.py
├── train_hf_dpo.py
├── train_hf_finetune.py
├── train_hf_ppo.py
├── train_hf_reward_model.py
├── train_openai_finetune.py
├── train_sentence_transformer.py
├── train_setfit_classifier.py
└── trainer.py
└── utils
├── __init__.py
├── arg_utils.py
├── background_utils.py
├── collection_utils.py
├── device_utils.py
├── distributed_utils.py
├── fingerprint_utils.py
├── fs_utils.py
├── hf_chat_prompt_templates.py
├── hf_hub_utils.py
├── hf_model_utils.py
├── hf_structured_decoding_utils.py
├── hf_training_utils.py
├── import_utils.py
├── ring_utils.py
├── str_utils.py
└── time_utils.py
/.cirun.yml:
--------------------------------------------------------------------------------
1 | runners:
2 | - name: "aws-runner"
3 | cloud: "aws"
4 | region: us-east-1
5 | instance_type: "t3.xlarge"
6 | # Ubuntu-22.04, ami image -> Created via EC2 Image Builder: https://us-east-1.console.aws.amazon.com/imagebuilder/home?region=us-east-1#/viewPipelines
7 | machine_image: "ami-0dac1a7772e1ccf34"
8 | preemptible: true
9 | labels:
10 | - "cirun-aws-runner"
--------------------------------------------------------------------------------
/.github/workflows/cla.yml:
--------------------------------------------------------------------------------
1 | name: "CLA Assistant"
2 | on:
3 | issue_comment:
4 | types: [created]
5 | pull_request_target:
6 | types: [opened,closed,synchronize]
7 | permissions:
8 | actions: write
9 | contents: write
10 | pull-requests: write
11 | statuses: write
12 | jobs:
13 | CLAAssistant:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: "CLA Assistant"
17 | if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
18 | uses: contributor-assistant/github-action@v2.3.0
19 | env:
20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
21 | with:
22 | path-to-signatures: 'signatures/version1/cla.json'
23 | path-to-document: 'https://gist.github.com/AjayP13/b657de111d8d0907f48ba32eababd911'
24 | branch: 'cla_signatures'
25 | allowlist: AjayP13,dependabot[bot]
26 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.cluster
2 | /.vscode
3 | .ruff_cache
4 | .venv
5 | .venv_dev
6 |
7 | .venv_poetry
8 | dist/
9 | poetry.lock
10 |
11 | wandb/
12 |
13 | .DS_Store
14 | .jupyter_ystore.db
15 | .ipynb_checkpoints
16 | .coverage
17 | .coverage.*
18 | coverage.json
19 | coverage.json*
20 |
21 | .mypy_cache
22 | .pytest_cache
23 | /.python-version
24 |
25 | .tests_data
26 |
27 | # This directory is temporarily generated when building docs or publishing to PyPI. Remove this line before running poetry. (See: https://github.com/python-poetry/poetry/issues/5547)
28 | datadreamer
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | -----------
3 |
4 | Copyright (c) 2024 Ajay Patel
5 | Permission is hereby granted, free of charge, to any person
6 | obtaining a copy of this software and associated documentation
7 | files (the "Software"), to deal in the Software without
8 | restriction, including without limitation the rights to use,
9 | copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the
11 | Software is furnished to do so, subject to the following
12 | conditions:
13 |
14 | The above copyright notice and this permission notice shall be
15 | included in all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
19 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
21 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
22 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24 | OTHER DEALINGS IN THE SOFTWARE.
25 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 |
4 | build
5 | source/*.rst
6 | !source/index.rst
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/docs/source/__init__.py
--------------------------------------------------------------------------------
/docs/source/_static/images/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/docs/source/_static/images/cover.png
--------------------------------------------------------------------------------
/docs/source/_static/images/demo_code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/docs/source/_static/images/demo_code.png
--------------------------------------------------------------------------------
/docs/source/_static/js/custom.js:
--------------------------------------------------------------------------------
1 | (function() {
2 | const buttons = document.getElementsByClassName("theme-toggle");
3 | Array.from(buttons).forEach((btn) => {
4 | btn.addEventListener("click", (e) => {
5 | const currentTheme = localStorage.getItem("theme");
6 | const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
7 | if (currentTheme === (prefersDark ? "dark" : "light")) {
8 | // Skip the "auto" theme
9 | document.querySelector('.theme-toggle').click();
10 | }
11 | });
12 | });
13 | })();
--------------------------------------------------------------------------------
/docs/source/_templates/apidoc/module.rst_t:
--------------------------------------------------------------------------------
1 | {%- if show_headings %}
2 | {{- basename.split('.')[1:] | join('.') | e | heading }}
3 |
4 | {% endif -%}
5 | .. automodule:: {{ qualname }}
6 | {%- for option in automodule_options %}
7 | :{{ option }}:
8 | {%- endfor %}
9 |
--------------------------------------------------------------------------------
/docs/source/_templates/apidoc/package.rst_t:
--------------------------------------------------------------------------------
1 | {%- macro automodule(modname, options) -%}
2 | .. automodule:: {{ modname }}
3 | {%- for option in options %}
4 | :{{ option }}:
5 | {%- endfor %}
6 | {%- endmacro %}
7 |
8 | {%- macro toctree(docnames) -%}
9 | .. toctree::
10 | :maxdepth: {{ maxdepth }}
11 | {% for docname in docnames %}
12 | {{ docname }}
13 | {%- endfor %}
14 | {%- endmacro %}
15 |
16 | {%- if is_namespace %}
17 | {{- pkgname.split('.')[1:] | join('.') | e | heading }}
18 | {% else %}
19 | {{- (pkgname.split('.')[1:] if pkgname.split('.')|length > 1 else [pkgname] ) | join('.') | e | heading }}
20 | {% endif %}
21 |
22 | {%- if is_namespace %}
23 | .. py:module:: {{ pkgname }}
24 | {% endif %}
25 |
26 | {%- if modulefirst and not is_namespace %}
27 | {{ automodule(pkgname, automodule_options) }}
28 | {% endif %}
29 |
30 | {%- if subpackages %}
31 | Subpackages
32 | -----------
33 |
34 | {{ toctree(subpackages) }}
35 | {% endif %}
36 |
37 | {%- if submodules %}
38 | Submodules
39 | ----------
40 | {% if separatemodules %}
41 | {{ toctree(submodules) }}
42 | {% else %}
43 | {%- for submodule in submodules %}
44 | {% if show_headings %}
45 | {{- submodule.split('.')[1:] | join('.') | e | heading(2) }}
46 | {% endif %}
47 | {{ automodule(submodule, automodule_options) }}
48 | {% endfor %}
49 | {%- endif %}
50 | {%- endif %}
51 |
52 | {%- if not modulefirst and not is_namespace %}
53 | Module contents
54 | ---------------
55 |
56 | {{ automodule(pkgname, automodule_options) }}
57 | {% endif %}
--------------------------------------------------------------------------------
/docs/source/_templates/metatags.html:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/docs/source/_templates/sidebar/brand.html:
--------------------------------------------------------------------------------
1 |
2 | {% block brand_content %}
3 | {%- if logo_url %}
4 |
7 | {%- endif %}
8 | {%- if theme_light_logo and theme_dark_logo %}
9 |
13 | {%- endif %}
14 | {% if (not theme_sidebar_hide_name) and (not logo_url) %}
15 | {{ docstitle if docstitle else project }}
16 | {%- endif %}
17 | v{{ release }}
18 | {% endblock brand_content %}
19 |
20 |
23 |
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/caching_and_saved_outputs.rst:
--------------------------------------------------------------------------------
1 | Caching and Saved Outputs
2 | #######################################################
3 |
4 | DataDreamer aggressively caches and saves its work at multiple levels to avoid re-computation and keep workflows as time- and cost-efficient as possible.
5 |
6 | - **Step Outputs**: DataDreamer caches the results of each step run within a session to the output folder. If a session is interrupted and re-run, DataDreamer will automatically load the results of previously completed steps from disk and resume where it left off.
7 | - **Model Generations and Outputs**: DataDreamer caches the results computed by a :py:class:`~datadreamer.llms.LLM`, :py:class:`~datadreamer.embedders.Embedder` model, etc.
8 | - **Training Checkpoints**: DataDreamer will automatically save and resume from checkpoints when training a model with :py:class:`~datadreamer.trainers.Trainer`.
9 |
10 | Output Folder File Structure
11 | ===============================
12 |
13 | :py:class:`~datadreamer.DataDreamer` sessions write to an output folder where all outputs and caches are saved. Below is a brief description of the output folder structure.
14 |
15 | - **Step Folders**: Each :py:class:`~datadreamer.steps.Step` will produce a named folder within the output folder. The name of the folder is the name of the step, and the folder contains the output dataset of the step within a ``_dataset`` folder. ``step.json`` contains metadata about the step. If a step is run within another step, its folder will be nested under the parent step's folder.
16 | - **Trainer Folders**: Each :py:class:`~datadreamer.trainers.Trainer` will produce a named folder within the output folder. The name of the folder is the name of the trainer, and the folder contains saved checkpoints during training to a ``_checkpoints`` folder and the final trained model to a ``_model`` folder. Various JSON files inside the ``_model`` folder like ``training_args.json`` contain metadata about the training configuration.
17 | - **Cache Folder**: The ``.cache`` folder in the output folder holds the SQLite databases that are used to cache the generations and outputs produced by models like :py:class:`~datadreamer.llms.LLM` or :py:class:`~datadreamer.embedders.Embedder`.
18 | - **Backups Folder**: The ``_backups`` folder in the output folder holds backups of step or trainer folders that have since been invalidated by a newer configuration of that step or trainer. They are kept in case a user reverts to a previous configuration of the step or trainer.
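19 |
20 | Example: Resuming a Session
21 | ===============================
22 |
23 | As a minimal sketch of caching in action (the step name, model, and instruction below are illustrative), running
24 | the script below twice against the same ``./output`` folder will not re-query the LLM the second time: the step's
25 | output and the model's generations are loaded from the caches described above.
26 |
27 | .. code-block:: python
28 |
29 |     from datadreamer import DataDreamer
30 |     from datadreamer.llms import OpenAI
31 |     from datadreamer.steps import DataFromPrompt
32 |
33 |     with DataDreamer("./output"):
34 |         # Load GPT-4
35 |         gpt_4 = OpenAI(model_name="gpt-4")
36 |
37 |         # On a second run, this step's results are loaded from disk instead of re-generated
38 |         sentences = DataFromPrompt(
39 |             "Generate Example Sentences",
40 |             args={
41 |                 "llm": gpt_4,
42 |                 "n": 100,
43 |                 "instruction": "Generate a short example sentence.",
44 |             },
45 |             outputs={"generations": "sentences"},
46 |         )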
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../index.rst:
--------------------------------------------------------------------------------
1 | Creating a new DataDreamer...
2 | #######################################################
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 |
7 | Step
8 | LLM
9 | Trainer
10 | Other
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../llm.rst:
--------------------------------------------------------------------------------
1 | Creating a new LLM
2 | #######################################################
3 |
4 | To create a new DataDreamer LLM class to support a new LLM library or API service, you will want to subclass
5 | the :py:class:`~datadreamer.llms.LLM` class. You can see example implementations of various LLMs by clicking on the
6 | ``[source]`` links on the :doc:`LLMs <../../../datadreamer.llms>` page. These may be helpful as reference implementations.
7 |
8 | Contributing
9 | ============
10 |
11 | If you would like to contribute the new LLM class you created to DataDreamer for others to use, see the :doc:`Contributing <../../../pages/contributing>` page.
12 | If applicable, please ensure your implementation includes model metadata, such as a link to the model card, the model's license, and the model's citation
13 | information.
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../other.rst:
--------------------------------------------------------------------------------
1 | Creating a new ...
2 | #######################################################
3 |
4 | To create a new implementation of an embedder, retriever, or other DataDreamer class, you will want to subclass its base class.
5 | You can see example implementations of various classes by clicking on the ``[source]`` links through the :doc:`API Reference <../../../datadreamer>` pages.
6 | These may be helpful as reference implementations.
7 |
8 | Contributing
9 | ============
10 |
11 | If you would like to contribute the new class you created to DataDreamer for others to use, see the :doc:`Contributing <../../../pages/contributing>` page.
12 | If applicable, please ensure your implementation includes appropriate metadata such as license and citation information.
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/creating_a_new_datadreamer_.../trainer.rst:
--------------------------------------------------------------------------------
1 | Creating a new Trainer
2 | #######################################################
3 |
4 | To create a new DataDreamer trainer class to support a new training library or training technique, you will want to subclass
5 | the :py:class:`~datadreamer.trainers.Trainer` class. You can see example implementations of various trainers by clicking on the
6 | ``[source]`` links on the :doc:`Trainers <../../../datadreamer.trainers>` page. These may be helpful as reference implementations.
7 |
8 | Contributing
9 | ============
10 |
11 | If you would like to contribute the new trainer class you created to DataDreamer for others to use, see the :doc:`Contributing <../../../pages/contributing>` page.
12 | If applicable, please ensure your implementation includes appropriate metadata, such as a link to the model card of the model being trained, the model's license, and
13 | the model and training technique's citation information.
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/index.rst:
--------------------------------------------------------------------------------
1 | Advanced Usage
2 | #######################################################
3 |
4 | Advanced users may want to configure DataDreamer to their needs or learn more about DataDreamer internals.
5 | This section provides a deeper look into DataDreamer and covers various advanced topics.
6 |
7 | .. toctree::
8 | :maxdepth: 1
9 |
10 | Caching and Saved Outputs
11 | Creating a New DataDreamer ...
12 | Parallelization
13 | Quantization
14 | Parameter-Efficient Training
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/parallelization/index.rst:
--------------------------------------------------------------------------------
1 | Parallelization
2 | #######################################################
3 |
4 | Parallelization is a core feature of DataDreamer that allows for performant workflows. Advanced users may want to implement
5 | parallelization in different ways, depending on their use case. This section will cover the different ways parallelization
6 | can be implemented in DataDreamer, including device parallelization (multi-GPU inference and training).
7 |
8 | .. toctree::
9 | :maxdepth: 1
10 |
11 | Running Steps in Parallel
12 | Running Models on Multiple GPUs
13 | Training Models on Multiple GPUs
14 |
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/parallelization/running_models_on_multiple_gpus.rst:
--------------------------------------------------------------------------------
1 | Running Models on Multiple GPUs
2 | #######################################################
3 |
4 | There are various ways to run models on multiple GPUs in DataDreamer.
5 |
6 | Large LLMs on Multiple GPUs
7 | ===========================
8 |
9 | To split a large model that cannot fit on a single GPU, you can set the ``device_map`` parameter of the
10 | :py:class:`~datadreamer.llms.HFTransformers` class to ``'auto'``. This will automatically split the model by layer
11 | onto your available GPUs. You can also manually specify
12 | `how and where the model should be split `_.
13 |
14 | Smaller Models
15 | ==============
16 |
17 | For smaller models, the :py:class:`~datadreamer.llms.ParallelLLM` wrapper takes in multiple :py:class:`~datadreamer.llms.LLM` objects
18 | and behaves like a single unified :py:class:`~datadreamer.llms.LLM` object that can then be passed to a step like :py:class:`~datadreamer.steps.Prompt`.
19 | :py:class:`~datadreamer.llms.ParallelLLM` will run any inputs it receives against all of the models in parallel. This is useful for running smaller models on multiple GPUs
20 | as each :py:class:`~datadreamer.llms.LLM` passed to the wrapper can be on a different GPU. Your model must be able to fit on a single GPU
21 | for this to work.
22 |
23 | Similarly, we have other parallelization wrappers for other types of models like :py:class:`~datadreamer.embedders.ParallelEmbedder`,
24 | :py:class:`~datadreamer.retrievers.ParallelRetriever`, etc.
25 |
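26 | As a rough sketch of both approaches above (the model names and device identifiers are illustrative, and the
27 | constructor arguments are written in the style used elsewhere in this documentation):
28 |
29 | .. code-block:: python
30 |
31 |     from datadreamer import DataDreamer
32 |     from datadreamer.llms import HFTransformers, ParallelLLM
33 |
34 |     with DataDreamer("./output"):
35 |         # Split a large model by layer across all available GPUs
36 |         large_llm = HFTransformers(
37 |             model_name="mistralai/Mistral-7B-Instruct-v0.1", device_map="auto"
38 |         )
39 |
40 |         # Run copies of a smaller model on two GPUs behind a single unified interface
41 |         parallel_llm = ParallelLLM(
42 |             HFTransformers(model_name="gpt2", device="cuda:0"),
43 |             HFTransformers(model_name="gpt2", device="cuda:1"),
44 |         )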
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/parallelization/running_steps_in_parallel.rst:
--------------------------------------------------------------------------------
1 | Running Steps in Parallel
2 | #######################################################
3 |
4 | There are two ways to run steps in parallel:
5 |
6 | 1. **Running steps in different processes:** Steps can be run asynchronously in a background process.
7 | To run multiple steps in parallel, you can run them all in the background and then wait for them to complete.
8 |
9 | 2. **Running steps in different threads:** You can group steps into Python functions. You can then run these functions
10 | in parallel using :py:func:`~datadreamer.steps.concurrent`.
11 |
12 | Running steps in different processes
13 | ====================================
14 | You can run steps in the background by passing the ``background=True`` keyword argument when constructing a :py:class:`~datadreamer.steps.Step`.
15 | This will run the step in its own background process asynchronously.
16 |
17 | Waiting for :py:attr:`~datadreamer.steps.Step.output`
18 | -----------------------------------------------------
19 | When you run a step in the background, its output may not be immediately ready, and trying to access
20 | :py:attr:`~datadreamer.steps.Step.output` may raise an exception until the step has completed running in
21 | the background. To wait for a step's output to be ready, you can call :py:func:`~datadreamer.steps.wait`
22 | on the step. This will block until the step's output is ready.
23 |
24 | Running steps in different threads
25 | ==================================
26 |
27 | To run multiple steps in parallel, you can group them into Python functions and run these functions in parallel using threads. You can pass
28 | the functions to :py:func:`~datadreamer.steps.concurrent` to run them in parallel using threading.
29 |
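30 | As a minimal sketch of both approaches (the LLM, step names, and instructions below are illustrative):
31 |
32 | .. code-block:: python
33 |
34 |     from datadreamer import DataDreamer
35 |     from datadreamer.llms import OpenAI
36 |     from datadreamer.steps import DataFromPrompt, concurrent, wait
37 |
38 |     with DataDreamer("./output"):
39 |         gpt_4 = OpenAI(model_name="gpt-4")
40 |
41 |         # 1. Run two steps asynchronously in background processes
42 |         questions = DataFromPrompt(
43 |             "Generate Questions",
44 |             args={"llm": gpt_4, "n": 100, "instruction": "Generate a trivia question."},
45 |             outputs={"generations": "questions"},
46 |             background=True,
47 |         )
48 |         jokes = DataFromPrompt(
49 |             "Generate Jokes",
50 |             args={"llm": gpt_4, "n": 100, "instruction": "Generate a short joke."},
51 |             outputs={"generations": "jokes"},
52 |             background=True,
53 |         )
54 |
55 |         # Block until each step's output is ready
56 |         wait(questions)
57 |         wait(jokes)
58 |
59 |         # 2. Or, group steps into functions and run the functions in parallel threads
60 |         def more_questions():
61 |             return DataFromPrompt(
62 |                 "Generate More Questions",
63 |                 args={"llm": gpt_4, "n": 100, "instruction": "Generate a trivia question."},
64 |                 outputs={"generations": "questions"},
65 |             )
66 |
67 |         def more_jokes():
68 |             return DataFromPrompt(
69 |                 "Generate More Jokes",
70 |                 args={"llm": gpt_4, "n": 100, "instruction": "Generate a short joke."},
71 |                 outputs={"generations": "jokes"},
72 |             )
73 |
74 |         concurrent(more_questions, more_jokes)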
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/parameter_efficient_training.rst:
--------------------------------------------------------------------------------
1 | Parameter-Efficient Training
2 | #######################################################
3 |
4 | DataDreamer makes setting up parameter-efficient training simple.
5 | You can pass a :py:class:`~peft.PeftConfig` to the ``peft_config`` argument of a class
6 | like :py:class:`~datadreamer.trainers.TrainHFFineTune` to enable parameter-efficient training.
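7 |
8 | For example, a minimal sketch (the trainer name and model are illustrative, and the LoRA settings are left at
9 | their defaults):
10 |
11 | .. code-block:: python
12 |
13 |     from datadreamer import DataDreamer
14 |     from datadreamer.trainers import TrainHFFineTune
15 |     from peft import LoraConfig
16 |
17 |     with DataDreamer("./output"):
18 |         # Passing a PeftConfig (here, LoRA) trains only a small set of adapter
19 |         # weights instead of the full model
20 |         trainer = TrainHFFineTune(
21 |             "Fine-Tune with LoRA",
22 |             model_name="google/t5-v1_1-base",
23 |             peft_config=LoraConfig(),
24 |         )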
--------------------------------------------------------------------------------
/docs/source/pages/advanced_usage/quantization.rst:
--------------------------------------------------------------------------------
1 | Quantization
2 | #######################################################
3 |
4 | DataDreamer makes setting up quantization simple. You can pass a
5 | `quantization config object `_
6 | to the ``quantization_config`` argument of a class like :py:class:`~datadreamer.llms.HFTransformers` or
7 | :py:class:`~datadreamer.trainers.TrainHFFineTune` to enable quantization.
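8 |
9 | For example, a minimal sketch (the model name and the 4-bit settings are illustrative):
10 |
11 | .. code-block:: python
12 |
13 |     from datadreamer import DataDreamer
14 |     from datadreamer.llms import HFTransformers
15 |     from transformers import BitsAndBytesConfig
16 |
17 |     with DataDreamer("./output"):
18 |         # Load the model in 4-bit precision via bitsandbytes
19 |         llm = HFTransformers(
20 |             model_name="mistralai/Mistral-7B-Instruct-v0.1",
21 |             quantization_config=BitsAndBytesConfig(load_in_4bit=True),
22 |         )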
--------------------------------------------------------------------------------
/docs/source/pages/get_started/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | #######################################################
3 |
4 | You can install DataDreamer via `PyPI `_::
5 |
6 | .. code-block:: bash
7 |
8 | pip3 install datadreamer.dev
9 |
10 | ..
11 | This page is redirected via the sphinx_reredirect extension.
--------------------------------------------------------------------------------
/docs/source/pages/get_started/motivation_and_design.rst:
--------------------------------------------------------------------------------
1 | Motivation and Design
2 | #######################################################
3 |
4 | DataDreamer is an open-source Python library that is made to help streamline and accelerate increasingly common
5 | LLM-related workflows for ML / NLP researchers and users and help increase the rate of research progress through
6 | features encouraging open science and reproducibility.
7 |
8 | Design Principles
9 | =================
10 |
11 | A few design principles of DataDreamer are:
12 |
13 | - 🔬 **Research-Grade and Production-Grade:** Implementations of techniques that are consistent with the established work and best practices for correctness and efficiency.
14 | - 💪 **Reproducibility and Robustness:** A focus on `reproducibility and robustness <../../pages/get_started/overview_guide.html#reproducibility>`_.
15 | - 🧩 **Simple with Sensible Defaults:** Simple and easy to get started with little configuration through sensible defaults.
16 | - 🛠️ **Adaptable, Extensible, and Customizable:** Selectively overridable advanced configuration and the ability to support :doc:`new techniques <../../pages/advanced_usage/creating_a_new_datadreamer_.../step>` or :doc:`models <../../pages/advanced_usage/creating_a_new_datadreamer_.../llm>`.
17 | - 👥 **Accessible:** :doc:`Aggressive caching and efficiency techniques <../../pages/get_started/overview_guide>` to make both computationally- and financially-expensive LLM-related workflows more accessible to resource-constrained researchers.
18 | - 🤝 **Community-Driven:** Community members can :doc:`contribute <../../pages/contributing>` to extend DataDreamer's abilities.
19 |
20 | For Anyone
21 | ==========
22 |
23 | While DataDreamer was designed *for researchers, by researchers*, it is also meant to be accessible to
24 | anyone who wants to use it.
25 |
26 | Use in Teaching
27 | ===============
28 |
29 | While DataDreamer was built to help researchers and practitioners implement complex LLM-related workflows, it is extremely simple to use, making bleeding-edge models,
30 | techniques, and training accessible to reasonably technical students.
31 |
32 | If you are a university professor of a graduate-level NLP or machine learning course
33 | and would like to trial using DataDreamer in your course for instruction, assignments, or projects please reach out to
34 | `Ajay Patel (ajayp@upenn.edu) `_ and
35 | `Professor Chris Callison-Burch (ccb@upenn.edu) `_ at the University of Pennsylvania.
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/abstract_to_tweet.rst:
--------------------------------------------------------------------------------
1 | Training an "Abstract to Tweet Model" with Fully Synthetic Data
2 | ###############################################################
3 |
4 | In this demonstration, we show how to train a small model to generate tweets summarizing ML research paper abstracts. We use DataDreamer to generate a fully synthetic dataset, distill the knowledge to a small T5 model, and publish both the dataset and model.
5 |
6 | .. raw:: html
7 |
8 | See the resulting synthetic dataset and the trained model.
9 |
10 | .. code-block:: python
11 |
12 | from datadreamer import DataDreamer
13 | from datadreamer.llms import OpenAI
14 | from datadreamer.steps import DataFromPrompt, ProcessWithPrompt
15 | from datadreamer.trainers import TrainHFFineTune
16 | from peft import LoraConfig
17 |
18 | with DataDreamer("./output"):
19 | # Load GPT-4
20 | gpt_4 = OpenAI(model_name="gpt-4")
21 |
22 | # Generate synthetic arXiv-style research paper abstracts with GPT-4
23 | arxiv_dataset = DataFromPrompt(
24 | "Generate Research Paper Abstracts",
25 | args={
26 | "llm": gpt_4,
27 | "n": 1000,
28 | "temperature": 1.2,
29 | "instruction": (
30 | "Generate an arXiv abstract of an NLP research paper."
31 | " Return just the abstract, no titles."
32 | ),
33 | },
34 | outputs={"generations": "abstracts"},
35 | )
36 |
37 | # Ask GPT-4 to convert the abstracts to tweets
38 | abstracts_and_tweets = ProcessWithPrompt(
39 | "Generate Tweets from Abstracts",
40 | inputs={"inputs": arxiv_dataset.output["abstracts"]},
41 | args={
42 | "llm": gpt_4,
43 | "instruction": (
44 | "Given the abstract, write a tweet to summarize the work."
45 | ),
46 | "top_p": 1.0,
47 | },
48 | outputs={"inputs": "abstracts", "generations": "tweets"},
49 | )
50 |
51 | # Create training data splits
52 | splits = abstracts_and_tweets.splits(train_size=0.90, validation_size=0.10)
53 |
54 | # Train a model to convert research paper abstracts to tweets
55 | # with the synthetic dataset
56 | trainer = TrainHFFineTune(
57 | "Train an Abstract => Tweet Model",
58 | model_name="google/t5-v1_1-base",
59 | peft_config=LoraConfig(),
60 | )
61 | trainer.train(
62 | train_input=splits["train"].output["abstracts"],
63 | train_output=splits["train"].output["tweets"],
64 | validation_input=splits["validation"].output["abstracts"],
65 | validation_output=splits["validation"].output["tweets"],
66 | epochs=30,
67 | batch_size=8,
68 | )
69 |
70 | # Publish and share the synthetic dataset
71 | abstracts_and_tweets.publish_to_hf_hub(
72 | "datadreamer-dev/abstracts_and_tweets",
73 | train_size=0.90,
74 | validation_size=0.10,
75 | )
76 |
77 | # Publish and share the trained model
78 | trainer.publish_to_hf_hub("datadreamer-dev/abstracts_to_tweet_model")
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/aligning.rst:
--------------------------------------------------------------------------------
1 | Aligning a LLM with Human Preferences
2 | #####################################
3 |
4 | In order to better align the responses :doc:`instruction-tuned LLMs ` generate to what humans would prefer, we can train LLMs against a reward model or a dataset of human preferences in a process known as `RLHF (Reinforcement Learning from Human Feedback) `_.
5 |
6 | DataDreamer makes this process extremely simple and straightforward to accomplish. We demonstrate it below using LoRA to only train
7 | a fraction of the weights with `DPO `_ (a more stable and efficient alignment method than traditional RLHF).
8 |
9 | .. code-block:: python
10 |
11 | from datadreamer import DataDreamer
12 | from datadreamer.steps import HFHubDataSource
13 | from datadreamer.trainers import TrainHFDPO
14 | from peft import LoraConfig
15 |
16 | with DataDreamer("./output"):
17 | # Get the DPO dataset
18 | dpo_dataset = HFHubDataSource(
19 | "Get DPO Dataset", "Intel/orca_dpo_pairs", split="train"
20 | )
21 |
22 | # Keep only 1000 examples as a quick demo
23 | dpo_dataset = dpo_dataset.take(1000)
24 |
25 | # Create training data splits
26 | splits = dpo_dataset.splits(train_size=0.90, validation_size=0.10)
27 |
28 | # Align the TinyLlama chat model with human preferences
29 | trainer = TrainHFDPO(
30 | "Align TinyLlama-Chat",
31 | model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
32 | peft_config=LoraConfig(),
33 | device=["cuda:0", "cuda:1"],
34 | dtype="bfloat16",
35 | )
36 | trainer.train(
37 | train_prompts=splits["train"].output["question"],
38 | train_chosen=splits["train"].output["chosen"],
39 | train_rejected=splits["train"].output["rejected"],
40 | validation_prompts=splits["validation"].output["question"],
41 | validation_chosen=splits["validation"].output["chosen"],
42 | validation_rejected=splits["validation"].output["rejected"],
43 | epochs=3,
44 | batch_size=1,
45 | gradient_accumulation_steps=32,
46 | )
47 |
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/attributed_prompts.rst:
--------------------------------------------------------------------------------
1 | Generating Training Data with Attributed Prompts
2 | ################################################
3 |
4 | By using the `attributed prompt method `_ of generating training data, we can create a diverse dataset that is more representative of real-world data.
5 | We demonstrate this below by generating movie reviews.
6 |
7 | .. raw:: html
8 |
9 | See the resulting synthetic dataset.
10 |
11 | .. code-block:: python
12 |
13 | from datadreamer import DataDreamer
14 | from datadreamer.llms import OpenAI
15 | from datadreamer.steps import (
16 | Prompt,
17 | DataSource,
18 | DataFromAttributedPrompt,
19 | )
20 |
21 | with DataDreamer("./output"):
22 | # Load GPT-4
23 | gpt_4 = OpenAI(model_name="gpt-4")
24 |
25 | # Create prompts to generate attributes for movie reviews
26 | attribute_generation_prompts = DataSource(
27 | "Attribute Generation Prompts",
28 | data={
29 | "prompts": [
30 | "Generate the names of 10 movies released in theatres in the past, in a comma separated list.",
31 | "Generate 10 elements of a movie a reviewer might consider, in a comma separated list.",
32 | "Generate 10 adjectives that could describe a movie reviewer's style, in a comma separated list.",
33 | ],
34 | },
35 | )
36 |
37 | # Generate the attributes for movie reviews
38 | attributes = Prompt(
39 | "Generate Attributes",
40 | inputs={
41 | "prompts": attribute_generation_prompts.output["prompts"],
42 | },
43 | args={
44 | "llm": gpt_4,
45 | },
46 | ).output["generations"]
47 |
48 | # Generate movie reviews with varied attributes
49 | movie_reviews = (
50 | DataFromAttributedPrompt(
51 | "Generate Movie Reviews",
52 | args={
53 | "llm": gpt_4,
54 | "n": 1000,
55 | "instruction": "Generate a few sentence {review_style} movie review about {movie_name} that focuses on {movie_element}.",
56 | "attributes": {
57 | "movie_name": attributes[0].split(","),
58 | "movie_element": attributes[1].split(","),
59 | "review_style": attributes[2].split(","),
60 | },
61 | },
62 | outputs={"generations": "reviews"},
63 | )
64 | .select_columns(["reviews"])
65 | .shuffle()
66 | )
67 |
68 | # Publish and share the synthetic dataset
69 | movie_reviews.publish_to_hf_hub(
70 | "datadreamer-dev/movie_reviews",
71 | )
72 |
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/dataset_augmentation.rst:
--------------------------------------------------------------------------------
1 | Augmenting an Existing Dataset
2 | ##############################
3 |
4 | DataDreamer can help augment existing datasets using LLMs. We demonstrate this below by augmenting questions from HotpotQA
5 | with a decomposition of what steps a user would need to take to solve the complex question.
6 |
7 | .. raw:: html
8 |
9 | See the resulting synthetic dataset.
10 |
11 | .. code-block:: python
12 |
13 | from datadreamer import DataDreamer
14 | from datadreamer.llms import OpenAI
15 | from datadreamer.steps import ProcessWithPrompt, HFHubDataSource
16 |
17 | with DataDreamer("./output"):
18 | # Load GPT-4
19 | gpt_4 = OpenAI(model_name="gpt-4")
20 |
21 | # Get HotPot QA questions
22 | hotpot_qa_dataset = HFHubDataSource(
23 | "Get Hotpot QA Questions",
24 | "hotpot_qa",
25 | config_name="distractor",
26 | split="train",
27 | ).select_columns(["question"])
28 |
29 | # Keep only 1000 questions as a quick demo
30 | hotpot_qa_dataset = hotpot_qa_dataset.take(1000)
31 |
32 | # Ask GPT-4 to decompose the question
33 | questions_and_decompositions = ProcessWithPrompt(
34 | "Generate Decompositions",
35 | inputs={"inputs": hotpot_qa_dataset.output["question"]},
36 | args={
37 | "llm": gpt_4,
38 | "instruction": (
39 | "Given the question which requires multiple steps to solve, give a numbered list of intermediate questions required to solve the question."
40 | "Return only the list, nothing else."
41 | ),
42 | },
43 | outputs={"inputs": "questions", "generations": "decompositions"},
44 | ).select_columns(["questions", "decompositions"])
45 |
46 | # Publish and share the synthetic dataset
47 | questions_and_decompositions.publish_to_hf_hub(
48 | "datadreamer-dev/hotpot_qa_augmented",
49 | )
50 |
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/dataset_cleaning.rst:
--------------------------------------------------------------------------------
1 | Cleaning an Existing Dataset
2 | ############################
3 |
4 | DataDreamer can help clean or filter existing datasets using LLMs. We demonstrate this below by filtering a dataset of
5 | news articles to only include those that are about sports.
6 |
7 | .. raw:: html
8 |
9 | See the resulting synthetic dataset.
10 |
11 | .. code-block:: python
12 |
13 | from datadreamer import DataDreamer
14 | from datadreamer.llms import OpenAI
15 | from datadreamer.steps import FilterWithPrompt, HFHubDataSource
16 |
17 | with DataDreamer("./output"):
18 | # Load GPT-4
19 | gpt_4 = OpenAI(model_name="gpt-4")
20 |
21 | # Get news articles
22 | news_dataset = HFHubDataSource(
23 | "Get CNN & Daily Mail News Articles",
24 | "cnn_dailymail",
25 | config_name="3.0.0",
26 | split="test",
27 | )
28 |
29 | # Keep only 1000 articles as a quick demo
30 | news_dataset = news_dataset.take(1000)
31 |
32 | # Ask GPT-4 to filter the dataset
33 | sports_news_dataset = FilterWithPrompt(
34 | "Filter to only keep sports articles",
35 | inputs={"inputs": news_dataset.output["article"]},
36 | args={
37 | "llm": gpt_4,
38 | "instruction": "Is the article about sports? Answer 'Yes' or 'No'.",
39 | },
40 | )
41 |
42 | # Publish and share the synthetic dataset
43 | sports_news_dataset.publish_to_hf_hub(
44 | "datadreamer-dev/cnn_dailymail_sports",
45 | )
46 |
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/index.rst:
--------------------------------------------------------------------------------
1 | Quick Tour
2 | #######################################################
3 |
4 | Below we outline a few examples of using DataDreamer for various use cases to help you get started. It is by no means exhaustive, but should give you a good idea of what
5 | is possible with DataDreamer and how to use it. For more details on the various components of DataDreamer, please see the :doc:`Overview Guide <../overview_guide>`.
6 |
7 | Synthetic Data Generation
8 | =========================
9 |
10 | - :doc:`Training an "Abstract to Tweet Model" with Fully Synthetic Data `
11 | - :doc:`Generating Training Data with Attributed Prompts `
12 | - :doc:`Distilling GPT-4 Capabilities to GPT-3.5 `
13 | - :doc:`Augmenting an Existing Dataset `
14 | - :doc:`Cleaning an Existing Dataset `
15 | - :doc:`Bootstrapping Synthetic Few-Shot Examples `
16 |
17 | Instruction-Tuning and Aligning Models
18 | ======================================
19 |
20 | - :doc:`Instruction-Tuning a LLM `
21 | - :doc:`Aligning a LLM with Human Preferences (RLHF) `
22 | - :doc:`Training a Self-Improving LLM with Self-Rewarding (RLAIF) `
23 |
24 |
25 |
26 | .. toctree::
27 | :hidden:
28 |
29 | Synthetic Data Generation
30 | ../motivation_and_design
31 | abstract_to_tweet
32 | attributed_prompts
33 | openai_distillation
34 | dataset_augmentation
35 | dataset_cleaning
36 | bootstrapping_machine_translation
37 | Instruction-Tuning and Aligning Models
38 | instruction_tuning
39 | aligning
40 | self_rewarding
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/instruction_tuning.rst:
--------------------------------------------------------------------------------
1 | Instruction-Tuning a LLM
2 | ########################
3 |
4 | When LLMs are pre-trained, they are pre-trained in a self-supervised manner to simply predict the next word in a sentence. This can yield
5 | a model that can follow human instructions to some degree, but is often not very effective until this "base model" is fine-tuned on a dataset
6 | of example instructions and responses in a process known as `"instruction-tuning" `_, essentially allowing the model to learn to follow natural
7 | language instructions.
8 |
9 | DataDreamer makes this process extremely simple and straightforward to accomplish. We demonstrate it below using LoRA to only train
10 | a fraction of the weights.
11 |
12 | .. code-block:: python
13 |
14 | from datadreamer import DataDreamer
15 | from datadreamer.steps import HFHubDataSource
16 | from datadreamer.trainers import TrainHFFineTune
17 | from peft import LoraConfig
18 |
19 | with DataDreamer("./output"):
20 | # Get the Alpaca instruction-tuning dataset (cleaned version)
21 | instruction_tuning_dataset = HFHubDataSource(
22 | "Get Alpaca Instruction-Tuning Dataset", "yahma/alpaca-cleaned", split="train"
23 | )
24 |
25 | # Keep only 1000 examples as a quick demo
26 | instruction_tuning_dataset = instruction_tuning_dataset.take(1000)
27 |
28 | # Some examples take in an "input"; we'll format those into the instruction
29 | instruction_tuning_dataset = instruction_tuning_dataset.map(
30 | lambda row: {
31 | "instruction": (
32 | row["instruction"]
33 | if len(row["input"]) == 0
34 | else f"Input: {row['input']}\n\n{row['instruction']}"
35 | ),
36 | "output": row["output"],
37 | },
38 | lazy=False,
39 | )
40 |
41 | # Create training data splits
42 | splits = instruction_tuning_dataset.splits(train_size=0.90, validation_size=0.10)
43 |
44 | # Define what the prompt template should be when instruction-tuning
45 | chat_prompt_template = "### Instruction:\n{{prompt}}\n\n### Response:\n"
46 |
47 | # Instruction-tune the base TinyLlama model to make it follow instructions
48 | trainer = TrainHFFineTune(
49 | "Instruction-Tune TinyLlama",
50 | model_name="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
51 | chat_prompt_template=chat_prompt_template,
52 | peft_config=LoraConfig(),
53 | device=["cuda:0", "cuda:1"],
54 | dtype="bfloat16",
55 | )
56 | trainer.train(
57 | train_input=splits["train"].output["instruction"],
58 | train_output=splits["train"].output["output"],
59 | validation_input=splits["validation"].output["instruction"],
60 | validation_output=splits["validation"].output["output"],
61 | epochs=3,
62 | batch_size=1,
63 | gradient_accumulation_steps=32,
64 | )
65 |
--------------------------------------------------------------------------------
/docs/source/pages/get_started/quick_tour/openai_distillation.rst:
--------------------------------------------------------------------------------
1 | Distilling GPT-4 Capabilities to GPT-3.5
2 | ########################################
3 |
4 | If you want to make GPT-3.5 (a cheaper, smaller, and faster model) more capable, you can use DataDreamer to distill the capabilities of GPT-4 into GPT-3.5. This will allow you to create a more capable model that is cheaper and faster than GPT-4.
5 |
6 | We demonstrate an example below on the `ELI5 ("Explain it like I'm 5") `_ task.
7 |
8 | .. code-block:: python
9 |
10 | from datadreamer import DataDreamer
11 | from datadreamer.llms import OpenAI
12 | from datadreamer.steps import ProcessWithPrompt, HFHubDataSource
13 | from datadreamer.trainers import TrainOpenAIFineTune
14 |
15 | with DataDreamer("./output"):
16 | # Load GPT-4
17 | gpt_4 = OpenAI(model_name="gpt-4")
18 |
19 | # Get ELI5 questions
20 | eli5_dataset = HFHubDataSource(
21 | "Get ELI5 Questions",
22 | "eli5_category",
23 | split="train",
24 | trust_remote_code=True,
25 | ).select_columns(["title"])
26 |
27 | # Keep only 1000 examples as a quick demo
28 | eli5_dataset = eli5_dataset.take(1000)
29 |
30 | # Ask GPT-4 to ELI5
31 | questions_and_answers = ProcessWithPrompt(
32 | "Generate Explanations",
33 | inputs={"inputs": eli5_dataset.output["title"]},
34 | args={
35 | "llm": gpt_4,
36 | "instruction": (
37 | 'Given the question, give an "Explain it like I\'m 5" answer.'
38 | ),
39 | "top_p": 1.0,
40 | },
41 | outputs={"inputs": "questions", "generations": "answers"},
42 | )
43 |
44 | # Create training data splits
45 | splits = questions_and_answers.splits(train_size=0.90, validation_size=0.10)
46 |
47 | # Train a model to answer questions in ELI5 style
48 | trainer = TrainOpenAIFineTune(
49 | "Distill capabilities to GPT-3.5",
50 | model_name="gpt-3.5-turbo-1106",
51 | )
52 | trainer.train(
53 | train_input=splits["train"].output["questions"],
54 | train_output=splits["train"].output["answers"],
55 | validation_input=splits["validation"].output["questions"],
56 | validation_output=splits["validation"].output["answers"],
57 | epochs=30,
58 | batch_size=8,
59 | )
60 |
61 |
--------------------------------------------------------------------------------
/scripts/.cluster/.gitignore:
--------------------------------------------------------------------------------
1 | .lock
2 | .bin/
3 | output
4 | output/
--------------------------------------------------------------------------------
/scripts/.cluster/_boot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Run setup
4 | . ./.cluster/_setup.sh
5 |
6 | # Run the python script
7 | if [ "$PROJECT_CLUSTER_TYPE" == "slurm" ] && [ "$PROJECT_INTERACTIVE" != "1" ]; then
8 | srun -u .cluster/_command.sh "$@"
9 | else
10 | .cluster/_command.sh "$@"
11 | fi
12 |
--------------------------------------------------------------------------------
/scripts/.cluster/_command.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Write the command arguments to a file
4 | export PROJECT_COMMAND_PATH=$(mktemp /tmp/projectcmd.XXXXXX)
5 | (
6 | printf '%s' "$COMMAND_PRE"
7 | printf " "
8 | printf '%s' "$COMMAND_ENTRYPOINT"
9 | printf " "
10 |
11 | # For each argument
12 | for ARG in "$@"; do
13 | printf "$'"
14 | # echo the argument, except:
15 | # * Replace backslashes with escaped backslashes
16 | # * Replace single quotes with escaped single quotes
17 | echo -n "$ARG" | sed -e "s/\\\\/\\\\\\\\/g;" | sed -e "s/'/\\\\'/g;"
18 | # echo `'`
19 | printf "' "
20 | done
21 |
22 | printf " "
23 | printf '%s' "$COMMAND_POST"
24 | ) >"$PROJECT_COMMAND_PATH"
25 | chmod +x "$PROJECT_COMMAND_PATH"
26 |
27 | # Run the script
28 | if [ "$PROJECT_CLUSTER" == "1" ]; then
29 | if [ "$PROJECT_INTERACTIVE" == "1" ]; then
30 | exec 2>&4 1>&3
31 | script -efq "$PROJECT_STDOUT_FILE" -c "$PROJECT_COMMAND_PATH 2> >(tee -a $PROJECT_STDERR_FILE >&2)"
32 | else
33 | $PROJECT_COMMAND_PATH 1>>"$PROJECT_STDOUT_FILE" 2>>"$PROJECT_STDERR_FILE"
34 | fi
35 | else
36 | $PROJECT_COMMAND_PATH
37 | fi
38 |
--------------------------------------------------------------------------------
/scripts/.cluster/_lock.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Adapted from: https://stackoverflow.com/questions/1715137/what-is-the-best-way-to-ensure-only-one-instance-of-a-bash-script-is-running
4 |
5 | function lockfile_waithold() {
6 | declare -ir time_beg=$(date '+%s')
7 | declare -ir time_max=7200
8 |
9 | while ! (
10 | set -o noclobber
11 | echo -e "DATE:$(date)\nUSER:$(whoami)\nPID:$$" \ >.cluster/.lock
12 | ) 2>/dev/null; do
13 | if [ $(($(date '+%s') - ${time_beg})) -gt ${time_max} ]; then
14 | echo "Error: waited too long for lock file .cluster/.lock" 1>&2
15 | return 1
16 | fi
17 | sleep 2
18 | done
19 |
20 | return 0
21 | }
22 |
23 | function lockfile_release() {
24 | rm -f .cluster/.lock
25 | }
26 |
27 | if ! lockfile_waithold; then
28 | exit 1
29 | fi
30 | trap lockfile_release EXIT
31 |
--------------------------------------------------------------------------------
/scripts/.cluster/args_log.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Get the job name and task ID
7 | if [[ "$1" =~ ':'$ ]]; then
8 | export FIRST_ARGUMENT="$1"
9 | else
10 | export FIRST_ARGUMENT="$1:"
11 | fi
12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:)
13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:)
14 |
15 | # Fetch the args log
16 | if [ -z "$JOB_NAME" ]; then
17 | echo "Fetching the args log of the last job..." 1>&2;
18 | if [ -f "./output/named/_latest/job/.array" ]; then
19 | if [ -z "$TASK_ID" ]; then
20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2;
21 | exit 1
22 | else
23 | export LOG_PATH=./output/named/_latest/$TASK_ID/args.json
24 | fi
25 | else
26 | export LOG_PATH=./output/named/_latest/args.json
27 | fi
28 | if [ -f "$LOG_PATH" ]; then
29 | cat $LOG_PATH
30 | else
31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2;
32 | exit 1
33 | fi
34 | else
35 | echo "Fetching the args log of the '$JOB_NAME' job..." 1>&2;
36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then
37 | if [ -z "$TASK_ID" ]; then
38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2;
39 | exit 1
40 | else
41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/args.json
42 | fi
43 | else
44 | export LOG_PATH=./output/named/$JOB_NAME/args.json
45 | fi
46 | if [ -f "$LOG_PATH" ]; then
47 | cat "$LOG_PATH"
48 | else
49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2;
50 | exit 1
51 | fi
52 | fi
--------------------------------------------------------------------------------
/scripts/.cluster/config_log.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Get the job name and task ID
7 | if [[ "$1" =~ ':'$ ]]; then
8 | export FIRST_ARGUMENT="$1"
9 | else
10 | export FIRST_ARGUMENT="$1:"
11 | fi
12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:)
13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:)
14 |
15 | # Fetch the config log
16 | if [ -z "$JOB_NAME" ]; then
17 | echo "Fetching the config log of the last job..." 1>&2
18 | if [ -f "./output/named/_latest/job/.array" ]; then
19 | if [ -z "$TASK_ID" ]; then
20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
21 | exit 1
22 | else
23 | export LOG_PATH=./output/named/_latest/$TASK_ID/config.json
24 | fi
25 | else
26 | export LOG_PATH=./output/named/_latest/config.json
27 | fi
28 | if [ -f "$LOG_PATH" ]; then
29 | cat "$LOG_PATH"
30 | else
31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
32 | exit 1
33 | fi
34 | else
35 | echo "Fetching the config log of the '$JOB_NAME' job..." 1>&2
36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then
37 | if [ -z "$TASK_ID" ]; then
38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
39 | exit 1
40 | else
41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/config.json
42 | fi
43 | else
44 | export LOG_PATH=./output/named/$JOB_NAME/config.json
45 | fi
46 | if [ -f "$LOG_PATH" ]; then
47 | cat "$LOG_PATH"
48 | else
49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
50 | exit 1
51 | fi
52 | fi
53 |
--------------------------------------------------------------------------------
/scripts/.cluster/direct/.gitignore:
--------------------------------------------------------------------------------
1 | .last_job/
--------------------------------------------------------------------------------
/scripts/.cluster/direct/_direct_config.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ################################################################
4 | # To run a job array, define the array task IDs here,
5 | # for example: ARRAY_TASK_IDS=(10 20 30 40 50)
6 | # If only one value is provided, this job will not be treated as a job array.
7 | ARRAY_TASK_IDS=(0)
8 | MIN_MEMORY_PER_TASK=10G
9 | MAX_PARALLEL_TASKS=8
10 | ################################################################
11 |
12 | # Load environment variables for direct running
13 | source project.env
14 |
15 | # Ensure direct environment variables are provided
16 | source .cluster/direct/_direct_env.sh
17 |
18 | # Redirect to submission log
19 | if [ "$PROJECT_INTERACTIVE" != "1" ]; then
20 | exec 3>&1 4>&2 >.cluster/direct/.last_job/submission.out 2>&1
21 | fi
22 |
23 | # Define task runner
24 | TASK_RUNNER() {
25 | TASK_ID=$1
26 | ARRAY_TASK_IDS_LENGTH=$2
27 |
28 | # Source the user's bashrc
29 | # shellcheck disable=SC1090
30 | source ~/.bashrc
31 |
32 | # Mark that we are using direct
33 | export PROJECT_CLUSTER=1
34 | export PROJECT_CLUSTER_TYPE=direct
35 |
36 | # Set direct dependent environment variables
37 | export PROJECT_VENV=.venv/direct
38 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache
39 | if [ "$ARRAY_TASK_IDS_LENGTH" == "1" ]; then
40 | # Only one task ID detected, therefore this is not a job array
41 | :
42 | else
43 | export PROJECT_TASK_ID=$TASK_ID
44 | fi
45 |
46 | # Store the direct last job information
47 | cp .cluster/direct/_direct_config.sh .cluster/direct/.last_job/resources
48 | echo $PROJECT_CLUSTER_TYPE >.cluster/direct/.last_job/type
49 | echo "$PROJECT_JOB_NAME" >.cluster/direct/.last_job/job_name
50 | echo $$ >.cluster/direct/.last_job/job_id
51 | hostname >.cluster/direct/.last_job/hostname
52 | echo "$PROJECT_CURRENT_DATE" >.cluster/direct/.last_job/date
53 | echo "$PROJECT_CURRENT_COMMIT" >.cluster/direct/.last_job/commit
54 |
55 | # Run the boot script
56 | shift
57 | shift
58 | .cluster/_boot.sh "$@"
59 | }
60 | export -f TASK_RUNNER
61 |
62 | # Run
63 | if [ "$PROJECT_INTERACTIVE" != "1" ]; then
64 | # shellcheck disable=SC1083
65 | parallel --memfree $MIN_MEMORY_PER_TASK --jobs $MAX_PARALLEL_TASKS TASK_RUNNER {.} "${#ARRAY_TASK_IDS[@]}" "$@" ::: "${ARRAY_TASK_IDS[@]}"
66 | else
67 | # shellcheck disable=SC1083
68 | parallel --tty --memfree $MIN_MEMORY_PER_TASK --jobs $MAX_PARALLEL_TASKS TASK_RUNNER {.} "${#ARRAY_TASK_IDS[@]}" "$@" ::: "${ARRAY_TASK_IDS[@]}"
69 | fi
70 |
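71 | # Example job-array configuration (a sketch; the values are illustrative):
72 | # editing the block at the top of this file as follows would run five tasks,
73 | # at most 8 at a time, each waiting for 10G of free memory before starting:
74 | #
75 | #   ARRAY_TASK_IDS=(10 20 30 40 50)
76 | #   MIN_MEMORY_PER_TASK=10G
77 | #   MAX_PARALLEL_TASKS=8
78 | #
79 | # Each task then receives its ID via PROJECT_TASK_ID inside TASK_RUNNER.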
--------------------------------------------------------------------------------
/scripts/.cluster/direct/_direct_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check for the existence of environment variables
4 | if [ -z "$PROJECT_DATA" ]; then
5 | echo "You must define the 'PROJECT_DATA' environment variable in project.env. Set it to a directory where output files will be written and large data can be saved while running." 1>&2
6 | exit 1
7 | fi
8 | if [ -z "$PROJECT_ACCELERATOR_TYPE" ]; then
9 | if [ "$(uname)" != "Darwin" ] && (env -0 | cut -z -f1 -d= | tr '\0' '\n' | grep -q "TPU_"); then
10 | export PROJECT_ACCELERATOR_TYPE=tpu
11 | export PROJECT_VISIBLE_ACCELERATOR_DEVICES=all
12 | echo "Detected an environment with TPUs. For this reason, accelerator device dependencies will not be installed to a virtual environment." 1>&2
13 | export PROJECT_DISABLE_ACCELERATOR_REQUIREMENTS=1
14 | elif [ -x "$(command -v nvidia-smi)" ]; then
15 | export PROJECT_ACCELERATOR_TYPE=cuda
16 | echo "Detected an environment with CUDA GPUs." 1>&2
17 | export PROJECT_VISIBLE_ACCELERATOR_DEVICES=$(seq --separator="," 0 $(($(nvidia-smi --list-gpus | wc -l) - 1)))
18 | fi
19 | fi
20 |
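21 | # The auto-detection above only runs when PROJECT_ACCELERATOR_TYPE is unset, so
22 | # it can be bypassed by exporting the variables in project.env before submitting.
23 | # A minimal sketch (device IDs are illustrative):
24 | #
25 | #   export PROJECT_ACCELERATOR_TYPE=cuda
26 | #   export PROJECT_VISIBLE_ACCELERATOR_DEVICES=0,1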
--------------------------------------------------------------------------------
/scripts/.cluster/direct/cancel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Confirm
7 | read -r -p "Are you sure? [y/N] " response
8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm again
15 | read -r -p "Are you really sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | :
18 | else
19 | exit 1
20 | fi
21 |
22 | # Fetch the status
23 | if [ -z "$1" ]; then
24 | echo "Canceling the last job..." 1>&2
25 | if [ -f "../output/named/_latest/job/job_id" ]; then
26 | JOB_PID=$(ps -o sid= -p "$(cat ../output/named/_latest/job/job_id)")
27 | kill -9 "$JOB_PID"
28 | else
29 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
30 | exit 1
31 | fi
32 | else
33 | echo "Canceling the '$1' job..." 1>&2
34 | if [ -f "../output/named/$1/job/job_id" ]; then
35 | JOB_PID=$(ps -o sid= -p "$(cat "../output/named/$1/job/job_id")")
36 | kill -9 "$JOB_PID"
37 | else
38 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
39 | exit 1
40 | fi
41 | fi
42 |
--------------------------------------------------------------------------------
/scripts/.cluster/direct/cancel_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Confirm
7 | read -r -p "Are you sure? [y/N] " response
8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm again
15 | read -r -p "Are you really sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | echo "Canceling..."
18 | else
19 | exit 1
20 | fi
21 |
22 | # Cancel all jobs
23 | for JOB_FOLDER in ../output/dated/*/; do
24 | if [ -f "$JOB_FOLDER"/job/type ] && [ -f "$JOB_FOLDER"/job/job_id ]; then
25 | if grep -q "direct" "$JOB_FOLDER"/job/type; then
26 | JOB_PID=$(ps -o sid= -p "$(cat "$JOB_FOLDER"/job/job_id)")
27 | if [ -n "$JOB_PID" ]; then
28 | kill -9 "$JOB_PID"
29 | fi
30 | fi
31 | fi
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/.cluster/direct/interactive_submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Change directory to project root location
7 | cd ../../
8 |
9 | # Check if there is already a job running
10 | if ( (.cluster/direct/status_all.sh | grep -q "$USER") 2>/dev/null); then
11 | echo -e "WARNING: There is already a job running!\n"
12 | .cluster/direct/status_all.sh
13 |
14 | # Confirm
15 | echo ""
16 | read -r -p "Are you sure you want to submit another job? [y/N] " response
17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
18 | :
19 | else
20 | exit 1
21 | fi
22 | elif (
23 | # shellcheck disable=SC2009
24 | ps -aux | grep -v "grep" | grep -q .cluster/direct/_direct_config.sh
25 | ); then
26 | echo -e "WARNING: There is already a job running! It is still initializing...\n"
27 |
28 | # Confirm
29 | echo ""
30 | read -r -p "Are you sure you want to submit another job? [y/N] " response
31 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
32 | :
33 | else
34 | exit 1
35 | fi
36 | fi
37 |
38 | # Get the job name
39 | if [ -z "$1" ]; then
40 | echo "You must submit with a job name as the first argument to this script."
41 | exit 1
42 | else
43 | export JOB_NAME="$1"
44 | fi
45 |
46 | # Submit
47 | mkdir -p .cluster/direct/.last_job
48 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T)
49 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited")
50 | export PROJECT_JOB_NAME="$JOB_NAME"
51 | rm -rf .cluster/.lock
52 | export PROJECT_INTERACTIVE=1
53 | touch .cluster/direct/.last_job/.interactive
54 | echo -n "" >.cluster/direct/.last_job/submission.out
55 | shift
56 | .cluster/direct/_direct_config.sh "$@"
57 |
--------------------------------------------------------------------------------
/scripts/.cluster/direct/reset_venv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Reset the virtual env
7 | rm -rf ../../.venv/direct/*
8 | rm -rf ../../.venv/direct/
9 |
--------------------------------------------------------------------------------
/scripts/.cluster/direct/status.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the status
7 | if [ -z "$1" ]; then
8 | echo "Fetching the status of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/job_id" ]; then
10 | JOB_PID=$(ps -o sid= -p "$(cat ../output/named/_latest/job/job_id)")
11 | if [ -n "$JOB_PID" ]; then
12 | # shellcheck disable=SC2086
13 | ps --forest -wwo user,sid,pid,stat,start,time,%cpu,%mem,args -g $JOB_PID 2>/dev/null
14 | fi
15 | else
16 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
17 | exit 1
18 | fi
19 | else
20 | echo "Fetching the status of the '$1' job..." 1>&2
21 | if [ -f "../output/named/$1/job/job_id" ]; then
22 | JOB_PID=$(ps -o sid= -p "$(cat "../output/named/$1/job/job_id")")
23 | if [ -n "$JOB_PID" ]; then
24 | # shellcheck disable=SC2086
25 | ps --forest -wwo user,sid,pid,stat,start,time,%cpu,%mem,args -g $JOB_PID 2>/dev/null
26 | fi
27 | else
28 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
29 | exit 1
30 | fi
31 | fi
32 |
--------------------------------------------------------------------------------
/scripts/.cluster/direct/status_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the status of all jobs
7 | echo "Fetching the status of all jobs..." 1>&2
8 | for JOB_FOLDER in ../output/dated/*/; do
9 | if [ -f "$JOB_FOLDER"/job/type ] && [ -f "$JOB_FOLDER"/job/job_id ] && [ -f "$JOB_FOLDER"/job/job_name ]; then
10 | if grep -q "direct" "$JOB_FOLDER"/job/type; then
11 | JOB_PID=$(ps -o sid= -p "$(cat "$JOB_FOLDER"/job/job_id)")
12 | if [ -n "$JOB_PID" ]; then
13 | echo "------------------------------------------------------------------------"
14 | echo "Job: $(cat "$JOB_FOLDER"/job/job_name) ($(cat "$JOB_FOLDER"/job/date))"
15 | echo "------------------------------------------------------------------------"
16 | # shellcheck disable=SC2086
17 | ps --forest -wwo user,sid,pid,stat,start,time,%cpu,%mem,args -g $JOB_PID 2>/dev/null
18 | fi
19 | fi
20 | fi
21 | done
22 |
--------------------------------------------------------------------------------
/scripts/.cluster/direct/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Change directory to project root location
7 | cd ../../
8 |
9 | # Check if there is already a job running
10 | if ( (.cluster/direct/status_all.sh | grep -q "$USER") 2>/dev/null); then
11 | echo -e "WARNING: There is already a job running!\n"
12 | .cluster/direct/status_all.sh
13 |
14 | # Confirm
15 | echo ""
16 | read -r -p "Are you sure you want to submit another job? [y/N] " response
17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
18 | :
19 | else
20 | exit 1
21 | fi
22 | elif (
23 | # shellcheck disable=SC2009
24 | ps -aux | grep -v "grep" | grep -q .cluster/direct/_direct_config.sh
25 | ); then
26 | echo -e "WARNING: There is already a job running! It is still initializing...\n"
27 |
28 | # Confirm
29 | echo ""
30 | read -r -p "Are you sure you want to submit another job? [y/N] " response
31 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
32 | :
33 | else
34 | exit 1
35 | fi
36 | fi
37 |
38 | # Get the job name
39 | if [ -z "$1" ]; then
40 | echo "You must submit with a job name as the first argument to this script."
41 | exit 1
42 | else
43 | export JOB_NAME="$1"
44 | fi
45 |
46 | # Submit
47 | rm -rf .cluster/direct/.last_job/*
48 | mkdir -p .cluster/direct/.last_job
49 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T)
50 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited")
51 | export PROJECT_JOB_NAME="$JOB_NAME"
52 | rm -rf .cluster/.lock
53 | shift
54 | setsid .cluster/direct/_direct_config.sh "$@" &
55 |
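56 | # Example invocation (job name and arguments are illustrative; everything after
57 | # the job name is forwarded through _direct_config.sh to .cluster/_boot.sh):
58 | #
59 | #   .cluster/direct/submit.sh my_job --some-arg value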
--------------------------------------------------------------------------------
/scripts/.cluster/full_reset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Warn
7 | read -r -p $'Are you sure? You will lose ALL data, logs, experiments, etc. associated with this project.\nType "delete all my data" to confirm: ' response
8 | if [[ "$response" =~ "delete all my data" ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm
15 | read -r -p "Are you sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | :
18 | else
19 | exit 1
20 | fi
21 |
22 | # Confirm again
23 | read -r -p "Are you really sure? [y/N] " response
24 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
25 | echo "Fully resetting..."
26 | else
27 | exit 1
28 | fi
29 |
30 | # Reset all data
31 | rm -rf ../.cluster/.lock
32 | rm -rf ../.cluster/output/*/* || true
33 | rm -rf ../.cluster/output/ || true
34 | rm ../.cluster/output 2>/dev/null || true
35 | rm -rf ../.cluster/*/.last_job || true
36 | rm -rf ../.venv/*/* || true
37 | rm -rf ../.venv/ || true
38 | rm -rf ../.venv_dev/*/* || true
39 | rm -rf ../.venv_dev/ || true
40 |
--------------------------------------------------------------------------------
/scripts/.cluster/install_log.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Get the job name and task ID
7 | if [[ "$1" =~ ':'$ ]]; then
8 | export FIRST_ARGUMENT="$1"
9 | else
10 | export FIRST_ARGUMENT="$1:"
11 | fi
12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:)
13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:)
14 |
15 | # Fetch the install log
16 | if [ -z "$JOB_NAME" ]; then
17 | echo "Fetching the install log of the last job..." 1>&2
18 | if [ -f "./output/named/_latest/job/.array" ]; then
19 | if [ -z "$TASK_ID" ]; then
20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
21 | exit 1
22 | else
23 | export LOG_PATH=./output/named/_latest/$TASK_ID/install.out
24 | fi
25 | else
26 | export LOG_PATH=./output/named/_latest/install.out
27 | fi
28 | if [ -f "$LOG_PATH" ]; then
29 | cat "$LOG_PATH"
30 | else
31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
32 | exit 1
33 | fi
34 | else
35 | echo "Fetching the install log of the '$JOB_NAME' job..." 1>&2
36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then
37 | if [ -z "$TASK_ID" ]; then
38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
39 | exit 1
40 | else
41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/install.out
42 | fi
43 | else
44 | export LOG_PATH=./output/named/$JOB_NAME/install.out
45 | fi
46 | if [ -f "$LOG_PATH" ]; then
47 | cat "$LOG_PATH"
48 | else
49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
50 | exit 1
51 | fi
52 | fi
53 |
--------------------------------------------------------------------------------
/scripts/.cluster/reset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Warn
7 | read -r -p $'Are you sure? You will lose all data, logs, etc. associated with this project that are not tagged as an experiment.\nType "i only want experiments" to confirm: ' response
8 | if [[ "$response" =~ "i only want experiments" ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm
15 | read -r -p "Are you sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | :
18 | else
19 | exit 1
20 | fi
21 |
22 | # Confirm again
23 | read -r -p "Are you really sure? [y/N] " response
24 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
25 | echo "Resetting..."
26 | else
27 | exit 1
28 | fi
29 |
30 | # Reset all data
31 | rm -rf ../.cluster/output/committed || true
32 | rm -rf ../.cluster/output/named || true
33 | (
34 | cd ../.cluster/output/dated || exit
35 | for x in *; do readlink -f ../experiments/* | grep -q "$x" || rm -rf "$x"; done
36 | )
37 | (
38 | cd ../.cluster/output/persistent_data || exit
39 | for x in *; do readlink -f ../dated/*/data/* | grep -q "$x" || rm -rf "$x"; done
40 | )
41 | rm -rf ../.cluster/*/.last_job || true
42 |
--------------------------------------------------------------------------------
/scripts/.cluster/reset_venv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Change directory to project root location
7 | cd ../
8 |
9 | # Reset the virtual env
10 | rm -rf .venv/*/* || true
11 | rm -rf .venv/ || true
12 | rm -rf .venv_dev/*/* || true
13 | rm -rf .venv_dev/ || true
14 |
--------------------------------------------------------------------------------
/scripts/.cluster/series_log.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Get the job name, task ID, and series
7 | if [[ "$1" =~ '/'$ ]]; then
8 | export FIRST_ARGUMENT="$1"
9 | else
10 | export FIRST_ARGUMENT="$1/"
11 | fi
12 | export FIRST_ARGUMENT_=$(echo "$FIRST_ARGUMENT" | cut -f1 -d/)
13 | if [[ "$FIRST_ARGUMENT_" =~ ':'$ ]]; then
14 | export FIRST_ARGUMENT_="$FIRST_ARGUMENT_"
15 | else
16 | export FIRST_ARGUMENT_="$FIRST_ARGUMENT_:"
17 | fi
18 | export SERIES_NAME=$(echo "$FIRST_ARGUMENT" | cut -f2 -d/)
19 | export JOB_NAME=$(echo "$FIRST_ARGUMENT_" | cut -f1 -d:)
20 | export TASK_ID=$(echo "$FIRST_ARGUMENT_" | cut -f2 -d:)
21 |
22 | # Make sure a series name was provided
23 | if [ -z "$SERIES_NAME" ]; then
24 | echo "You must provide the series name with a slash (/)." 1>&2
25 | exit 1
26 | fi
27 |
28 | # Fetch the $SERIES_NAME series log
29 | if [ -z "$JOB_NAME" ]; then
30 | echo "Fetching the $SERIES_NAME series log of the last job..." 1>&2
31 | if [ -f "./output/named/_latest/job/.array" ]; then
32 | if [ -z "$TASK_ID" ]; then
33 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
34 | exit 1
35 | else
36 | export LOG_PATH=./output/named/_latest/$TASK_ID/series/$SERIES_NAME.csv
37 | fi
38 | else
39 | export LOG_PATH=./output/named/_latest/series/$SERIES_NAME.csv
40 | fi
41 | if [ -f "$LOG_PATH" ]; then
42 | perl -pe 's/((?<=,)|(?<=^)),/ ,/g;' <"$LOG_PATH" | column -t -s, | less -S
43 | else
44 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
45 | exit 1
46 | fi
47 | else
48 | echo "Fetching the $SERIES_NAME series log of the '$JOB_NAME' job..." 1>&2
49 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then
50 | if [ -z "$TASK_ID" ]; then
51 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
52 | exit 1
53 | else
54 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/series/$SERIES_NAME.csv
55 | fi
56 | else
57 | export LOG_PATH=./output/named/$JOB_NAME/series/$SERIES_NAME.csv
58 | fi
59 | if [ -f "$LOG_PATH" ]; then
60 | perl -pe 's/((?<=,)|(?<=^)),/ ,/g;' <"$LOG_PATH" | column -t -s, | less -S
61 | else
62 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
63 | exit 1
64 | fi
65 | fi
66 |
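67 | # Example invocations (job and series names are hypothetical):
68 | #   .cluster/series_log.sh /loss           # 'loss' series of the most recent job
69 | #   .cluster/series_log.sh my_job/loss     # 'loss' series of the 'my_job' job
70 | #   .cluster/series_log.sh my_job:3/loss   # task 3 of a 'my_job' job array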
--------------------------------------------------------------------------------
/scripts/.cluster/sge/.gitignore:
--------------------------------------------------------------------------------
1 | .last_job/
--------------------------------------------------------------------------------
/scripts/.cluster/sge/_qsub_config.sh:
--------------------------------------------------------------------------------
1 | #$ -l h_rt=60:00:00
2 | #$ -o .cluster/sge/.last_job/submission.out
3 | #$ -e .cluster/sge/.last_job/submission.out
4 | #$ -cwd
5 | #$ -l mem=1G
6 | #$ -pe parallel-onenode 4
7 | #$ -l h=nlpgrid12
8 |
9 | # Source the user's bashrc
10 | # shellcheck disable=SC1090
11 | source ~/.bashrc
12 |
13 | # Mark that we are using sge
14 | export PROJECT_CLUSTER=1
15 | export PROJECT_CLUSTER_TYPE=sge
16 |
17 | # Set sge dependent environment variables
18 | export PROJECT_VENV=.venv/sge
19 | export PROJECT_DATA=/nlp/data/$USER
20 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache
21 | export PROJECT_JOB_NAME=$JOB_NAME
22 | if [ "$SGE_TASK_ID" != "undefined" ]; then
23 | export PROJECT_TASK_ID=$SGE_TASK_ID
24 | fi
25 |
26 | # Set up global cache directories
27 | if [ ! -e "$PROJECT_CACHE_DIR/huggingface_cache" ]; then
28 | ln -s /nlp/data/huggingface_cache "$PROJECT_CACHE_DIR/huggingface_cache"
29 | fi
30 | if [ ! -e "$PROJECT_CACHE_DIR/sentence_transformers_cache" ]; then
31 | ln -s /nlp/data/huggingface_cache/sentence_transformers "$PROJECT_CACHE_DIR/sentence_transformers_cache"
32 | fi
33 |
34 | # Change directory to submit location
35 | cd "$SGE_O_WORKDIR" || exit
36 |
37 | # Store the sge last job information
38 | cp .cluster/sge/_qsub_config.sh .cluster/sge/.last_job/resources
39 | echo $PROJECT_CLUSTER_TYPE >.cluster/sge/.last_job/type
40 | echo "$PROJECT_JOB_NAME" >.cluster/sge/.last_job/job_name
41 | # SGE reports the same JOB_ID for every task of a job array, so a single write
42 | # covers both the array and non-array cases (any task ID is tracked separately
43 | # via PROJECT_TASK_ID above).
44 | echo "$JOB_ID" >.cluster/sge/.last_job/job_id
45 |
46 | echo "$HOSTNAME" >.cluster/sge/.last_job/nodelist
47 | echo "$PROJECT_CURRENT_DATE" >.cluster/sge/.last_job/date
48 | echo "$PROJECT_CURRENT_COMMIT" >.cluster/sge/.last_job/commit
49 |
50 | # Run the boot script
51 | .cluster/_boot.sh "$@"
52 |
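53 | # Note (assumption about intended use): to run this configuration as an SGE job
54 | # array, the job would be submitted with qsub's '-t' option (e.g. 'qsub -t 1-5 ...');
55 | # SGE then sets SGE_TASK_ID for each task, which is picked up above as PROJECT_TASK_ID.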
--------------------------------------------------------------------------------
/scripts/.cluster/sge/cancel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Confirm
7 | read -r -p "Are you sure? [y/N] " response
8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm again
15 | read -r -p "Are you really sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | :
18 | else
19 | exit 1
20 | fi
21 |
22 | # Fetch the status
23 | if [ -z "$1" ]; then
24 | echo "Canceling the last job..." 1>&2
25 | if [ -f "../output/named/_latest/job/job_id" ]; then
26 | qdel "$(cat ../output/named/_latest/job/job_id)"
27 | else
28 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
29 | exit 1
30 | fi
31 | else
32 | echo "Canceling the '$1' job..." 1>&2
33 | if [ -f "../output/named/$1/job/job_id" ]; then
34 | qdel "$(cat ../output/named/"$1"/job/job_id)"
35 | else
36 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
37 | exit 1
38 | fi
39 | fi
40 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/cancel_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Confirm
7 | read -r -p "Are you sure? [y/N] " response
8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm again
15 | read -r -p "Are you really sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | echo "Canceling..."
18 | else
19 | exit 1
20 | fi
21 |
22 | qdel -u "$USER"
23 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/htop.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the nodelist
7 | if [ -z "$1" ]; then
8 | echo "Fetching the compute environment of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/nodelist" ]; then
10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist)
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the compute environment of the '$1' job..." 1>&2
17 | if [ -f "../output/named/$1/job/nodelist" ]; then
18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist)
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 1 hour, this is a time limit to prevent hanging resources)" 1>&2
26 | sleep 2
27 | echo -e "\n" 1>&2
28 | qrsh -now no -cwd -pty y -V -N shell -l h_rt=1:00:00 -pe parallel-onenode 1 -l mem=100M -l h="$NODELIST" htop -u "$USER"
29 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/interactive_submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Change directory to project root location
7 | cd ../../
8 |
9 | # Check if there is already a job running
10 | if ( (.cluster/sge/status_all.sh | grep -q "$USER") 2>/dev/null); then
11 | echo -e "WARNING: There is already a job running!\n"
12 | .cluster/sge/status_all.sh
13 |
14 | # Confirm
15 | echo ""
16 | read -r -p "Are you sure you want to submit another job? [y/N] " response
17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
18 | :
19 | else
20 | exit 1
21 | fi
22 | fi
23 |
24 | # Get the job name
25 | if [ -z "$1" ]; then
26 | echo "You must submit with a job name as the first argument to this script."
27 | exit 1
28 | else
29 | export JOB_NAME="$1"
30 | fi
31 |
32 | # Submit
33 | mkdir -p .cluster/sge/.last_job
34 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T)
35 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited")
36 | rm -rf .cluster/.lock
37 | export QSUB_ARGS=$(grep -E '^#\$' .cluster/sge/_qsub_config.sh | sed -E 's/\s*#\$\s*//g' | grep -v '^-o' | grep -v '^-e' | tr '\n' ' ')
38 | export PROJECT_INTERACTIVE=1
39 | touch .cluster/sge/.last_job/.interactive
40 | echo -n "" >.cluster/sge/.last_job/submission.out
41 | shift
42 | # shellcheck disable=SC2086
43 | qrsh -now no -cwd -pty y -V -N "$JOB_NAME" $QSUB_ARGS .cluster/sge/_qsub_config.sh "$@"
44 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/reset_venv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Reset the virtual env
7 | rm -rf ../../.venv/sge/*
8 | rm -rf ../../.venv/sge/
9 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/shell.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the nodelist
7 | if [ -z "$1" ]; then
8 | echo "Fetching the compute environment of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/nodelist" ]; then
10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist)
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the compute environment of the '$1' job..." 1>&2
17 | if [ -f "../output/named/$1/job/nodelist" ]; then
18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist)
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 3 hours, this is a time limit to prevent hanging resources)" 1>&2
26 | sleep 2
27 | qrsh -now no -cwd -pty y -N shell -l h_rt=3:00:00 -pe parallel-onenode 1 -l mem=100M -l h="$NODELIST" /bin/bash -c "echo; cd ../../; /bin/bash -i"
28 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/status.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the status
7 | if [ -z "$1" ]; then
8 | echo "Fetching the status of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/job_id" ]; then
10 | qstat | grep --color=never -E "job-ID|----------------------------------|$(cat ../output/named/_latest/job/job_id)"
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the status of the '$1' job..." 1>&2
17 | if [ -f "../output/named/$1/job/job_id" ]; then
18 | qstat | grep --color=never -E "job-ID|----------------------------------|$(cat ../output/named/"$1"/job/job_id)"
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/status_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the status of all jobs
7 | echo "Fetching the status of all jobs..." 1>&2
8 | qstat -u "$USER"
9 |
--------------------------------------------------------------------------------
/scripts/.cluster/sge/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Change directory to project root location
7 | cd ../../
8 |
9 | # Check if there is already a job running
10 | if ( (.cluster/sge/status_all.sh | grep -q "$USER") 2>/dev/null); then
11 | echo -e "WARNING: There is already a job running!\n"
12 | .cluster/sge/status_all.sh
13 |
14 | # Confirm
15 | echo ""
16 | read -r -p "Are you sure you want to submit another job? [y/N] " response
17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
18 | :
19 | else
20 | exit 1
21 | fi
22 | fi
23 |
24 | # Get the job name
25 | if [ -z "$1" ]; then
26 | echo "You must submit with a job name as the first argument to this script."
27 | exit 1
28 | else
29 | export JOB_NAME="$1"
30 | fi
31 |
32 | # Submit
33 | rm -rf .cluster/sge/.last_job/*
34 | mkdir -p .cluster/sge/.last_job
35 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T)
36 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited")
37 | rm -rf .cluster/.lock
38 | shift
39 | qsub -N "$JOB_NAME" -V .cluster/sge/_qsub_config.sh "$@"
40 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/.gitignore:
--------------------------------------------------------------------------------
1 | .last_job/
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/_sbatch_config.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | #SBATCH --time=2:00:00
4 | #SBATCH --partition=interactive
5 | #SBATCH --output=.cluster/slurm/.last_job/submission.out
6 | #SBATCH --ntasks 1
7 | #SBATCH --cpus-per-task 16
8 | #SBATCH --mem=30G
9 | #SBATCH --gpus=2
10 |
11 | # Source the user's bashrc
12 | # shellcheck disable=SC1090
13 | source ~/.bashrc
14 |
15 | # Mark that we are using slurm
16 | export PROJECT_CLUSTER=1
17 | export PROJECT_CLUSTER_TYPE=slurm
18 |
19 | # Set slurm dependent environment variables
20 | export PROJECT_VENV=.venv/slurm
21 | export PROJECT_DATA=/nlp/data/$USER
22 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache
23 | export PROJECT_JOB_NAME=$SLURM_JOB_NAME
24 | export PROJECT_TASK_ID=$SLURM_ARRAY_TASK_ID
25 |
26 | # Set up global cache directories
27 | if [ ! -e "$PROJECT_CACHE_DIR/huggingface_cache" ]; then
28 | ln -s /nlp/data/huggingface_cache "$PROJECT_CACHE_DIR/huggingface_cache"
29 | fi
30 | if [ ! -e "$PROJECT_CACHE_DIR/sentence_transformers_cache" ]; then
31 | ln -s /nlp/data/huggingface_cache/sentence_transformers "$PROJECT_CACHE_DIR/sentence_transformers_cache"
32 | fi
33 |
34 | # Change directory to submit location
35 | cd "$SLURM_SUBMIT_DIR" || exit
36 |
37 | # Store the slurm last job information
38 | cp .cluster/slurm/_sbatch_config.sh .cluster/slurm/.last_job/resources
39 | echo $PROJECT_CLUSTER_TYPE >.cluster/slurm/.last_job/type
40 | echo "$PROJECT_JOB_NAME" >.cluster/slurm/.last_job/job_name
41 | if [ -z "$PROJECT_TASK_ID" ]; then
42 | echo "$SLURM_JOBID" >.cluster/slurm/.last_job/job_id
43 | else
44 | echo "$SLURM_ARRAY_JOB_ID" >.cluster/slurm/.last_job/job_id
45 | fi
46 | echo "$SLURM_JOB_NODELIST" >.cluster/slurm/.last_job/nodelist
47 | echo "$PROJECT_CURRENT_DATE" >.cluster/slurm/.last_job/date
48 | echo "$PROJECT_CURRENT_COMMIT" >.cluster/slurm/.last_job/commit
49 |
50 | # Run the boot script
51 | .cluster/_boot.sh "$@"
52 |
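53 | # Note (assumption about intended use): to run this configuration as a Slurm job
54 | # array, an '#SBATCH --array=...' directive (or 'sbatch --array=...' at submit time)
55 | # would be added; Slurm then sets SLURM_ARRAY_TASK_ID for each task, which is
56 | # exported above as PROJECT_TASK_ID.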
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/cancel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Confirm
7 | read -r -p "Are you sure? [y/N] " response
8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm again
15 | read -r -p "Are you really sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | :
18 | else
19 | exit 1
20 | fi
21 |
22 | # Fetch the status
23 | if [ -z "$1" ]; then
24 | echo "Canceling the last job..." 1>&2
25 | if [ -f "../output/named/_latest/job/job_id" ]; then
26 | scancel "$(cat ../output/named/_latest/job/job_id)"
27 | else
28 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
29 | exit 1
30 | fi
31 | else
32 | echo "Canceling the '$1' job..." 1>&2
33 | if [ -f "../output/named/$1/job/job_id" ]; then
34 | scancel "$(cat ../output/named/"$1"/job/job_id)"
35 | else
36 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
37 | exit 1
38 | fi
39 | fi
40 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/cancel_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Confirm
7 | read -r -p "Are you sure? [y/N] " response
8 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
9 | :
10 | else
11 | exit 1
12 | fi
13 |
14 | # Confirm again
15 | read -r -p "Are you really sure? [y/N] " response
16 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
17 | echo "Canceling..."
18 | else
19 | exit 1
20 | fi
21 |
22 | scancel -u "$USER"
23 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/htop.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the nodelist
7 | if [ -z "$1" ]; then
8 | echo "Fetching the compute environment of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/nodelist" ]; then
10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist)
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the compute environment of the '$1' job..." 1>&2
17 | if [ -f "../output/named/$1/job/nodelist" ]; then
18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist)
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 1 hour, this is a time limit to prevent hanging resources)" 1>&2
26 | sleep 2
27 | echo -e "\n" 1>&2
28 | srun -u --job-name=shell --time=1:00:00 --ntasks 1 --cpus-per-task 1 --mem=100M --nodelist="$NODELIST" --pty htop -u "$USER"
29 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/interactive_submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Change directory to project root location
7 | cd ../../
8 |
9 | # Check if there is already a job running
10 | if ( (.cluster/slurm/status_all.sh | grep -q "$USER") 2>/dev/null); then
11 | echo -e "WARNING: There is already a job running!\n"
12 | .cluster/slurm/status_all.sh
13 |
14 | # Confirm
15 | echo ""
16 | read -r -p "Are you sure you want to submit another job? [y/N] " response
17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
18 | :
19 | else
20 | exit 1
21 | fi
22 | fi
23 |
24 | # Get the job name
25 | if [ -z "$1" ]; then
26 | echo "You must submit with a job name as the first argument to this script."
27 | exit 1
28 | else
29 | export JOB_NAME="$1"
30 | fi
31 |
32 | # Submit
33 | mkdir -p .cluster/slurm/.last_job
34 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T)
35 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited")
36 | rm -rf .cluster/.lock
37 | export SBATCH_ARGS=$(grep -E '^#SBATCH' .cluster/slurm/_sbatch_config.sh | sed -E 's/\s*#SBATCH\s*//g' | grep -v '^--output' | tr '\n' ' ')
38 | export PROJECT_INTERACTIVE=1
39 | touch .cluster/slurm/.last_job/.interactive
40 | echo -n "" >.cluster/slurm/.last_job/submission.out
41 | shift
42 | # shellcheck disable=SC2086
43 | srun -u --job-name="$JOB_NAME" $SBATCH_ARGS --pty .cluster/slurm/_sbatch_config.sh "$@"
44 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/nvidia-smi.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the nodelist
7 | if [ -z "$1" ]; then
8 | echo "Fetching the compute environment of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/nodelist" ]; then
10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist)
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the compute environment of the '$1' job..." 1>&2
17 | if [ -f "../output/named/$1/job/nodelist" ]; then
18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist)
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 1 hour, this is a time limit to prevent hanging resources)" 1>&2
26 | sleep 2
27 | echo -e "\n" 1>&2
28 | srun -u --job-name=shell --time=1:00:00 --ntasks 1 --cpus-per-task 1 --mem=100M --nodelist="$NODELIST" --pty watch -n0.5 nvidia-smi
29 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/reset_venv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Reset the virtual env
7 | rm -rf ../../.venv/slurm/*
8 | rm -rf ../../.venv/slurm/
9 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/shell.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the nodelist
7 | if [ -z "$1" ]; then
8 | echo "Fetching the compute environment of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/nodelist" ]; then
10 | export NODELIST=$(cat ../output/named/_latest/job/nodelist)
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the compute environment of the '$1' job..." 1>&2
17 | if [ -f "../output/named/$1/job/nodelist" ]; then
18 | export NODELIST=$(cat ../output/named/"$1"/job/nodelist)
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
25 | echo -e "Opening a shell to the $NODELIST compute environment...\n(note: the shell will only remain open for 3 hours, this is a time limit to prevent hanging resources)" 1>&2
26 | sleep 2
27 | echo -e "\n" 1>&2
28 | srun -u --job-name=shell --time=3:00:00 --ntasks 1 --cpus-per-task 1 --mem=100M --nodelist="$NODELIST" --pty /bin/bash -c "cd ../../; /bin/bash"
29 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/status.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the status
7 | if [ -z "$1" ]; then
8 | echo "Fetching the status of the last job..." 1>&2
9 | if [ -f "../output/named/_latest/job/job_id" ]; then
10 | squeue -j "$(cat ../output/named/_latest/job/job_id)"
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the status of the '$1' job..." 1>&2
17 | if [ -f "../output/named/$1/job/job_id" ]; then
18 | squeue -j "$(cat ../output/named/"$1"/job/job_id)"
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/status_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the status of all jobs
7 | echo "Fetching the status of all jobs..." 1>&2
8 | squeue -u "$USER"
9 |
--------------------------------------------------------------------------------
/scripts/.cluster/slurm/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Change directory to project root location
7 | cd ../../
8 |
9 | # Check if there is already a job running
10 | if ( (.cluster/slurm/status_all.sh | grep -q "$USER") 2>/dev/null); then
11 | echo -e "WARNING: There is already a job running!\n"
12 | .cluster/slurm/status_all.sh
13 |
14 | # Confirm
15 | echo ""
16 | read -r -p "Are you sure you want to submit another job? [y/N] " response
17 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
18 | :
19 | else
20 | exit 1
21 | fi
22 | fi
23 |
24 | # Get the job name
25 | if [ -z "$1" ]; then
26 | echo "You must submit with a job name as the first argument to this script."
27 | exit 1
28 | else
29 | export JOB_NAME="$1"
30 | fi
31 |
32 | # Submit
33 | rm -rf .cluster/slurm/.last_job/*
34 | mkdir -p .cluster/slurm/.last_job
35 | export PROJECT_CURRENT_DATE=$(TZ=America/New_York date +%Y-%m-%d-%T)
36 | export PROJECT_CURRENT_COMMIT=$(git rev-parse --verify HEAD 2>/dev/null || echo "_uncommited")
37 | rm -rf .cluster/.lock
38 | shift
39 | sbatch --job-name="$JOB_NAME" .cluster/slurm/_sbatch_config.sh "$@"
40 |
--------------------------------------------------------------------------------
/scripts/.cluster/stderr_log.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Get the job name and task ID
7 | if [[ "$1" =~ ':'$ ]]; then
8 | export FIRST_ARGUMENT="$1"
9 | else
10 | export FIRST_ARGUMENT="$1:"
11 | fi
12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:)
13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:)
14 |
15 | # Fetch the stderr log
16 | if [ -z "$JOB_NAME" ]; then
17 | echo "Fetching the stderr log of the last job..." 1>&2
18 | if [ -f "./output/named/_latest/job/.array" ]; then
19 | if [ -z "$TASK_ID" ]; then
20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
21 | exit 1
22 | else
23 | export LOG_PATH=./output/named/_latest/$TASK_ID/stderr.out
24 | fi
25 | else
26 | export LOG_PATH=./output/named/_latest/stderr.out
27 | fi
28 | if [ -f "$LOG_PATH" ]; then
29 | cat "$LOG_PATH"
30 | else
31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
32 | exit 1
33 | fi
34 | else
35 | echo "Fetching the stderr log of the '$JOB_NAME' job..." 1>&2
36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then
37 | if [ -z "$TASK_ID" ]; then
38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
39 | exit 1
40 | else
41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/stderr.out
42 | fi
43 | else
44 | export LOG_PATH=./output/named/$JOB_NAME/stderr.out
45 | fi
46 | if [ -f "$LOG_PATH" ]; then
47 | cat "$LOG_PATH"
48 | else
49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
50 | exit 1
51 | fi
52 | fi
53 |
--------------------------------------------------------------------------------
/scripts/.cluster/stdout_log.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Get the job name and task ID
7 | if [[ "$1" =~ ':'$ ]]; then
8 | export FIRST_ARGUMENT="$1"
9 | else
10 | export FIRST_ARGUMENT="$1:"
11 | fi
12 | export JOB_NAME=$(echo "$FIRST_ARGUMENT" | cut -f1 -d:)
13 | export TASK_ID=$(echo "$FIRST_ARGUMENT" | cut -f2 -d:)
14 |
15 | # Fetch the stdout log
16 | if [ -z "$JOB_NAME" ]; then
17 | echo "Fetching the stdout log of the last job..." 1>&2
18 | if [ -f "./output/named/_latest/job/.array" ]; then
19 | if [ -z "$TASK_ID" ]; then
20 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
21 | exit 1
22 | else
23 | export LOG_PATH=./output/named/_latest/$TASK_ID/stdout.out
24 | fi
25 | else
26 | export LOG_PATH=./output/named/_latest/stdout.out
27 | fi
28 | if [ -f "$LOG_PATH" ]; then
29 | cat "$LOG_PATH"
30 | else
31 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
32 | exit 1
33 | fi
34 | else
35 | echo "Fetching the stdout log of the '$JOB_NAME' job..." 1>&2
36 | if [ -f "./output/named/$JOB_NAME/job/.array" ]; then
37 | if [ -z "$TASK_ID" ]; then
38 | echo "You must provide the task ID with a colon (:) since this job was run as a job array." 1>&2
39 | exit 1
40 | else
41 | export LOG_PATH=./output/named/$JOB_NAME/$TASK_ID/stdout.out
42 | fi
43 | else
44 | export LOG_PATH=./output/named/$JOB_NAME/stdout.out
45 | fi
46 | if [ -f "$LOG_PATH" ]; then
47 | cat "$LOG_PATH"
48 | else
49 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
50 | exit 1
51 | fi
52 | fi
53 |
--------------------------------------------------------------------------------
/scripts/.cluster/submission_log.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the submission log
7 | if [ -z "$1" ]; then
8 | echo "Fetching the submission log of the last job..." 1>&2
9 | if [ -f "./output/named/_latest/job/submission.out" ]; then
10 | cat ./output/named/_latest/job/submission.out
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the submission log of the '$1' job..." 1>&2
17 | if [ -f "./output/named/$1/job/submission.out" ]; then
18 | cat ./output/named/"$1"/job/submission.out
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
--------------------------------------------------------------------------------
/scripts/.cluster/tag.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")" || exit
5 |
6 | # Fetch the run dir
7 | if [ -z "$1" ]; then
8 | echo "Fetching the run output of the last job..." 1>&2
9 | if [ -e "./output/named/_latest" ]; then
10 | export RUN_DIR=$(readlink -f ./output/named/_latest)
11 | else
12 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
13 | exit 1
14 | fi
15 | else
16 | echo "Fetching the run output of the '$1' job..." 1>&2
17 | if [ -e "./output/named/$1" ]; then
18 | export RUN_DIR=$(readlink -f ./output/named/"$1")
19 | else
20 | echo "Job does not exist. If you just submitted the job, try again in a few seconds." 1>&2
21 | exit 1
22 | fi
23 | fi
24 |
25 | # Fetch the experiment dir
26 | export JOB_NAME=$(cat "$RUN_DIR"/job/job_name)
27 | read -r -p "What do you want to name the experiment (default: '$JOB_NAME')? " EXPERIMENT_NAME
28 | if [ -z "$EXPERIMENT_NAME" ]; then
29 | export EXPERIMENT_NAME=$JOB_NAME
30 | fi
31 | export EXPERIMENT_DIR=./output/experiments/$EXPERIMENT_NAME
32 |
33 | # Store the run as an experiment
34 | if [ -e "$EXPERIMENT_DIR" ]; then
35 | echo "This '$EXPERIMENT_NAME' experiment already exists. Do you want to replace it?"
36 |
37 | # Confirm
38 | read -r -p "Are you sure? [y/N] " response
39 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
40 | :
41 | else
42 | exit 1
43 | fi
44 |
45 | # Confirm again
46 | read -r -p "Are you really sure? [y/N] " response
47 | if [[ "$response" =~ ^([yY][eE][sS]|[yY])$ ]]; then
48 | :
49 | else
50 | exit 1
51 | fi
52 |
53 | rm "$EXPERIMENT_DIR"
54 | fi
55 | ln -s "$RUN_DIR" "$EXPERIMENT_DIR" && echo -e "\nDone. You can find the experiment at: $(
56 | cd "$EXPERIMENT_DIR" || exit
57 | pwd
58 | )"
59 |
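60 | # Example invocations (the job name is hypothetical):
61 | #   .cluster/tag.sh my_job   # tag the 'my_job' run as an experiment (prompts for a name)
62 | #   .cluster/tag.sh          # tag the most recent run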
--------------------------------------------------------------------------------
/scripts/.githooks/post-checkout:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Change directory to script location
5 | cd "$(dirname "$0")/../../" || exit
6 |
7 | # Symlink .cluster and .vscode at the top-level
8 | if [ ! -L ./.cluster ] || [ ! -e ./.cluster ]; then
9 | rm ./.cluster 1>/dev/null 2>/dev/null || true
10 | rm -rf ./.cluster 1>/dev/null 2>/dev/null || true
11 | ln -s "$(realpath ./scripts/.cluster)" ./.cluster
12 | fi
13 | if [ ! -L ./.vscode ] || [ ! -e ./.vscode ]; then
14 | rm ./.vscode 1>/dev/null 2>/dev/null || true
15 | rm -rf ./.vscode 1>/dev/null 2>/dev/null || true
16 | ln -s "$(realpath ./scripts/.vscode)" ./.vscode
17 | fi
18 |
19 | # Don't track local changes to project.env
20 | git update-index --skip-worktree ./scripts/project.env
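21 |
22 | # Note (assumption about setup): this hook only runs if git is pointed at this
23 | # hooks directory, e.g. via 'git config core.hooksPath scripts/.githooks' or by
24 | # linking it into .git/hooks.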
--------------------------------------------------------------------------------
/scripts/.githooks/pre-commit:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to script location
4 | cd "$(dirname "$0")/../../" || exit
5 |
6 | # Format and lint
7 | ./scripts/format.sh
8 |
9 | # Check whether formatting or linting failed
10 | retval=$?
11 | if [ $retval -ne 0 ]; then
12 | echo "Formatting or linting failed. Please fix the reported errors. Commit aborted."
13 | exit $retval
14 | fi
15 |
--------------------------------------------------------------------------------
/scripts/.python-version:
--------------------------------------------------------------------------------
1 | 3.10.9
2 |
--------------------------------------------------------------------------------
/scripts/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "ms-vscode-remote.remote-ssh",
4 | "mechatroner.rainbow-csv",
5 | "ms-python.python",
6 | "ms-python.vscode-pylance",
7 | "charliermarsh.ruff",
8 | "ms-python.mypy-type-checker",
9 | "visualstudioexptteam.vscodeintellicode",
10 | "njpwerner.autodocstring",
11 | "iliazeus.vscode-ansi",
12 | "timonwong.shellcheck",
13 | "foxundermoon.shell-format",
14 | "mikestead.dotenv",
15 | "bungcip.better-toml",
16 | "github.copilot"
17 | ]
18 | }
--------------------------------------------------------------------------------
/scripts/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "terminal.integrated.scrollback": 100000,
3 | "[python]": {
4 | "editor.tabSize": 4,
5 | "editor.insertSpaces": true,
6 | "editor.rulers": [88],
7 | "editor.defaultFormatter": "charliermarsh.ruff",
8 | "editor.formatOnSave": true,
9 | "editor.codeActionsOnSave": {
10 | "source.organizeImports": "explicit",
11 | "source.fixAll": "explicit"
12 | },
13 | },
14 | "[shellscript]": {
15 | "editor.tabSize": 4,
16 | "editor.insertSpaces": true,
17 | "editor.formatOnSave": true,
18 | "editor.formatOnPaste": true,
19 | "editor.rulers": [88],
20 | },
21 | "python.terminal.activateEnvironment": false,
22 | "python.linting.flake8Enabled": false,
23 | "python.analysis.extraPaths": [
24 | "./src",
25 | "./.venv/slurm/lib64/python3.10/site-packages",
26 | "./.venv/sge/lib64/python3.10/site-packages",
27 | "./.venv/direct/lib64/python3.10/site-packages"
28 | ],
29 | "autoDocstring.docstringFormat": "google",
30 | "autoDocstring.guessTypes": false,
31 | "autoDocstring.quoteStyle": "\"\"\"",
32 | "shellcheck.exclude": ["2155", "1091", "2004"],
33 | "search.exclude": {
34 | "**/.git": true,
35 | "**/.cluster": true,
36 | "**/.venv": true,
37 | "**/.venv_dev": true,
38 | "**/wandb/**": true,
39 | "**/docs/build": true,
40 | "**/docs/source/!(index.rst)": true,
41 | },
42 | "files.watcherExclude": {
43 | "**/.git/**": true,
44 | "**/.cluster/**": true,
45 | "**/.venv/**": true,
46 | "**/.venv_dev/**": true,
47 | "**/wandb/**": true,
48 | "**/docs/build/**": true,
49 | "**/docs/source/!(index.rst)": true,
50 | }
51 | }
--------------------------------------------------------------------------------
/scripts/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | // See https://go.microsoft.com/fwlink/?LinkId=733558
3 | // for the documentation about the tasks.json format
4 | "version": "2.0.0",
5 | "tasks": [
6 | {
7 | "label": "Format Project",
8 | "type": "shell",
9 | "command": "./scripts/format.sh",
10 | "problemMatcher": []
11 | }
12 | ]
13 | }
--------------------------------------------------------------------------------
/scripts/format.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to the project root
4 | cd "$(dirname "$0")/../" || exit
5 |
6 | # Create and activate a virtual environment for dev requirements
7 | export PROJECT_VENV=.venv_dev/dev
8 | mkdir -p $PROJECT_VENV
9 | python3 -m venv $PROJECT_VENV
10 | source $PROJECT_VENV/bin/activate
11 | pip3 install -q -U pip
12 | pip3 install -q -r src/requirements-dev.txt
13 |
14 | # Ruff format: https://github.com/astral-sh/ruff/issues/8232
15 | ruff check --select I --fix .
16 | python3 -m ruff format src/
17 |
18 | # Lint
19 | python3 -m ruff check src/ --fix
20 |
--------------------------------------------------------------------------------
/scripts/lint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to the project root
4 | cd "$(dirname "$0")/../" || exit
5 |
6 | # Create and activate a virtual environment for dev requirements
7 | export PROJECT_VENV=.venv_dev/dev
8 | mkdir -p $PROJECT_VENV
9 | python3 -m venv $PROJECT_VENV
10 | source $PROJECT_VENV/bin/activate
11 | pip3 install -q -U pip
12 | pip3 install -q -r src/requirements-dev.txt
13 |
14 | # Lint
15 | python3 -m ruff check src/ --fix
16 |
--------------------------------------------------------------------------------
/scripts/package.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Change directory to the project root
5 | cd "$(dirname "$0")/../" || exit
6 |
7 | # Create and activate a virtual environment for the build
8 | export PROJECT_VENV=.venv_poetry
9 | mkdir -p $PROJECT_VENV
10 | python3 -m venv $PROJECT_VENV
11 | source $PROJECT_VENV/bin/activate
12 | pip3 install -q -U pip
13 |
14 | export PACKAGE_NAME=$(grep "include = " pyproject.toml | head -n1 | cut -d'"' -f 2 | awk '{print tolower($0)}')
15 | export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
16 | export POETRY_HOME=/nlp/data/ajayp/.cache/pypoetry
17 | mkdir -p $POETRY_HOME
18 | export POETRY_CACHE_DIR=/nlp/data/ajayp/.cache/pypoetry/cache
19 | mkdir -p $POETRY_CACHE_DIR
20 |
21 | echo "Setting up..."
22 | cp pyproject.toml pyproject.toml.bak
23 | poetry add $(cat src/requirements.txt | grep -v "pandas-stubs>=")
24 | poetry add $(cat src/requirements-cpu.txt | sed 's/+cpu//' | grep -v "find-links" | grep -v "torchvision=" | grep -v "torchaudio=" | grep -v "tensorflow=")
25 | poetry add --group dev $(cat src/requirements-dev.txt)
26 | mv src/ $PACKAGE_NAME/
27 |
28 | echo "Building 'dist' folder..."
29 | rm -rf dist || true
30 | cp .gitignore .gitignore.bak
31 | cat .gitignore.bak | grep -v '^\s*datadreamer\s*$' >.gitignore
32 | poetry build
33 | mv .gitignore.bak .gitignore
34 |
35 | echo "Cleaning up..."
36 | mv pyproject.toml.bak pyproject.toml
37 | mv $PACKAGE_NAME/ src/
38 | if [[ $* != *--keep-venv* ]]; then
39 | rm -rf ./.venv_poetry
40 | fi
41 | rm -rf poetry.lock
42 |
--------------------------------------------------------------------------------
/scripts/package_publish.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to the project root
4 | cd "$(dirname "$0")/../" || exit
5 |
6 | poetry publish "$@"
7 |
--------------------------------------------------------------------------------
/scripts/project.env:
--------------------------------------------------------------------------------
1 | export PROJECT_JOB_NAME="test" # This makes run.sh run with `pytest` instead of `python3`
2 | export PROJECT_DATA=~/.datadreamer_dev/ # Where project dependencies will be installed and stored
3 | export PROJECT_DISABLE_TUNNEL=1 # Disables certain dependencies that are not required
4 |
5 | # API Keys and Tokens
6 | # export HUGGING_FACE_HUB_TOKEN="your huggingface_hub token" # (optional) Some tests require a Hugging Face Hub token
7 | # export OPENAI_API_KEY="your_openai_api_key" # (optional) Some tests require an OpenAI API key
8 |
9 | # You can un-comment the line below to make subsequent runs faster
10 | # after project dependencies have been installed.
11 | # export PROJECT_SKIP_INSTALL_REQS=1 # Skip installing reqs
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change directory to the project root
4 | cd "$(dirname "$0")/../" || exit
5 |
6 | # Set environment variables
7 | export PROJECT_VENV=.venv/prod
8 |
9 | # Load user-specified project environment variables
10 | source ./scripts/project.env
11 |
12 | # Ensure direct environment variables are provided
13 | source .cluster/direct/_direct_env.sh
14 |
15 | # Load user-specified project environment variables
16 | source ./scripts/project.env
17 |
18 | # Set environment variables
19 | export PROJECT_CACHE_DIR=$PROJECT_DATA/.cache
20 |
21 | # Run the boot script
22 | .cluster/_boot.sh "$@"
23 |
--------------------------------------------------------------------------------
/src/.env:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # If you have secret environment variables (API keys or passwords), do NOT
3 | # place them in this file.
4 | # Instead, create a file called ".secrets.env".
5 | # ------------------------------------------------------------------------
6 | # Define any environment variables you want set when running below.
7 | # Example:
8 | # export FOO="bar"
9 | # ------------------------------------------------------------------------
10 |
11 | # Output in unbuffered mode
12 | export PYTHONUNBUFFERED=1
13 |
14 | # Control log level
15 | export LOGURU_LEVEL="TRACE"
16 |
17 | # Disable TensorFlow from allocating all GPU memory on startup
18 | export TF_FORCE_GPU_ALLOW_GROWTH=true
19 |
20 | # Set cache directories
21 | export NLTK_DATA=$PROJECT_CACHE_DIR/nltk
22 | mkdir -p $NLTK_DATA
23 | export HF_HOME=$PROJECT_CACHE_DIR/huggingface_cache_datadreamer
24 | mkdir -p $HF_HOME
25 | export SENTENCE_TRANSFORMERS_HOME=$PROJECT_CACHE_DIR/sentence_transformers_cache_datadreamer
26 | mkdir -p $SENTENCE_TRANSFORMERS_HOME
--------------------------------------------------------------------------------
/src/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .secrets.env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
--------------------------------------------------------------------------------
/src/.secrets.template.env:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # If you have secret environment variables (API keys or passwords), place
3 | # them in this file and rename it to ".secrets.env" instead of
4 | # ".secrets.template.env".
5 | # ------------------------------------------------------------------------
6 | # Define any environment variables you want set when running below.
7 | # Example:
8 | # export FOO="bar"
9 | # ------------------------------------------------------------------------
10 |
11 | # ngrok
12 | export NGROK_AUTHTOKEN=your_api_key
13 |
14 | # Weights & Biases
15 | export WANDB_ENTITY="username"
16 | export WANDB_MODE="online" # or disabled
17 | export WANDB_API_KEY=your_api_key
18 |
19 | # Hugging Face Hub
20 | export HUGGINGFACEHUB_API_TOKEN=your_api_key
21 |
22 | # OpenAI
23 | export OPENAI_API_KEY=your_api_key
24 |
--------------------------------------------------------------------------------
/src/__cli__.py:
--------------------------------------------------------------------------------
1 | import click
2 |
3 | from . import project
4 |
5 |
6 | # Register main
7 | @click.group()
8 | @click.pass_context
9 | def _main(*args, **kwargs): # pragma: no cover
10 | # Run init
11 | project.init()
12 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | DataDreamer Sessions
4 | ====================
5 |
6 | You can run prompting, synthetic data generation, and training workflows within
7 | a DataDreamer session using a context manager like so:
8 |
9 | .. code-block:: python
10 |
11 | from datadreamer import DataDreamer
12 |
13 | with DataDreamer('./output/'):
14 | # ... run steps or trainers here ...
15 |
16 | Inside the ``with`` block, you can run any :py:class:`~datadreamer.steps.Step` or
17 | :py:class:`~datadreamer.trainers.Trainer` you want. DataDreamer will automatically
18 | organize, cache, and save the results of each step run within a session to the output
19 | folder.
20 |
21 | In-Memory Sessions
22 | ------------------------------------------------
23 |
24 | Optionally, you can run DataDreamer fully in-memory, without it saving anything to disk,
25 | by passing ``':memory:'`` as the ``output_folder_path`` argument like
26 | ``with DataDreamer(':memory:'):``.
27 |
28 | Sessions in Interactive Environments
29 | ------------------------------------------------
30 |
31 | As an alternative to using a Python context manager (``with`` block), you can also
32 | structure your code with :py:meth:`~DataDreamer.start` and :py:meth:`~DataDreamer.stop`
33 | to achieve the same result. Using the context manager, however, is recommended.
34 | Using :py:meth:`~DataDreamer.start` and :py:meth:`~DataDreamer.stop` may be
35 | useful if you want to run DataDreamer in a Jupyter or Google Colab notebook or
36 | another interactive environment.
37 |
38 | .. code-block:: python
39 |
40 | from datadreamer import DataDreamer
41 |
42 | dd = DataDreamer('./output/')
43 | dd.start()
44 | # ... run steps or trainers here ...
45 | dd.stop()
46 |
47 | Caching
48 | =======
49 |
50 | DataDreamer caches the results of each step or trainer run within a session to the
51 | output folder. If a session is interrupted and re-run, DataDreamer will automatically
52 | load the results of previously completed steps from disk and resume where it left off.
53 |
54 |
55 | Attributes:
56 | __version__ (str): The version of DataDreamer installed.
57 | """
58 |
59 | from .utils import import_utils # isort: skip # noqa: F401
60 |
61 | import importlib.metadata
62 | import os
63 |
64 | from .datadreamer import DataDreamer
65 |
66 | try:
67 | project_root_dir = os.path.dirname(os.path.dirname(__file__))
68 | with open(os.path.join(project_root_dir, "./pyproject.toml")) as pyproject_fp:
69 | version_line = [
70 | line.strip() for line in pyproject_fp if line.startswith("version")
71 | ][0]
72 | __version__ = version_line[version_line.find('"') + 1 : version_line.rfind('"')]
73 | except FileNotFoundError: # pragma: no cover
74 | __version__ = importlib.metadata.version(
75 | os.path.basename(os.path.dirname(__file__)) + "-dev"
76 | )
77 |
78 | __all__ = ["__version__", "DataDreamer"]
79 |
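To complement the docstring above, here is a minimal sketch of the in-memory mode it describes. It assumes the ``DataSource`` step exported by ``datadreamer.steps`` (not defined in this file) and uses made-up data:

.. code-block:: python

    from datadreamer import DataDreamer
    from datadreamer.steps import DataSource  # assumed import; not defined in this file

    # Nothing is written to disk when ':memory:' is used as the output folder,
    # so no caching or resuming happens across runs.
    with DataDreamer(":memory:"):
        greetings = DataSource("Greetings", data={"text": ["hello", "hi", "hey"]})
        print(greetings.output)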
--------------------------------------------------------------------------------
/src/__main__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from . import project
4 |
5 | try:
6 | from . import __entry__ # type: ignore[attr-defined] # noqa: F401
7 | except ImportError:
8 | pass
9 | from .__cli__ import _main
10 |
11 | if __name__ == "__main__": # pragma: no cover
12 | # Set the initial cwd
13 | project.INITIAL_CWD = os.path.abspath(os.getcwd())
14 |
15 | _main()
16 |
--------------------------------------------------------------------------------
/src/_cachable/__init__.py:
--------------------------------------------------------------------------------
1 | from ._cachable import _Cachable, _default_batch_scheduler_buffer_size
2 | from ._parallel_cachable import _ParallelCachable
3 |
4 | __all__ = ["_Cachable", "_ParallelCachable", "_default_batch_scheduler_buffer_size"]
5 |
--------------------------------------------------------------------------------
/src/_patches/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/_patches/__init__.py
--------------------------------------------------------------------------------
/src/_patches/datasets_reset_state_hack.py:
--------------------------------------------------------------------------------
1 | # An update in datasets 2.20.0 adding state_dict to IterableDataset seems to have
2 | # broken IterableDataset. This patch is a temporary fix until the issue is resolved.
3 |
4 | import contextlib
5 | from unittest.mock import patch
6 |
7 | from datasets.iterable_dataset import (
8 | ArrowExamplesIterable,
9 | ExamplesIterable,
10 | TypedExamplesIterable,
11 | )
12 |
13 | __original_init_state_dict = TypedExamplesIterable._init_state_dict
14 | __original_examples__iter__ = ExamplesIterable.__iter__
15 | __original_arrowexamples__iter__ = ArrowExamplesIterable.__iter__
16 | _should_reset_state_dict = False
17 |
18 |
19 | def patched_examples__iter__(self):
20 | global _should_reset_state_dict
21 | if _should_reset_state_dict:
22 | self._init_state_dict()
23 | return __original_examples__iter__(self)
24 |
25 |
26 | def patched_arrowexamples__iter__(self):
27 | global _should_reset_state_dict
28 | if _should_reset_state_dict:
29 | self._init_state_dict()
30 | return __original_arrowexamples__iter__(self)
31 |
32 |
33 | ExamplesIterable.__iter__ = patched_examples__iter__
34 | ArrowExamplesIterable.__iter__ = patched_arrowexamples__iter__
35 |
36 |
37 | @contextlib.contextmanager
38 | def apply_datasets_reset_state_hack():
39 | def patched_init_state_dict(self):
40 | self._state_dict = None # Set to None to ensure it is reset
41 | return __original_init_state_dict(self)
42 |
43 | with patch(
44 | "datasets.iterable_dataset.TypedExamplesIterable._init_state_dict",
45 | patched_init_state_dict,
46 | ):
47 | yield None
48 |
49 |
50 | def start_datasets_reset_state_hack():
51 | global _should_reset_state_dict
52 | _should_reset_state_dict = True
53 |
54 |
55 | def stop_datasets_reset_state_hack():
56 | global _should_reset_state_dict
57 | _should_reset_state_dict = False
58 |
--------------------------------------------------------------------------------
/src/_patches/setfit_import_hack.py:
--------------------------------------------------------------------------------
1 | # SetFit is out-of-date with huggingface_hub and throws an error when trying to
2 | # import from it, like this:
3 | #   ImportError: cannot import name 'DatasetFilter' from 'huggingface_hub'
4 |
5 | # To fix this, we need to monkey patch huggingface_hub to prevent the import error
6 |
7 | from ..utils.import_utils import ignore_pydantic_warnings
8 |
9 |
10 | def apply_setfit_import_hack():
11 | with ignore_pydantic_warnings():
12 | import huggingface_hub
13 |
14 | huggingface_hub.DatasetFilter = None
15 |
--------------------------------------------------------------------------------
/src/_stubs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/_stubs/.gitkeep
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | :py:class:`OutputDataset` and :py:class:`OutputIterableDataset` dataset objects are
3 | returned as outputs from :py:class:`~datadreamer.steps.Step` objects under the
4 | :py:attr:`~datadreamer.steps.Step.output` attribute.
5 |
6 | .. tip::
7 |
8 | You never need to construct a dataset object yourself. They are returned as
9 | :py:attr:`~datadreamer.steps.Step.output` from
10 | :py:class:`~datadreamer.steps.Step` objects. If you need to convert in-memory Python
11 | data or data in files to a DataDreamer dataset object, see the
12 | `DataSource steps <./datadreamer.steps.html#types-of-steps>`_
13 | available in :py:mod:`datadreamer.steps`.
14 |
15 | Accessing Columns
16 | =================
17 |
18 | To access a column on the dataset objects you can use the ``__getitem__`` operator like
19 | so: ``step.output['column_name']``. This will return a :py:class:`OutputDatasetColumn`
20 | or :py:class:`OutputIterableDatasetColumn` column object that can be passed as an input
21 | to the ``inputs`` argument of a :py:class:`~datadreamer.steps.Step`.
22 | """
23 |
24 | from .datasets import (
25 | OutputDataset,
26 | OutputDatasetColumn,
27 | OutputIterableDataset,
28 | OutputIterableDatasetColumn,
29 | )
30 |
31 | __all__ = [
32 | "OutputDataset",
33 | "OutputDatasetColumn",
34 | "OutputIterableDataset",
35 | "OutputIterableDatasetColumn",
36 | ]
37 |
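A short sketch of the column-access pattern described in the docstring above; the ``DataSource`` step and the column name are illustrative, not part of this module:

.. code-block:: python

    from datadreamer import DataDreamer
    from datadreamer.steps import DataSource  # assumed import

    with DataDreamer(":memory:"):
        step = DataSource("Example data", data={"column_name": [1, 2, 3]})
        column = step.output["column_name"]  # an OutputDatasetColumn
        # `column` can now be passed to the `inputs` argument of a downstream Step.
        print(column)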
--------------------------------------------------------------------------------
/src/datasets/utils.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from functools import partial
3 | from itertools import chain
4 | from typing import Any, Iterator, cast
5 |
6 | from datasets import Dataset, IterableDataset
7 | from datasets.features.features import Features
8 |
9 | from .. import DataDreamer
10 |
11 |
12 | def get_column_names(dataset: Dataset | IterableDataset) -> list[str]:
13 | column_names = cast(None | list[str], dataset.column_names)
14 | if column_names:
15 | return column_names
16 | else:
17 | try:
18 | first_row = next(iter(dataset))
19 | except StopIteration:
20 | return []
21 | return list(first_row.keys())
22 |
23 |
24 | def drop_unsupported_features(dataset: Dataset | IterableDataset):
25 | if isinstance(dataset, Dataset):
26 | dataset.reset_format()
27 | for index_name in dataset.list_indexes():
28 | dataset.drop_index(index_name) # pragma: no cover
29 |
30 |
31 | def dataset_zip(
32 | *datasets: Dataset,
33 | writer_batch_size: None | int = 1000,
34 | num_proc: None | int = None,
35 | ) -> Dataset:
36 | if len(datasets) == 0:
37 | raise ValueError("You must provide at least one dataset to zip.")
38 | datasets = tuple([deepcopy(d) for d in datasets])
39 | for d in datasets:
40 | drop_unsupported_features(d)
41 | smallest_dataset = min(datasets, key=lambda d: len(d))
42 |
43 | def merge_rows(datasets, x, idx):
44 | result_row = {}
45 | for d in datasets:
46 | result_row.update(d[idx])
47 | return result_row
48 |
49 | DataDreamer._enable_hf_datasets_logging()
50 | zip_results = smallest_dataset.map(
51 | partial(merge_rows, datasets),
52 | with_indices=True,
53 | desc="Zipping datasets together",
54 | writer_batch_size=writer_batch_size,
55 | num_proc=num_proc,
56 | )
57 | DataDreamer._disable_hf_datasets_logging()
58 | return zip_results
59 |
60 |
61 | def iterable_dataset_zip(*datasets: Dataset | IterableDataset) -> IterableDataset:
62 | if len(datasets) == 0:
63 | raise ValueError("You must provide at least one dataset to zip.")
64 | datasets = tuple([deepcopy(d) for d in datasets])
65 | for d in datasets:
66 | drop_unsupported_features(d)
67 |
68 | def merged_generator(datasets):
69 | iters: list[Iterator[dict[str, Any]]] = [iter(d) for d in datasets]
70 | for row_dicts in zip(*iters):
71 | row = {}
72 | for d in row_dicts:
73 | for k, v in d.items():
74 | row[k] = v
75 | yield row
76 |
77 | column_names: list[str] = list(
78 | chain.from_iterable([get_column_names(d) for d in datasets])
79 | )
80 | features = Features([(n, None) for n in column_names])
81 | return IterableDataset.from_generator(
82 | partial(merged_generator, datasets), features=features
83 | )
84 |
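As a standalone illustration of the zip semantics implemented above (rows merged by index, result truncated to the smallest input), using Hugging Face ``datasets`` directly with made-up columns:

.. code-block:: python

    from datasets import Dataset

    a = Dataset.from_dict({"x": [1, 2, 3]})
    b = Dataset.from_dict({"y": ["a", "b"]})

    # dataset_zip maps over the smallest dataset and merges the rows of all
    # datasets at each index, so the zipped length equals len(b) here.
    merged = [dict(a[i], **b[i]) for i in range(min(len(a), len(b)))]
    print(merged)  # [{'x': 1, 'y': 'a'}, {'x': 2, 'y': 'b'}]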
--------------------------------------------------------------------------------
/src/embedders/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | :py:class:`Embedder` objects help convert texts to :wikipedia:`embeddings <Word embedding>`.
3 | All embedders derive from the :py:class:`Embedder` base class.
4 |
5 | .. tip::
6 |
7 | Instead of using :py:meth:`~Embedder.run` directly, use a
8 | :py:class:`step <datadreamer.steps.Step>` that takes an :py:class:`Embedder` as an ``args``
9 | argument such as :py:class:`~datadreamer.steps.Embed` or construct an
10 | :py:class:`~datadreamer.retrievers.EmbeddingRetriever` with the embedder and then use
11 | a retrieval step such as :py:class:`~datadreamer.steps.Retrieve`.
12 |
13 | Caching
14 | =======
15 | Embedders internally perform caching to disk, so if you embed the same text multiple
16 | times, the embedder will only embed the text once and then cache the results for
17 | future runs.
18 | """
19 |
20 | from .embedder import Embedder
21 | from .openai_embedder import OpenAIEmbedder
22 | from .parallel_embedder import ParallelEmbedder
23 | from .sentence_transformers_embedder import SentenceTransformersEmbedder
24 | from .together_embedder import TogetherEmbedder
25 |
26 | __all__ = [
27 | "Embedder",
28 | "OpenAIEmbedder",
29 | "SentenceTransformersEmbedder",
30 | "TogetherEmbedder",
31 | "ParallelEmbedder",
32 | ]
33 |
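A minimal sketch of the recommended pattern from the tip above. The ``Embed`` step, its ``inputs``/``args`` names, and the model name are assumptions for illustration rather than definitions from this file:

.. code-block:: python

    from datadreamer import DataDreamer
    from datadreamer.embedders import SentenceTransformersEmbedder
    from datadreamer.steps import DataSource, Embed  # assumed imports

    with DataDreamer(":memory:"):
        texts = DataSource("Texts", data={"texts": ["a first sentence", "a second sentence"]})
        embedded = Embed(
            "Embed texts",
            inputs={"texts": texts.output["texts"]},
            args={"embedder": SentenceTransformersEmbedder("all-MiniLM-L6-v2")},
        )
        # In a normal on-disk session the embeddings are also cached, so re-running
        # the same embedder on the same texts does no extra work.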
--------------------------------------------------------------------------------
/src/embedders/parallel_embedder.py:
--------------------------------------------------------------------------------
1 | from typing import cast
2 |
3 | from ..task_models.parallel_task_model import ParallelTaskModel
4 | from .embedder import Embedder
5 |
6 |
7 | class ParallelEmbedder(ParallelTaskModel, Embedder):
8 | def __init__(self, *embedders: Embedder):
9 | """
10 | Creates an embedder that will run multiple embedders in parallel. See
11 | :doc:`running models in parallel
12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>`
13 | for more details.
14 |
15 | Args:
16 | *embedders: The embedders to run in parallel.
17 | """
18 | super().__init__(*embedders)
19 | self.embedders = cast(list[Embedder], self.cachables)
20 |
21 | @property
22 | def model_max_length(self) -> int:
23 | return self.embedders[0].model_max_length
24 |
25 | @property
26 | def dims(self) -> int:
27 | return self.embedders[0].dims
28 |
29 |
30 | __all__ = ["ParallelEmbedder"]
31 |
--------------------------------------------------------------------------------
/src/errors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Various exceptions that may be raised when using DataDreamer.
3 | """
4 |
5 | from .steps.step import StepOutputError, StepOutputTypeError
6 |
7 | __all__ = ["StepOutputError", "StepOutputTypeError"]
8 |
--------------------------------------------------------------------------------
/src/errors/steps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/errors/steps/__init__.py
--------------------------------------------------------------------------------
/src/errors/steps/step.py:
--------------------------------------------------------------------------------
1 | class StepOutputError(Exception):
2 | """Raised when a :py:class:`~datadreamer.steps.Step` is constructing its
3 | :py:attr:`~datadreamer.steps.Step.output` and an error occurs."""
4 |
5 | pass
6 |
7 |
8 | class StepOutputTypeError(TypeError, StepOutputError):
9 | """Raised when a :py:class:`~datadreamer.steps.Step` is constructing its
10 | :py:attr:`~datadreamer.steps.Step.output` and a type error occurs."""
11 |
12 | def __init__(self, message: str):
13 | if message:
14 | super().__init__(
15 | "Error processing dataset, make sure all values for each output of the"
16 | " dataset are of the same Python type/shape. If you need more"
17 | " flexibility you can pickle your data using the .pickle() method on"
18 | " a Step object. Data will automatically be un-pickled when read."
19 | f" Detailed error: {message.replace('struct', 'dict')}"
20 | )
21 |
--------------------------------------------------------------------------------
/src/llms/_tokenizers.py:
--------------------------------------------------------------------------------
1 | from functools import cache
2 |
3 | TOGETHER_TOKENIZERS = {
4 | "togethercomputer/Pythia-Chat-Base-7B-v0.16": "togethercomputer/Pythia-Chat-Base-7B",
5 | "togethercomputer/Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
6 | "togethercomputer/Koala-13B": "meta-llama/Llama-2-13b-hf",
7 | "togethercomputer/llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf",
8 | "togethercomputer/llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
9 | "togethercomputer/llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf",
10 | "togethercomputer/CodeLlama-7b-Instruct": "codellama/CodeLlama-7b-Instruct-hf",
11 | "togethercomputer/CodeLlama-13b-Instruct": "codellama/CodeLlama-13b-Instruct-hf",
12 | "togethercomputer/CodeLlama-34b-Instruct": "codellama/CodeLlama-34b-Instruct-hf",
13 | "togethercomputer/mpt-7b-chat": "mosaicml/mpt-7b-chat",
14 | "togethercomputer/mpt-30b-chat": "mosaicml/mpt-30b-chat",
15 | "togethercomputer/alpaca-7b": "huggyllama/llama-7b",
16 | "togethercomputer/falcon-7b-instruct": "tiiuae/falcon-7b-instruct",
17 | "togethercomputer/falcon-40b-instruct": "tiiuae/falcon-40b-instruct",
18 | "togethercomputer/guanaco-7b": "huggyllama/llama-7b",
19 | "togethercomputer/guanaco-13b": "huggyllama/llama-13b",
20 | "togethercomputer/guanaco-33b": "huggyllama/llama-30b",
21 | "togethercomputer/guanaco-65b": "huggyllama/llama-65b",
22 | "togethercomputer/Qwen-7B": "Qwen/Qwen-7B",
23 | "togethercomputer/llama-2-7b": "meta-llama/Llama-2-7b-hf",
24 | "togethercomputer/llama-2-13b": "meta-llama/Llama-2-13b-hf",
25 | "togethercomputer/llama-2-70b": "meta-llama/Llama-2-70b-hf",
26 | "togethercomputer/mpt-30b": "mosaicml/mpt-30b",
27 | "togethercomputer/mpt-30b-instruct": "mosaicml/mpt-30b-instruct",
28 | "togethercomputer/falcon-7b": "tiiuae/falcon-7b",
29 | "togethercomputer/falcon-40b": "tiiuae/falcon-40b",
30 | "togethercomputer/codegen2-7B": "Salesforce/codegen2-7B",
31 | "togethercomputer/codegen2-16B": "Salesforce/codegen2-16B",
32 | "togethercomputer/CodeLlama-7b": "codellama/CodeLlama-7b-hf",
33 | "togethercomputer/CodeLlama-13b": "codellama/CodeLlama-13b-hf",
34 | "togethercomputer/CodeLlama-34b": "codellama/CodeLlama-34b-hf",
35 | "togethercomputer/CodeLlama-7b-Python": "codellama/CodeLlama-7b-Python-hf",
36 | "togethercomputer/CodeLlama-13b-Python": "codellama/CodeLlama-13b-Python-hf",
37 | "togethercomputer/CodeLlama-34b-Python": "codellama/CodeLlama-34b-Python-hf",
38 | }
39 |
40 |
41 | @cache
42 | def _model_name_to_tokenizer_model_name(model_name: str) -> str: # pragma: no cover
43 | model_name_lower = model_name.lower()
44 | if all(fragment in model_name_lower for fragment in ["llama-", "-2-", "-chat"]):
45 | return "meta-llama/Llama-2-7b-chat-hf"
46 | return "gpt2"
47 |
48 |
49 | __all__ = ["TOGETHER_TOKENIZERS"]
50 |
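To illustrate how the mapping and the fallback above resolve tokenizer names (the unknown model strings are made up):

.. code-block:: python

    from datadreamer.llms._tokenizers import (
        TOGETHER_TOKENIZERS,
        _model_name_to_tokenizer_model_name,
    )

    # Known Together AI models map to the Hugging Face repo hosting their tokenizer.
    print(TOGETHER_TOKENIZERS["togethercomputer/llama-2-7b-chat"])
    # -> meta-llama/Llama-2-7b-chat-hf

    # Unknown Llama 2 chat variants fall back to the Llama 2 chat tokenizer...
    print(_model_name_to_tokenizer_model_name("some-org/llama-2-13b-chat-variant"))
    # -> meta-llama/Llama-2-7b-chat-hf

    # ...and anything else falls back to "gpt2".
    print(_model_name_to_tokenizer_model_name("unknown/model"))
    # -> gpt2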
--------------------------------------------------------------------------------
/src/llms/ai21.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 |
3 | from ._litellm import LiteLLM
4 |
5 |
6 | class AI21(LiteLLM):
7 | def __init__(
8 | self,
9 | model_name: str,
10 | api_key: None | str = None,
11 | retry_on_fail: bool = True,
12 | cache_folder_path: None | str = None,
13 | **kwargs,
14 | ):
15 | super().__init__(
16 | model_name=model_name,
17 | api_key=api_key,
18 | retry_on_fail=retry_on_fail,
19 | cache_folder_path=cache_folder_path,
20 | **kwargs,
21 | )
22 | self._model_name_prefix = ""
23 |
24 | @cached_property
25 | def model_card(self) -> None | str:
26 | return "https://www.ai21.com/blog/introducing-j2"
27 |
28 | @cached_property
29 | def license(self) -> None | str:
30 | return "https://www.ai21.com/terms-of-use"
31 |
32 |
33 | __all__ = ["AI21"]
34 |
--------------------------------------------------------------------------------
/src/llms/anthropic.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Callable
3 |
4 | from ._litellm import LiteLLM
5 | from .llm import DEFAULT_BATCH_SIZE
6 |
7 |
8 | class Anthropic(LiteLLM):
9 | def __init__(
10 | self,
11 | model_name: str,
12 | api_key: None | str = None,
13 | retry_on_fail: bool = True,
14 | cache_folder_path: None | str = None,
15 | **kwargs,
16 | ):
17 | super().__init__(
18 | model_name=model_name,
19 | api_key=api_key,
20 | retry_on_fail=retry_on_fail,
21 | cache_folder_path=cache_folder_path,
22 | **kwargs,
23 | )
24 | self._model_name_prefix = ""
25 |
26 | def _run_batch(
27 | self,
28 | max_length_func: Callable[[list[str]], int],
29 | inputs: list[str],
30 | max_new_tokens: None | int = None,
31 | temperature: float = 1.0,
32 | top_p: float = 0.0,
33 | n: int = 1,
34 | stop: None | str | list[str] = None,
35 | repetition_penalty: None | float = None,
36 | logit_bias: None | dict[int, float] = None,
37 | batch_size: int = DEFAULT_BATCH_SIZE,
38 | seed: None | int = None,
39 | **kwargs,
40 | ) -> list[str] | list[list[str]]:
41 | assert (
42 | repetition_penalty is None
43 | ), f"`repetition_penalty` is not supported for {type(self).__name__}"
44 | assert n == 1, f"Only `n` = 1 is supported for {type(self).__name__}"
45 | return super()._run_batch(
46 | max_length_func=max_length_func,
47 | inputs=inputs,
48 | max_new_tokens=max_new_tokens,
49 | temperature=temperature,
50 | top_p=top_p,
51 | n=n,
52 | stop=stop,
53 | repetition_penalty=repetition_penalty,
54 | logit_bias=logit_bias,
55 | batch_size=batch_size,
56 | seed=seed,
57 | **kwargs,
58 | )
59 |
60 | @cached_property
61 | def model_card(self) -> None | str:
62 |         return "https://www.anthropic.com/claude"
63 |
64 | @cached_property
65 | def license(self) -> None | str:
66 | return "https://console.anthropic.com/legal/terms"
67 |
68 |
69 | __all__ = ["Anthropic"]
70 |
--------------------------------------------------------------------------------
/src/llms/bedrock.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Callable
3 |
4 | from ._litellm import LiteLLM
5 | from .llm import DEFAULT_BATCH_SIZE
6 |
7 |
8 | class Bedrock(LiteLLM):
9 | def __init__(
10 | self,
11 | model_name: str,
12 | aws_access_key_id: None | str = None,
13 | aws_secret_access_key: None | str = None,
14 | aws_region_name: None | str = None,
15 | retry_on_fail: bool = True,
16 | cache_folder_path: None | str = None,
17 | **kwargs,
18 | ):
19 | super().__init__(
20 | model_name=model_name,
21 | aws_access_key_id=aws_access_key_id,
22 | aws_secret_access_key=aws_secret_access_key,
23 | aws_region_name=aws_region_name,
24 | retry_on_fail=retry_on_fail,
25 | cache_folder_path=cache_folder_path,
26 | **kwargs,
27 | )
28 | self._model_name_prefix = ""
29 |
30 | def _run_batch(
31 | self,
32 | max_length_func: Callable[[list[str]], int],
33 | inputs: list[str],
34 | max_new_tokens: None | int = None,
35 | temperature: float = 1.0,
36 | top_p: float = 0.0,
37 | n: int = 1,
38 | stop: None | str | list[str] = None,
39 | repetition_penalty: None | float = None,
40 | logit_bias: None | dict[int, float] = None,
41 | batch_size: int = DEFAULT_BATCH_SIZE,
42 | seed: None | int = None,
43 | **kwargs,
44 | ) -> list[str] | list[list[str]]:
45 | assert (
46 | repetition_penalty is None
47 | ), f"`repetition_penalty` is not supported for {type(self).__name__}"
48 | assert n == 1, f"Only `n` = 1 is supported for {type(self).__name__}"
49 | return super()._run_batch(
50 | max_length_func=max_length_func,
51 | inputs=inputs,
52 | max_new_tokens=max_new_tokens,
53 | temperature=temperature,
54 | top_p=top_p,
55 | n=n,
56 | stop=stop,
57 | repetition_penalty=repetition_penalty,
58 | logit_bias=logit_bias,
59 | batch_size=batch_size,
60 | seed=seed,
61 | **kwargs,
62 | )
63 |
64 | @cached_property
65 | def model_card(self) -> None | str:
66 | return "https://aws.amazon.com/bedrock/"
67 |
68 | @cached_property
69 | def license(self) -> None | str:
70 | return "https://aws.amazon.com/terms/"
71 |
72 |
73 | __all__ = ["Bedrock"]
74 |
--------------------------------------------------------------------------------
/src/llms/cohere.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 |
3 | from ._litellm import LiteLLM
4 |
5 |
6 | class Cohere(LiteLLM):
7 | def __init__(
8 | self,
9 | model_name: str,
10 | api_key: None | str = None,
11 | retry_on_fail: bool = True,
12 | cache_folder_path: None | str = None,
13 | **kwargs,
14 | ):
15 | super().__init__(
16 | model_name=model_name,
17 | api_key=api_key,
18 | retry_on_fail=retry_on_fail,
19 | cache_folder_path=cache_folder_path,
20 | **kwargs,
21 | )
22 | self._model_name_prefix = ""
23 |
24 | @cached_property
25 | def model_card(self) -> None | str:
26 | return "https://docs.cohere.com/docs/models"
27 |
28 | @cached_property
29 | def license(self) -> None | str:
30 | return "https://cohere.com/saas-agreement"
31 |
32 |
33 | __all__ = ["Cohere"]
34 |
--------------------------------------------------------------------------------
/src/llms/google_ai_studio.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Callable
3 |
4 | from ._litellm import LiteLLM
5 | from .llm import DEFAULT_BATCH_SIZE
6 |
7 |
8 | class GoogleAIStudio(LiteLLM):
9 | def __init__(
10 | self,
11 | model_name: str,
12 | api_key: None | str = None,
13 | retry_on_fail: bool = True,
14 | cache_folder_path: None | str = None,
15 | **kwargs,
16 | ):
17 | super().__init__(
18 | model_name=model_name,
19 | api_key=api_key,
20 | retry_on_fail=retry_on_fail,
21 | cache_folder_path=cache_folder_path,
22 | **kwargs,
23 | )
24 | self._model_name_prefix = "gemini/"
25 |
26 | def _run_batch(
27 | self,
28 | max_length_func: Callable[[list[str]], int],
29 | inputs: list[str],
30 | max_new_tokens: None | int = None,
31 | temperature: float = 1.0,
32 | top_p: float = 0.0,
33 | n: int = 1,
34 | stop: None | str | list[str] = None,
35 | repetition_penalty: None | float = None,
36 | logit_bias: None | dict[int, float] = None,
37 | batch_size: int = DEFAULT_BATCH_SIZE,
38 | seed: None | int = None,
39 | **kwargs,
40 | ) -> list[str] | list[list[str]]: # pragma: no cover
41 | assert (
42 | repetition_penalty is None
43 | ), f"`repetition_penalty` is not supported for {type(self).__name__}"
44 | return super()._run_batch(
45 | max_length_func=max_length_func,
46 | inputs=inputs,
47 | max_new_tokens=max_new_tokens,
48 | temperature=temperature,
49 | top_p=top_p,
50 | n=n,
51 | stop=stop,
52 | repetition_penalty=repetition_penalty,
53 | logit_bias=logit_bias,
54 | batch_size=batch_size,
55 | seed=seed,
56 | **kwargs,
57 | )
58 |
59 | @cached_property
60 | def model_card(self) -> None | str:
61 | return "https://arxiv.org/abs/2312.11805"
62 |
63 | @cached_property
64 | def license(self) -> None | str:
65 | return "https://ai.google.dev/gemini-api/terms"
66 |
67 | @cached_property
68 | def citation(self) -> None | list[str]:
69 | citations = []
70 | citations.append(
71 | """@article{anil2023gemini,
72 | title={Gemini: A family of highly capable multimodal models},
73 | author={Anil, Rohan and Borgeaud, Sebastian and Wu, Yonghui and Alayrac, Jean-Baptiste and Yu, Jiahui and Soricut, Radu and Schalkwyk, Johan and Dai, Andrew M and Hauth, Anja and Millican, Katie and others},
74 | journal={arXiv preprint arXiv:2312.11805},
75 | volume={1},
76 | year={2023}
77 | }""".strip()
78 | )
79 | return citations
80 |
81 |
82 | __all__ = ["GoogleAIStudio"]
83 |
--------------------------------------------------------------------------------
/src/llms/parallel_llm.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, Iterable, cast
2 |
3 | from .._cachable import _ParallelCachable
4 | from .llm import DEFAULT_BATCH_SIZE, LLM
5 |
6 |
7 | class ParallelLLM(_ParallelCachable, LLM):
8 | def __init__(self, *llms: LLM):
9 | """
10 |         Creates an LLM that will run multiple LLMs in parallel. See
11 | :doc:`running models in parallel
12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>`
13 | for more details.
14 |
15 | Args:
16 | *llms: The LLMs to run in parallel.
17 | """
18 | super().__init__(*llms, cls=LLM)
19 | self.llms = cast(list[LLM], self.cachables)
20 |
21 | def count_tokens(self, value: str) -> int:
22 | """Counts the number of tokens in a string.
23 |
24 | Args:
25 | value: The string to count tokens for.
26 |
27 | Returns:
28 | The number of tokens in the string.
29 | """
30 |
31 | return self.llms[0].count_tokens(value=value)
32 |
33 | def get_max_context_length(self, max_new_tokens: int) -> int:
34 | """Gets the maximum context length for the model. When ``max_new_tokens`` is
35 | greater than 0, the maximum number of tokens that can be used for the prompt
36 | context is returned.
37 |
38 | Args:
39 | max_new_tokens: The maximum number of tokens that can be generated.
40 |
41 | Returns:
42 | The maximum context length.
43 | """
44 | return self.llms[0].get_max_context_length(max_new_tokens=max_new_tokens)
45 |
46 | def format_prompt( # noqa: C901
47 | self,
48 | max_new_tokens: None | int = None,
49 | beg_instruction: None | str = None,
50 | in_context_examples: None | list[str] = None,
51 | end_instruction: None | str = None,
52 | sep="\n",
53 | min_in_context_examples: None | int = None,
54 | max_in_context_examples: None | int = None,
55 | ) -> str:
56 | return self.llms[0].format_prompt(
57 | max_new_tokens=max_new_tokens,
58 | beg_instruction=beg_instruction,
59 | in_context_examples=in_context_examples,
60 | end_instruction=end_instruction,
61 | sep=sep,
62 | min_in_context_examples=min_in_context_examples,
63 | max_in_context_examples=max_in_context_examples,
64 | )
65 |
66 | def run(
67 | self, prompts: Iterable[str], *args, **kwargs
68 | ) -> Generator[str | list[str], None, None] | list[str | list[str]]:
69 | kwargs["batch_size"] = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE)
70 | results_generator = self._run_in_parallel(prompts, *args, **kwargs)
71 | if not kwargs.get("return_generator", False):
72 | return list(results_generator)
73 | else:
74 | return results_generator
75 |
76 | def unload_model(self):
77 | for llm in self.llms:
78 | llm.unload_model()
79 |
80 |
81 | __all__ = ["ParallelLLM"]
82 |
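A brief sketch of constructing and running a ``ParallelLLM`` as described in the docstring above. It assumes the ``HFTransformers`` LLM wrapper exported by ``datadreamer.llms`` and two available GPUs; the model name and device indices are illustrative:

.. code-block:: python

    from datadreamer import DataDreamer
    from datadreamer.llms import HFTransformers, ParallelLLM  # HFTransformers assumed

    with DataDreamer(":memory:"):
        # One copy of the model per GPU; ParallelLLM splits the prompts across the copies.
        llm = ParallelLLM(
            HFTransformers("google/flan-t5-small", device=0),
            HFTransformers("google/flan-t5-small", device=1),
        )
        generations = llm.run(
            ["Translate to French: Hello.", "Translate to French: Goodbye."],
            max_new_tokens=10,
        )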
--------------------------------------------------------------------------------
/src/llms/vertex_ai.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Callable
3 |
4 | from ._litellm import LiteLLM
5 | from .llm import DEFAULT_BATCH_SIZE
6 |
7 |
8 | class VertexAI(LiteLLM):
9 | def __init__(
10 | self,
11 | model_name: str,
12 | vertex_project: None | str = None,
13 | vertex_location: None | str = None,
14 | retry_on_fail: bool = True,
15 | cache_folder_path: None | str = None,
16 | **kwargs,
17 | ):
18 | super().__init__(
19 | model_name=model_name,
20 | vertex_project=vertex_project,
21 | vertex_location=vertex_location,
22 | retry_on_fail=retry_on_fail,
23 | cache_folder_path=cache_folder_path,
24 | **kwargs,
25 | )
26 | self._model_name_prefix = ""
27 |
28 | def _run_batch(
29 | self,
30 | max_length_func: Callable[[list[str]], int],
31 | inputs: list[str],
32 | max_new_tokens: None | int = None,
33 | temperature: float = 1.0,
34 | top_p: float = 0.0,
35 | n: int = 1,
36 | stop: None | str | list[str] = None,
37 | repetition_penalty: None | float = None,
38 | logit_bias: None | dict[int, float] = None,
39 | batch_size: int = DEFAULT_BATCH_SIZE,
40 | seed: None | int = None,
41 | **kwargs,
42 | ) -> list[str] | list[list[str]]:
43 | assert stop is None, f"`stop` is not supported for {type(self).__name__}"
44 | assert (
45 | repetition_penalty is None
46 | ), f"`repetition_penalty` is not supported for {type(self).__name__}"
47 | assert n == 1, f"Only `n` = 1 is supported for {type(self).__name__}"
48 | return super()._run_batch(
49 | max_length_func=max_length_func,
50 | inputs=inputs,
51 | max_new_tokens=max_new_tokens,
52 | temperature=temperature,
53 | top_p=top_p,
54 | n=n,
55 | stop=stop,
56 | repetition_penalty=repetition_penalty,
57 | logit_bias=logit_bias,
58 | batch_size=batch_size,
59 | seed=seed,
60 | **kwargs,
61 | )
62 |
63 | @cached_property
64 | def model_card(self) -> None | str:
65 | return "https://arxiv.org/abs/2312.11805"
66 |
67 | @cached_property
68 | def license(self) -> None | str:
69 | return "https://ai.google.dev/gemini-api/terms"
70 |
71 | @cached_property
72 | def citation(self) -> None | list[str]:
73 | citations = []
74 | citations.append(
75 | """@article{anil2023gemini,
76 | title={Gemini: A family of highly capable multimodal models},
77 | author={Anil, Rohan and Borgeaud, Sebastian and Wu, Yonghui and Alayrac, Jean-Baptiste and Yu, Jiahui and Soricut, Radu and Schalkwyk, Johan and Dai, Andrew M and Hauth, Anja and Millican, Katie and others},
78 | journal={arXiv preprint arXiv:2312.11805},
79 | volume={1},
80 | year={2023}
81 | }""".strip()
82 | )
83 | return citations
84 |
85 |
86 | __all__ = ["VertexAI"]
87 |
--------------------------------------------------------------------------------
/src/logging/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import DATEFMT, DATETIME_FORMAT, STANDARD_FORMAT, logger
2 |
3 | __all__ = ["logger", "DATEFMT", "DATETIME_FORMAT", "STANDARD_FORMAT"]
4 |
--------------------------------------------------------------------------------
/src/logging/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from logging import Logger
3 |
4 | from ..project.environment import RUNNING_IN_PYTEST
5 |
6 | DATEFMT: str = "[%Y-%m-%d %H:%M:%S %z]"
7 | STANDARD_FORMAT: str = "[ \N{ESC}[35m🤖 Data\N{ESC}[33mDr\N{ESC}[31mea\N{ESC}[35mmer\u001b[0m 💤 ] %(message)s" # noqa: B950
8 | DATETIME_FORMAT: str = "%(asctime)s [ \N{ESC}[35m🤖 Data\N{ESC}[33mDr\N{ESC}[31mea\N{ESC}[35mmer\u001b[0m 💤 ] %(message)s" # noqa: B950
9 |
10 | # stderr Handler
11 | STDERR_HANDLER = logging.StreamHandler()
12 | STDERR_HANDLER.setLevel(logging.DEBUG)
13 |
14 | # Logger
15 | logger: Logger = logging.getLogger("datadreamer")
16 | if RUNNING_IN_PYTEST:
17 | logger.propagate = True
18 | else:
19 | logger.propagate = False # pragma: no cover
20 | formatter = logging.Formatter(
21 |     STANDARD_FORMAT, datefmt=DATEFMT, validate=False
22 | )
23 | STDERR_HANDLER.setFormatter(formatter)
24 | logger.addHandler(STDERR_HANDLER)
25 | logger.setLevel(logging.CRITICAL + 1)
26 |
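Since the logger above is effectively silenced by default (its level is set just above ``CRITICAL``), a caller that wants to see DataDreamer log output can lower the level, e.g.:

.. code-block:: python

    import logging

    from datadreamer.logging import logger

    # Lower the level from CRITICAL + 1 so records reach the stderr handler.
    logger.setLevel(logging.INFO)
    logger.info("DataDreamer logging enabled.")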
--------------------------------------------------------------------------------
/src/pickling/__init__.py:
--------------------------------------------------------------------------------
1 | from .pickle import unpickle, unpickle_transform
2 |
3 | __all__ = ["unpickle", "unpickle_transform"]
4 |
--------------------------------------------------------------------------------
/src/pickling/pickle.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Any
3 |
4 | from datasets.features.features import Features, Value
5 | from dill import dumps, loads
6 |
7 | _INTERNAL_PICKLE_KEY = "__DataDreamer__pickle_internal__"
8 | _PICKLE_KEY = "__DataDreamer__pickle__"
9 | __FEATURES_DEFAULT = Features()
10 |
11 |
12 | def _pickle(value: Any, *args: Any, **kwargs: Any) -> bytes:
13 | if _INTERNAL_PICKLE_KEY not in kwargs:
14 | warnings.warn(
15 | "Do not call pickle() directly. You should instead use the .pickle()"
16 | " method on a Step object.",
17 | stacklevel=2,
18 | )
19 | else:
20 | del kwargs[_INTERNAL_PICKLE_KEY]
21 | return dumps({_PICKLE_KEY: value}, *args, **kwargs)
22 |
23 |
24 | def unpickle(value: bytes) -> Any:
25 | return loads(value)[_PICKLE_KEY]
26 |
27 |
28 | def _unpickle_transform_value(value):
29 | if (
30 | isinstance(value, bytes)
31 | and len(value) >= 2
32 | and value[0] == 128
33 | and value[-1] == 46
34 | and _PICKLE_KEY.encode("utf8") in value[:100]
35 | ):
36 | return unpickle(value)
37 | else:
38 | return value
39 |
40 |
41 | def unpickle_transform(batch, features=__FEATURES_DEFAULT, batched=False):
42 | for column in batch:
43 | feature = features.get(column, None)
44 | if not isinstance(feature, Value) or feature.dtype != "binary":
45 | continue
46 | if batched:
47 | for i in range(len(batch[column])):
48 | batch[column][i] = _unpickle_transform_value(batch[column][i])
49 | else:
50 | batch[column] = _unpickle_transform_value(batch[column])
51 | return batch
52 |
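A small round-trip sketch of the helpers above. In normal use you would call ``.pickle()`` on a Step object instead, which is why the internal key is passed here to silence the warning:

.. code-block:: python

    from datadreamer.pickling.pickle import _INTERNAL_PICKLE_KEY, _pickle, unpickle

    # Wrap an arbitrary Python object in the module's pickle envelope and read it back.
    blob = _pickle({"a": {1, 2, 3}}, **{_INTERNAL_PICKLE_KEY: True})
    assert unpickle(blob) == {"a": {1, 2, 3}}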
--------------------------------------------------------------------------------
/src/project/__init__.py:
--------------------------------------------------------------------------------
1 | """``project`` provides project-wide helpers and utilities useful in machine learning projects.
2 |
3 | Attributes:
4 | INITIAL_CWD (None | str): The initial current working directory path.
5 |     context (dict): A dictionary used to store global context.
6 | RUNNING_IN_PYTEST (bool): Whether or not the project is running in ``pytest``.
7 | RUNNING_IN_CLUSTER (bool): Whether or not the project is running on a cluster.
8 | """
9 |
10 | import json
11 | import os
12 | import sys
13 |
14 | from loguru import logger
15 |
16 | from .debug import bash, context, debugger
17 | from .devices import (
18 | get_jax_cpu_device,
19 | get_jax_device,
20 | get_jax_devices,
21 | get_tf_cpu_device,
22 | get_tf_device,
23 | get_tf_devices,
24 | get_torch_cpu_device,
25 | get_torch_device,
26 | get_torch_devices,
27 | )
28 | from .environment import RUNNING_IN_CLUSTER, RUNNING_IN_PYTEST
29 | from .persistent_storage import get_persistent_dir
30 | from .report import reporter # type:ignore[attr-defined]
31 | from .serve import run_ngrok
32 |
33 | # Initial cwd (defined in __main__.py)
34 | INITIAL_CWD: None | str = None
35 |
36 | # Make sure CUDA/NVIDIA_VISIBLE_DEVICES is set if it is needed
37 | if os.environ.get("PROJECT_ACCELERATOR_TYPE", None) == "cuda":
38 | if "PROJECT_VISIBLE_ACCELERATOR_DEVICES" in os.environ:
39 | os.environ["NVIDIA_VISIBLE_DEVICES"] = os.environ[
40 | "PROJECT_VISIBLE_ACCELERATOR_DEVICES"
41 | ]
42 | os.environ["CUDA_VISIBLE_DEVICES"] = os.environ[
43 | "PROJECT_VISIBLE_ACCELERATOR_DEVICES"
44 | ]
45 |
46 | # Make sure if CUDA/NVIDIA_VISIBLE_DEVICES is set, PROJECT_*_ACCELERATOR_* is set
47 | if (
48 | "CUDA_VISIBLE_DEVICES" in os.environ
49 | and os.environ.get("PROJECT_ACCELERATOR_TYPE", None) is None
50 | ):
51 | os.environ["PROJECT_ACCELERATOR_TYPE"] = "cuda"
52 | os.environ["PROJECT_VISIBLE_ACCELERATOR_DEVICES"] = os.environ[
53 | "CUDA_VISIBLE_DEVICES"
54 | ]
55 | elif (
56 | "NVIDIA_VISIBLE_DEVICES" in os.environ
57 | and os.environ.get("PROJECT_ACCELERATOR_TYPE", None) is None
58 | ):
59 | os.environ["PROJECT_ACCELERATOR_TYPE"] = "cuda"
60 | os.environ["PROJECT_VISIBLE_ACCELERATOR_DEVICES"] = os.environ[
61 | "NVIDIA_VISIBLE_DEVICES"
62 | ]
63 |
64 |
65 | def init():
66 | """Initializes the project. Adds logging and does any other project setup."""
67 | # Setup logger
68 | logger.remove()
69 | logger.add(
70 | sys.stderr,
71 | colorize=True,
72 | format="[{process}] | {time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", # noqa: B950
73 | )
74 |
75 | # Write args
76 | if RUNNING_IN_CLUSTER:
77 | with open(os.environ["PROJECT_ARGS_FILE"], "w+") as f:
78 | f.write(json.dumps(sys.argv, indent=2))
79 |
80 |
81 | __all__ = [
82 | "RUNNING_IN_CLUSTER",
83 | "RUNNING_IN_PYTEST",
84 | "bash",
85 | "context",
86 | "debugger",
87 | "get_jax_cpu_device",
88 | "get_jax_device",
89 | "get_jax_devices",
90 | "get_tf_cpu_device",
91 | "get_tf_device",
92 | "get_tf_devices",
93 | "get_torch_cpu_device",
94 | "get_torch_device",
95 | "get_torch_devices",
96 | "get_persistent_dir",
97 | "reporter",
98 | "run_ngrok",
99 | "init",
100 | ]
101 |
--------------------------------------------------------------------------------
/src/project/builtin_tasks.py:
--------------------------------------------------------------------------------
1 | import click
2 | from loguru import logger
3 |
4 | from .serve import (
5 | run_cloudflared,
6 | run_http_server,
7 | run_jupyter,
8 | run_ngrok,
9 | sleep_infinity,
10 | )
11 |
12 |
13 | @click.option(
14 | "--tunnel",
15 | "-t",
16 | default="cloudflare",
17 | type=click.Choice(["cloudflare", "ngrok"]),
18 | help="The tunneling service to use.",
19 | )
20 | @click.option(
21 | "--hostname", "-h", default=None, type=str, help="The hostname to serve at."
22 | )
23 | @click.option("--password", "-p", default=None, type=str, help="The password to use.")
24 | def jupyter(ctx, tunnel, hostname, password):
25 | """This command runs Jupyter Lab."""
26 | logger.info("Running Jupyter Lab...")
27 | port = run_jupyter(password=password)
28 | if tunnel == "cloudflare":
29 | url = run_cloudflared(port, hostname=hostname)
30 | else:
31 | url = run_ngrok(port, hostname=hostname)
32 | logger.info(f"Jupyter Lab is available at URL: {url}")
33 | sleep_infinity()
34 |
35 |
36 | @click.option(
37 | "--tunnel",
38 | "-t",
39 | default="cloudflare",
40 | type=click.Choice(["cloudflare", "ngrok"]),
41 | help="The tunneling service to use.",
42 | )
43 | @click.option(
44 | "--hostname", "-h", default=None, type=str, help="The hostname to serve at."
45 | )
46 | def http_server(ctx, tunnel, hostname):
47 |     """This command runs an HTTP server."""
48 | logger.info("Running HTTP server...")
49 | port = run_http_server()
50 | if tunnel == "cloudflare":
51 | url = run_cloudflared(port, hostname=hostname)
52 | else:
53 | url = run_ngrok(port, hostname=hostname)
54 | logger.info(f"HTTP server is available at URL: {url}")
55 | sleep_infinity()
56 |
57 |
58 | def register_builtin_tasks(_main):
59 | _main.command(hidden=True)(click.pass_context(jupyter))
60 | _main.command(hidden=True)(click.pass_context(http_server))
61 |
--------------------------------------------------------------------------------
/src/project/debug.py:
--------------------------------------------------------------------------------
1 | import code
2 | import inspect
3 | import os
4 | import pty
5 | import tempfile
6 | from time import sleep
7 |
8 | from .report import _deep_defaultdict # type:ignore[attr-defined]
9 |
10 |
11 | def _get_callers_locals_and_globals():
12 | """Gets the local and global variables from the caller's frame.
13 |
14 | Returns:
15 | tuple[dict, dict]: A tuple of a dictionary of local variables and global
16 | variables.
17 | """
18 | frame = inspect.currentframe()
19 | if frame and frame.f_back and frame.f_back.f_back:
20 | try:
21 | return frame.f_back.f_back.f_locals, frame.f_back.f_back.f_globals
22 | finally:
23 | del frame
24 |
25 |
26 | def debugger(rank=None, launch_on_rank=0):
27 | """Pauses execution and opens an interactive REPL with access to local and global
28 | variables.
29 |
30 | Args:
31 | rank (Any, optional): The current rank. Defaults to None (will always
32 | launch the debugger).
33 | launch_on_rank (Any, optional): What rank the debugger should be launched on.
34 | Defaults to 0.
35 | """
36 | if "PROJECT_INTERACTIVE" in os.environ:
37 | from filelock import FileLock
38 |
39 | lock = FileLock(
40 | os.path.join(
41 | os.path.dirname(tempfile.mkdtemp()),
42 | f"{os.environ['PROJECT_NAME']}-debugger.lock",
43 | )
44 | )
45 | if rank is None or rank == launch_on_rank:
46 | with lock.acquire():
47 | ls, gs = _get_callers_locals_and_globals()
48 | all_items = list(ls.items()) + list(gs.items())
49 | code.interact(
50 | banner="Opening Python REPL (press 'Ctrl-D' to exit the shell)...",
51 | local=dict(all_items),
52 | )
53 | else:
54 | sleep(5)
55 | while True:
56 | with lock.acquire():
57 | break
58 |
59 |
60 | def bash(rank=None, launch_on_rank=0):
61 | """Pauses execution and opens a bash shell.
62 |
63 | Args:
64 | rank (Any, optional): The current rank. Defaults to None (will always
65 | launch bash).
66 | launch_on_rank (Any, optional): What rank bash should be launched on.
67 | Defaults to 0.
68 | """
69 | if "PROJECT_INTERACTIVE" in os.environ:
70 | from filelock import FileLock
71 |
72 | lock = FileLock(
73 | os.path.join(
74 | os.path.dirname(tempfile.mkdtemp()),
75 | f"{os.environ['PROJECT_NAME']}-bash.lock",
76 | )
77 | )
78 | if rank is None or rank == launch_on_rank:
79 | with lock.acquire():
80 | print("Opening bash shell (type 'exit' to exit the shell)...")
81 | pty.spawn("/bin/bash")
82 | else:
83 | sleep(5)
84 | while True:
85 | with lock.acquire():
86 | break
87 |
88 |
89 | # Create a context to help store global context when debugging
90 | context = _deep_defaultdict()
91 |
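A short sketch of how ``debugger`` is intended to be called from a multi-process script. The ``RANK`` environment variable is a hypothetical stand-in for whatever rank your launcher provides, and ``PROJECT_INTERACTIVE``/``PROJECT_NAME`` must be set for the helper to do anything:

.. code-block:: python

    import os

    from datadreamer.project import debugger  # assumes the re-export from this package

    rank = int(os.environ.get("RANK", 0))  # hypothetical launcher-provided rank

    # Only rank 0 opens the REPL; other ranks wait on the shared file lock until
    # the debugging session ends.
    debugger(rank=rank, launch_on_rank=0)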
--------------------------------------------------------------------------------
/src/project/environment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # Detect environments
5 | RUNNING_IN_PYTEST = os.path.basename(sys.argv[0]) == "pytest"
6 | RUNNING_IN_CLUSTER = "PROJECT_CLUSTER" in os.environ
7 |
--------------------------------------------------------------------------------
/src/project/pennnlp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import sys
4 |
5 | from loguru import logger
6 |
7 | """This file contains PennNLP cluster specific utilities."""
8 |
9 | NLPDATA_PATH = os.path.join("/nlp/data/", os.environ["USER"])
10 |
11 | SCRATCH_PATH = os.path.join("/scratch/", os.environ["USER"], os.environ["PROJECT_NAME"])
12 |
13 |
14 | def detect_pennnlp():
15 | """Detect if running on PennNLP's cluster.
16 |
17 | Returns:
18 | bool: Whether or not we are running on PennNLP's cluster.
19 | """
20 | return os.path.exists(NLPDATA_PATH)
21 |
22 |
23 | def copy_file(src, dest):
24 | """Copies a file from the source path to the destination path, but skips the copying
25 | if the file already exists (determined by last modified time or file size).
26 |
27 | Args:
28 | src (str): The source path.
29 | dest (str): The destination path.
30 | """
31 | if (
32 | (not os.path.exists(dest))
33 | or (os.stat(src).st_mtime - os.stat(dest).st_mtime > 1)
34 | or (os.stat(src).st_size != os.stat(dest).st_size)
35 | ):
36 | shutil.copy2(src, dest)
37 |
38 |
39 | def copy_files_to_ssd(*paths, subfolder=None):
40 | if detect_pennnlp():
41 | # Create scratch dir for SSD disk speed on PennNLP cluster
42 | scratch_path = SCRATCH_PATH
43 | if subfolder is True:
44 | scratch_path = os.path.join(SCRATCH_PATH, sys.argv[1])
45 | elif subfolder:
46 | scratch_path = os.path.join(SCRATCH_PATH, subfolder)
47 | os.makedirs(scratch_path, exist_ok=True)
48 |
49 | # Copy files to scratch dir for SSD disk speed on PennNLP cluster
50 | new_paths = []
51 | for path in paths:
52 | path = os.path.normpath(os.path.abspath(path))
53 | basename = os.path.basename(path)
54 | new_path = os.path.join(scratch_path, basename)
55 | new_paths.append(new_path)
56 | logger.debug(f"Copying file {path} to PennNLP scratch path: {new_path}...")
57 | copy_file(path, new_path)
58 | logger.debug("Done copying file.")
59 | return new_paths
60 | else:
61 | return paths
62 |
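A small usage sketch (hypothetical file paths; assumes the package is importable as "project" and that USER and PROJECT_NAME are set in the environment):

from project.pennnlp import copy_files_to_ssd  # assumed import path

# On the PennNLP cluster, the files are copied to /scratch/$USER/$PROJECT_NAME and the new
# paths are returned; anywhere else, the original paths are returned unchanged.
train_path, dev_path = copy_files_to_ssd("data/train.jsonl", "data/dev.jsonl")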
--------------------------------------------------------------------------------
/src/project/persistent_storage.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import json
3 | import os
4 |
5 | from .report import reporter # type:ignore[attr-defined]
6 |
7 |
8 | def _dict_hash(dictionary):
9 | """Returns the MD5 hash of a dictionary.
10 |
11 | Args:
12 | dictionary (dict[str, Any]): The dictionary to hash.
13 |
14 | Returns:
15 | str: The MD5 hash.
16 | """
17 | dhash = hashlib.md5()
18 | # We need to sort arguments so {'a': 1, 'b': 2} is
19 | # the same as {'b': 2, 'a': 1}
20 | encoded = json.dumps(dictionary, sort_keys=True).encode()
21 | dhash.update(encoded)
22 | return dhash.hexdigest()
23 |
24 |
25 | def get_persistent_dir(name, config_path):
26 | """Returns the path to a persistent directory that will be usable across jobs. Any
27 | future jobs
28 |
29 | Args:
30 | name (str): The name of the local symlink to create in the project write directory.
31 | config_path (str): The reporter path of the configuration used to identify the persistent directory.
32 | """
33 | config_hash = _dict_hash(reporter.get(config_path))
34 | persistent_dir = os.path.join(
35 | os.environ["PROJECT_DATA_OUTPUT_PERSISTENT_DATA"], config_hash
36 | )
37 | local_persistent_dir = os.path.join(os.environ["PROJECT_WRITE_DIR"], name)
38 | os.makedirs(persistent_dir, exist_ok=True)
39 | try:
40 | os.symlink(persistent_dir, local_persistent_dir, target_is_directory=True)
41 | except FileExistsError:
42 | pass
43 | return local_persistent_dir
44 |
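A sketch of the intended call pattern, assuming the reporter already holds a configuration dict at a (hypothetical) path "train/config" and that PROJECT_DATA_OUTPUT_PERSISTENT_DATA and PROJECT_WRITE_DIR are set by the cluster scripts:

from project.persistent_storage import get_persistent_dir  # assumed import path

# The config is hashed (MD5 of its sorted-key JSON), so any future job with an identical
# config resolves to the same shared directory.
checkpoint_dir = get_persistent_dir("checkpoints", "train/config")
print(checkpoint_dir)  # a symlink under $PROJECT_WRITE_DIR pointing at the shared directory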
--------------------------------------------------------------------------------
/src/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/py.typed
--------------------------------------------------------------------------------
/src/requirements-accelerator-device.txt:
--------------------------------------------------------------------------------
1 | torch>=2.5.1,<3.0.0
--------------------------------------------------------------------------------
/src/requirements-cpu.txt:
--------------------------------------------------------------------------------
1 | torch>=2.5.1,<3.0.0
--------------------------------------------------------------------------------
/src/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | ruff==0.1.9
2 | Sphinx==7.0.1
3 | furo==2023.5.20
4 | sphinx-inline-tabs==2023.4.21
5 | sphinx-autobuild==2021.3.14
6 | sphinx-copybutton==0.5.2
7 | sphinx-click==4.4.0
8 | sphinx-sitemap==2.5.0
9 | sphinx-autodoc-typehints==1.23.0
10 | sphinx-toolbox==3.4.0
11 | sphinx_design==0.5.0
12 | sphinx-reredirects==0.1.3
--------------------------------------------------------------------------------
/src/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest==7.3.1
2 | pytest-cov==4.1.0
3 | flaky==3.7.0
4 | mypy==1.4.1
5 | pytest-mock==3.11.1
6 | pytest-timeout==2.2.0
7 | pytest-order==1.2.0
--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | click>=8.1.3
2 | loguru>=0.7.0,<1.0.0
3 | filelock>=3.13.1,<4.0.0
4 | jsonlines>=4.0.0,<7.0.0
5 | numpy>=1.26.4,<2.0.0
6 | sortedcontainers>=2.4.0,<3.0.0
7 | sqlitedict>=2.1.0,<3.0.0
8 | pandas>=1.5.3,<2.0.0
9 | pandas-stubs>=1.5.3.230321,<2.0.0
10 | tenacity>=9.0.0
11 | dill>=0.3.8,<1.0.0
12 | ring>=0.10.1,<1.0.0
13 | psutil>=6.1.1
14 | faiss-cpu>=1.9.0.post1,<2.0.0
15 | evaluate>=0.4.3,<1.0.0
16 | tiktoken>=0.7.0,<1.0.0
17 | sentence-transformers>=3.4.0,<4.0.0
18 | setfit>=1.1.1,<2.0.0
19 | openai>=1.59.6,<2.0.0
20 | datasets>=3.2.0,<4.0.0
21 | peft>=0.14.0,<1.0.0
22 | bitsandbytes>=0.45.0,<1.0.0
23 | huggingface-hub>=0.27.1,<1.0.0
24 | optimum>=1.21.2,<2.0.0
25 | accelerate>=1.3.0,<2.0.0
26 | transformers>=4.48.1,<4.50.0
27 | ctransformers>=0.2.27,<1.0.0
28 | outlines-core>=0.1.26
29 | Pyro5>=5.15
30 | litellm==1.57.8
31 | trl==0.9.6
32 |
--------------------------------------------------------------------------------
/src/retrievers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | :py:class:`Retriever` objects help retrieve texts based on a set of queries.
3 | All retrievers derive from the :py:class:`Retriever` base class.
4 |
5 | .. tip::
6 |
7 | Instead of using :py:meth:`~Retriever.run` directly, use a
8 | :py:class:`step <datadreamer.steps.Step>` that takes a :py:class:`Retriever` as an ``args``
9 | argument such as :py:class:`~datadreamer.steps.Retrieve` and
10 | :py:class:`~datadreamer.steps.RAGPrompt`. Some other steps, like
11 | :py:class:`~datadreamer.steps.FewShotPromptWithRetrieval`,
12 | use retrievers internally.
13 |
14 | Caching
15 | =======
16 | Retrievers typically build an index once and cache that index to disk.
17 | Retrievers also cache retrieval results to disk, so if you retrieve results
18 | for the same query multiple times, the retriever will only run the query once
19 | and then reuse the cached results on future runs.
20 | """
21 |
22 | from .embedding_retriever import EmbeddingRetriever
23 | from .parallel_retriever import ParallelRetriever
24 | from .retriever import Retriever
25 |
26 | __all__ = ["Retriever", "EmbeddingRetriever", "ParallelRetriever"]
27 |
--------------------------------------------------------------------------------
/src/retrievers/parallel_retriever.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, Iterable, cast
2 |
3 | from .._cachable import _ParallelCachable
4 | from .retriever import DEFAULT_BATCH_SIZE, Retriever
5 |
6 |
7 | class ParallelRetriever(_ParallelCachable, Retriever):
8 | def __init__(self, *retrievers: Retriever):
9 | """
10 | Creates a retriever that will run multiple retrievers in parallel. See
11 | :doc:`running models in parallel
12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>`
13 | for more details.
14 |
15 | Args:
16 | *retrievers: The retrievers to run in parallel.
17 | """
18 | super().__init__(*retrievers, cls=Retriever)
19 | self.retrievers = cast(list[Retriever], self.cachables)
20 |
21 | @property
22 | def index(self): # pragma: no cover
23 | return self.retrievers[0].index
24 |
25 | def run(
26 | self, queries: Iterable[str], *args, **kwargs
27 | ) -> Generator[str | list[str], None, None] | list[str | list[str]]:
28 | kwargs["batch_size"] = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE)
29 | results_generator = self._run_in_parallel(queries, *args, **kwargs)
30 | if not kwargs.get("return_generator", False):
31 | return list(results_generator)
32 | else:
33 | return results_generator
34 |
35 | def unload_model(self):
36 | for retriever in self.retrievers:
37 | retriever.unload_model()
38 |
39 |
40 | __all__ = ["ParallelRetriever"]
41 |
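A sketch of the parallel pattern, assuming two already-constructed Retriever objects (for example, one per GPU); the retriever names below are placeholders:

from datadreamer.retrievers import ParallelRetriever

parallel_retriever = ParallelRetriever(retriever_on_gpu_0, retriever_on_gpu_1)
results = parallel_retriever.run(queries=["what is attention?", "what is a transformer?"])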
--------------------------------------------------------------------------------
/src/steps/data_card.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from ..utils.collection_utils import sort_keys
4 |
5 |
6 | class DataCardType:
7 | """The types of data card entries."""
8 |
9 | DATETIME = "Date & Time"
10 | MODEL_NAME = "Model Name"
11 | DATASET_NAME = "Dataset Name"
12 | LICENSE = "License Information"
13 | CITATION = "Citation Information"
14 | DATASET_CARD = "Dataset Card"
15 | MODEL_CARD = "Model Card"
16 | URL = "URL"
17 |
18 |
19 | def sort_data_card(data_card: dict[str, list[Any]]) -> dict[str, list[Any]]:
20 | return sort_keys(
21 | data_card,
22 | key_order=[
23 | DataCardType.DATETIME,
24 | DataCardType.DATASET_NAME,
25 | DataCardType.MODEL_NAME,
26 | DataCardType.URL,
27 | DataCardType.DATASET_CARD,
28 | DataCardType.MODEL_CARD,
29 | DataCardType.LICENSE,
30 | DataCardType.CITATION,
31 | ],
32 | )
33 |
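A quick, self-contained illustration of the canonical ordering (import path assumed from the file location):

from datadreamer.steps.data_card import DataCardType, sort_data_card  # assumed import path

data_card = {
    DataCardType.CITATION: ["@misc{example2024}"],
    DataCardType.DATETIME: ["2024-01-01T00:00:00"],
    DataCardType.MODEL_NAME: ["gpt2"],
}
print(list(sort_data_card(data_card)))
# ['Date & Time', 'Model Name', 'Citation Information']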
--------------------------------------------------------------------------------
/src/steps/data_sources/csv_data_source.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Sequence
3 |
4 | from datasets import DatasetDict, load_dataset
5 | from datasets.fingerprint import Hasher
6 |
7 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY
8 | from .data_source import DataSource
9 |
10 |
11 | class CSVDataSource(DataSource):
12 | """Loads a CSV dataset from a local path. See :py:func:`datasets.load_dataset` for
13 | more details.
14 |
15 | Args:
16 | name: The name of the step.
17 | data_folder: The path to the dataset folder.
18 | data_files: The names of the files from the folder to load.
19 | progress_interval: How often to log progress in seconds.
20 | force: Whether to force run the step (ignore saved results).
21 | verbose: Whether or not to print verbose logs.
22 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.).
23 | save_num_proc: The number of processes to use if saving to disk.
24 | save_num_shards: The number of shards on disk to save the dataset into.
25 | background: Whether to run the operation in the background.
26 | **config_kwargs: Additional keyword arguments to pass to
27 | :py:func:`datasets.load_dataset`.
28 | """
29 |
30 | def __init__(
31 | self,
32 | name: str,
33 | data_folder: None | str = None,
34 | data_files: None | str | Sequence[str] = None,
35 | sep: str = ",",
36 | progress_interval: None | int = 60,
37 | force: bool = False,
38 | verbose: None | bool = None,
39 | log_level: None | int = None,
40 | save_num_proc: None | int = None,
41 | save_num_shards: None | int = None,
42 | background: bool = False,
43 | **config_kwargs,
44 | ):
45 | self.data_folder = data_folder
46 | self.data_files = data_files
47 | self.sep = sep
48 | self.config_kwargs = config_kwargs
49 | super().__init__(
50 | name,
51 | data=None, # type: ignore[arg-type]
52 | progress_interval=progress_interval,
53 | force=force,
54 | verbose=verbose,
55 | log_level=log_level,
56 | save_num_proc=save_num_proc,
57 | save_num_shards=save_num_shards,
58 | background=background,
59 | )
60 |
61 | def setup(self):
62 | pass
63 |
64 | def run(self):
65 | if isinstance(self.data_files, dict):
66 | raise ValueError(
67 | "You supplied a dict to data_files, multiple splits are not supported."
68 | )
69 | result = load_dataset(
70 | "csv",
71 | data_dir=self.data_folder,
72 | data_files=self.data_files,
73 | num_proc=self.save_num_proc,
74 | sep=self.sep,
75 | **self.config_kwargs,
76 | )
77 | if isinstance(result, DatasetDict):
78 | result = result["train"]
79 | return result
80 |
81 | @cached_property
82 | def fingerprint(self) -> str:
83 | return Hasher.hash(
84 | [super().fingerprint, self.data_folder, self.data_files, self.sep, self.config_kwargs]
85 | )
86 |
87 |
88 | setattr(CSVDataSource, _INTERNAL_STEP_OPERATION_KEY, True)
89 |
90 | __all__ = ["CSVDataSource"]
91 |
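A usage sketch with hypothetical paths (JSONDataSource and TextDataSource below follow the same pattern); steps are normally constructed inside a DataDreamer session:

from datadreamer import DataDreamer
from datadreamer.steps import CSVDataSource  # assumed export from the steps package

with DataDreamer("./output"):
    step = CSVDataSource(
        "Load training CSV",
        data_folder="./data",    # hypothetical folder
        data_files="train.csv",  # hypothetical file
        sep=",",
    )
    print(step.output)  # the loaded rows (the "train" split of the underlying dataset)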
--------------------------------------------------------------------------------
/src/steps/data_sources/hf_dataset_data_source.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 |
3 | from datasets import Dataset
4 | from datasets.fingerprint import Hasher
5 |
6 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY
7 | from .data_source import DataSource
8 |
9 |
10 | class HFDatasetDataSource(DataSource):
11 | """Loads a Hugging Face :py:class:`~datasets.Dataset` from a local path. See
12 | :py:func:`datasets.load_from_disk` for more details.
13 |
14 | Args:
15 | name: The name of the step.
16 | dataset_path: The path to the :py:class:`datasets.Dataset` folder.
17 | progress_interval: How often to log progress in seconds.
18 | force: Whether to force run the step (ignore saved results).
19 | verbose: Whether or not to print verbose logs.
20 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.).
21 | save_num_proc: The number of processes to use if saving to disk.
22 | save_num_shards: The number of shards on disk to save the dataset into.
23 | background: Whether to run the operation in the background.
24 | """
25 |
26 | def __init__(
27 | self,
28 | name: str,
29 | dataset_path: str,
30 | progress_interval: None | int = 60,
31 | force: bool = False,
32 | verbose: None | bool = None,
33 | log_level: None | int = None,
34 | save_num_proc: None | int = None,
35 | save_num_shards: None | int = None,
36 | background: bool = False,
37 | ):
38 | self.path_to_dataset = dataset_path
39 | super().__init__(
40 | name,
41 | data=None, # type: ignore[arg-type]
42 | progress_interval=progress_interval,
43 | force=force,
44 | verbose=verbose,
45 | log_level=log_level,
46 | save_num_proc=save_num_proc,
47 | save_num_shards=save_num_shards,
48 | background=background,
49 | )
50 |
51 | def setup(self):
52 | pass
53 |
54 | def run(self):
55 | return Dataset.load_from_disk(self.path_to_dataset)
56 |
57 | @cached_property
58 | def fingerprint(self) -> str:
59 | return Hasher.hash([super().fingerprint, self.path_to_dataset])
60 |
61 |
62 | setattr(HFDatasetDataSource, _INTERNAL_STEP_OPERATION_KEY, True)
63 |
64 | __all__ = ["HFDatasetDataSource"]
65 |
--------------------------------------------------------------------------------
/src/steps/data_sources/json_data_source.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Sequence
3 |
4 | from datasets import DatasetDict, load_dataset
5 | from datasets.fingerprint import Hasher
6 |
7 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY
8 | from .data_source import DataSource
9 |
10 |
11 | class JSONDataSource(DataSource):
12 | """Loads a JSON dataset from a local path. See :py:func:`datasets.load_dataset` for
13 | more details.
14 |
15 | Args:
16 | name: The name of the step.
17 | data_folder: The path to the dataset folder.
18 | data_files: The names of the files from the folder to load.
19 | progress_interval: How often to log progress in seconds.
20 | force: Whether to force run the step (ignore saved results).
21 | verbose: Whether or not to print verbose logs.
22 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.).
23 | save_num_proc: The number of processes to use if saving to disk.
24 | save_num_shards: The number of shards on disk to save the dataset into.
25 | background: Whether to run the operation in the background.
26 | **config_kwargs: Additional keyword arguments to pass to
27 | :py:func:`datasets.load_dataset`.
28 | """
29 |
30 | def __init__(
31 | self,
32 | name: str,
33 | data_folder: None | str = None,
34 | data_files: None | str | Sequence[str] = None,
35 | progress_interval: None | int = 60,
36 | force: bool = False,
37 | verbose: None | bool = None,
38 | log_level: None | int = None,
39 | save_num_proc: None | int = None,
40 | save_num_shards: None | int = None,
41 | background: bool = False,
42 | **config_kwargs,
43 | ):
44 | self.data_folder = data_folder
45 | self.data_files = data_files
46 | self.config_kwargs = config_kwargs
47 | super().__init__(
48 | name,
49 | data=None, # type: ignore[arg-type]
50 | progress_interval=progress_interval,
51 | force=force,
52 | verbose=verbose,
53 | log_level=log_level,
54 | save_num_proc=save_num_proc,
55 | save_num_shards=save_num_shards,
56 | background=background,
57 | )
58 |
59 | def setup(self):
60 | pass
61 |
62 | def run(self):
63 | if isinstance(self.data_files, dict):
64 | raise ValueError(
65 | "You supplied a dict to data_files, multiple splits are not supported."
66 | )
67 | result = load_dataset(
68 | "json",
69 | data_dir=self.data_folder,
70 | data_files=self.data_files,
71 | num_proc=self.save_num_proc,
72 | **self.config_kwargs,
73 | )
74 | if isinstance(result, DatasetDict):
75 | result = result["train"]
76 | return result
77 |
78 | @cached_property
79 | def fingerprint(self) -> str:
80 | return Hasher.hash(
81 | [super().fingerprint, self.data_folder, self.data_files, self.config_kwargs]
82 | )
83 |
84 |
85 | setattr(JSONDataSource, _INTERNAL_STEP_OPERATION_KEY, True)
86 |
87 | __all__ = ["JSONDataSource"]
88 |
--------------------------------------------------------------------------------
/src/steps/data_sources/text_data_source.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Sequence
3 |
4 | from datasets import DatasetDict, load_dataset
5 | from datasets.fingerprint import Hasher
6 |
7 | from ..step_operations import _INTERNAL_STEP_OPERATION_KEY
8 | from .data_source import DataSource
9 |
10 |
11 | class TextDataSource(DataSource):
12 | """Loads a text dataset from a local path. See :py:func:`datasets.load_dataset` for
13 | more details.
14 |
15 | Args:
16 | name: The name of the step.
17 | data_folder: The path to the dataset folder.
18 | data_files: The names of the files from the folder to load.
19 | progress_interval: How often to log progress in seconds.
20 | force: Whether to force run the step (ignore saved results).
21 | verbose: Whether or not to print verbose logs.
22 | log_level: The logging level to use (:py:data:`~logging.DEBUG`, :py:data:`~logging.INFO`, etc.).
23 | save_num_proc: The number of processes to use if saving to disk.
24 | save_num_shards: The number of shards on disk to save the dataset into.
25 | background: Whether to run the operation in the background.
26 | **config_kwargs: Additional keyword arguments to pass to
27 | :py:func:`datasets.load_dataset`.
28 | """
29 |
30 | def __init__(
31 | self,
32 | name: str,
33 | data_folder: None | str = None,
34 | data_files: None | str | Sequence[str] = None,
35 | progress_interval: None | int = 60,
36 | force: bool = False,
37 | verbose: None | bool = None,
38 | log_level: None | int = None,
39 | save_num_proc: None | int = None,
40 | save_num_shards: None | int = None,
41 | background: bool = False,
42 | **config_kwargs,
43 | ):
44 | self.data_folder = data_folder
45 | self.data_files = data_files
46 | self.config_kwargs = config_kwargs
47 | super().__init__(
48 | name,
49 | data=None, # type: ignore[arg-type]
50 | progress_interval=progress_interval,
51 | force=force,
52 | verbose=verbose,
53 | log_level=log_level,
54 | save_num_proc=save_num_proc,
55 | save_num_shards=save_num_shards,
56 | background=background,
57 | )
58 |
59 | def setup(self):
60 | pass
61 |
62 | def run(self):
63 | if isinstance(self.data_files, dict):
64 | raise ValueError(
65 | "You supplied a dict to data_files, multiple splits are not supported."
66 | )
67 | result = load_dataset(
68 | "text",
69 | data_dir=self.data_folder,
70 | data_files=self.data_files,
71 | num_proc=self.save_num_proc,
72 | **self.config_kwargs,
73 | )
74 | if isinstance(result, DatasetDict):
75 | result = result["train"]
76 | return result
77 |
78 | @cached_property
79 | def fingerprint(self) -> str:
80 | return Hasher.hash(
81 | [super().fingerprint, self.data_folder, self.data_files, self.config_kwargs]
82 | )
83 |
84 |
85 | setattr(TextDataSource, _INTERNAL_STEP_OPERATION_KEY, True)
86 |
87 | __all__ = ["TextDataSource"]
88 |
--------------------------------------------------------------------------------
/src/steps/prompt/data_from_prompt.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from ..._cachable._cachable import _StrWithSeed
4 | from ._prompt_base import _PromptBase
5 |
6 |
7 | class DataFromPrompt(_PromptBase):
8 | """Generates ``n`` rows of data using an instruction with a
9 | :py:class:`~datadreamer.llms.LLM`."""
10 |
11 | def setup(self):
12 | self._prompt_input_type = "none"
13 | self._register_prompt_args()
14 | self.register_arg(
15 | "instruction",
16 | required=True,
17 | help="The instruction to use to generate data.",
18 | )
19 | self.register_arg(
20 | "n", required=True, help="The number of rows to generate from the prompt."
21 | )
22 | self.register_arg(
23 | "temperature",
24 | required=False,
25 | default=1.0,
26 | help="The temperature to use when generating data.",
27 | )
28 | self.register_arg(
29 | "top_p",
30 | required=False,
31 | default=1.0,
32 | help="The top_p to use when generating data.",
33 | )
34 | self._register_prompt_optional_args()
35 | self._register_prompt_outputs()
36 |
37 | def run(self):
38 | # Get inputs and arguments
39 | args = self.args
40 | instruction = args.pop("instruction")
41 | n = args.pop("n")
42 | _seed = args.pop("_seed", None)
43 |
44 | def create_prompts(instruction, n, seed):
45 | for prompt_idx in range(n):
46 | yield _StrWithSeed(
47 | instruction,
48 | seed=((seed, prompt_idx) if seed is not None else prompt_idx),
49 | )
50 |
51 | return self._run_prompts(
52 | args=args,
53 | prompts=partial(create_prompts, instruction, n, _seed),
54 | total_num_prompts=n,
55 | )
56 |
57 |
58 | __all__ = ["DataFromPrompt"]
59 |
--------------------------------------------------------------------------------
/src/steps/prompt/process_with_prompt.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from ._prompt_base import _PromptBase
4 |
5 |
6 | class ProcessWithPrompt(_PromptBase):
7 | """Processes a set of inputs using an instruction with a
8 | :py:class:`~datadreamer.llms.LLM`."""
9 |
10 | def setup(self):
11 | self._register_prompt_inputs(prompt_input_type="input")
12 | self._register_prompt_args()
13 | self.register_arg(
14 | "instruction",
15 | required=True,
16 | help="An instruction that describes how to process the input.",
17 | )
18 | self.register_arg(
19 | "input_label",
20 | required=False,
21 | default="Input:",
22 | help="The label to use for inputs.",
23 | )
24 | self.register_arg(
25 | "instruction_label",
26 | required=False,
27 | default="Instruction:",
28 | help="The label to use for the instruction.",
29 | )
30 | self.register_arg(
31 | "max_new_tokens",
32 | required=False,
33 | help="The maximum number of tokens to generate.",
34 | )
35 | self.register_arg(
36 | "sep",
37 | required=False,
38 | default="\n\n",
39 | help="The separator to use between instructions and the input.",
40 | )
41 | self._register_prompt_optional_args()
42 | self._register_prompt_outputs()
43 |
44 | def run(self):
45 | # Get inputs and arguments
46 | args = self.args
47 | llm = args["llm"]
48 | inputs = self.inputs["inputs"]
49 | input_label = args.pop("input_label")
50 | instruction_label = args.pop("instruction_label")
51 | max_new_tokens = args["max_new_tokens"]
52 | format_prompt_args = dict(
53 | max_new_tokens=max_new_tokens,
54 | end_instruction=(
55 | f"{instruction_label} {args.pop('instruction')}"
56 | if instruction_label
57 | else args.pop("instruction")
58 | ),
59 | sep=args.pop("sep"),
60 | )
61 |
62 | def create_process_input_with_instruction_prompts(
63 | llm, inputs, input_label, format_prompt_args
64 | ):
65 | for input in inputs:
66 | beg_instruction = f"{input_label} {input}" if input_label else input
67 | yield llm.format_prompt(
68 | beg_instruction=beg_instruction, **format_prompt_args
69 | )
70 |
71 | # Generate
72 | return self._run_prompts(
73 | args=args,
74 | prompts=partial(
75 | create_process_input_with_instruction_prompts,
76 | llm,
77 | inputs,
78 | input_label,
79 | format_prompt_args,
80 | ),
81 | )
82 |
83 |
84 | __all__ = ["ProcessWithPrompt"]
85 |
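For intuition, here is roughly how each prompt is laid out, assuming the LLM's format_prompt() simply joins the input block and the instruction block with sep (the exact formatting is up to the LLM class):

# Illustrative only; the real prompt is produced by llm.format_prompt().
input_label, instruction_label, sep = "Input:", "Instruction:", "\n\n"
instruction = "Summarize the input in one sentence."
text = "DataDreamer helps generate synthetic datasets with LLMs."
prompt = f"{input_label} {text}{sep}{instruction_label} {instruction}"
print(prompt)
# Input: DataDreamer helps generate synthetic datasets with LLMs.
#
# Instruction: Summarize the input in one sentence.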
--------------------------------------------------------------------------------
/src/steps/prompt/prompt.py:
--------------------------------------------------------------------------------
1 | from ._prompt_base import _PromptBase
2 |
3 |
4 | class Prompt(_PromptBase):
5 | "Runs a set of prompts against a :py:class:`~datadreamer.llms.LLM`."
6 |
7 | def setup(self):
8 | self._register_prompt_inputs()
9 | self._register_prompt_args()
10 | self._register_prompt_optional_args()
11 | self._register_prompt_outputs()
12 |
13 | def run(self):
14 | return self._run_prompts(args=self.args)
15 |
16 |
17 | __all__ = ["Prompt"]
18 |
--------------------------------------------------------------------------------
/src/steps/step_background.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from multiprocessing.dummy import Pool as ThreadPool
3 | from time import sleep
4 | from typing import TYPE_CHECKING, Callable
5 |
6 | from .. import DataDreamer
7 | from ..errors import StepOutputError
8 | from ..utils.background_utils import get_thread_id, run_in_background_thread
9 |
10 | if TYPE_CHECKING: # pragma: no cover
11 | from .step import Step
12 |
13 |
14 | def _check_step_output(step: "Step") -> bool:
15 | try:
16 | step.output # noqa: B018
17 | return True
18 | except StepOutputError:
19 | return False
20 |
21 |
22 | def _waiter(steps, poll_interval=1.0):
23 | while len(steps) > 0:
24 | step = steps[-1]
25 | if _check_step_output(step):
26 | steps.pop()
27 | else:
28 | sleep(poll_interval)
29 |
30 |
31 | def wait(*steps: "Step", poll_interval=1.0):
32 | """Wait for all steps to complete if they are running in the background.
33 |
34 | Args:
35 | poll_interval: How often to poll in seconds.
36 | """
37 | from ..steps import Step
38 |
39 | if not all([isinstance(s, Step) for s in steps]):
40 | raise TypeError("All arguments to wait() must be of type Step.")
41 | if all([_check_step_output(step) for step in steps]):
42 | return
43 | steps_list = list(steps)
44 | wait_thread = run_in_background_thread(
45 | _waiter, steps_list, poll_interval=poll_interval
46 | )
47 | wait_thread.join()
48 |
49 |
50 | def concurrent(*funcs: Callable):
51 | """Run a set of functions (which run steps) concurrently.
52 |
53 | Args:
54 | *funcs: The functions to run concurrently.
55 | """
56 | parent_thread_id = get_thread_id()
57 |
58 | def wrapper_func(parent_thread_id, func):
59 | if DataDreamer.initialized():
60 | DataDreamer._register_child_thread(parent_thread_id)
61 | return func()
62 |
63 | if not all([callable(f) for f in funcs]):
64 | raise TypeError("All arguments to concurrent() must be functions.")
65 | thread_pool = ThreadPool(len(funcs))
66 | results = thread_pool.map(partial(wrapper_func, parent_thread_id), funcs)
67 | thread_pool.close()
68 | thread_pool.join()
69 | return results
70 |
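A sketch of running two steps concurrently and waiting on them, assuming a DataDreamer session and the DataSource step (import locations assumed):

from datadreamer import DataDreamer
from datadreamer.steps import DataSource, concurrent, wait  # assumed exports

with DataDreamer("./output"):

    def make_a():
        return DataSource("A", data={"x": [1, 2, 3]}, background=True)

    def make_b():
        return DataSource("B", data={"y": [4, 5, 6]}, background=True)

    # concurrent() runs both functions on separate threads; wait() blocks until the
    # background steps have produced their outputs.
    step_a, step_b = concurrent(make_a, make_b)
    wait(step_a, step_b)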
--------------------------------------------------------------------------------
/src/steps/step_export.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import partial
3 |
4 | from datasets import DatasetDict
5 |
6 | from ..datasets import OutputDataset
7 | from ..pickling import unpickle_transform
8 |
9 |
10 | def _path_to_split_paths(path: str, dataset_dict: DatasetDict) -> dict[str, str]:
11 | os.makedirs(os.path.dirname(path), exist_ok=True)
12 | base, extension = os.path.splitext(path)
13 | paths: dict[str, str] = {}
14 | for split_name in dataset_dict:
15 | if split_name == "validation":
16 | path_split_name = "val"
17 | else:
18 | path_split_name = split_name
19 | split_path = f"{base}.{path_split_name}{extension}"
20 | paths[split_name] = split_path
21 | return paths
22 |
23 |
24 | def _unpickle_export(export: DatasetDict | list | dict, output_dataset: OutputDataset):
25 | if output_dataset._pickled:
26 | if isinstance(export, DatasetDict):
27 | export.set_transform(
28 | partial(
29 | unpickle_transform, features=output_dataset._features, batched=True
30 | )
31 | )
32 | return export
33 | elif isinstance(export, list):
34 | return [
35 | unpickle_transform(
36 | row, features=output_dataset._features, batched=False
37 | )
38 | for row in export
39 | ]
40 | else:
41 | return unpickle_transform(
42 | export, features=output_dataset._features, batched=True
43 | )
44 | else:
45 | return export
46 |
47 |
48 | __all__ = ["_path_to_split_paths", "_unpickle_export"]
49 |
--------------------------------------------------------------------------------
/src/steps/tasks/embed.py:
--------------------------------------------------------------------------------
1 | from itertools import tee
2 |
3 | from ..data_card import DataCardType
4 | from ..step import Step
5 | from ..step_output import LazyRows
6 |
7 |
8 | class Embed(Step):
9 | "Embeds a set of texts with an :py:class:`~datadreamer.embedders.Embedder`."
10 |
11 | def setup(self):
12 | self.register_input("texts", help="The texts to embed.")
13 | self.register_arg("embedder", help="The Embedder to use.")
14 | self.register_arg(
15 | "truncate",
16 | required=False,
17 | default=False,
18 | help="Whether or not to truncate inputs.",
19 | )
20 | self.register_arg(
21 | "instruction",
22 | required=False,
23 | help="The instruction to prefix inputs to the embedding model with.",
24 | )
25 | self.register_arg(
26 | "lazy", required=False, default=False, help="Whether to run lazily or not."
27 | )
28 | self.register_arg(
29 | "**kwargs",
30 | required=False,
31 | help="Any other arguments you want to pass to the .run() method of the Embedder.",
32 | )
33 | self.register_output("texts", help="The texts that were embedded.")
34 | self.register_output("embeddings", help="The embeddings generated by the Embedder.")
35 |
36 | def run(self):
37 | args = self.args
38 |
39 | # Get inputs and arguments
40 | embedder = args.pop("embedder")
41 | lazy = args.pop("lazy")
42 |
43 | # Register trace info from the Embedder model
44 | if hasattr(embedder, "model_name"):
45 | self.register_data_card(DataCardType.MODEL_NAME, embedder.model_name)
46 | self.register_data_card(DataCardType.MODEL_CARD, embedder.model_card)
47 | self.register_data_card(DataCardType.LICENSE, embedder.license)
48 | for citation in embedder.citation or []:
49 | self.register_data_card(DataCardType.CITATION, citation)
50 |
51 | # Get the total number of texts
52 | texts = self.inputs["texts"]
53 | total_num_texts = texts.num_rows
54 |
55 | # Define a function that yields embeddings
56 | def get_embeddings():
57 | # Get an iterator over texts
58 | texts_iter_1, texts_iter_2 = tee(iter(texts), 2)
59 |
60 | # Generate
61 | embeddings_iter = iter(
62 | embedder.run(
63 | texts=texts_iter_1,
64 | progress_interval=self.progress_interval,
65 | total_num_texts=total_num_texts,
66 | return_generator=True,
67 | _step=self,
68 | **args,
69 | )
70 | )
71 |
72 | yield from zip(texts_iter_2, embeddings_iter)
73 |
74 | # Return embeddings
75 | return LazyRows(
76 | get_embeddings,
77 | total_num_rows=total_num_texts,
78 | auto_progress=False,
79 | save=(not lazy),
80 | )
81 |
82 |
83 | __all__ = ["Embed"]
84 |
--------------------------------------------------------------------------------
/src/steps/tasks/run_task_model.py:
--------------------------------------------------------------------------------
1 | from itertools import tee
2 |
3 | from ..data_card import DataCardType
4 | from ..step import Step
5 | from ..step_output import LazyRows
6 |
7 |
8 | class RunTaskModel(Step):
9 | "Runs a set of texts against a :py:class:`~datadreamer.task_models.TaskModel`."
10 |
11 | def setup(self):
12 | self.register_input("texts", help="The texts to process with the TaskModel.")
13 | self.register_arg("model", help="The TaskModel to use.")
14 | self.register_arg(
15 | "truncate",
16 | required=False,
17 | default=False,
18 | help="Whether or not to truncate inputs.",
19 | )
20 | self.register_arg(
21 | "lazy", required=False, default=False, help="Whether to run lazily or not."
22 | )
23 | self.register_arg(
24 | "**kwargs",
25 | required=False,
26 | help="Any other arguments you want to pass to the .run() method of the TaskModel.",
27 | )
28 | self.register_output("texts", help="The texts processed with the TaskModel.")
29 | self.register_output("results", help="The results from the TaskModel.")
30 |
31 | def run(self):
32 | args = self.args
33 |
34 | # Get inputs and arguments
35 | model = args.pop("model")
36 | lazy = args.pop("lazy")
37 |
38 | # Register trace info from the TaskModel model
39 | if hasattr(model, "model_name"):
40 | self.register_data_card(DataCardType.MODEL_NAME, model.model_name)
41 | self.register_data_card(DataCardType.MODEL_CARD, model.model_card)
42 | self.register_data_card(DataCardType.LICENSE, model.license)
43 | for citation in model.citation or []:
44 | self.register_data_card(DataCardType.CITATION, citation)
45 |
46 | # Get the total number of texts
47 | texts = self.inputs["texts"]
48 | total_num_texts = texts.num_rows
49 |
50 | # Define a function that yields results
51 | def get_results():
52 | # Get an iterator over texts
53 | texts_iter_1, texts_iter_2 = tee(iter(texts), 2)
54 |
55 | # Generate
56 | results_iter = iter(
57 | model.run(
58 | texts=texts_iter_1,
59 | progress_interval=self.progress_interval,
60 | total_num_texts=total_num_texts,
61 | return_generator=True,
62 | _step=self,
63 | **args,
64 | )
65 | )
66 |
67 | yield from zip(texts_iter_2, results_iter)
68 |
69 | # Return results
70 | return LazyRows(
71 | get_results,
72 | total_num_rows=total_num_texts,
73 | auto_progress=False,
74 | save=(not lazy),
75 | )
76 |
77 |
78 | __all__ = ["RunTaskModel"]
79 |
--------------------------------------------------------------------------------
/src/task_models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | :py:class:`TaskModel` objects help perform an arbitrary NLP task
3 | (classification, etc.).
4 | All task models derive from the :py:class:`TaskModel` base class.
5 |
6 | .. tip::
7 |
8 | Instead of using :py:meth:`~TaskModel.run` directly, use a
9 | :py:class:`step <datadreamer.steps.Step>` that takes a :py:class:`TaskModel` as an
10 | ``args`` argument such as :py:class:`~datadreamer.steps.RunTaskModel`.
11 |
12 | Caching
13 | =======
14 | Task models internally perform caching to disk, so if you run the same text multiple
15 | times, the task model will only run once and then cache the results for future runs.
16 | """
17 |
18 | from .hf_classification_task_model import HFClassificationTaskModel
19 | from .parallel_task_model import ParallelTaskModel
20 | from .task_model import TaskModel
21 |
22 | __all__ = ["TaskModel", "HFClassificationTaskModel", "ParallelTaskModel"]
23 |
--------------------------------------------------------------------------------
/src/task_models/parallel_task_model.py:
--------------------------------------------------------------------------------
1 | from typing import Generator, Iterable, cast
2 |
3 | from .._cachable import _ParallelCachable
4 | from .task_model import DEFAULT_BATCH_SIZE, TaskModel
5 |
6 |
7 | class ParallelTaskModel(_ParallelCachable, TaskModel):
8 | def __init__(self, *task_models: TaskModel):
9 | """
10 | Creates a task model that will run multiple task models in parallel. See
11 | :doc:`running models in parallel
12 | <./pages/advanced_usage/parallelization/running_models_on_multiple_gpus>`
13 | for more details.
14 |
15 | Args:
16 | *task_models: The task models to run in parallel.
17 | """
18 | super().__init__(*task_models, cls=TaskModel)
19 | self.task_models = cast(list[TaskModel], self.cachables)
20 |
21 | def count_tokens(self, value: str) -> int:
22 | """Counts the number of tokens in a string.
23 |
24 | Args:
25 | value: The string to count tokens for.
26 |
27 | Returns:
28 | The number of tokens in the string.
29 | """
31 | return self.task_models[0].count_tokens(value=value)
32 |
33 | @property
34 | def model_max_length(self) -> int: # pragma: no cover
35 | return self.task_models[0].model_max_length
36 |
37 | def run( # type:ignore[override]
38 | self, texts: Iterable[str], *args, **kwargs
39 | ) -> Generator[str | list[str], None, None] | list[str | list[str]]:
40 | kwargs["batch_size"] = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE)
41 | results_generator = self._run_in_parallel(texts, *args, **kwargs)
42 | if not kwargs.get("return_generator", False):
43 | return list(results_generator)
44 | else:
45 | return results_generator
46 |
47 | def unload_model(self):
48 | for task_model in self.task_models:
49 | task_model.unload_model()
50 |
51 |
52 | __all__ = ["ParallelTaskModel"]
53 |
--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/__init__.py
--------------------------------------------------------------------------------
/src/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | from .. import project
5 | from ..utils.fs_utils import clear_dir
6 |
7 |
8 | # Register pytest fixtures
9 | def refactor(string: str) -> str:
10 | return string.replace("/", ".").replace("\\", ".").replace(".py", "")
11 |
12 |
13 | pytest_plugins = [
14 | refactor(fixture)
15 | for fixture in glob("src/tests/test_utils/fixtures/*.py")
16 | if "__" not in fixture
17 | ]
18 |
19 | # Set the initial cwd
20 | project.INITIAL_CWD = os.path.abspath(os.getcwd())
21 |
22 | # Clear the tests data directory
23 | try:
24 | clear_dir("./.tests_data")
25 | except FileNotFoundError:
26 | pass
27 |
--------------------------------------------------------------------------------
/src/tests/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/datasets/__init__.py
--------------------------------------------------------------------------------
/src/tests/embedders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/embedders/__init__.py
--------------------------------------------------------------------------------
/src/tests/llms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/llms/__init__.py
--------------------------------------------------------------------------------
/src/tests/retrievers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/retrievers/__init__.py
--------------------------------------------------------------------------------
/src/tests/steps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/steps/__init__.py
--------------------------------------------------------------------------------
/src/tests/steps/prompt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/steps/prompt/__init__.py
--------------------------------------------------------------------------------
/src/tests/steps/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/steps/tasks/__init__.py
--------------------------------------------------------------------------------
/src/tests/task_models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/task_models/__init__.py
--------------------------------------------------------------------------------
/src/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | from ..__cli__ import _main
2 |
3 |
4 | class TestCli:
5 | def test_help(self, cli_runner):
6 | result = cli_runner.invoke(_main, ["--help"])
7 | assert result.output.count("Show this message and exit.") == 1
8 |
--------------------------------------------------------------------------------
/src/tests/test_package.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 |
4 | import pytest
5 | from mypy import api
6 |
7 | from .. import __version__
8 | from ..project import RUNNING_IN_PYTEST
9 |
10 |
11 | class TestPackage:
12 | def test_version(self):
13 | assert len(__version__.split(".")) == 3
14 |
15 | def test_running_in_pytest(self):
16 | assert RUNNING_IN_PYTEST
17 |
18 | @pytest.mark.skipif(
19 | "GITHUB_ACTIONS" not in os.environ and "PROJECT_CLUSTER" not in os.environ,
20 | reason="only run on CI",
21 | )
22 | def test_python_version(self):
23 | with open("./scripts/.python-version", "r") as f:
24 | python_version = f.read().strip()
25 | assert python_version == platform.python_version()
26 |
27 | def test_mypy(self):
28 | result = api.run(["src/", "--sqlite-cache", "--explicit-package-bases"])
29 | if result[0]:
30 | print("\nType checking report:\n")
31 | print(result[0])
32 | if result[1]:
33 | print("\nError report:\n")
34 | print(result[1])
35 | assert result[2] == 0
36 |
--------------------------------------------------------------------------------
/src/tests/test_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/test_utils/__init__.py
--------------------------------------------------------------------------------
/src/tests/test_utils/config.py:
--------------------------------------------------------------------------------
1 | TEST_DIR = "./.tests_data"
2 |
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/test_utils/fixtures/__init__.py
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/bitsandbytes_fixture.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import pytest
4 |
5 | imported_banb = False
6 |
7 |
8 | @pytest.fixture(autouse=True)
9 | def reset_bandb_import():
10 | """We need to import banb as it throws a warning on first import"""
11 | global imported_banb
12 |
13 | # Code that will run before your test
14 | with warnings.catch_warnings():
15 | warnings.filterwarnings(
16 | "ignore",
17 | category=UserWarning,
18 | message="The installed version of bitsandbytes was compiled without GPU.*",
19 | module="bitsandbytes.cextension",
20 | )
21 | if not imported_banb:
22 | print(
23 | "\nDataDreamer test suite is importing bitsandbytes,"
24 | " ignore any warnings below this...\n"
25 | )
26 | import bitsandbytes # noqa: F401
27 |
28 | if not imported_banb:
29 | print("\nDataDreamer test suite is done importing bitsandbytes.\n")
30 |
31 | imported_banb = True
32 |
33 | yield
34 | # Code that will run after your test
35 |
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/clear_space.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | from ....utils.fs_utils import clear_dir
6 |
7 |
8 | @pytest.fixture(autouse=True)
9 | def clear_github_space():
10 | yield
11 | if "GITHUB_ACTIONS" in os.environ:
12 | # Clear the tests data directory to make more disk space available
13 | try:
14 | clear_dir("./.tests_data")
15 | os.system("rm -rf ~/.cache/huggingface/")
16 | except FileNotFoundError:
17 | pass
18 |
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/cli_runner.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from click.testing import CliRunner
3 |
4 |
5 | @pytest.fixture
6 | def cli_runner():
7 | return CliRunner()
8 |
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/create_datadreamer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | from typing import Callable
4 |
5 | import pytest
6 |
7 | from .... import DataDreamer
8 | from ..config import TEST_DIR
9 |
10 |
11 | @pytest.fixture
12 | def create_datadreamer() -> Callable[..., DataDreamer]:
13 | def _create_datadreamer(path: None | str = None, **kwargs) -> DataDreamer:
14 | if path is None:
15 | path = uuid.uuid4().hex[0:10]
16 | if path == ":memory:":
17 | return DataDreamer(path, **kwargs)
18 | else:
19 | return DataDreamer(os.path.join(TEST_DIR, path), **kwargs)
20 |
21 | return _create_datadreamer
22 |
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/create_test_step.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import pytest
4 |
5 | from ....steps import Step
6 | from ....steps.step import _INTERNAL_TEST_KEY
7 |
8 |
9 | @pytest.fixture
10 | def create_test_step() -> Callable[..., Step]:
11 | def _create_test_step(
12 | name="my-step",
13 | inputs=None,
14 | args=None,
15 | outputs=None,
16 | output_names=None,
17 | setup=None,
18 | **kwargs,
19 | ) -> Step:
20 | if output_names is None:
21 | output_names = []
22 |
23 | class TestStep(Step):
24 | def setup(self):
25 | if isinstance(output_names, str):
26 | self.register_output(output_names)
27 | else:
28 | for o in output_names:
29 | self.register_output(o)
30 | if setup is not None:
31 | setup(self)
32 |
33 | setattr(TestStep, _INTERNAL_TEST_KEY, True)
34 |
35 | return TestStep(name, inputs=inputs, args=args, outputs=outputs, **kwargs)
36 |
37 | return _create_test_step
38 |
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/mock_llm.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import pytest
4 |
5 | from ....llms import LLM
6 |
7 |
8 | @pytest.fixture
9 | def mock_llm(
10 | allowed_kwargs=frozenset(
11 | {
12 | "inputs",
13 | "batch_size",
14 | "max_new_tokens",
15 | "temperature",
16 | "top_p",
17 | "n",
18 | "stop",
19 | "repetition_penalty",
20 | "logit_bias",
21 | "seed",
22 | "max_length_func",
23 | "cached_tokenizer",
24 | }
25 | ),
26 | ) -> Callable[..., LLM]:
27 | def _mock_llm(llm: LLM, responses: dict[str, str]) -> LLM:
28 | def _run_batch_mocked(**kwargs):
29 | for kwarg in kwargs:
30 | assert kwarg in allowed_kwargs, f"LLM got unexpected keyword: {kwarg}"
31 | return [responses[prompt] for prompt in kwargs["inputs"]]
32 |
33 | llm._run_batch = _run_batch_mocked # type: ignore[attr-defined]
34 |
35 | return llm
36 |
37 | return _mock_llm
38 |
--------------------------------------------------------------------------------
/src/tests/test_utils/fixtures/restore_os_environ.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 |
6 | @pytest.fixture(autouse=True)
7 | def restore_os_environ():
8 | orig_environ = os.environ.copy()
9 | yield
10 | os.environ.clear()
11 | os.environ.update(orig_environ)
12 |
--------------------------------------------------------------------------------
/src/tests/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/trainers/__init__.py
--------------------------------------------------------------------------------
/src/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/tests/utils/__init__.py
--------------------------------------------------------------------------------
/src/trainers/_vendored/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/trainers/_vendored/__init__.py
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datadreamer-dev/DataDreamer/4d232497a17b4e5f7392c1b06e2f8a7ad289ecd4/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/arg_utils.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 |
4 | class Default:
5 | def __init__(self, name: str):
6 | self.name = name
7 |
8 | def __repr__(self): # pragma: no cover
9 | return self.name
10 |
11 |
12 | DEFAULT = Default("DEFAULT")
13 | AUTO = Default("AUTO")
14 |
15 | T = TypeVar("T")
16 |
17 |
18 | def default_to(val: T | Default, default_val: T) -> T:
19 | if isinstance(val, Default):
20 | return default_val
21 | else:
22 | return val
23 |
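A tiny illustration of the Default sentinel pattern (import path assumed from the file location):

from datadreamer.utils.arg_utils import DEFAULT, default_to  # assumed import path

batch_size = DEFAULT               # caller left the argument unspecified
print(default_to(batch_size, 32))  # -> 32 (fall back to the default)
print(default_to(64, 32))          # -> 64 (explicit values pass through)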
--------------------------------------------------------------------------------
/src/utils/collection_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Any, Iterable
3 |
4 |
5 | def uniq_str(collection: Iterable) -> str:
6 | seen = set()
7 | uniqed = tuple([x for x in collection if not (x in seen or seen.add(x))]) # type: ignore[func-returns-value]
8 | return re.sub(r",}$", "}", f"{{{str(uniqed)[1:-1]}}}") if uniqed else "{}"
9 |
10 |
11 | def sort_keys(
12 | d: dict[Any, Any], key_order: list[Any]
13 | ) -> dict[Any, Any]: # pragma: no cover
14 | d_copy = dict(d)
15 | all_keys = set(d_copy.keys())
16 | other_keys = all_keys.difference(set(key_order))
17 | d.clear()
18 | for key in key_order:
19 | if key in d_copy:
20 | d[key] = d_copy[key]
21 | for key in other_keys:
22 | d[key] = d_copy[key]
23 | return d
24 |
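Concrete behavior of these helpers (import path assumed):

from datadreamer.utils.collection_utils import sort_keys, uniq_str  # assumed import path

print(uniq_str(["x", "y", "x"]))           # {'x', 'y'}  (order-preserving de-duplication)
d = {"b": 2, "z": 26, "a": 1}
print(sort_keys(d, key_order=["a", "b"]))  # {'a': 1, 'b': 2, 'z': 26}  (mutates d in place)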
--------------------------------------------------------------------------------
/src/utils/fs_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from typing import Any
4 |
5 |
6 | def mkdir(path: str):
7 | try:
8 | os.makedirs(path, exist_ok=True)
9 | except FileExistsError: # pragma: no cover
10 | pass
11 |
12 |
13 | def safe_fn(value: str, allow_slashes=False, to_lower=False) -> str:
14 | if allow_slashes:
15 | value = value.replace(" / ", "/")
16 | value = value.replace(" ", "-")
17 | if not allow_slashes:
18 | value = value.replace("/", "-")
19 | safe_chars: Any = ("-", "_")
20 | if allow_slashes:
21 | safe_chars = ("-", "_", "/")
22 | strip_chars = "".join(c for c in value if c.isalnum() or c in safe_chars).strip()
23 | if to_lower:
24 | return strip_chars.lower()
25 | else:
26 | return strip_chars
27 |
28 |
29 | def rm_dir(path: str):
30 | if os.path.isdir(path):
31 | shutil.rmtree(path, ignore_errors=True)
32 |
33 |
34 | def clear_dir(path: str):
35 | if os.path.isdir(path):
36 | shutil.rmtree(path, ignore_errors=True)
37 | mkdir(path)
38 |
39 |
40 | def move_dir(src_path: str, dst_path: str):
41 | mkdir(src_path)
42 | clear_dir(dst_path)
43 | shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
44 | clear_dir(src_path)
45 |
46 |
47 | def dir_size(path: str) -> int: # pragma: no cover
48 | total_size = 0
49 | for dirpath, _, filenames in os.walk(path):
50 | for f in filenames:
51 | fp = os.path.join(dirpath, f)
52 | if not os.path.islink(fp):
53 | total_size += os.path.getsize(fp)
54 | return total_size
55 |
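For example, safe_fn sanitizes step names into filesystem-safe strings (import path assumed):

from datadreamer.utils.fs_utils import safe_fn  # assumed import path

print(safe_fn("My Step / Sub Step"))                      # My-Step---Sub-Step
print(safe_fn("My Step / Sub Step", allow_slashes=True))  # My-Step/Sub-Step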
--------------------------------------------------------------------------------
/src/utils/ring_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import ring
4 |
5 |
6 | def lru(*args, **kwargs):
7 | if "SPHINX_BUILD" in os.environ: # pragma: no cover
8 |
9 | def noop_decorator(func):
10 | return func
11 |
12 | return noop_decorator
13 | else:
14 | return ring.lru(*args, **kwargs)
15 |
--------------------------------------------------------------------------------
/src/utils/str_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from itertools import chain
3 |
4 |
5 | def replace_many(text, substitutions):
6 | pattern = re.compile("|".join(map(re.escape, substitutions.keys())))
7 | return pattern.sub(lambda match: substitutions[match.group(0)], text)
8 |
9 |
10 | def get_templated_var_names(templated_str):
11 | escaped_components = re.split(r"{{|}}", templated_str) # Ignore {{ and }}
12 | template_var_pattern = r"{([a-zA-Z0-9:_\.]+)}"
13 | return list(
14 | chain.from_iterable(
15 | [
16 | re.findall(template_var_pattern, component)
17 | for component in escaped_components
18 | ]
19 | )
20 | )
21 |
22 |
23 | def replace_templated_vars(templated_str, var_name_to_values):
24 | return replace_many(
25 | templated_str, {"{" + k + "}": str(v) for k, v in var_name_to_values.items()}
26 | )
27 |
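Concrete behavior of the templating helpers (import path assumed):

from datadreamer.utils.str_utils import get_templated_var_names, replace_templated_vars

print(get_templated_var_names("Hello {name}, ignore {{braces}}"))  # ['name']
print(replace_templated_vars("Hello {name}!", {"name": "Ada"}))    # Hello Ada!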
--------------------------------------------------------------------------------
/src/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | from time import time
2 |
3 | TIME_DURATION_UNITS = (
4 | ("week", 60 * 60 * 24 * 7),
5 | ("day", 60 * 60 * 24),
6 | ("hour", 60 * 60),
7 | ("min", 60),
8 | ("sec", 1),
9 | )
10 |
11 |
12 | def human_time_duration(seconds: float) -> str: # pragma: no cover
13 | parts = []
14 | for unit, div in TIME_DURATION_UNITS:
15 | amount, seconds = divmod(int(seconds), div)
16 | if amount > 0:
17 | parts.append("{} {}{}".format(amount, unit, "" if amount == 1 else "s"))
18 | if len(parts) == 0:
19 | return "0 secs"
20 | return ", ".join(parts)
21 |
22 |
23 | def progress_eta(progress: float, start_time: float) -> str:
24 | elapsed_time = time() - start_time
25 | eta = (
26 | human_time_duration((elapsed_time / progress) - elapsed_time)
27 | if progress > 0
28 | else "calculating..."
29 | )
30 | return f"(Estimated time left: {eta})"
31 |
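For example (import path assumed):

from datadreamer.utils.time_utils import human_time_duration

print(human_time_duration(3725))  # 1 hour, 2 mins, 5 secs
print(human_time_duration(0.4))   # 0 secs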
--------------------------------------------------------------------------------