├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ └── feature-request.yml └── workflows │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── code_quality.yml │ ├── delete_doc_comment.yml │ ├── docker.yml │ ├── ngc.yml │ ├── stale.yml │ ├── tests.yml │ └── upload_pr_documentation.yml ├── .gitignore ├── Dockerfile ├── Dockerfile.api ├── Dockerfile.app ├── LICENSE ├── Makefile ├── Manifest.in ├── README.md ├── colabs ├── AutoTrain.ipynb ├── AutoTrain_LLM.ipynb ├── AutoTrain_ngrok.ipynb └── image_classification.ipynb ├── configs ├── extractive_question_answering │ ├── hub_dataset.yml │ └── local_dataset.yml ├── image_classification │ ├── hub_dataset.yml │ └── local.yml ├── image_scoring │ ├── hub_dataset.yml │ ├── image_quality.yml │ └── local.yml ├── llm_finetuning │ ├── gpt2_sft.yml │ ├── llama3-70b-orpo-v1.yml │ ├── llama3-70b-sft.yml │ ├── llama3-8b-dpo-qlora.yml │ ├── llama3-8b-orpo-space.yml │ ├── llama3-8b-orpo.yml │ ├── llama3-8b-sft-unsloth.yml │ ├── llama32-1b-sft.yml │ ├── qwen.yml │ ├── smollm2.yml │ ├── smollm2_guanaco.yml │ └── smollm2_orpo.yml ├── object_detection │ ├── hub_dataset.yml │ └── local.yml ├── sentence_transformers │ ├── local_dataset.yml │ ├── pair.yml │ ├── pair_class.yml │ ├── pair_score.yml │ ├── qa.yml │ └── triplet.yml ├── seq2seq │ ├── hub_dataset.yml │ └── local.yml ├── text_classification │ ├── hub_dataset.yml │ └── local_dataset.yml ├── text_regression │ ├── hub_dataset.yml │ └── local_dataset.yml ├── token_classification │ ├── hub_dataset.yml │ └── local_dataset.yml └── vlm │ └── paligemma_vqa.yml ├── docs ├── README.md └── source │ ├── _toctree.yml │ ├── autotrain_api.mdx │ ├── col_map.mdx │ ├── config.mdx │ ├── cost.mdx │ ├── faq.mdx │ ├── getting_started.bck │ ├── index.mdx │ ├── quickstart.mdx │ ├── quickstart_py.mdx │ ├── quickstart_spaces.mdx │ ├── starting_ui.bck │ ├── support.mdx │ └── tasks │ ├── extractive_qa.mdx │ ├── image_classification_regression.mdx │ ├── 
llm_finetuning.mdx │ ├── object_detection.mdx │ ├── sentence_transformer.mdx │ ├── seq2seq.mdx │ ├── tabular.mdx │ ├── text_classification_regression.mdx │ └── token_classification.mdx ├── notebooks ├── llm_finetuning.ipynb ├── text_classification.ipynb └── text_regression.ipynb ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── autotrain │ ├── __init__.py │ ├── app │ ├── __init__.py │ ├── api_routes.py │ ├── app.py │ ├── colab.py │ ├── db.py │ ├── models.py │ ├── oauth.py │ ├── params.py │ ├── static │ │ ├── logo.png │ │ └── scripts │ │ │ ├── fetch_data_and_update_models.js │ │ │ ├── listeners.js │ │ │ ├── logs.js │ │ │ ├── poll.js │ │ │ └── utils.js │ ├── templates │ │ ├── duplicate.html │ │ ├── error.html │ │ ├── index.html │ │ └── login.html │ ├── training_api.py │ ├── ui_routes.py │ └── utils.py │ ├── backends │ ├── __init__.py │ ├── base.py │ ├── endpoints.py │ ├── local.py │ ├── ngc.py │ ├── nvcf.py │ └── spaces.py │ ├── cli │ ├── __init__.py │ ├── autotrain.py │ ├── run_api.py │ ├── run_app.py │ ├── run_extractive_qa.py │ ├── run_image_classification.py │ ├── run_image_regression.py │ ├── run_llm.py │ ├── run_object_detection.py │ ├── run_sent_tranformers.py │ ├── run_seq2seq.py │ ├── run_setup.py │ ├── run_spacerunner.py │ ├── run_tabular.py │ ├── run_text_classification.py │ ├── run_text_regression.py │ ├── run_token_classification.py │ ├── run_tools.py │ ├── run_vlm.py │ └── utils.py │ ├── client.py │ ├── commands.py │ ├── config.py │ ├── dataset.py │ ├── help.py │ ├── logging.py │ ├── params.py │ ├── parser.py │ ├── preprocessor │ ├── __init__.py │ ├── tabular.py │ ├── text.py │ ├── vision.py │ └── vlm.py │ ├── project.py │ ├── tasks.py │ ├── tests │ ├── test_cli.py │ └── test_dummy.py │ ├── tools │ ├── __init__.py │ ├── convert_to_kohya.py │ └── merge_adapter.py │ ├── trainers │ ├── __init__.py │ ├── clm │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── callbacks.py │ │ ├── params.py │ │ ├── train_clm_default.py │ │ ├── train_clm_dpo.py │ │ ├── 
train_clm_orpo.py │ │ ├── train_clm_reward.py │ │ ├── train_clm_sft.py │ │ └── utils.py │ ├── common.py │ ├── extractive_question_answering │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ ├── generic │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── params.py │ │ └── utils.py │ ├── image_classification │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ ├── image_regression │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ ├── object_detection │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ ├── sent_transformers │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── params.py │ │ └── utils.py │ ├── seq2seq │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ ├── tabular │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── params.py │ │ └── utils.py │ ├── text_classification │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ ├── text_regression │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ ├── token_classification │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ └── utils.py │ └── vlm │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── dataset.py │ │ ├── params.py │ │ ├── train_vlm_generic.py │ │ └── utils.py │ └── utils.py └── static ├── autotrain_homepage.png ├── autotrain_model_choice.png ├── autotrain_space.png ├── autotrain_text_classification.png ├── cost.png ├── dreambooth1.jpeg ├── dreambooth2.png ├── duplicate_space.png ├── ext_qa.png ├── hub_model_choice.png ├── image_classification_1.png ├── img_reg_ui.png ├── llm_1.png ├── llm_2.png ├── llm_3.png ├── llm_orpo_example.png ├── logo.png ├── model_choice_1.png ├── param_choice_1.png ├── param_choice_2.png ├── space_template_1.png ├── space_template_2.png ├── 
space_template_3.png ├── space_template_4.png ├── space_template_5.png ├── text_classification_1.png └── ui.png /.dockerignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | logs/ 4 | output/ 5 | output2/ 6 | test/ 7 | test.py 8 | .DS_Store 9 | .vscode/ 10 | op* 11 | op_* 12 | .git 13 | *.db 14 | autotrain-data* 15 | autotrain-* -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: 🐛 Submit a bug report to help us improve AutoTrain 3 | title: "[BUG]" 4 | labels: ["bug"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thanks for taking the time to fill out this bug report! Before you proceed, please make sure you've checked the documentation and previous issues for similar problems. 10 | Always remember to hide any sensitive information such as API keys, passwords and tokens. 11 | 12 | - type: checkboxes 13 | id: read-docs 14 | attributes: 15 | label: Prerequisites 16 | description: Confirm before submitting the bug report. 17 | options: 18 | - label: I have read the [documentation](https://hf.co/docs/autotrain). 19 | required: true 20 | - label: I have checked other issues for similar problems. 21 | required: true 22 | 23 | - type: dropdown 24 | id: backend 25 | attributes: 26 | label: Backend 27 | description: Which backend are you using? 28 | options: 29 | - Local 30 | - Colab 31 | - Hugging Face Space/Endpoints 32 | - Other cloud providers 33 | validations: 34 | required: true 35 | 36 | - type: dropdown 37 | id: interface 38 | attributes: 39 | label: Interface Used 40 | description: Are you using the CLI or the UI? 
41 | options: 42 | - CLI 43 | - UI 44 | validations: 45 | required: true 46 | 47 | - type: textarea 48 | id: cli-command 49 | attributes: 50 | label: CLI Command 51 | description: If you're using the CLI, please provide the full command you ran. 52 | validations: 53 | required: false 54 | 55 | - type: textarea 56 | id: ui-params 57 | attributes: 58 | label: UI Screenshots & Parameters 59 | description: If using the UI, add a screenshot of the UI & copy-paste the parameters you used. 60 | validations: 61 | required: false 62 | 63 | - type: textarea 64 | id: stacktrace 65 | attributes: 66 | label: Error Logs 67 | 68 | description: Please provide a stack trace or detailed error message, if available. 69 | placeholder: Paste the stack trace here 70 | validations: 71 | required: true 72 | 73 | - type: textarea 74 | id: additional-info 75 | attributes: 76 | label: Additional Information 77 | description: Any additional information or context about the problem. 78 | validations: 79 | required: false 80 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | version: 2.1 3 | contact_links: 4 | - name: AutoTrain Documentation 5 | url: https://huggingface.co/docs/autotrain 6 | about: Getting started and FAQs 7 | - name: AutoTrain Discussions 8 | url: https://github.com/huggingface/autotrain-advanced/discussions 9 | about: General usage questions and community discussions -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: 🚀 Submit a proposal/request for a new AutoTrain feature 3 | title: "[FEATURE REQUEST]" 4 | labels: ["feature request"] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | 
Thanks for taking the time to submit a feature request! Please provide as much detail as possible. 10 | 11 | - type: textarea 12 | id: feature-request 13 | attributes: 14 | label: Feature Request 15 | description: A clear and concise description of the feature proposal. Please provide a link to the paper and code if they exist. 16 | validations: 17 | required: true 18 | 19 | - type: textarea 20 | id: motivation 21 | attributes: 22 | label: Motivation 23 | description: A clear and concise description of what you want to happen. 24 | validations: 25 | required: true 26 | 27 | - type: textarea 28 | id: additional-context 29 | attributes: 30 | label: Additional Context 31 | description: Add any other context, details, or screenshots about the feature request here. 32 | validations: 33 | required: false 34 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.sha }} 15 | package: autotrain-advanced 16 | package_name: autotrain 17 | secrets: 18 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 19 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.event.pull_request.head.sha 
}} 15 | pr_number: ${{ github.event.number }} 16 | package: autotrain-advanced 17 | package_name: autotrain 18 | -------------------------------------------------------------------------------- /.github/workflows/code_quality.yml: -------------------------------------------------------------------------------- 1 | name: Code quality 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | release: 11 | types: 12 | - created 13 | 14 | jobs: 15 | check_code_quality: 16 | name: Check code quality 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python 3.9 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: 3.9 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | python -m pip install flake8 black isort 28 | - name: Make quality 29 | run: | 30 | make quality 31 | -------------------------------------------------------------------------------- /.github/workflows/delete_doc_comment.yml: -------------------------------------------------------------------------------- 1 | name: Delete doc comment 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Delete doc comment trigger"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | delete: 11 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main 12 | secrets: 13 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Docker Hub 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | dockerhub: 10 | name: Docker Hub 11 | runs-on: 12 | group: 'aws-general-8-plus' 13 | steps: 14 | - name: Check out the repo 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Docker Buildx 18 | uses: docker/setup-buildx-action@v1 19 | 20 | - name: Log in to Docker Hub 21 | uses: 
docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 22 | with: 23 | username: ${{ secrets.DOCKERHUB_USERNAME }} 24 | password: ${{ secrets.DOCKERHUB_PASSWORD }} 25 | 26 | - name: Set short git commit SHA 27 | id: vars 28 | run: | 29 | sha=$(git rev-parse --short ${{ github.sha }}) 30 | echo "SHA=$sha" >> $GITHUB_ENV 31 | 32 | - name: SHA 33 | run: echo ${{ env.SHA }} 34 | 35 | - name: Build and Push Docker Image 36 | run: | 37 | docker build -t autotrain-advanced:latest . 38 | docker tag autotrain-advanced:latest huggingface/autotrain-advanced:latest 39 | docker tag autotrain-advanced:latest huggingface/autotrain-advanced:${{ env.SHA }} 40 | docker push huggingface/autotrain-advanced:latest 41 | docker push huggingface/autotrain-advanced:${{ env.SHA }} 42 | -------------------------------------------------------------------------------- /.github/workflows/ngc.yml: -------------------------------------------------------------------------------- 1 | name: NGC Build & Push 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | dockerhub: 10 | name: NGC 11 | runs-on: 12 | group: 'aws-general-8-plus' 13 | steps: 14 | - name: Check out the repo 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Docker Buildx 18 | uses: docker/setup-buildx-action@v1 19 | 20 | - name: Log in to NGC 21 | uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 22 | with: 23 | registry: nvcr.io 24 | username: ${{ secrets.NVCR_USERNAME }} 25 | password: ${{ secrets.NVCR_PASSWORD }} 26 | 27 | - name: Set short git commit SHA 28 | id: vars 29 | run: | 30 | sha=$(git rev-parse --short ${{ github.sha }}) 31 | echo "SHA=$sha" >> $GITHUB_ENV 32 | 33 | - name: SHA 34 | run: echo ${{ env.SHA }} 35 | 36 | 37 | - name: Build and Push NGC Image 38 | run: | 39 | docker build -t autotrain-advanced:latest . 
40 | docker tag autotrain-advanced:latest nvcr.io/ycymhzotssoi/autotrain-advanced:latest 41 | docker tag autotrain-advanced:latest nvcr.io/ycymhzotssoi/autotrain-advanced:${{ env.SHA }} 42 | docker push nvcr.io/ycymhzotssoi/autotrain-advanced:latest 43 | docker push nvcr.io/ycymhzotssoi/autotrain-advanced:${{ env.SHA }} 44 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "0 15 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v5 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 20 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 19 | close-issue-message: "This issue was closed because it has been inactive for 20 days since being marked as stale." 
20 | days-before-pr-stale: 30 21 | days-before-pr-close: 20 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | release: 11 | types: 12 | - created 13 | 14 | jobs: 15 | tests: 16 | name: Run unit tests 17 | runs-on: 18 | group: aws-g6-4xlarge-plus 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python 3.9 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.9 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install .[dev] 29 | - name: Make test 30 | run: | 31 | make test 32 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: autotrain 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Local stuff 2 | .DS_Store 3 | .vscode/ 4 | test/ 5 | test.py 6 | output/ 7 | output2/ 8 | logs/ 9 | op_*/ 10 | autotrain.db 11 | autotrain.log 12 | *.db 13 | data/ 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | autotrain-data* 19 | autotrain-* 20 | op-* 21 | # C extensions 22 | 
*.so 23 | test.yml 24 | test.ipynb 25 | output.png 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive \ 4 | TZ=UTC \ 5 | HF_HUB_ENABLE_HF_TRANSFER=1 6 | 7 | ENV PATH="${HOME}/miniconda3/bin:${PATH}" 8 | ARG PATH="${HOME}/miniconda3/bin:${PATH}" 9 | ENV PATH="/app/ngc-cli:${PATH}" 10 | ARG PATH="/app/ngc-cli:${PATH}" 11 | 12 | RUN mkdir -p /tmp/model && \ 13 | chown -R 1000:1000 /tmp/model && \ 14 | mkdir -p /tmp/data && \ 15 | chown -R 1000:1000 /tmp/data 16 | 17 | RUN apt-get update && \ 18 | apt-get upgrade -y && \ 19 | apt-get install -y \ 20 | build-essential \ 21 | cmake \ 22 | curl \ 23 | ca-certificates \ 24 | gcc \ 25 | git \ 26 | locales \ 27 | net-tools \ 28 | wget \ 29 | libpq-dev \ 30 | libsndfile1-dev \ 31 | git \ 32 | git-lfs \ 33 | libgl1 \ 34 | unzip \ 35 | libjpeg-dev \ 36 | libpng-dev \ 37 | libgomp1 \ 38 | && rm -rf /var/lib/apt/lists/* && \ 39 | apt-get clean 40 | 41 | 42 | RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \ 43 | git lfs install 44 | 45 | WORKDIR /app 46 | RUN mkdir -p /app/.cache 47 | ENV HF_HOME="/app/.cache" 48 | RUN useradd -m -u 1000 user 49 | RUN chown -R user:user /app 50 | USER user 51 | ENV HOME=/app 52 | 53 | ENV 
PYTHONPATH=$HOME/app \ 54 | PYTHONUNBUFFERED=1 \ 55 | GRADIO_ALLOW_FLAGGING=never \ 56 | GRADIO_NUM_PORTS=1 \ 57 | GRADIO_SERVER_NAME=0.0.0.0 \ 58 | SYSTEM=spaces 59 | 60 | 61 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 62 | && sh Miniconda3-latest-Linux-x86_64.sh -b -p /app/miniconda \ 63 | && rm -f Miniconda3-latest-Linux-x86_64.sh 64 | ENV PATH /app/miniconda/bin:$PATH 65 | 66 | RUN conda create -p /app/env -y python=3.10 67 | 68 | SHELL ["conda", "run","--no-capture-output", "-p","/app/env", "/bin/bash", "-c"] 69 | 70 | RUN conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia && \ 71 | conda clean -ya && \ 72 | conda install -c "nvidia/label/cuda-12.1.1" cuda-nvcc && conda clean -ya && \ 73 | conda install xformers -c xformers && conda clean -ya 74 | 75 | COPY --chown=1000:1000 . /app/ 76 | 77 | RUN pip install -e . && \ 78 | python -m nltk.downloader punkt && \ 79 | pip install -U ninja && \ 80 | pip install -U flash-attn --no-build-isolation && \ 81 | pip install -U deepspeed && \ 82 | pip install --upgrade --force-reinstall --no-cache-dir "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git" --no-deps && \ 83 | pip cache purge 84 | -------------------------------------------------------------------------------- /Dockerfile.api: -------------------------------------------------------------------------------- 1 | FROM huggingface/autotrain-advanced:latest 2 | 3 | CMD autotrain api --port 7860 --host 0.0.0.0 -------------------------------------------------------------------------------- /Dockerfile.app: -------------------------------------------------------------------------------- 1 | FROM huggingface/autotrain-advanced:latest 2 | CMD uvicorn autotrain.app:app --host 0.0.0.0 --port 7860 --reload --workers 4 3 | -------------------------------------------------------------------------------- /Makefile: 
-------------------------------------------------------------------------------- 1 | .PHONY: quality style test docker api ngc pip 2 | 3 | # Check that source code meets quality standards 4 | 5 | quality: 6 | black --check --line-length 119 --target-version py38 . 7 | isort --check-only . 8 | flake8 --max-line-length 119 9 | 10 | # Format source code automatically 11 | 12 | style: 13 | black --line-length 119 --target-version py38 . 14 | isort . 15 | 16 | test: 17 | pytest -sv ./src/ 18 | 19 | docker: 20 | docker build -t autotrain-advanced:latest . 21 | docker tag autotrain-advanced:latest huggingface/autotrain-advanced:latest 22 | docker push huggingface/autotrain-advanced:latest 23 | 24 | api: 25 | docker build -t autotrain-advanced-api:latest -f Dockerfile.api . 26 | docker tag autotrain-advanced-api:latest public.ecr.aws/z4c3o6n6/autotrain-api:latest 27 | docker push public.ecr.aws/z4c3o6n6/autotrain-api:latest 28 | 29 | ngc: 30 | docker build -t autotrain-advanced:latest . 31 | docker tag autotrain-advanced:latest nvcr.io/ycymhzotssoi/autotrain-advanced:latest 32 | docker push nvcr.io/ycymhzotssoi/autotrain-advanced:latest 33 | 34 | pip: 35 | rm -rf build/ 36 | rm -rf dist/ 37 | $(MAKE) style && $(MAKE) quality 38 | python setup.py sdist bdist_wheel 39 | twine upload dist/* --verbose --repository autotrain-advanced -------------------------------------------------------------------------------- /Manifest.in: -------------------------------------------------------------------------------- 1 | recursive-include src/autotrain/app/static * 2 | recursive-include src/autotrain/app/templates * -------------------------------------------------------------------------------- /colabs/AutoTrain.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"AutoTrain\"\n", 8 | "\n", 9 | "- Attach proper hardware\n", 10 | "- Click Runtime > Run all\n", 11 | "- Read the 
[docs](https://hf.co/docs/autotrain) for data format, parameters and other questions\n", 12 | "- GitHub Repo: https://github.com/huggingface/autotrain-advanced" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "!pip install -U autotrain-advanced > install_logs.txt 2>&1\n", 22 | "from IPython.display import display\n", 23 | "from autotrain.app.colab import colab_app\n", 24 | "elements = colab_app()\n", 25 | "display(elements)" 26 | ] 27 | } 28 | ], 29 | "metadata": { 30 | "kernelspec": { 31 | "display_name": "autotrain", 32 | "language": "python", 33 | "name": "python3" 34 | }, 35 | "language_info": { 36 | "codemirror_mode": { 37 | "name": "ipython", 38 | "version": 3 39 | }, 40 | "file_extension": ".py", 41 | "mimetype": "text/x-python", 42 | "name": "python", 43 | "nbconvert_exporter": "python", 44 | "pygments_lexer": "ipython3", 45 | "version": "3.1.-1" 46 | } 47 | }, 48 | "nbformat": 4, 49 | "nbformat_minor": 2 50 | } 51 | -------------------------------------------------------------------------------- /colabs/AutoTrain_ngrok.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "cellView": "form", 8 | "id": "II6F7ThkI10I" 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "#@title 🤗 AutoTrain\n", 13 | "#@markdown In order to use this colab\n", 14 | "#@markdown - Enter your [Hugging Face Write Token](https://huggingface.co/settings/tokens)\n", 15 | "#@markdown - Enter your [ngrok auth token](https://dashboard.ngrok.com/get-started/your-authtoken)\n", 16 | "huggingface_token = '' # @param {type:\"string\"}\n", 17 | "ngrok_token = \"\" # @param {type:\"string\"}\n", 18 | "\n", 19 | "#@markdown\n", 20 | "#@markdown - Attach appropriate accelerator `Runtime > Change runtime type > Hardware accelerator`\n", 21 | "#@markdown - click `Runtime > 
Run all`\n", 22 | "#@markdown - Follow the link to access the UI\n", 23 | "#@markdown - Training happens inside this Google Colab\n", 24 | "#@markdown - report issues / feature requests [here](https://github.com/huggingface/autotrain-advanced/issues)\n", 25 | "\n", 26 | "import os\n", 27 | "os.environ[\"HF_TOKEN\"] = str(huggingface_token)\n", 28 | "os.environ[\"NGROK_AUTH_TOKEN\"] = str(ngrok_token)\n", 29 | "os.environ[\"AUTOTRAIN_LOCAL\"] = \"1\"\n", 30 | "\n", 31 | "!pip install -U autotrain-advanced > install_logs.txt 2>&1\n", 32 | "!autotrain app --share" 33 | ] 34 | } 35 | ], 36 | "metadata": { 37 | "accelerator": "GPU", 38 | "colab": { 39 | "gpuType": "T4", 40 | "provenance": [] 41 | }, 42 | "kernelspec": { 43 | "display_name": "Python 3", 44 | "name": "python3" 45 | }, 46 | "language_info": { 47 | "name": "python" 48 | } 49 | }, 50 | "nbformat": 4, 51 | "nbformat_minor": 0 52 | } 53 | -------------------------------------------------------------------------------- /colabs/image_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%%writefile config.yml\n", 10 | "task: image_classification # do not change\n", 11 | "base_model: google/vit-base-patch16-224 # the model to be used from hugging face hub\n", 12 | "project_name: autotrain-image-classification-model # the name of the project, must be unique\n", 13 | "log: tensorboard # do not change\n", 14 | "backend: local # do not change\n", 15 | "\n", 16 | "data:\n", 17 | " path: data/ # the path to the data folder, this folder consists of `train` and `valid` (if any) folders\n", 18 | " train_split: train # this folder inside data/ will be used for training, it contains the images in subfolders.\n", 19 | " valid_split: null # this folder inside data/ will be used for validation, it contains the images in subfolders. 
If not available, set it to null\n", 20 | " column_mapping: # do not change\n", 21 | " image_column: image\n", 22 | " target_column: labels\n", 23 | "\n", 24 | "params:\n", 25 | " epochs: 2\n", 26 | " batch_size: 4\n", 27 | " lr: 2e-5\n", 28 | " optimizer: adamw_torch\n", 29 | " scheduler: linear\n", 30 | " gradient_accumulation: 1\n", 31 | " mixed_precision: fp16\n", 32 | "\n", 33 | "hub:\n", 34 | " username: ${HF_USERNAME} # please set HF_USERNAME in colab secrets\n", 35 | " token: ${HF_TOKEN} # please set HF_TOKEN in colab secrets, must be valid hugging face write token\n", 36 | " push_to_hub: true # set to true if you want to push the model to the hub" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "import os\n", 46 | "from google.colab import userdata\n", 47 | "HF_USERNAME = userdata.get('HF_USERNAME')\n", 48 | "HF_TOKEN = userdata.get('HF_TOKEN')\n", 49 | "os.environ['HF_USERNAME'] = HF_USERNAME\n", 50 | "\n", 51 | "os.environ['HF_TOKEN'] = HF_TOKEN\n", 52 | "!autotrain --config config.yml" 53 | ] 54 | } 55 | ], 56 | "metadata": { 57 | "language_info": { 58 | "name": "python" 59 | } 60 | }, 61 | "nbformat": 4, 62 | "nbformat_minor": 2 63 | } 64 | -------------------------------------------------------------------------------- /configs/extractive_question_answering/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: extractive-qa 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-ex-qa1 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: lhoestq/squad 9 | train_split: train 10 | valid_split: validation 11 | column_mapping: 12 | text_column: context 13 | question_column: question 14 | answer_column: answers 15 | 16 | params: 17 | max_seq_length: 512 18 | max_doc_stride: 128 19 | epochs: 3 20 | batch_size: 4 21 | lr: 2e-5 22 | optimizer: adamw_torch 23 | scheduler: linear 24 
| gradient_accumulation: 1 25 | mixed_precision: fp16 26 | 27 | hub: 28 | username: ${HF_USERNAME} 29 | token: ${HF_TOKEN} 30 | push_to_hub: true -------------------------------------------------------------------------------- /configs/extractive_question_answering/local_dataset.yml: -------------------------------------------------------------------------------- 1 | task: extractive-qa 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-ex-qa2 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: data/ # this must be the path to the directory containing the train and valid files 9 | train_split: train # this must be either train.csv or train.json 10 | valid_split: valid # this must be either valid.csv or valid.json 11 | column_mapping: 12 | text_column: context 13 | question_column: question 14 | answer_column: answers 15 | 16 | params: 17 | max_seq_length: 512 18 | max_doc_stride: 128 19 | epochs: 3 20 | batch_size: 4 21 | lr: 2e-5 22 | optimizer: adamw_torch 23 | scheduler: linear 24 | gradient_accumulation: 1 25 | mixed_precision: fp16 26 | 27 | hub: 28 | username: ${HF_USERNAME} 29 | token: ${HF_TOKEN} 30 | push_to_hub: true -------------------------------------------------------------------------------- /configs/image_classification/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: image_classification 2 | base_model: google/vit-base-patch16-224 3 | project_name: autotrain-cats-vs-dogs-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: cats_vs_dogs 9 | train_split: train 10 | valid_split: null 11 | column_mapping: 12 | image_column: image 13 | target_column: labels 14 | 15 | params: 16 | epochs: 2 17 | batch_size: 4 18 | lr: 2e-5 19 | optimizer: adamw_torch 20 | scheduler: linear 21 | gradient_accumulation: 1 22 | mixed_precision: fp16 23 | 24 | hub: 25 | username: ${HF_USERNAME} 26 | token: ${HF_TOKEN} 27 | push_to_hub: true 
-------------------------------------------------------------------------------- /configs/image_classification/local.yml: -------------------------------------------------------------------------------- 1 | task: image_classification 2 | base_model: google/vit-base-patch16-224 3 | project_name: autotrain-image-classification-model 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: data/ 9 | train_split: train # this folder inside data/ will be used for training, it contains the images in subfolders. 10 | valid_split: null 11 | column_mapping: 12 | image_column: image 13 | target_column: label 14 | 15 | params: 16 | epochs: 2 17 | batch_size: 4 18 | lr: 2e-5 19 | optimizer: adamw_torch 20 | scheduler: linear 21 | gradient_accumulation: 1 22 | mixed_precision: fp16 23 | 24 | hub: 25 | username: ${HF_USERNAME} 26 | token: ${HF_TOKEN} 27 | push_to_hub: true -------------------------------------------------------------------------------- /configs/image_scoring/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: image_regression 2 | base_model: google/vit-base-patch16-224 3 | project_name: autotrain-cats-vs-dogs-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: cats_vs_dogs 9 | train_split: train 10 | valid_split: null 11 | column_mapping: 12 | image_column: image 13 | target_column: labels 14 | 15 | params: 16 | epochs: 2 17 | batch_size: 4 18 | lr: 2e-5 19 | optimizer: adamw_torch 20 | scheduler: linear 21 | gradient_accumulation: 1 22 | mixed_precision: fp16 23 | 24 | hub: 25 | username: ${HF_USERNAME} 26 | token: ${HF_TOKEN} 27 | push_to_hub: true -------------------------------------------------------------------------------- /configs/image_scoring/image_quality.yml: -------------------------------------------------------------------------------- 1 | task: image_regression 2 | base_model: microsoft/resnet-50 3 | project_name: autotrain-img-quality-resnet50 4 | log: 
tensorboard 5 | backend: local 6 | 7 | data: 8 | path: abhishek/img-quality-full 9 | train_split: train 10 | valid_split: null 11 | column_mapping: 12 | image_column: image 13 | target_column: target 14 | 15 | params: 16 | epochs: 10 17 | batch_size: 8 18 | lr: 2e-3 19 | optimizer: adamw_torch 20 | scheduler: cosine 21 | gradient_accumulation: 1 22 | mixed_precision: fp16 23 | 24 | hub: 25 | username: ${HF_USERNAME} 26 | token: ${HF_TOKEN} 27 | push_to_hub: true -------------------------------------------------------------------------------- /configs/image_scoring/local.yml: -------------------------------------------------------------------------------- 1 | task: image_regression 2 | base_model: google/vit-base-patch16-224 3 | project_name: autotrain-image-regression-model 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: data/ 9 | train_split: train # this folder inside data/ will be used for training, it contains the images and metadata.jsonl 10 | valid_split: valid # this folder inside data/ will be used for validation, it contains the images and metadata.jsonl. 
can be set to null 11 | # column mapping should not be changed for local datasets 12 | column_mapping: 13 | image_column: image 14 | target_column: target 15 | 16 | params: 17 | epochs: 2 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/llm_finetuning/gpt2_sft.yml: -------------------------------------------------------------------------------- 1 | task: llm-sft 2 | base_model: openai-community/gpt2 3 | project_name: autotrain-gpt2-finetuned-guanaco 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: timdettmers/openassistant-guanaco 9 | train_split: train 10 | valid_split: null 11 | chat_template: null 12 | column_mapping: 13 | text_column: text 14 | 15 | params: 16 | block_size: 1024 17 | model_max_length: 2048 18 | max_prompt_length: 512 19 | epochs: 3 20 | batch_size: 2 21 | lr: 3e-5 22 | padding: right 23 | optimizer: adamw_torch 24 | scheduler: linear 25 | gradient_accumulation: 4 26 | mixed_precision: fp16 27 | merge_adapter: true 28 | 29 | hub: 30 | username: ${HF_USERNAME} 31 | token: ${HF_TOKEN} 32 | push_to_hub: false -------------------------------------------------------------------------------- /configs/llm_finetuning/llama3-70b-orpo-v1.yml: -------------------------------------------------------------------------------- 1 | task: llm-orpo 2 | base_model: meta-llama/Meta-Llama-3-70B-Instruct 3 | project_name: autotrain-llama3-70b-orpo-v1 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: argilla/distilabel-capybara-dpo-7k-binarized 9 | train_split: train 10 | valid_split: valid 11 | chat_template: chatml 12 | column_mapping: 13 | text_column: chosen 14 | rejected_text_column: rejected 15 | prompt_text_column: prompt 16 | 17 | params: 18 | block_size: 2048 19 
| model_max_length: 8192 20 | max_prompt_length: 1024 21 | epochs: 3 22 | batch_size: 1 23 | lr: 1e-5 24 | peft: true 25 | quantization: null 26 | target_modules: all-linear 27 | padding: right 28 | optimizer: paged_adamw_8bit 29 | scheduler: linear 30 | gradient_accumulation: 4 31 | mixed_precision: bf16 32 | 33 | hub: 34 | username: ${HF_USERNAME} 35 | token: ${HF_TOKEN} 36 | push_to_hub: true -------------------------------------------------------------------------------- /configs/llm_finetuning/llama3-70b-sft.yml: -------------------------------------------------------------------------------- 1 | task: llm-sft 2 | base_model: meta-llama/Meta-Llama-3-70B-Instruct 3 | project_name: autotrain-llama3-70b-math-v1 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: rishiraj/guanaco-style-metamath-40k 9 | train_split: train 10 | valid_split: null 11 | chat_template: null 12 | column_mapping: 13 | text_column: text 14 | 15 | params: 16 | block_size: 2048 17 | model_max_length: 8192 18 | epochs: 2 19 | batch_size: 1 20 | lr: 1e-5 21 | peft: true 22 | quantization: null 23 | target_modules: all-linear 24 | padding: right 25 | optimizer: paged_adamw_8bit 26 | scheduler: linear 27 | gradient_accumulation: 8 28 | mixed_precision: bf16 29 | 30 | hub: 31 | username: ${HF_USERNAME} 32 | token: ${HF_TOKEN} 33 | push_to_hub: true -------------------------------------------------------------------------------- /configs/llm_finetuning/llama3-8b-dpo-qlora.yml: -------------------------------------------------------------------------------- 1 | task: llm-dpo 2 | base_model: meta-llama/Meta-Llama-3-8B-Instruct 3 | project_name: autotrain-llama3-8b-dpo-qlora 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: mlabonne/orpo-dpo-mix-40k 9 | train_split: train 10 | valid_split: null 11 | chat_template: chatml 12 | column_mapping: 13 | text_column: chosen 14 | rejected_text_column: rejected 15 | prompt_text_column: prompt 16 | 17 | params: 18 | block_size: 1024 19 
| model_max_length: 2048 20 | max_prompt_length: 512 21 | epochs: 3 22 | batch_size: 2 23 | lr: 3e-5 24 | peft: true 25 | quantization: int4 26 | target_modules: all-linear 27 | padding: right 28 | optimizer: adamw_torch 29 | scheduler: linear 30 | gradient_accumulation: 4 31 | mixed_precision: fp16 32 | 33 | hub: 34 | username: ${HF_USERNAME} 35 | token: ${HF_TOKEN} 36 | push_to_hub: false -------------------------------------------------------------------------------- /configs/llm_finetuning/llama3-8b-orpo-space.yml: -------------------------------------------------------------------------------- 1 | task: llm-orpo 2 | base_model: meta-llama/Meta-Llama-3-8B-Instruct 3 | project_name: autotrain-llama3-8b-orpo-t1 4 | log: tensorboard 5 | backend: spaces-a10g-largex4 6 | 7 | data: 8 | path: argilla/distilabel-capybara-dpo-7k-binarized 9 | train_split: train 10 | valid_split: null 11 | chat_template: chatml 12 | column_mapping: 13 | text_column: chosen 14 | rejected_text_column: rejected 15 | prompt_text_column: prompt 16 | 17 | params: 18 | block_size: 1024 19 | model_max_length: 8192 20 | max_prompt_length: 512 21 | epochs: 3 22 | batch_size: 2 23 | lr: 3e-5 24 | peft: true 25 | quantization: int4 26 | target_modules: all-linear 27 | padding: right 28 | optimizer: adamw_torch 29 | scheduler: linear 30 | gradient_accumulation: 4 31 | mixed_precision: fp16 32 | 33 | hub: 34 | username: ${HF_USERNAME} 35 | token: ${HF_TOKEN} 36 | push_to_hub: true -------------------------------------------------------------------------------- /configs/llm_finetuning/llama3-8b-orpo.yml: -------------------------------------------------------------------------------- 1 | task: llm-orpo 2 | base_model: meta-llama/Meta-Llama-3-8B-Instruct 3 | project_name: autotrain-llama3-8b-orpo 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: argilla/distilabel-capybara-dpo-7k-binarized 9 | train_split: train 10 | valid_split: null 11 | chat_template: chatml 12 | column_mapping: 13 | 
text_column: chosen 14 | rejected_text_column: rejected 15 | prompt_text_column: prompt 16 | 17 | params: 18 | block_size: 1024 19 | model_max_length: 8192 20 | max_prompt_length: 512 21 | epochs: 3 22 | batch_size: 2 23 | lr: 3e-5 24 | peft: true 25 | quantization: int4 26 | target_modules: all-linear 27 | padding: right 28 | optimizer: adamw_torch 29 | scheduler: linear 30 | gradient_accumulation: 4 31 | mixed_precision: fp16 32 | 33 | hub: 34 | username: ${HF_USERNAME} 35 | token: ${HF_TOKEN} 36 | push_to_hub: true -------------------------------------------------------------------------------- /configs/llm_finetuning/llama3-8b-sft-unsloth.yml: -------------------------------------------------------------------------------- 1 | task: llm-sft 2 | base_model: meta-llama/Meta-Llama-3-8B-Instruct 3 | project_name: autotrain-llama3-8b-sft-unsloth 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: rishiraj/guanaco-style-metamath-40k 9 | train_split: train 10 | valid_split: null 11 | chat_template: null 12 | column_mapping: 13 | text_column: text 14 | 15 | params: 16 | block_size: 1024 17 | model_max_length: 8192 18 | max_prompt_length: 512 19 | epochs: 3 20 | batch_size: 2 21 | lr: 3e-5 22 | peft: true 23 | quantization: int4 24 | target_modules: all-linear 25 | padding: right 26 | optimizer: adamw_torch 27 | scheduler: linear 28 | gradient_accumulation: 4 29 | mixed_precision: fp16 30 | unsloth: true 31 | lora_dropout: 0 32 | 33 | hub: 34 | username: ${HF_USERNAME} 35 | token: ${HF_TOKEN} 36 | push_to_hub: true -------------------------------------------------------------------------------- /configs/llm_finetuning/llama32-1b-sft.yml: -------------------------------------------------------------------------------- 1 | task: llm-sft 2 | base_model: meta-llama/Llama-3.2-1B 3 | project_name: autotrain-llama32-1b-finetune 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: HuggingFaceH4/no_robots 9 | train_split: train 10 | valid_split: null 11 | 
chat_template: tokenizer 12 | column_mapping: 13 | text_column: messages 14 | 15 | params: 16 | block_size: 2048 17 | model_max_length: 4096 18 | epochs: 2 19 | batch_size: 1 20 | lr: 1e-5 21 | peft: true 22 | quantization: int4 23 | target_modules: all-linear 24 | padding: right 25 | optimizer: paged_adamw_8bit 26 | scheduler: linear 27 | gradient_accumulation: 8 28 | mixed_precision: bf16 29 | merge_adapter: true 30 | 31 | hub: 32 | username: ${HF_USERNAME} 33 | token: ${HF_TOKEN} 34 | push_to_hub: true 35 | -------------------------------------------------------------------------------- /configs/llm_finetuning/qwen.yml: -------------------------------------------------------------------------------- 1 | task: llm-sft 2 | base_model: Qwen/Qwen2.5-Coder-7B-Instruct 3 | project_name: autotrain-qwen-finetune 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: HuggingFaceH4/no_robots 9 | train_split: test 10 | valid_split: null 11 | chat_template: tokenizer 12 | column_mapping: 13 | text_column: messages 14 | 15 | params: 16 | block_size: 2048 17 | model_max_length: 4096 18 | epochs: 1 19 | batch_size: 1 20 | lr: 1e-5 21 | peft: true 22 | quantization: int4 23 | target_modules: all-linear 24 | padding: right 25 | optimizer: adamw_torch 26 | scheduler: linear 27 | gradient_accumulation: 1 28 | mixed_precision: fp16 29 | merge_adapter: true 30 | 31 | hub: 32 | username: ${HF_USERNAME} 33 | token: ${HF_TOKEN} 34 | push_to_hub: true 35 | -------------------------------------------------------------------------------- /configs/llm_finetuning/smollm2.yml: -------------------------------------------------------------------------------- 1 | task: llm-sft 2 | base_model: HuggingFaceTB/SmolLM2-1.7B-Instruct 3 | project_name: autotrain-smollm2-finetune 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: HuggingFaceH4/no_robots 9 | train_split: train 10 | valid_split: null 11 | chat_template: tokenizer 12 | column_mapping: 13 | text_column: messages 14 | 
15 | params: 16 | block_size: 2048 17 | model_max_length: 4096 18 | epochs: 2 19 | batch_size: 1 20 | lr: 1e-5 21 | peft: true 22 | quantization: int4 23 | target_modules: all-linear 24 | padding: right 25 | optimizer: paged_adamw_8bit 26 | scheduler: linear 27 | gradient_accumulation: 8 28 | mixed_precision: bf16 29 | merge_adapter: true 30 | 31 | hub: 32 | username: ${HF_USERNAME} 33 | token: ${HF_TOKEN} 34 | push_to_hub: true 35 | -------------------------------------------------------------------------------- /configs/llm_finetuning/smollm2_guanaco.yml: -------------------------------------------------------------------------------- 1 | task: llm-sft 2 | base_model: HuggingFaceTB/SmolLM2-135M-Instruct 3 | project_name: autotrain-smollm2-135m-finetune-guanaco 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: timdettmers/openassistant-guanaco 9 | train_split: train 10 | valid_split: null 11 | chat_template: null 12 | column_mapping: 13 | text_column: text 14 | 15 | params: 16 | block_size: 1024 17 | model_max_length: 2048 18 | epochs: 1 19 | batch_size: 1 20 | lr: 1e-5 21 | peft: true 22 | quantization: int4 23 | target_modules: all-linear 24 | padding: right 25 | optimizer: paged_adamw_8bit 26 | scheduler: linear 27 | gradient_accumulation: 8 28 | mixed_precision: bf16 29 | merge_adapter: true 30 | 31 | hub: 32 | username: ${HF_USERNAME} 33 | token: ${HF_TOKEN} 34 | push_to_hub: true -------------------------------------------------------------------------------- /configs/llm_finetuning/smollm2_orpo.yml: -------------------------------------------------------------------------------- 1 | task: llm-orpo 2 | base_model: HuggingFaceTB/SmolLM2-1.7B-Instruct 3 | project_name: autotrain-smallm2-orpo 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: argilla/distilabel-capybara-dpo-7k-binarized 9 | train_split: train 10 | valid_split: null 11 | chat_template: chatml 12 | column_mapping: 13 | text_column: chosen 14 | rejected_text_column: 
rejected 15 | prompt_text_column: prompt 16 | 17 | params: 18 | block_size: 1024 19 | model_max_length: 2048 20 | max_prompt_length: 512 21 | epochs: 3 22 | batch_size: 2 23 | lr: 3e-5 24 | peft: true 25 | quantization: int4 26 | target_modules: all-linear 27 | padding: right 28 | optimizer: adamw_torch 29 | scheduler: linear 30 | gradient_accumulation: 4 31 | mixed_precision: fp16 32 | 33 | hub: 34 | username: ${HF_USERNAME} 35 | token: ${HF_TOKEN} 36 | push_to_hub: false -------------------------------------------------------------------------------- /configs/object_detection/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: object_detection 2 | base_model: facebook/detr-resnet-50 3 | project_name: autotrain-obj-det-cppe5-2 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: cppe-5 9 | train_split: train 10 | valid_split: test 11 | column_mapping: 12 | image_column: image 13 | objects_column: objects 14 | 15 | params: 16 | image_square_size: 600 17 | epochs: 100 18 | batch_size: 8 19 | lr: 5e-5 20 | weight_decay: 1e-4 21 | optimizer: adamw_torch 22 | scheduler: linear 23 | gradient_accumulation: 1 24 | mixed_precision: fp16 25 | early_stopping_patience: 50 26 | early_stopping_threshold: 0.001 27 | 28 | hub: 29 | username: ${HF_USERNAME} 30 | token: ${HF_TOKEN} 31 | push_to_hub: true -------------------------------------------------------------------------------- /configs/object_detection/local.yml: -------------------------------------------------------------------------------- 1 | task: object_detection 2 | base_model: facebook/detr-resnet-50 3 | project_name: autotrain-obj-det-local-dataset 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: data/ # this contains the train and validation folders 9 | train_split: train # this is the folder name inside the data path, contains images and metadata.jsonl 10 | valid_split: validation # this is the folder name inside the data path, 
contains images and metadata.jsonl, optional 11 | column_mapping: 12 | image_column: image 13 | objects_column: objects 14 | 15 | params: 16 | image_square_size: 600 17 | epochs: 100 18 | batch_size: 8 19 | lr: 5e-5 20 | weight_decay: 1e-4 21 | optimizer: adamw_torch 22 | scheduler: linear 23 | gradient_accumulation: 1 24 | mixed_precision: fp16 25 | early_stopping_patience: 50 26 | early_stopping_threshold: 0.001 27 | 28 | hub: 29 | username: ${HF_USERNAME} 30 | token: ${HF_TOKEN} 31 | push_to_hub: true -------------------------------------------------------------------------------- /configs/sentence_transformers/local_dataset.yml: -------------------------------------------------------------------------------- 1 | task: sentence-transformers:pair_score 2 | base_model: microsoft/mpnet-base 3 | project_name: autotrain-st-pair-score-local-dataset 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: /path/to/your/dataset # this must be the path to the directory containing the train and valid files 9 | train_split: train # this is the name of the train file (csv or jsonl) 10 | valid_split: null # this is the name of the valid file (csv or jsonl), optional 11 | column_mapping: 12 | sentence1_column: input_sentence 13 | sentence2_column: target_sentence 14 | target_column: score 15 | 16 | params: 17 | max_seq_length: 512 18 | epochs: 5 19 | batch_size: 8 20 | lr: 2e-5 21 | optimizer: adamw_torch 22 | scheduler: linear 23 | gradient_accumulation: 1 24 | mixed_precision: fp16 25 | 26 | hub: 27 | username: ${HF_USERNAME} 28 | token: ${HF_TOKEN} 29 | push_to_hub: true -------------------------------------------------------------------------------- /configs/sentence_transformers/pair.yml: -------------------------------------------------------------------------------- 1 | task: sentence-transformers:pair 2 | base_model: microsoft/mpnet-base 3 | project_name: autotrain-st-pair 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: 
sentence-transformers/all-nli 9 | train_split: pair:train 10 | valid_split: pair:dev 11 | column_mapping: 12 | sentence1_column: anchor 13 | sentence2_column: positive 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 5 18 | batch_size: 8 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/sentence_transformers/pair_class.yml: -------------------------------------------------------------------------------- 1 | task: sentence-transformers:pair_class 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-st-pair-class 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: sentence-transformers/all-nli 9 | train_split: pair-class:train 10 | valid_split: pair-class:test 11 | column_mapping: 12 | sentence1_column: premise 13 | sentence2_column: hypothesis 14 | target_column: label 15 | 16 | params: 17 | max_seq_length: 512 18 | epochs: 5 19 | batch_size: 8 20 | lr: 2e-5 21 | optimizer: adamw_torch 22 | scheduler: linear 23 | gradient_accumulation: 1 24 | mixed_precision: fp16 25 | 26 | hub: 27 | username: ${HF_USERNAME} 28 | token: ${HF_TOKEN} 29 | push_to_hub: true -------------------------------------------------------------------------------- /configs/sentence_transformers/pair_score.yml: -------------------------------------------------------------------------------- 1 | task: sentence-transformers:pair_score 2 | base_model: microsoft/mpnet-base 3 | project_name: autotrain-st-pair-score 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: sentence-transformers/all-nli 9 | train_split: pair-score:train 10 | valid_split: pair-score:dev 11 | column_mapping: 12 | sentence1_column: sentence1 13 | sentence2_column: sentence2 14 | target_column: score 15 | 16 | params: 17 | 
max_seq_length: 512 18 | epochs: 5 19 | batch_size: 8 20 | lr: 2e-5 21 | optimizer: adamw_torch 22 | scheduler: linear 23 | gradient_accumulation: 1 24 | mixed_precision: fp16 25 | 26 | hub: 27 | username: ${HF_USERNAME} 28 | token: ${HF_TOKEN} 29 | push_to_hub: true -------------------------------------------------------------------------------- /configs/sentence_transformers/qa.yml: -------------------------------------------------------------------------------- 1 | task: sentence-transformers:qa 2 | base_model: microsoft/mpnet-base 3 | project_name: autotrain-st-qa 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: sentence-transformers/natural-questions 9 | train_split: train 10 | valid_split: null 11 | column_mapping: 12 | sentence1_column: query 13 | sentence2_column: answer 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 5 18 | batch_size: 8 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/sentence_transformers/triplet.yml: -------------------------------------------------------------------------------- 1 | task: sentence-transformers:triplet 2 | base_model: microsoft/mpnet-base 3 | project_name: autotrain-st-triplet 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: sentence-transformers/all-nli 9 | train_split: triplet:train 10 | valid_split: triplet:dev 11 | column_mapping: 12 | sentence1_column: anchor 13 | sentence2_column: positive 14 | sentence3_column: negative 15 | 16 | params: 17 | max_seq_length: 512 18 | epochs: 5 19 | batch_size: 8 20 | lr: 2e-5 21 | optimizer: adamw_torch 22 | scheduler: linear 23 | gradient_accumulation: 1 24 | mixed_precision: fp16 25 | 26 | hub: 27 | username: ${HF_USERNAME} 28 | token: ${HF_TOKEN} 29 | push_to_hub: true 
-------------------------------------------------------------------------------- /configs/seq2seq/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: seq2seq 2 | base_model: google/flan-t5-base 3 | project_name: autotrain-seq2seq-hub-dataset 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: samsum 9 | train_split: train 10 | valid_split: test 11 | column_mapping: 12 | text_column: dialogue 13 | target_column: summary 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 3 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: none 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/seq2seq/local.yml: -------------------------------------------------------------------------------- 1 | task: seq2seq 2 | base_model: google/flan-t5-base 3 | project_name: autotrain-seq2seq-local 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: path/to/your/dataset # this must be the path to the directory containing your csv/jsonl files 9 | train_split: train 10 | valid_split: test 11 | column_mapping: 12 | text_column: text 13 | target_column: target 14 | 15 | 16 | params: 17 | max_seq_length: 512 18 | epochs: 3 19 | batch_size: 4 20 | lr: 2e-5 21 | optimizer: adamw_torch 22 | scheduler: linear 23 | gradient_accumulation: 1 24 | mixed_precision: none 25 | 26 | hub: 27 | username: ${HF_USERNAME} 28 | token: ${HF_TOKEN} 29 | push_to_hub: true -------------------------------------------------------------------------------- /configs/text_classification/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: text_classification 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-imdb-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: stanfordnlp/imdb 9 | 
train_split: train 10 | valid_split: test 11 | column_mapping: 12 | text_column: text 13 | target_column: label 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 3 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/text_classification/local_dataset.yml: -------------------------------------------------------------------------------- 1 | task: text_classification 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-imdb-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: data/ # this must be the path to the directory containing the train and valid files 9 | train_split: train # this must be either train.csv or train.json 10 | valid_split: valid # this must be either valid.csv or valid.json 11 | column_mapping: 12 | text_column: text # this must be the name of the column containing the text 13 | target_column: label # this must be the name of the column containing the target 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 3 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/text_regression/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: text_regression 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-sms-spam-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: sms_spam 9 | train_split: train 10 | valid_split: null 11 | column_mapping: 12 | text_column: sms 13 | 
target_column: label 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 3 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/text_regression/local_dataset.yml: -------------------------------------------------------------------------------- 1 | task: text_regression 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-custom-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: data/ # this must be the path to the directory containing the train and valid files 9 | train_split: train # this must be either train.csv or train.json 10 | valid_split: valid # this must be either valid.csv or valid.json 11 | column_mapping: 12 | text_column: text # this must be the name of the column containing the text 13 | target_column: label # this must be the name of the column containing the target 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 3 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/token_classification/hub_dataset.yml: -------------------------------------------------------------------------------- 1 | task: token_classification 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-conll2003-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: conll2003 9 | train_split: train 10 | valid_split: validation 11 | column_mapping: 12 | tokens_column: tokens 13 | tags_column: ner_tags 14 | 15 | params: 16 | max_seq_length: 512 17 | 
epochs: 3 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/token_classification/local_dataset.yml: -------------------------------------------------------------------------------- 1 | task: token_classification 2 | base_model: google-bert/bert-base-uncased 3 | project_name: autotrain-bert-custom-finetuned 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: data/ # this must be the path to the directory containing the train and valid files 9 | train_split: train # this must be train.json 10 | valid_split: test # this must be test.json (the file name matches this value), can also be set to null 11 | column_mapping: 12 | tokens_column: tokens # this must be the name of the column containing the tokens 13 | tags_column: tags # this must be the name of the column containing the target 14 | 15 | params: 16 | max_seq_length: 512 17 | epochs: 3 18 | batch_size: 4 19 | lr: 2e-5 20 | optimizer: adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 1 23 | mixed_precision: fp16 24 | 25 | hub: 26 | username: ${HF_USERNAME} 27 | token: ${HF_TOKEN} 28 | push_to_hub: true -------------------------------------------------------------------------------- /configs/vlm/paligemma_vqa.yml: -------------------------------------------------------------------------------- 1 | task: vlm:vqa 2 | base_model: google/paligemma-3b-pt-224 3 | project_name: autotrain-paligemma-finetuned-vqa 4 | log: tensorboard 5 | backend: local 6 | 7 | data: 8 | path: abhishek/vqa_small 9 | train_split: train 10 | valid_split: validation 11 | column_mapping: 12 | image_column: image 13 | text_column: multiple_choice_answer 14 | prompt_text_column: question 15 | 16 | params: 17 | epochs: 3 18 | batch_size: 2 19 | lr: 2e-5 20 | optimizer: 
adamw_torch 21 | scheduler: linear 22 | gradient_accumulation: 4 23 | mixed_precision: fp16 24 | peft: true 25 | quantization: int4 26 | 27 | hub: 28 | username: ${HF_USERNAME} 29 | token: ${HF_TOKEN} 30 | push_to_hub: true -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Generating the documentation 2 | 3 | To generate the documentation, you have to build it. Several packages are necessary to build the doc. 4 | 5 | First, you need to install the project itself by running the following command at the root of the code repository: 6 | 7 | ```bash 8 | pip install -e . 9 | ``` 10 | 11 | You also need to install 2 extra packages: 12 | 13 | ```bash 14 | # `hf-doc-builder` to build the docs 15 | pip install git+https://github.com/huggingface/doc-builder@main 16 | # `watchdog` for live reloads 17 | pip install watchdog 18 | ``` 19 | 20 | --- 21 | **NOTE** 22 | 23 | You only need to generate the documentation to inspect it locally (if you're planning changes and want to 24 | check how they look before committing for instance). You don't have to commit the built documentation. 25 | 26 | --- 27 | 28 | ## Building the documentation 29 | 30 | Once you have set up the `doc-builder` and additional packages with the pip install command above, 31 | you can generate the documentation by typing the following command: 32 | 33 | ```bash 34 | doc-builder build autotrain docs/source/ --build_dir ~/tmp/test-build 35 | ``` 36 | 37 | You can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate 38 | the MDX files that will be rendered as the documentation on the main website. You can inspect them in your favorite 39 | Markdown editor. 
40 | 41 | ## Previewing the documentation 42 | 43 | To preview the docs, run the following command: 44 | 45 | ```bash 46 | doc-builder preview autotrain docs/source/ 47 | ``` 48 | 49 | The docs will be viewable at [http://localhost:5173](http://localhost:5173). You can also preview the docs once you 50 | have opened a PR. You will see a bot add a comment with a link to where the documentation with your changes lives. 51 | 52 | --- 53 | **NOTE** 54 | 55 | The `preview` command only works with existing doc files. When you add a completely new file, you need to update 56 | `_toctree.yml` & restart the `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again). 57 | 58 | --- -------------------------------------------------------------------------------- /docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - sections: 2 | - local: index 3 | title: 🤗 AutoTrain 4 | - local: cost 5 | title: How much does it cost? 6 | - local: support 7 | title: Get help and support 8 | - local: faq 9 | title: Frequently Asked Questions 10 | title: Getting Started 11 | - sections: 12 | - local: quickstart_spaces 13 | title: Train on Spaces 14 | - local: quickstart_py 15 | title: Python SDK 16 | - local: quickstart 17 | title: Train Locally 18 | - local: config 19 | title: Config File 20 | title: Quickstart 21 | - sections: 22 | - local: tasks/llm_finetuning 23 | title: LLM Finetuning 24 | - local: tasks/text_classification_regression 25 | title: Text Classification/Regression 26 | - local: tasks/extractive_qa 27 | title: Extractive QA 28 | - local: tasks/sentence_transformer 29 | title: Sentence Transformer 30 | - local: tasks/image_classification_regression 31 | title: Image Classification / Regression 32 | - local: tasks/object_detection 33 | title: Object Detection 34 | - local: tasks/seq2seq 35 | title: Seq2Seq 36 | - local: tasks/token_classification 37 | title: Token Classification 38 | - local: tasks/tabular 39 | 
title: Tabular 40 | title: Tasks 41 | - sections: 42 | - local: col_map 43 | title: Understanding Column Mapping 44 | - local: autotrain_api 45 | title: AutoTrain API 46 | title: Miscellaneous -------------------------------------------------------------------------------- /docs/source/autotrain_api.mdx: -------------------------------------------------------------------------------- 1 | # AutoTrain API 2 | 3 | With AutoTrain API, you can run your own instance of AutoTrain and use it to 4 | train models on Hugging Face Spaces infrastructure (local training coming soon). 5 | This API is designed to be used with autotrain compatible models and datasets, and it provides a simple interface to 6 | train models with minimal configuration. 7 | 8 | ## Getting Started 9 | 10 | To get started with AutoTrain API, all you need to do is install `autotrain-advanced` 11 | as discussed in running locally section and run the autotrain app command: 12 | 13 | ```bash 14 | $ autotrain app --port 8000 --host 127.0.0.1 15 | ``` 16 | 17 | You can then access the API reference at `http://127.0.0.1:8000/docs`. 
18 | 19 | ## Example Usage 20 | 21 | ```bash 22 | curl -X POST "http://127.0.0.1:8000/api/create_project" \ 23 | -H "Content-Type: application/json" \ 24 | -H "Authorization: Bearer hf_XXXXX" \ 25 | -d '{ 26 | "username": "abhishek", 27 | "project_name": "my-autotrain-api-model", 28 | "task": "llm:orpo", 29 | "base_model": "meta-llama/Meta-Llama-3-8B-Instruct", 30 | "hub_dataset": "argilla/distilabel-capybara-dpo-7k-binarized", 31 | "train_split": "train", 32 | "hardware": "spaces-a10g-large", 33 | "column_mapping": { 34 | "text_column": "chosen", 35 | "rejected_text_column": "rejected", 36 | "prompt_text_column": "prompt" 37 | }, 38 | "params": { 39 | "block_size": 1024, 40 | "model_max_length": 4096, 41 | "max_prompt_length": 512, 42 | "epochs": 1, 43 | "batch_size": 2, 44 | "lr": 0.00003, 45 | "peft": true, 46 | "quantization": "int4", 47 | "target_modules": "all-linear", 48 | "padding": "right", 49 | "optimizer": "adamw_torch", 50 | "scheduler": "linear", 51 | "gradient_accumulation": 4, 52 | "mixed_precision": "fp16", 53 | "chat_template": "chatml" 54 | } 55 | }' 56 | ``` 57 | 58 | -------------------------------------------------------------------------------- /docs/source/config.mdx: -------------------------------------------------------------------------------- 1 | # AutoTrain Configs 2 | 3 | AutoTrain Configs are the way to use and train models using AutoTrain locally. 4 | 5 | Once you have installed AutoTrain Advanced, you can use the following command to train models using AutoTrain config files: 6 | 7 | ```bash 8 | $ export HF_USERNAME=your_hugging_face_username 9 | $ export HF_TOKEN=your_hugging_face_write_token 10 | 11 | $ autotrain --config path/to/config.yaml 12 | ``` 13 | 14 | Example configurations for all tasks can be found in the `configs` directory of 15 | the [AutoTrain Advanced GitHub repository](https://github.com/huggingface/autotrain-advanced). 
16 | 17 | Here is an example of an AutoTrain config file: 18 | 19 | ```yaml 20 | task: llm 21 | base_model: meta-llama/Meta-Llama-3-8B-Instruct 22 | project_name: autotrain-llama3-8b-orpo 23 | log: tensorboard 24 | backend: local 25 | 26 | data: 27 | path: argilla/distilabel-capybara-dpo-7k-binarized 28 | train_split: train 29 | valid_split: null 30 | chat_template: chatml 31 | column_mapping: 32 | text_column: chosen 33 | rejected_text_column: rejected 34 | 35 | params: 36 | trainer: orpo 37 | block_size: 1024 38 | model_max_length: 2048 39 | max_prompt_length: 512 40 | epochs: 3 41 | batch_size: 2 42 | lr: 3e-5 43 | peft: true 44 | quantization: int4 45 | target_modules: all-linear 46 | padding: right 47 | optimizer: adamw_torch 48 | scheduler: linear 49 | gradient_accumulation: 4 50 | mixed_precision: bf16 51 | 52 | hub: 53 | username: ${HF_USERNAME} 54 | token: ${HF_TOKEN} 55 | push_to_hub: true 56 | ``` 57 | 58 | In this config, we are finetuning the `meta-llama/Meta-Llama-3-8B-Instruct` model 59 | on the `argilla/distilabel-capybara-dpo-7k-binarized` dataset using the `orpo` 60 | trainer for 3 epochs with a batch size of 2 and a learning rate of `3e-5`. 61 | More information on the available parameters can be found in the *Data Formats and Parameters* section. 62 | 63 | In case you don't want to push the model to the Hub, you can set `push_to_hub` to `false` in the config file. 64 | If you are not pushing the model to the Hub, the username and token are not required. Note: they may still be needed 65 | if you are trying to access gated models or datasets. -------------------------------------------------------------------------------- /docs/source/cost.mdx: -------------------------------------------------------------------------------- 1 | # How much does it cost? 2 | 3 | AutoTrain offers an accessible approach to model training, providing deployable models 4 | with just a few clicks. 
Understanding the cost involved is essential to planning and 5 | executing your projects efficiently. 6 | 7 | 8 | ## Local Usage 9 | 10 | When you choose to use AutoTrain locally on your own hardware, there is no cost. 11 | This option is ideal for those who prefer to manage their own infrastructure and 12 | do not require the scalability that cloud resources offer. 13 | 14 | ## Using AutoTrain on Hugging Face Spaces 15 | 16 | **Pay-As-You-Go**: Costs for using AutoTrain in Hugging Face Spaces are based on the 17 | computing resources you consume. This flexible pricing structure ensures you only pay 18 | for what you use, making it cost-effective and scalable for projects of any size. 19 | 20 | 21 | **Ownership and Portability**: Unlike some other platforms, AutoTrain does not retain 22 | ownership of your models. Once training is complete, you are free to download and 23 | deploy your models wherever you choose, providing flexibility and control over all your assets. 24 | 25 | ### Pricing Details 26 | 27 | **Resource-Based Billing**: Charges are accrued per minute according to the type of hardware 28 | utilized during training. This means you can scale your resource usage based on the 29 | complexity and needs of your projects. 30 | 31 | For a detailed breakdown of the costs associated with using Hugging Face Spaces, 32 | please refer to the [pricing](https://huggingface.co/pricing#spaces) section on our website. 33 | 34 | To access the paid features of AutoTrain, you must have a valid payment method on file. 35 | You can manage your payment options and view your billing information in 36 | the [billing section of your Hugging Face account settings.](https://huggingface.co/settings/billing) 37 | 38 | By offering both free and flexible paid options, AutoTrain ensures that users can choose 39 | the most suitable model training solution for their needs, whether they are experimenting 40 | on a local machine or scaling up operations on Hugging Face Spaces. 
41 | -------------------------------------------------------------------------------- /docs/source/getting_started.bck: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | There is no installation required! AutoTrain Advanced runs on Hugging Face Spaces. All you need to do is create a new space with the AutoTrain Advanced template: https://huggingface.co/new-space?template=autotrain-projects/autotrain-advanced. Please make sure you keep the space private. 4 | 5 | ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_1.png) 6 | 7 | Once you have selected Docker > AutoTrain template and an appropriate hardware, you can click on "Create Space" and you will be redirected to your new space. 8 | 9 | ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_2.png) 10 | 11 | Make sure to use a write token and keep the space private to prevent any unauthorized access. 12 | 13 | # Updating AutoTrain Advanced to Latest Version 14 | 15 | We are constantly adding new features and tasks to AutoTrain Advanced. It's always a good idea to update your space to the latest version before starting a new project. An up-to-date version of AutoTrain Advanced will have the latest tasks, features and bug fixes! Updating is as easy as clicking on the "Factory reboot" button in the settings page of your space. 16 | 17 | ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_5.png) 18 | 19 | Please note that "restarting" a space will not update it to the latest version. You need to "Factory reboot" the space to update it to the latest version. 20 | 21 | And now we are all set and we can start with our first project! 
22 | 23 | # Understanding the UI 24 | 25 | ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/ui.png) 26 | -------------------------------------------------------------------------------- /docs/source/index.mdx: -------------------------------------------------------------------------------- 1 | # AutoTrain 2 | 3 | ![autotrain-homepage](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/autotrain_homepage.png) 4 | 5 | 🤗 AutoTrain Advanced (or simply AutoTrain), developed by Hugging Face, is a robust no-code 6 | platform designed to simplify the process of training state-of-the-art models across 7 | multiple domains: Natural Language Processing (NLP), Computer Vision (CV), 8 | and even Tabular Data analysis. This tool leverages the powerful frameworks created by 9 | various teams at Hugging Face, making advanced machine learning and artificial intelligence accessible to a broader 10 | audience without requiring deep technical expertise. 11 | 12 | ## Who should use AutoTrain? 13 | 14 | AutoTrain is the perfect tool for anyone eager to dive into the world of machine learning 15 | without getting bogged down by the complexities of model training. 16 | Whether you're a business professional, researcher, educator, or hobbyist, 17 | AutoTrain offers the simplicity of a no-code interface while still providing the 18 | capabilities necessary to develop sophisticated models tailored to your unique datasets. 19 | 20 | AutoTrain is for anyone who wants to train a state-of-the-art model for a NLP, CV, Speech or even Tabular task, 21 | but doesn't want to spend time on the technical details of training a model. 22 | 23 | Our mission is to democratize machine learning technology, ensuring it is not only 24 | accessible to data scientists and ML engineers but also to those without a technical 25 | background. If you're looking to harness the power of AI for your projects, 26 | AutoTrain is your answer. 
27 | 28 | 29 | ## How to use AutoTrain? 30 | 31 | We offer several ways to use AutoTrain: 32 | 33 | - No code users can use `AutoTrain Advanced` by creating a new space with AutoTrain Docker image: 34 | [Click here](https://huggingface.co/login?next=/spaces/autotrain-projects/autotrain-advanced?duplicate=true) to create AutoTrain Space. 35 | Remember to keep your space private and ensure it is equipped with the necessary hardware resources (GPU) for optimal performance. 36 | 37 | - If you prefer a more hands-on approach, AutoTrain Advanced can also be run locally 38 | through its intuitive UI or accessed via the Python API provided in the autotrain-advanced 39 | package. This flexibility allows developers to integrate AutoTrain capabilities directly 40 | into their projects, customize workflows, and enhance their toolsets with advanced machine 41 | learning functionalities. 42 | 43 | 44 | By bridging the gap between cutting-edge technology and practical usability, 45 | AutoTrain Advanced empowers users to achieve remarkable results in AI without the need 46 | for extensive programming knowledge. Start your journey with AutoTrain today and unlock 47 | the potential of machine learning for your projects! 
48 | 49 | 50 | ## Walkthroughs 51 | 52 | To get started with AutoTrain, check out our walkthroughs and tutorials: 53 | 54 | - [Extractive Question Answering with AutoTrain](https://huggingface.co/blog/abhishek/extractive-qa-autotrain) 55 | - [Finetuning PaliGemma with AutoTrain](https://huggingface.co/blog/abhishek/paligemma-finetuning-autotrain) 56 | - [Training an Object Detection Model with AutoTrain](https://huggingface.co/blog/abhishek/object-detection-autotrain) 57 | - [How to Fine-Tune Custom Embedding Models Using AutoTrain](https://huggingface.co/blog/abhishek/finetune-custom-embeddings-autotrain) 58 | - [Train Custom Models on Hugging Face Spaces with AutoTrain SpaceRunner](https://huggingface.co/blog/abhishek/autotrain-spacerunner) 59 | - [How to Finetune phi-3 on MacBook Pro](https://huggingface.co/blog/abhishek/phi3-finetune-macbook) 60 | - [Finetune Mixtral 8x7B with AutoTrain](https://huggingface.co/blog/abhishek/autotrain-mixtral-dgx-cloud-local) 61 | - [Easily Train Models with H100 GPUs on NVIDIA DGX Cloud](https://huggingface.co/blog/train-dgx-cloud) 62 | -------------------------------------------------------------------------------- /docs/source/quickstart.mdx: -------------------------------------------------------------------------------- 1 | # Quickstart Guide for Local Training 2 | 3 | This quickstart is for local installation and usage. 4 | If you want to use AutoTrain on Hugging Face Spaces, please refer to the *AutoTrain on Hugging Face Spaces* section. 5 | 6 | You can install AutoTrain Advanced using pip: 7 | 8 | ```bash 9 | $ pip install autotrain-advanced 10 | ``` 11 | 12 | It is advised to install autotrain-advanced in a virtual environment to avoid any conflicts with other packages. 13 | Note: AutoTrain doesn't install pytorch, torchaudio, torchvision, or any other large dependencies. You will need to install them separately. 
14 | 15 | ```bash 16 | $ conda create -n autotrain python=3.10 17 | $ conda activate autotrain 18 | $ pip install autotrain-advanced 19 | $ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia 20 | $ conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc 21 | $ conda install xformers -c xformers 22 | $ python -m nltk.downloader punkt 23 | $ pip install flash-attn --no-build-isolation # if you want to use flash-attn 24 | $ pip install deepspeed # if you want to use deepspeed 25 | ``` 26 | 27 | # Running AutoTrain User Interface (UI) 28 | 29 | To run the autotrain app locally, you can use the following command: 30 | 31 | ```bash 32 | $ export HF_TOKEN=your_hugging_face_write_token 33 | $ autotrain app --host 127.0.0.1 --port 8000 34 | ``` 35 | 36 | This will start the app on `http://127.0.0.1:8000`. 37 | 38 | 39 | # Using AutoTrain Command Line Interface (CLI) 40 | 41 | It is also possible to use the CLI: 42 | 43 | ```bash 44 | $ export HF_TOKEN=your_hugging_face_write_token 45 | $ autotrain --help 46 | ``` 47 | 48 | This will show the CLI commands that can be used: 49 | 50 | ```bash 51 | usage: autotrain <command> [<args>] 52 | 53 | positional arguments: 54 | { 55 | app, 56 | llm, 57 | setup, 58 | api, 59 | text-classification, 60 | text-regression, 61 | image-classification, 62 | tabular, 63 | spacerunner, 64 | seq2seq, 65 | token-classification 66 | } 67 | 68 | commands 69 | 70 | options: 71 | -h, --help show this help message and exit 72 | --version, -v Display AutoTrain version 73 | --config CONFIG Optional configuration file 74 | 75 | For more information about a command, run: `autotrain <command> --help` 76 | ``` 77 | 78 | It is advised to use only the `autotrain --config CONFIG_FILE` command for training when using the CLI. 
79 | 80 | The autotrain commands that end users will be interested in are: 81 | 82 | - `app`: Start the AutoTrain UI 83 | - `llm`: Train a language model 84 | - `text-classification`: Train a text classification model 85 | - `text-regression`: Train a text regression model 86 | - `image-classification`: Train an image classification model 87 | - `tabular`: Train a tabular model 88 | - `spacerunner`: Train any custom model using SpaceRunner 89 | - `seq2seq`: Train a sequence-to-sequence model 90 | - `token-classification`: Train a token classification model 91 | 92 | Note: above commands are not required if you use preferred `autotrain --config CONFIG_FILE` command to train the models. -------------------------------------------------------------------------------- /docs/source/quickstart_py.mdx: -------------------------------------------------------------------------------- 1 | # Quickstart with Python 2 | 3 | AutoTrain is a library that allows you to train state of the art models on Hugging Face Spaces, or locally. 4 | It provides a simple and easy-to-use interface to train models for various tasks like llm finetuning, text classification, 5 | image classification, object detection, and more. 6 | 7 | In this quickstart guide, we will show you how to train a model using AutoTrain in Python. 
8 | 9 | ## Getting Started 10 | 11 | AutoTrain can be installed using pip: 12 | 13 | ```bash 14 | $ pip install autotrain-advanced 15 | ``` 16 | 17 | The example code below shows how to finetune an LLM model using AutoTrain in Python: 18 | 19 | ```python 20 | import os 21 | 22 | from autotrain.params import LLMTrainingParams 23 | from autotrain.project import AutoTrainProject 24 | 25 | 26 | params = LLMTrainingParams( 27 | model="meta-llama/Llama-3.2-1B-Instruct", 28 | data_path="HuggingFaceH4/no_robots", 29 | chat_template="tokenizer", 30 | text_column="messages", 31 | train_split="train", 32 | trainer="sft", 33 | epochs=3, 34 | batch_size=1, 35 | lr=1e-5, 36 | peft=True, 37 | quantization="int4", 38 | target_modules="all-linear", 39 | padding="right", 40 | optimizer="paged_adamw_8bit", 41 | scheduler="cosine", 42 | gradient_accumulation=8, 43 | mixed_precision="bf16", 44 | merge_adapter=True, 45 | project_name="autotrain-llama32-1b-finetune", 46 | log="tensorboard", 47 | push_to_hub=True, 48 | username=os.environ.get("HF_USERNAME"), 49 | token=os.environ.get("HF_TOKEN"), 50 | ) 51 | 52 | 53 | backend = "local" 54 | project = AutoTrainProject(params=params, backend=backend, process=True) 55 | project.create() 56 | ``` 57 | 58 | In this example, we are finetuning the `meta-llama/Llama-3.2-1B-Instruct` model on the `HuggingFaceH4/no_robots` dataset. 59 | We are training the model for 3 epochs with a batch size of 1 and a learning rate of `1e-5`. 60 | We are using the `paged_adamw_8bit` optimizer and the `cosine` scheduler. 61 | We are also using mixed precision training with a gradient accumulation of 8. 62 | The final model will be pushed to the Hugging Face Hub after training. 63 | 64 | To train the model, save the code above as `train.py` and run the following commands: 65 | 66 | ```bash 67 | $ export HF_USERNAME=<your-hf-username> 68 | $ export HF_TOKEN=<your-hf-write-token> 69 | $ python train.py 70 | ``` 71 | 72 | This will create a new project directory with the name `autotrain-llama32-1b-finetune` and start the training process. 
73 | Once the training is complete, the model will be pushed to the Hugging Face Hub. 74 | 75 | Your HF_TOKEN and HF_USERNAME are only required if you want to push the model or if you are accessing a gated model or dataset. 76 | 77 | ## AutoTrainProject Class 78 | 79 | [[autodoc]] project.AutoTrainProject 80 | 81 | ## Parameters 82 | 83 | ### Text Tasks 84 | 85 | [[autodoc]] trainers.clm.params.LLMTrainingParams 86 | 87 | [[autodoc]] trainers.sent_transformers.params.SentenceTransformersParams 88 | 89 | [[autodoc]] trainers.seq2seq.params.Seq2SeqParams 90 | 91 | [[autodoc]] trainers.token_classification.params.TokenClassificationParams 92 | 93 | [[autodoc]] trainers.extractive_question_answering.params.ExtractiveQuestionAnsweringParams 94 | 95 | [[autodoc]] trainers.text_classification.params.TextClassificationParams 96 | 97 | [[autodoc]] trainers.text_regression.params.TextRegressionParams 98 | 99 | ### Image Tasks 100 | 101 | [[autodoc]] trainers.image_classification.params.ImageClassificationParams 102 | 103 | [[autodoc]] trainers.image_regression.params.ImageRegressionParams 104 | 105 | [[autodoc]] trainers.object_detection.params.ObjectDetectionParams 106 | 107 | 108 | ### Tabular Tasks 109 | 110 | [[autodoc]] trainers.tabular.params.TabularParams -------------------------------------------------------------------------------- /docs/source/quickstart_spaces.mdx: -------------------------------------------------------------------------------- 1 | # Quickstart Guide to AutoTrain on Hugging Face Spaces 2 | 3 | AutoTrain on Hugging Face Spaces is the preferred choice for a streamlined experience in 4 | model training. This platform is optimized for ease of use, with pre-installed dependencies 5 | and managed hardware resources. AutoTrain on Hugging Face Spaces can be used both by 6 | no-code users and developers, making it versatile for various levels of expertise. 
7 | 8 | 9 | ## Creating a New AutoTrain Space 10 | 11 | Getting started with AutoTrain is straightforward. Here’s how you can create your new space: 12 | 13 | 1. **Visit the AutoTrain Page**: To create a new space with AutoTrain Docker image, all you need to do is go 14 | to [AutoTrain Homepage](https://hf.co/autotrain) and click on "Create new project". 15 | 16 | 2. **Log In or View the Setup Screen**: If not logged in, you'll be prompted to do so. Then, you’ll see a screen similar to this: 17 | 18 | ![autotrain-duplicate-space](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/duplicate_space.png) 19 | 20 | 3. **Set Up Your Space**: 21 | 22 | - **Choose a Space Name**: Name your space something relevant to your project. 23 | 24 | - **Allocate Hardware Resources**: Select the necessary computational resources based on your project needs. 25 | 26 | - **Duplicate Space**: Click on "Duplicate Space" to initiate your AutoTrain space with the Docker image. 27 | 28 | 4. **Configuration Options**: 29 | 30 | - PAUSE_ON_FAILURE: Set this to 0 if you prefer the space not to pause on training failures, useful for running continuous experiments. This option can also be used if you want to continuously perform many experiments in the same space. 31 | 32 | 5. **Launch and Train**: 33 | 34 | - Once done, in a few seconds, the AutoTrain Space will be up and running and you will be presented with the following screen: 35 | 36 | ![autotrain-space](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/autotrain_space.png) 37 | 38 | - From here, you can select tasks, upload datasets, choose models, adjust hyperparameters (if needed), 39 | and start the training process directly within the space. 40 | 41 | - The space will manage its own activity, shutting down post-training unless configured 42 | otherwise based on the `PAUSE_ON_FAILURE` setting. 43 | 44 | 6. 
**Monitoring Progress**: 45 | 46 | - All training logs and progress can be monitored via TensorBoard, accessible under 47 | `username/project_name` on the Hugging Face Hub. 48 | 49 | - Once training concludes successfully, you’ll find the model files in the same repository. 50 | 51 | 7. **Navigating the UI**: 52 | 53 | - If you need help understanding any UI elements, click on the small (i) information icons for detailed descriptions. 54 | 55 | If you are confused about the UI elements, click on the small (i) information icon to get more information about the UI element. 56 | 57 | For data formats and detailed parameter information, please see the Data Formats and Parameters section where we provide 58 | example datasets and detailed information about the parameters for each task supported by AutoTrain. 59 | 60 | ## Ensuring Your AutoTrain is Up-to-Date 61 | 62 | We are constantly adding new features and tasks to AutoTrain Advanced. To benefit from the latest features, tasks, and bug fixes, update your AutoTrain space regularly: 63 | 64 | - *Factory Reboot*: Navigate to the settings page of your space and click on "Factory reboot" to upgrade to the latest version of AutoTrain Advanced. 65 | 66 | ![autotrain-space-template](https://raw.githubusercontent.com/huggingface/autotrain-advanced/main/static/space_template_5.png) 67 | 68 | - *Note*: Simply "restarting" the space does not update it; a factory reboot is necessary for a complete update. 69 | 70 | 71 | For additional details on data formats and specific parameters, refer to the 72 | 'Data Formats and Parameters' section where we provide example datasets and extensive 73 | parameter information for each supported task by AutoTrain. 74 | 75 | 76 | With these steps, you can effortlessly initiate and manage your AutoTrain projects on 77 | Hugging Face Spaces, leveraging the platform's robust capabilities for your machine learning and AI 78 | needs. 
79 | -------------------------------------------------------------------------------- /docs/source/starting_ui.bck: -------------------------------------------------------------------------------- 1 | # Starting the UI 2 | 3 | The AutoTrain UI can be started in multiple ways depending on your needs. 4 | We offer UI on Hugging Face Spaces, Colab and locally! 5 | 6 | ## Hugging Face Spaces 7 | 8 | To start the UI on Hugging Face Spaces, you can simply click on the following link: 9 | 10 | [![Deploy on Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/deploy-on-spaces-md.svg)](https://huggingface.co/login?next=/spaces/autotrain-projects/autotrain-advanced?duplicate=true) 11 | 12 | Please make sure you keep the space private and attach appropriate hardware to the space. 13 | You can also read more about AutoTrain on the homepage and follow the link there to start your own training instance on 14 | Hugging Face Spaces. [Click here](https://huggingface.co/autotrain) to visit the homepage. 15 | 16 | ## Colab 17 | 18 | To start the UI on Colab, you can simply click on the following link: 19 | 20 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/autotrain-advanced/blob/main/colabs/AutoTrain.ipynb) 21 | 22 | Please note, to run the app on Colab, you will need an ngrok token. You can get one by signing up for free on [ngrok](https://ngrok.com/). 23 | This is because Colab does not allow exposing ports to the internet directly. 24 | 25 | 26 | ## Locally 27 | 28 | To run the autotrain app locally, install autotrain-advanced python package: 29 | 30 | ```bash 31 | $ pip install autotrain-advanced 32 | ``` 33 | 34 | and then run the following command: 35 | 36 | ```bash 37 | $ export HF_TOKEN=your_hugging_face_write_token 38 | $ autotrain app --host 127.0.0.1 --port 8000 39 | ``` 40 | 41 | This will start the app on `http://127.0.0.1:8000`. 
42 | 43 | AutoTrain doesn't install pytorch, torchaudio, torchvision, or any other dependencies. You will need to install them separately. 44 | It is thus recommended to use conda environment: 45 | 46 | 47 | ```bash 48 | $ conda create -n autotrain python=3.10 49 | $ conda activate autotrain 50 | 51 | $ pip install autotrain-advanced 52 | 53 | $ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia 54 | $ conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc 55 | $ conda install xformers -c xformers 56 | 57 | $ python -m nltk.downloader punkt 58 | $ pip install flash-attn --no-build-isolation 59 | $ pip install deepspeed 60 | 61 | $ export HF_TOKEN=your_hugging_face_write_token 62 | $ autotrain app --host 127.0.0.1 --port 8000 63 | ``` 64 | 65 | In case of any issues, please report on the [GitHub issues](https://github.com/huggingface/autotrain-advanced/). 66 | -------------------------------------------------------------------------------- /docs/source/support.mdx: -------------------------------------------------------------------------------- 1 | # Help and Support 2 | 3 | If you need assistance with AutoTrain Advanced or have questions about your projects, 4 | you can reach out through several dedicated support channels. We're here to help you 5 | navigate any issues you encounter, from technical queries to billing concerns. 6 | Below are the best ways to get support: 7 | 8 | 9 | - For technical support or to report a bug, you can [create an issue](https://github.com/huggingface/autotrain-advanced/issues/new) 10 | directly in the AutoTrain Advanced GitHub repository. GitHub repo is ideal for tracking bugs, 11 | requesting features, or getting help with troubleshooting problems. When submitting an 12 | issue, please include all the details in question to help us provide the most 13 | relevant support quickly. 14 | 15 | - [Ask in the Hugging Face Forum](https://discuss.huggingface.co/c/autotrain/16). 
This space is perfect for asking questions, 16 | sharing your experiences, or discussing AutoTrain with other users and the Hugging Face 17 | team. The forum is a great resource for getting advice, learning best practices, and 18 | connecting with other machine learning practitioners. 19 | 20 | - For enterprise users or specific inquiries related to billing, please [email us](mailto:autotrain@hf.co) directly. 21 | This channel ensures that your more sensitive or account-specific issues are handled 22 | appropriately and confidentially. When emailing, please provide your username and 23 | project name so we can assist you efficiently. 24 | 25 | Please note: e-mail support is only available for pro/enterprise users or those with specific queries about billing. 26 | 27 | 28 | By utilizing these support channels, you can ensure that any hurdles you face while using 29 | AutoTrain Advanced are addressed promptly, allowing you to focus on achieving your project 30 | goals. Whether you're a beginner or an experienced user, we are here to support your 31 | journey in AI model training. 32 | -------------------------------------------------------------------------------- /docs/source/tasks/object_detection.mdx: -------------------------------------------------------------------------------- 1 | # Object Detection 2 | 3 | Object detection is a form of supervised learning where a model is trained to identify 4 | and categorize objects within images. AutoTrain simplifies the process, enabling you to 5 | train a state-of-the-art object detection model by simply uploading labeled example images. 6 | 7 | 8 | ## Preparing your data 9 | 10 | To ensure your object detection model trains effectively, follow these guidelines for preparing your data: 11 | 12 | 13 | ### Organizing Images 14 | 15 | 16 | Prepare a zip file containing your images and metadata.jsonl. 17 | 18 | 19 | ``` 20 | Archive.zip 21 | ├── 0001.png 22 | ├── 0002.png 23 | ├── 0003.png 24 | ├── . 25 | ├── . 26 | ├── . 
27 | └── metadata.jsonl 28 | ``` 29 | 30 | Example for `metadata.jsonl`: 31 | 32 | ``` 33 | {"file_name": "0001.png", "objects": {"bbox": [[302.0, 109.0, 73.0, 52.0]], "category": [0]}} 34 | {"file_name": "0002.png", "objects": {"bbox": [[810.0, 100.0, 57.0, 28.0]], "category": [1]}} 35 | {"file_name": "0003.png", "objects": {"bbox": [[160.0, 31.0, 248.0, 616.0], [741.0, 68.0, 202.0, 401.0]], "category": [2, 2]}} 36 | ``` 37 | 38 | Please note that bboxes need to be in COCO format `[x, y, width, height]`. 39 | 40 | 41 | ### Image Requirements 42 | 43 | - Format: Ensure all images are in JPEG, JPG, or PNG format. 44 | 45 | - Quantity: Include at least 5 images to provide the model with sufficient examples for learning. 46 | 47 | - Exclusivity: The zip file should exclusively contain images and metadata.jsonl. 48 | No additional files or nested folders should be included. 49 | 50 | 51 | Some points to keep in mind: 52 | 53 | - The images must be jpeg, jpg or png. 54 | - There should be at least 5 images per split. 55 | - There must not be any other files in the zip file. 56 | - There must not be any other folders inside the zip folder. 57 | 58 | When train.zip is decompressed, it creates no folders: only images and metadata.jsonl. 59 | 60 | ## Parameters 61 | 62 | [[autodoc]] trainers.object_detection.params.ObjectDetectionParams 63 | -------------------------------------------------------------------------------- /docs/source/tasks/sentence_transformer.mdx: -------------------------------------------------------------------------------- 1 | # Sentence Transformers 2 | 3 | This task lets you easily train or fine-tune a Sentence Transformer model on your own dataset. 
4 | 5 | AutoTrain supports the following types of sentence transformer finetuning: 6 | 7 | - `pair`: dataset with two sentences: anchor and positive 8 | - `pair_class`: dataset with two sentences: premise and hypothesis and a target label 9 | - `pair_score`: dataset with two sentences: sentence1 and sentence2 and a target score 10 | - `triplet`: dataset with three sentences: anchor, positive and negative 11 | - `qa`: dataset with two sentences: query and answer 12 | 13 | ## Data Format 14 | 15 | Sentence Transformers finetuning accepts data in CSV/JSONL format. You can also use a dataset from Hugging Face Hub. 16 | 17 | ### `pair` 18 | 19 | For `pair` training, the data should be in the following format: 20 | 21 | | anchor | positive | 22 | |--------|----------| 23 | | hello | hi | 24 | | how are you | I am fine | 25 | | What is your name? | My name is Abhishek | 26 | | Which is the best programming language? | Python | 27 | 28 | ### `pair_class` 29 | 30 | For `pair_class` training, the data should be in the following format: 31 | 32 | | premise | hypothesis | label | 33 | |---------|------------|-------| 34 | | hello | hi | 1 | 35 | | how are you | I am fine | 0 | 36 | | What is your name? | My name is Abhishek | 1 | 37 | | Which is the best programming language? | Python | 1 | 38 | 39 | ### `pair_score` 40 | 41 | For `pair_score` training, the data should be in the following format: 42 | 43 | | sentence1 | sentence2 | score | 44 | |-----------|-----------|-------| 45 | | hello | hi | 0.8 | 46 | | how are you | I am fine | 0.2 | 47 | | What is your name? | My name is Abhishek | 0.9 | 48 | | Which is the best programming language? | Python | 0.7 | 49 | 50 | ### `triplet` 51 | 52 | For `triplet` training, the data should be in the following format: 53 | 54 | | anchor | positive | negative | 55 | |--------|----------|----------| 56 | | hello | hi | bye | 57 | | how are you | I am fine | I am not fine | 58 | | What is your name? 
| My name is Abhishek | Whats it to you? | 59 | | Which is the best programming language? | Python | Javascript | 60 | 61 | ### `qa` 62 | 63 | For `qa` training, the data should be in the following format: 64 | 65 | | query | answer | 66 | |-------|--------| 67 | | hello | hi | 68 | | how are you | I am fine | 69 | | What is your name? | My name is Abhishek | 70 | | Which is the best programming language? | Python | 71 | 72 | 73 | ## Parameters 74 | 75 | [[autodoc]] trainers.sent_transformers.params.SentenceTransformersParams 76 | -------------------------------------------------------------------------------- /docs/source/tasks/seq2seq.mdx: -------------------------------------------------------------------------------- 1 | # Seq2Seq 2 | 3 | Seq2Seq is a task that involves converting a sequence of words into another sequence of words. 4 | It is used in machine translation, text summarization, and question answering. 5 | 6 | ## Data Format 7 | 8 | You can have the dataset as a CSV file: 9 | 10 | ```csv 11 | text,target 12 | "this movie is great","dieser Film ist großartig" 13 | "this movie is bad","dieser Film ist schlecht" 14 | . 15 | . 16 | . 17 | ``` 18 | 19 | Or as a JSONL file: 20 | 21 | ```json 22 | {"text": "this movie is great", "target": "dieser Film ist großartig"} 23 | {"text": "this movie is bad", "target": "dieser Film ist schlecht"} 24 | . 25 | . 26 | . 27 | ``` 28 | 29 | 30 | ## Columns 31 | 32 | Your CSV/JSONL dataset must have two columns: `text` and `target`. 33 | 34 | 35 | ## Parameters 36 | 37 | [[autodoc]] trainers.seq2seq.params.Seq2SeqParams 38 | -------------------------------------------------------------------------------- /docs/source/tasks/tabular.mdx: -------------------------------------------------------------------------------- 1 | # Tabular Classification / Regression 2 | 3 | Using AutoTrain, you can train a model to classify or regress tabular data easily. 
4 | All you need to do is select from a list of models and upload your dataset. 5 | Parameter tuning is done automatically. 6 | 7 | ## Models 8 | 9 | The following models are available for tabular classification / regression. 10 | 11 | - xgboost 12 | - random_forest 13 | - ridge 14 | - logistic_regression 15 | - svm 16 | - extra_trees 17 | - gradient_boosting 18 | - adaboost 19 | - decision_tree 20 | - knn 21 | 22 | 23 | ## Data Format 24 | 25 | ```csv 26 | id,category1,category2,feature1,target 27 | 1,A,X,0.3373961604172684,1 28 | 2,B,Z,0.6481718720511972,0 29 | 3,A,Y,0.36824153984054797,1 30 | 4,B,Z,0.9571551589530464,1 31 | 5,B,Z,0.14035078041264515,1 32 | 6,C,X,0.8700872583584364,1 33 | 7,A,Y,0.4736080452737105,0 34 | 8,C,Y,0.8009107519796442,1 35 | 9,A,Y,0.5204774795512048,0 36 | 10,A,Y,0.6788795301189603,0 37 | . 38 | . 39 | . 40 | ``` 41 | 42 | ## Columns 43 | 44 | Your CSV dataset must have at least two columns: `id` and `target`. Additional columns, such as `category1`, `category2` and `feature1` in the example above, may also be present. 45 | 46 | 47 | ## Parameters 48 | 49 | [[autodoc]] trainers.tabular.params.TabularParams 50 | -------------------------------------------------------------------------------- /docs/source/tasks/token_classification.mdx: -------------------------------------------------------------------------------- 1 | # Token Classification 2 | 3 | Token classification is the task of classifying each token in a sequence. This can be used 4 | for Named Entity Recognition (NER), Part-of-Speech (POS) tagging, and more. Get your data ready in 5 | the proper format and then with just a few clicks, your state-of-the-art model will be ready to 6 | be used in production. 7 | 8 | ## Data Format 9 | 10 | The data should be in the following CSV format: 11 | 12 | ```csv 13 | tokens,tags 14 | "['I', 'love', 'Paris']","['O', 'O', 'B-LOC']" 15 | "['I', 'live', 'in', 'New', 'York']","['O', 'O', 'O', 'B-LOC', 'I-LOC']" 16 | . 17 | . 18 | . 
19 | ``` 20 | 21 | or you can also use JSONL format: 22 | 23 | ```json 24 | {"tokens": ["I", "love", "Paris"],"tags": ["O", "O", "B-LOC"]} 25 | {"tokens": ["I", "live", "in", "New", "York"],"tags": ["O", "O", "O", "B-LOC", "I-LOC"]} 26 | . 27 | . 28 | . 29 | ``` 30 | 31 | As you can see, we have two columns in the CSV file. One column is the tokens and the other 32 | is the tags. Both the columns are stringified lists! The tokens column contains the tokens 33 | of the sentence and the tags column contains the tags for each token. 34 | 35 | If your CSV is huge, you can divide it into multiple CSV files and upload them separately. 36 | Please make sure that the column names are the same in all CSV files. 37 | 38 | One way to divide the CSV file using pandas is as follows: 39 | 40 | ```python 41 | import pandas as pd 42 | 43 | # Set the chunk size 44 | chunk_size = 1000 45 | i = 1 46 | 47 | # Open the CSV file and read it in chunks 48 | for chunk in pd.read_csv('example.csv', chunksize=chunk_size): 49 | # Save each chunk to a new file 50 | chunk.to_csv(f'chunk_{i}.csv', index=False) 51 | i += 1 52 | ``` 53 | 54 | 55 | Sample dataset from HuggingFace Hub: [conll2003](https://huggingface.co/datasets/eriktks/conll2003) 56 | 57 | 58 | ## Columns 59 | 60 | Your CSV/JSONL dataset must have two columns: `tokens` and `tags`. 
61 | 62 | 63 | ## Parameters 64 | 65 | [[autodoc]] trainers.token_classification.params.TokenClassificationParams 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.4.23 2 | datasets[vision]~=3.2.0 3 | evaluate==0.4.3 4 | ipadic==1.0.0 5 | jiwer==3.0.5 6 | joblib==1.4.2 7 | loguru==0.7.3 8 | pandas==2.2.3 9 | nltk==3.9.1 10 | optuna==4.1.0 11 | Pillow==11.0.0 12 | sacremoses==0.1.1 13 | scikit-learn==1.6.0 14 | sentencepiece==0.2.0 15 | tqdm==4.67.1 16 | werkzeug==3.1.3 17 | xgboost==2.1.3 18 | huggingface_hub==0.27.0 19 | requests==2.32.3 20 | einops==0.8.0 21 | packaging==24.2 22 | cryptography==44.0.0 23 | nvitop==1.3.2 24 | # latest versions 25 | tensorboard==2.18.0 26 | peft==0.14.0 27 | trl==0.13.0 28 | tiktoken==0.8.0 29 | transformers==4.48.0 30 | accelerate==1.2.1 31 | bitsandbytes==0.45.0 32 | # extras 33 | rouge_score==0.1.2 34 | py7zr==0.22.0 35 | fastapi==0.115.6 36 | uvicorn==0.34.0 37 | python-multipart==0.0.20 38 | pydantic==2.10.4 39 | hf-transfer 40 | pyngrok==7.2.1 41 | authlib==1.4.0 42 | itsdangerous==2.2.0 43 | seqeval==1.2.2 44 | httpx==0.28.1 45 | pyyaml==6.0.2 46 | timm==1.0.12 47 | torchmetrics==1.6.0 48 | pycocotools==2.0.8 49 | sentence-transformers==3.3.1 50 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE 3 | version = attr: autotrain.__version__ 4 | 5 | [isort] 6 | ensure_newline_before_comments = True 7 | force_grid_wrap = 0 8 | include_trailing_comma = True 9 | line_length = 119 10 | lines_after_imports = 2 11 | multi_line_output = 3 12 | use_parentheses = True 13 | 14 | [flake8] 15 | ignore = E203, E501, W503 16 | max-line-length = 119 17 | per-file-ignores = 18 | # imported but unused 19 | __init__.py: F401, 
# Lint as: python3
"""
HuggingFace / AutoTrain Advanced
"""
import os

from setuptools import find_packages, setup


# The module docstring begins with a newline, so it must be stripped before
# splitting; otherwise DOCLINES[0] (used below as the package description)
# would be an empty string.
DOCLINES = __doc__.strip().split("\n")

this_directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
    LONG_DESCRIPTION = f.read()

# Build INSTALL_REQUIRES from requirements.txt, one requirement per line.
INSTALL_REQUIRES = []
requirements_path = os.path.join(this_directory, "requirements.txt")
with open(requirements_path, encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        # Blank lines and "# ..." section headers are not requirements.
        if not line or line.startswith("#"):
            continue
        # bitsandbytes is only requested on Linux; the environment marker
        # makes pip skip it on every other platform (macOS, Windows, ...).
        if "bitsandbytes" in line:
            line += " ; sys_platform == 'linux'"
        INSTALL_REQUIRES.append(line)

QUALITY_REQUIRE = [
    "black",
    "isort",
    "flake8==3.7.9",
]

TESTS_REQUIRE = ["pytest"]

CLIENT_REQUIRES = ["requests", "loguru"]


EXTRAS_REQUIRE = {
    "base": INSTALL_REQUIRES,
    "dev": INSTALL_REQUIRES + QUALITY_REQUIRE + TESTS_REQUIRE,
    "quality": INSTALL_REQUIRES + QUALITY_REQUIRE,
    "docs": INSTALL_REQUIRES
    + [
        "recommonmark",
        "sphinx==3.1.2",
        "sphinx-markdown-tables",
        "sphinx-rtd-theme==0.4.3",
        "sphinx-copybutton",
    ],
    "client": CLIENT_REQUIRES,
}

setup(
    name="autotrain-advanced",
    description=DOCLINES[0],
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    author="HuggingFace Inc.",
    author_email="autotrain@huggingface.co",
    url="https://github.com/huggingface/autotrain-advanced",
    download_url="https://github.com/huggingface/autotrain-advanced/tags",
    license="Apache 2.0",
    package_dir={"": "src"},
    packages=find_packages("src"),
    extras_require=EXTRAS_REQUIRE,
    install_requires=INSTALL_REQUIRES,
    entry_points={"console_scripts": ["autotrain=autotrain.cli.autotrain:main"]},
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    keywords="automl autonlp autotrain huggingface",
    data_files=[
        (
            "static",
            [
                "src/autotrain/app/static/logo.png",
                "src/autotrain/app/static/scripts/fetch_data_and_update_models.js",
                "src/autotrain/app/static/scripts/listeners.js",
                "src/autotrain/app/static/scripts/utils.js",
                "src/autotrain/app/static/scripts/poll.js",
                "src/autotrain/app/static/scripts/logs.js",
            ],
        ),
        (
            "templates",
            [
                "src/autotrain/app/templates/index.html",
                "src/autotrain/app/templates/error.html",
                "src/autotrain/app/templates/duplicate.html",
                "src/autotrain/app/templates/login.html",
            ],
        ),
    ],
    include_package_data=True,
)
def is_colab():
    """
    Report whether the `google.colab` package can be imported.

    Returns:
        bool: True when `import google.colab` succeeds, False when it raises
        ImportError.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
it") 67 | logger.warning(e) 68 | return False 69 | -------------------------------------------------------------------------------- /src/autotrain/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/app/__init__.py -------------------------------------------------------------------------------- /src/autotrain/app/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fastapi import FastAPI, Request 4 | from fastapi.responses import RedirectResponse 5 | from fastapi.staticfiles import StaticFiles 6 | 7 | from autotrain import __version__, logger 8 | from autotrain.app.api_routes import api_router 9 | from autotrain.app.oauth import attach_oauth 10 | from autotrain.app.ui_routes import ui_router 11 | 12 | 13 | logger.info("Starting AutoTrain...") 14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | app = FastAPI() 16 | if "SPACE_ID" in os.environ: 17 | attach_oauth(app) 18 | 19 | app.include_router(ui_router, prefix="/ui", include_in_schema=False) 20 | app.include_router(api_router, prefix="/api") 21 | static_path = os.path.join(BASE_DIR, "static") 22 | app.mount("/static", StaticFiles(directory=static_path), name="static") 23 | logger.info(f"AutoTrain version: {__version__}") 24 | logger.info("AutoTrain started successfully") 25 | 26 | 27 | @app.get("/") 28 | async def forward_to_ui(request: Request): 29 | """ 30 | Forwards the incoming request to the UI endpoint. 31 | 32 | Args: 33 | request (Request): The incoming HTTP request. 34 | 35 | Returns: 36 | RedirectResponse: A response object that redirects to the UI endpoint, 37 | including any query parameters from the original request. 
class AutoTrainDB:
    """
    A small wrapper around a SQLite database that tracks job process IDs.

    Attributes:
    -----------
    db_path : str
        The path to the SQLite database file.
    conn : sqlite3.Connection
        The SQLite database connection object.
    c : sqlite3.Cursor
        The SQLite database cursor object.

    Methods:
    --------
    __init__(db_path):
        Opens the database connection and creates the jobs table if it does not exist.

    create_jobs_table():
        Creates the jobs table in the database if it does not exist.

    add_job(pid):
        Inserts a row with the given process ID (pid) into the jobs table.

    get_running_jobs():
        Returns the list of all process IDs (pids) stored in the jobs table.

    delete_job(pid):
        Deletes every row with the given process ID (pid) from the jobs table.
    """

    def __init__(self, db_path):
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        self.c.execute(
            """CREATE TABLE IF NOT EXISTS jobs
            (id INTEGER PRIMARY KEY, pid INTEGER)"""
        )
        self.conn.commit()

    def add_job(self, pid):
        # Bind pid as a parameter instead of formatting it into the SQL text,
        # so the value can never be interpreted as SQL.
        self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
        self.conn.commit()

    def get_running_jobs(self):
        self.c.execute("""SELECT pid FROM jobs""")
        # Each fetched row is a 1-tuple (pid,); unwrap to a flat list.
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        # Parameter binding: a value such as "0 OR 1=1" is compared literally
        # and cannot widen the DELETE to other rows.
        self.c.execute("DELETE FROM jobs WHERE pid = ?", (pid,))
        self.conn.commit()
document.addEventListener('DOMContentLoaded', function () {
    // How often the log view is refreshed while the modal is shown.
    const LOG_REFRESH_MS = 5000;
    let refreshTimer;

    // Request /ui/logs and replace the contents of #logContent with the result.
    const fetchLogs = () => {
        fetch('/ui/logs')
            .then((response) => response.json())
            .then((payload) => {
                const logContainer = document.getElementById('logContent');
                logContainer.innerHTML = ''; // drop whatever was rendered before

                // A plain string is shown as-is; otherwise `logs` is iterated
                // as a list of entries.
                if (typeof payload.logs === 'string') {
                    logContainer.textContent = payload.logs;
                    return;
                }

                payload.logs.forEach((entry) => {
                    // Skip entries that are empty after trimming whitespace.
                    if (entry.trim().length === 0) {
                        return;
                    }
                    const line = document.createElement('p');
                    line.textContent = entry;
                    logContainer.appendChild(line); // keep the received order
                });
            })
            .catch((error) => console.error('Error fetching logs:', error));
    };

    // Start polling when #logs-modal is displayed as 'flex', stop otherwise.
    const syncPollingWithModal = () => {
        const modalEl = document.getElementById('logs-modal');
        const isOpen = window.getComputedStyle(modalEl).display === 'flex';

        if (!isOpen) {
            clearInterval(refreshTimer);
            return;
        }

        fetchLogs(); // immediate fetch on open
        clearInterval(refreshTimer); // never keep two timers alive
        refreshTimer = setInterval(fetchLogs, LOG_REFRESH_MS);
    };

    // Re-evaluate polling every time the modal's class attribute changes.
    const classWatcher = new MutationObserver((mutations) => {
        for (const mutation of mutations) {
            if (mutation.attributeName === 'class') {
                syncPollingWithModal();
            }
        }
    });

    classWatcher.observe(document.getElementById('logs-modal'), {
        attributes: true // listen to attribute changes
    });
});
// Query /ui/is_model_training and reflect the result in #is_model_training
// and in the start/stop training buttons.
function pollModelTrainingStatus() {
    // When autotrain_local_value is 0, show a fixed notice and skip the request.
    if (autotrain_local_value === 0) {
        const notice = document.getElementById('is_model_training');
        notice.innerText = 'Running jobs: Only available in local mode.';
        notice.style.display = 'block';
        return;
    }

    fetch('/ui/is_model_training')
        .then((response) => response.json())
        .then((status) => {
            const message = status.model_training
                ? `Running job PID(s): ${status.pids.join(', ')}`
                : 'No running jobs';

            const statusEl = document.getElementById('is_model_training');
            statusEl.innerText = message;
            const stopBtn = document.getElementById('stop-training-button');
            const startBtn = document.getElementById('start-training-button');

            // Red text + stop button while jobs run; green text + start button otherwise.
            statusEl.style.color = status.model_training ? 'red' : 'green';
            stopBtn.style.display = status.model_training ? 'block' : 'none';
            startBtn.style.display = status.model_training ? 'none' : 'block';
        })
        .catch((error) => {
            console.error('Error:', error);
            // Surface the failure in the status paragraph, in red.
            const statusEl = document.getElementById('is_model_training');
            statusEl.innerText = 'Error fetching training status';
            statusEl.style.color = 'red';
        });
}
20 |
21 | AutoTrain 22 |
23 |
24 | 25 |
26 |

Error

27 |

Please DUPLICATE 30 | this space in order to use it

31 |
32 | 33 | 34 | -------------------------------------------------------------------------------- /src/autotrain/app/templates/error.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 16 | 17 | 18 | 19 |
20 |
21 | AutoTrain 22 |
23 |
24 | 25 |
26 |

Error

27 |

HF_TOKEN environment variable is not set.

28 | Go back to Home 29 |
30 | 31 | 32 | -------------------------------------------------------------------------------- /src/autotrain/app/templates/login.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 16 | 17 | 18 | 24 | 25 | 26 |
27 |
28 | AutoTrain 29 |
30 |
31 | 32 |
33 |

Please login to use 35 | AutoTrain

36 |
37 | 38 | Login using Hugging Face 40 | 41 |
42 |

Alternatively, if you face login issues, 43 | you can add your 44 | Hugging Face Write Token to this space as a secret in space settings. Note: The name of secret must be 45 | HF_TOKEN and the value must be your Hugging Face WRITE token! You can find your tokens in user settings.

46 |
47 | Docs | 49 | GitHub 51 |
52 |
def graceful_exit(signum, frame):
    """
    Signal handler that logs a shutdown notice and terminates the process.

    The parameters follow the handler signature expected by `signal.signal`;
    neither is read by the body.

    Args:
        signum (int): Number of the signal that was delivered.
        frame (FrameType): Stack frame active when the signal arrived (or None).

    Raises:
        SystemExit: Always, with exit status 0 (via `sys.exit`).
    """
    logger.info("SIGTERM received. Performing cleanup...")
    sys.exit(0)
Shutting down the server.") 57 | kill_process_by_pid(os.getpid()) 58 | await asyncio.sleep(30) 59 | 60 | 61 | runner = BackgroundRunner() 62 | 63 | 64 | @asynccontextmanager 65 | async def lifespan(app: FastAPI): 66 | """ 67 | Manages the lifespan of the FastAPI application. 68 | 69 | This function is responsible for starting the training process and 70 | managing a background task runner. It logs the process ID of the 71 | training job, adds the job to the database, and ensures the background 72 | task is properly cancelled when the application shuts down. 73 | 74 | Args: 75 | app (FastAPI): The FastAPI application instance. 76 | 77 | Yields: 78 | None: This function is a generator that yields control back to the 79 | FastAPI application lifecycle. 80 | """ 81 | process_pid = run_training(params=PARAMS, task_id=TASK_ID) 82 | logger.info(f"Started training with PID {process_pid}") 83 | DB.add_job(process_pid) 84 | task = asyncio.create_task(runner.run_main()) 85 | yield 86 | 87 | task.cancel() 88 | try: 89 | await task 90 | except asyncio.CancelledError: 91 | logger.info("Background runner task cancelled.") 92 | 93 | 94 | api = FastAPI(lifespan=lifespan) 95 | logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}") 96 | logger.info(f"PROJECT_NAME: {PROJECT_NAME}") 97 | logger.info(f"TASK_ID: {TASK_ID}") 98 | logger.info(f"DATA_PATH: {DATA_PATH}") 99 | logger.info(f"MODEL: {MODEL}") 100 | 101 | 102 | @api.get("/") 103 | async def root(): 104 | return "Your model is being trained..." 
class EndpointsRunner(BaseBackend):
    """
    Backend that launches a training job as a Hugging Face Inference Endpoint.

    Methods
    -------
    create():
        Creates an endpoint instance with the hardware selected by
        ``self.backend`` and the training parameters in ``self.params``.
    """

    def create(self):
        """
        Create the endpoint that runs the training container.

        Returns
        -------
        str
            The ``name`` field of the JSON body returned by the Endpoints API.

        Raises
        ------
        requests.exceptions.RequestException
            If the HTTP request fails or the API answers with an error status.
        """
        hardware = self.available_hardware[self.backend]
        # The hardware identifier is underscore-separated; the positions used
        # here are vendor(0), region(1), accelerator(2), size(3), type(4).
        parts = hardware.split("_")
        vendor = parts[0]
        region = parts[1]
        accelerator = parts[2]
        instance_size = parts[3]
        instance_type = parts[4]
        payload = {
            "accountId": self.username,
            "compute": {
                "accelerator": accelerator,
                "instanceSize": instance_size,
                "instanceType": instance_type,
                "scaling": {"maxReplica": 1, "minReplica": 1},
            },
            "model": {
                "framework": "custom",
                "image": {
                    "custom": {
                        "env": {
                            "HF_TOKEN": self.params.token,
                            "AUTOTRAIN_USERNAME": self.username,
                            "PROJECT_NAME": self.params.project_name,
                            "PARAMS": self.params.model_dump_json(),
                            "DATA_PATH": self.params.data_path,
                            "TASK_ID": str(self.task_id),
                            "MODEL": self.params.model,
                            "ENDPOINT_ID": f"{self.username}/{self.params.project_name}",
                        },
                        "health_route": "/",
                        "port": 7860,
                        "url": "public.ecr.aws/z4c3o6n6/autotrain-api:latest",
                    }
                },
                "repository": "autotrain-projects/autotrain-advanced",
                "revision": "main",
                "task": "custom",
            },
            "name": self.params.project_name,
            "provider": {"region": region, "vendor": vendor},
            "type": "protected",
        }
        headers = {"Authorization": f"Bearer {self.params.token}"}
        r = requests.post(
            ENDPOINTS_URL + self.username,
            json=payload,
            headers=headers,
            timeout=120,
        )
        # Surface API failures (bad token, quota, invalid hardware, ...) as an
        # HTTPError instead of a misleading KeyError on the missing "name" key.
        r.raise_for_status()
        return r.json()["name"]
class SpaceRunner(BaseBackend):
    """
    Backend that runs a training job inside a private Hugging Face Docker Space.

    Methods
    -------
    _create_readme():
        Builds the README.md content for the space.

    _add_secrets(api, space_id):
        Stores the job configuration on the space as secrets.

    create():
        Creates the space repository, adds secrets, and uploads the README and Dockerfile.
    """

    def _create_readme(self):
        """Return the space's README.md (YAML front matter only) as an in-memory binary file."""
        front_matter = [
            "---",
            f"title: {self.params.project_name}",
            "emoji: 🚀",
            "colorFrom: green",
            "colorTo: indigo",
            "sdk: docker",
            "pinned: false",
            "tags:",
            "- autotrain",
            "duplicated_from: autotrain-projects/autotrain-advanced",
            "---",
        ]
        return io.BytesIO(("\n".join(front_matter) + "\n").encode())

    def _add_secrets(self, api, space_id):
        """Register every value the training container reads as a secret on the space."""
        is_generic = isinstance(self.params, GenericParams)

        if is_generic:
            # User-supplied env entries become secrets of their own and are
            # cleared before the params are serialized into PARAMS below.
            for env_key, env_value in self.params.env.items():
                api.add_space_secret(repo_id=space_id, key=env_key, value=env_value)
            self.params.env = {}

        secrets = [
            ("HF_TOKEN", self.params.token),
            ("AUTOTRAIN_USERNAME", self.username),
            ("PROJECT_NAME", self.params.project_name),
            ("TASK_ID", str(self.task_id)),
            ("PARAMS", self.params.model_dump_json()),
            ("DATA_PATH", self.params.data_path),
        ]
        if not is_generic:
            secrets.append(("MODEL", self.params.model))

        for secret_key, secret_value in secrets:
            api.add_space_secret(repo_id=space_id, key=secret_key, value=secret_value)

    def create(self):
        """Create the private Docker space, configure it, upload its files, and return its repo id."""
        hub = HfApi(token=self.params.token)
        space_id = f"{self.username}/autotrain-{self.params.project_name}"
        hub.create_repo(
            repo_id=space_id,
            repo_type="space",
            space_sdk="docker",
            space_hardware=self.available_hardware[self.backend],
            private=True,
        )
        self._add_secrets(hub, space_id)
        hub.set_space_sleep_time(repo_id=space_id, sleep_time=604800)

        uploads = (
            (self._create_readme(), "README.md"),
            (io.BytesIO(_DOCKERFILE.encode()), "Dockerfile"),
        )
        for file_obj, repo_path in uploads:
            hub.upload_file(
                path_or_fileobj=file_obj,
                path_in_repo=repo_path,
                repo_id=space_id,
                repo_type="space",
            )
        return space_id
def main():
    """
    Entry point of the ``autotrain`` command line interface.

    Registers every sub-command, then dispatches in this order:
    ``--version`` prints the package version, ``--config`` runs the given
    configuration file through ``AutoTrainConfigParser``, otherwise the
    selected sub-command is built and run. With no sub-command the help text
    is printed and the process exits with status 1.
    """
    # Imported locally so the block does not rely on a module-level import.
    import sys

    parser = argparse.ArgumentParser(
        "AutoTrain advanced CLI",
        usage="autotrain <command> [<args>]",
        epilog="For more information about a command, run: `autotrain <command> --help`",
    )
    parser.add_argument("--version", "-v", help="Display AutoTrain version", action="store_true")
    parser.add_argument("--config", help="Optional configuration file", type=str)
    commands_parser = parser.add_subparsers(help="commands")

    # Register commands
    RunAutoTrainAppCommand.register_subcommand(commands_parser)
    RunAutoTrainLLMCommand.register_subcommand(commands_parser)
    RunSetupCommand.register_subcommand(commands_parser)
    RunAutoTrainAPICommand.register_subcommand(commands_parser)
    RunAutoTrainTextClassificationCommand.register_subcommand(commands_parser)
    RunAutoTrainImageClassificationCommand.register_subcommand(commands_parser)
    RunAutoTrainTabularCommand.register_subcommand(commands_parser)
    RunAutoTrainSpaceRunnerCommand.register_subcommand(commands_parser)
    RunAutoTrainSeq2SeqCommand.register_subcommand(commands_parser)
    RunAutoTrainTokenClassificationCommand.register_subcommand(commands_parser)
    RunAutoTrainToolsCommand.register_subcommand(commands_parser)
    RunAutoTrainTextRegressionCommand.register_subcommand(commands_parser)
    RunAutoTrainObjectDetectionCommand.register_subcommand(commands_parser)
    RunAutoTrainSentenceTransformersCommand.register_subcommand(commands_parser)
    RunAutoTrainImageRegressionCommand.register_subcommand(commands_parser)
    RunAutoTrainExtractiveQACommand.register_subcommand(commands_parser)

    args = parser.parse_args()

    # sys.exit is used instead of the `exit` builtin, which is injected by the
    # `site` module and is not guaranteed to exist in every interpreter setup.
    if args.version:
        print(__version__)
        sys.exit(0)

    if args.config:
        logger.info(f"Using AutoTrain configuration: {args.config}")
        cp = AutoTrainConfigParser(args.config)
        cp.run()
        sys.exit(0)

    # `func` is only set by a sub-command's set_defaults(); absent means no command was given.
    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(1)

    command = args.func(args)
    command.run()
30 | """ 31 | 32 | @staticmethod 33 | def register_subcommand(parser: ArgumentParser): 34 | run_api_parser = parser.add_parser( 35 | "api", 36 | description="✨ Run AutoTrain API", 37 | ) 38 | run_api_parser.add_argument( 39 | "--port", 40 | type=int, 41 | default=7860, 42 | help="Port to run the api on", 43 | required=False, 44 | ) 45 | run_api_parser.add_argument( 46 | "--host", 47 | type=str, 48 | default="127.0.0.1", 49 | help="Host to run the api on", 50 | required=False, 51 | ) 52 | run_api_parser.add_argument( 53 | "--task", 54 | type=str, 55 | required=False, 56 | help="Task to run", 57 | ) 58 | run_api_parser.set_defaults(func=run_api_command_factory) 59 | 60 | def __init__(self, port, host, task): 61 | self.port = port 62 | self.host = host 63 | self.task = task 64 | 65 | def run(self): 66 | import uvicorn 67 | 68 | from autotrain.app.training_api import api 69 | 70 | uvicorn.run(api, host=self.host, port=self.port) 71 | -------------------------------------------------------------------------------- /src/autotrain/cli/run_seq2seq.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from autotrain import logger 4 | from autotrain.cli.utils import get_field_info 5 | from autotrain.project import AutoTrainProject 6 | from autotrain.trainers.seq2seq.params import Seq2SeqParams 7 | 8 | from . 
class RunAutoTrainSeq2SeqCommand(BaseAutoTrainCommand):
    """CLI command ``autotrain seq2seq``: validates arguments and starts a Seq2Seq training project."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Add the ``seq2seq`` sub-parser with the command flags plus one option per ``Seq2SeqParams`` field."""
        param_args = get_field_info(Seq2SeqParams)
        command_args = [
            {
                "arg": "--train",
                "help": "Command to train the model",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--deploy",
                "help": "Command to deploy the model (limited availability)",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--inference",
                "help": "Command to run inference (limited availability)",
                "required": False,
                "action": "store_true",
            },
            {
                "arg": "--backend",
                "help": "Backend",
                "required": False,
                "type": str,
                "default": "local",
            },
        ]
        run_seq2seq_parser = parser.add_parser("seq2seq", description="✨ Run AutoTrain Seq2Seq")
        for spec in command_args + param_args:
            flags = [spec["arg"]] + spec.get("alias", [])
            options = {
                "dest": spec["arg"].replace("--", "").replace("-", "_"),
                "help": spec["help"],
                "required": spec.get("required", False),
                "default": spec.get("default"),
            }
            # Flag-style arguments carry an action; value arguments carry type/choices.
            if "action" in spec:
                options["action"] = spec.get("action")
            else:
                options["type"] = spec.get("type")
                options["choices"] = spec.get("choices")
            run_seq2seq_parser.add_argument(*flags, **options)
        run_seq2seq_parser.set_defaults(func=run_seq2seq_command_factory)

    def __init__(self, args):
        """Normalize unset boolean flags to ``False`` and validate the arguments required for training."""
        self.args = args

        for flag in ("train", "deploy", "inference", "auto_find_batch_size", "push_to_hub", "peft"):
            if getattr(self.args, flag) is None:
                setattr(self.args, flag, False)

        if not self.args.train:
            raise ValueError("Must specify --train, --deploy or --inference")

        mandatory = (
            ("project_name", "Project name must be specified"),
            ("data_path", "Data path must be specified"),
            ("model", "Model must be specified"),
        )
        for attr, message in mandatory:
            if getattr(self.args, attr) is None:
                raise ValueError(message)

        if self.args.push_to_hub and self.args.username is None:
            raise ValueError("Username must be specified for push to hub")

    def run(self):
        """Create the training project on the selected backend and log the resulting job id."""
        logger.info("Running Seq2Seq Classification")
        if not self.args.train:
            return
        seq2seq_params = Seq2SeqParams(**vars(self.args))
        job_id = AutoTrainProject(params=seq2seq_params, backend=self.args.backend, process=True).create()
        logger.info(f"Job ID: {job_id}")
class RunSetupCommand(BaseAutoTrainCommand):
    """CLI command ``autotrain setup``: adjusts the xformers install and optionally reinstalls PyTorch."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Add the ``setup`` sub-parser with its ``--update-torch`` and ``--colab`` flags."""
        run_setup_parser = parser.add_parser(
            "setup",
            description="✨ Run AutoTrain setup",
        )
        run_setup_parser.add_argument(
            "--update-torch",
            action="store_true",
            help="Update PyTorch to latest version",
        )
        run_setup_parser.add_argument(
            "--colab",
            action="store_true",
            help="Run setup for Google Colab",
        )
        run_setup_parser.set_defaults(func=run_app_command_factory)

    def __init__(self, update_torch: bool, colab: bool = False):
        self.update_torch = update_torch
        self.colab = colab

    @staticmethod
    def _run_pip(cmd, description):
        """Run one pip command to completion and log whether it succeeded."""
        logger.info(f"{description}...")
        pipe = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        _, stderr = pipe.communicate()
        if pipe.returncode != 0:
            # Best-effort like before (no exception), but no longer reported as a success.
            details = stderr.decode(errors="replace").strip()
            logger.error(f"{description} failed with exit code {pipe.returncode}: {details}")
        else:
            logger.info(f"{description} finished successfully")

    def run(self):
        """
        Install the pinned xformers build on Colab, or uninstall xformers elsewhere,
        then reinstall PyTorch from the CUDA 12.1 wheel index when requested.
        """
        if self.colab:
            self._run_pip("pip install -U xformers==0.0.24", "Installing xformers==0.0.24")
        else:
            self._run_pip("pip uninstall -y xformers", "Uninstalling xformers")

        if self.update_torch:
            self._run_pip(
                "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121",
                "Installing latest PyTorch",
            )
import BaseAutoTrainCommand 4 | 5 | 6 | def run_tools_command_factory(args): 7 | return RunAutoTrainToolsCommand(args) 8 | 9 | 10 | class RunAutoTrainToolsCommand(BaseAutoTrainCommand): 11 | @staticmethod 12 | def register_subcommand(parser: ArgumentParser): 13 | run_app_parser = parser.add_parser("tools", help="Run AutoTrain tools") 14 | subparsers = run_app_parser.add_subparsers(title="tools", dest="tool_name") 15 | 16 | merge_llm_parser = subparsers.add_parser( 17 | "merge-llm-adapter", 18 | help="Merge LLM Adapter tool", 19 | ) 20 | merge_llm_parser.add_argument( 21 | "--base-model-path", 22 | type=str, 23 | help="Base model path", 24 | ) 25 | merge_llm_parser.add_argument( 26 | "--adapter-path", 27 | type=str, 28 | help="Adapter path", 29 | ) 30 | merge_llm_parser.add_argument( 31 | "--token", 32 | type=str, 33 | help="Token", 34 | default=None, 35 | required=False, 36 | ) 37 | merge_llm_parser.add_argument( 38 | "--pad-to-multiple-of", 39 | type=int, 40 | help="Pad to multiple of", 41 | default=None, 42 | required=False, 43 | ) 44 | merge_llm_parser.add_argument( 45 | "--output-folder", 46 | type=str, 47 | help="Output folder", 48 | required=False, 49 | default=None, 50 | ) 51 | merge_llm_parser.add_argument( 52 | "--push-to-hub", 53 | action="store_true", 54 | help="Push to Hugging Face Hub", 55 | required=False, 56 | ) 57 | merge_llm_parser.set_defaults(func=run_tools_command_factory, merge_llm_adapter=True) 58 | 59 | convert_to_kohya_parser = subparsers.add_parser("convert_to_kohya", help="Convert to Kohya tool") 60 | convert_to_kohya_parser.add_argument( 61 | "--input-path", 62 | type=str, 63 | help="Input path", 64 | ) 65 | convert_to_kohya_parser.add_argument( 66 | "--output-path", 67 | type=str, 68 | help="Output path", 69 | ) 70 | convert_to_kohya_parser.set_defaults(func=run_tools_command_factory, convert_to_kohya=True) 71 | 72 | def __init__(self, args): 73 | self.args = args 74 | 75 | def run(self): 76 | if getattr(self.args, "merge_llm_adapter", 
@dataclass
class Logger:
    """
    A custom logger class that sets up and manages logging configuration.

    Methods
    -------
    __post_init__():
        Initializes the logger with a specific format and sets up the logger.

    _should_log(record):
        Determines if a log record should be logged based on the process state.

    setup_logger():
        Configures the logger to output to stdout with the specified format and filter.

    get_logger():
        Returns the configured logger instance.
    """

    def __post_init__(self):
        self.log_format = (
            "{level: <8} | "
            "{time:YYYY-MM-DD HH:mm:ss} | "
            "{name}:{function}:{line} - "
            "{message}"
        )
        self.logger = logger
        self.setup_logger()

    def _should_log(self, record):
        """
        Filter callable for the sink: return True when ``record`` should be emitted.

        Without accelerate every record is emitted. With accelerate only the
        process for which ``PartialState().is_main_process`` is true logs.
        """
        if not IS_ACCELERATE_AVAILABLE:
            # A filter callable must return a truthy value to let a record
            # through; returning None here used to silence all logging.
            return True
        return PartialState().is_main_process

    def setup_logger(self):
        """Replace all existing handlers with a single stdout handler using the configured format."""
        self.logger.remove()
        self.logger.add(
            sys.stdout,
            format=self.log_format,
            filter=self._should_log,
        )

    def get_logger(self):
        """Return the configured logger instance."""
        return self.logger
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/preprocessor/__init__.py -------------------------------------------------------------------------------- /src/autotrain/tasks.py: -------------------------------------------------------------------------------- 1 | NLP_TASKS = { 2 | "text_binary_classification": 1, 3 | "text_multi_class_classification": 2, 4 | "text_token_classification": 4, 5 | "text_extractive_question_answering": 5, 6 | "text_summarization": 8, 7 | "text_single_column_regression": 10, 8 | "speech_recognition": 11, 9 | "natural_language_inference": 22, 10 | "lm_training": 9, 11 | "seq2seq": 28, # 27 is reserved for generic training 12 | "sentence_transformers": 30, 13 | "vlm": 31, 14 | } 15 | 16 | VISION_TASKS = { 17 | "image_binary_classification": 17, 18 | "image_multi_class_classification": 18, 19 | "image_single_column_regression": 24, 20 | "image_object_detection": 29, 21 | } 22 | 23 | TABULAR_TASKS = { 24 | "tabular_binary_classification": 13, 25 | "tabular_multi_class_classification": 14, 26 | "tabular_multi_label_classification": 15, 27 | "tabular_single_column_regression": 16, 28 | "tabular": 26, 29 | } 30 | 31 | 32 | TASKS = { 33 | **NLP_TASKS, 34 | **VISION_TASKS, 35 | **TABULAR_TASKS, 36 | } 37 | -------------------------------------------------------------------------------- /src/autotrain/tests/test_cli.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/tests/test_cli.py -------------------------------------------------------------------------------- /src/autotrain/tests/test_dummy.py: -------------------------------------------------------------------------------- 1 | def test_dummy(): 2 | assert 1 + 1 == 2 3 | 
-------------------------------------------------------------------------------- /src/autotrain/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/tools/__init__.py -------------------------------------------------------------------------------- /src/autotrain/tools/convert_to_kohya.py: -------------------------------------------------------------------------------- 1 | from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya 2 | from safetensors.torch import load_file, save_file 3 | 4 | from autotrain import logger 5 | 6 | 7 | def convert_to_kohya(input_path, output_path): 8 | """ 9 | Converts a Lora state dictionary to a Kohya state dictionary and saves it to the specified output path. 10 | 11 | Args: 12 | input_path (str): The file path to the input Lora state dictionary. 13 | output_path (str): The file path where the converted Kohya state dictionary will be saved. 
def merge_llm_adapter(
    base_model_path, adapter_path, token, output_folder=None, pad_to_multiple_of=None, push_to_hub=False
):
    """
    Merges a language model adapter into a base model and optionally saves or pushes the merged model.

    Args:
        base_model_path (str): Path to the base model.
        adapter_path (str): Path to the adapter model; the tokenizer is loaded from here and
            it is also used as the repository id when pushing to the Hub.
        token (str): Authentication token for accessing the models.
        output_folder (str, optional): Directory to save the merged model. Defaults to None.
        pad_to_multiple_of (int, optional): If specified, pad the token embeddings to a multiple of this value. Defaults to None.
        push_to_hub (bool, optional): If True, push the merged model to the Hugging Face Hub. Defaults to False.

    Raises:
        ValueError: If neither `output_folder` nor `push_to_hub` is specified.

    Returns:
        None
    """
    # Validate before loading anything: a falsy push_to_hub (False, None, 0) with no
    # output folder would otherwise load and merge the model and then discard it.
    if output_folder is None and not push_to_hub:
        raise ValueError("You must specify either --output-folder or --push-to-hub")

    logger.info("Loading adapter...")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=ALLOW_REMOTE_CODE,
        token=token,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        adapter_path,
        trust_remote_code=ALLOW_REMOTE_CODE,
        token=token,
    )
    # Resize the base model's embeddings to the adapter tokenizer's vocabulary
    # size before the adapter weights are applied.
    if pad_to_multiple_of:
        base_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=pad_to_multiple_of)
    else:
        base_model.resize_token_embeddings(len(tokenizer))

    model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        token=token,
    )
    model = model.merge_and_unload()

    if output_folder is not None:
        logger.info("Saving target model...")
        model.save_pretrained(output_folder)
        tokenizer.save_pretrained(output_folder)
        logger.info(f"Model saved to {output_folder}")

    if push_to_hub:
        logger.info("Pushing model to Hugging Face Hub...")
        model.push_to_hub(adapter_path)
        tokenizer.push_to_hub(adapter_path)
        logger.info(f"Model pushed to Hugging Face Hub as {adapter_path}")
def parse_args():
    """
    Parse the command-line arguments of the CLM training entry point.

    Returns:
        argparse.Namespace: Namespace whose `training_config` attribute holds the value of the
            required `--training_config` option (the path of the training configuration JSON file).
    """
    # get training_config.json from the end user
    parser = argparse.ArgumentParser()
    parser.add_argument("--training_config", type=str, required=True)
    return parser.parse_args()


@monitor
def train(config):
    """
    Dispatch LLM training to the trainer implementation selected by `config.trainer`.

    Each trainer module is imported inside its own branch, so only the selected implementation
    is loaded.

    Args:
        config (dict or LLMTrainingParams): Training configuration. A dict is converted to
            `LLMTrainingParams` first.

    Raises:
        ValueError: If `config.trainer` is not one of "default", "sft", "reward", "dpo" or "orpo".
    """
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)

    if config.trainer == "default":
        from autotrain.trainers.clm.train_clm_default import train as train_default

        train_default(config)

    elif config.trainer == "sft":
        from autotrain.trainers.clm.train_clm_sft import train as train_sft

        train_sft(config)

    elif config.trainer == "reward":
        from autotrain.trainers.clm.train_clm_reward import train as train_reward

        train_reward(config)

    elif config.trainer == "dpo":
        from autotrain.trainers.clm.train_clm_dpo import train as train_dpo

        train_dpo(config)

    elif config.trainer == "orpo":
        from autotrain.trainers.clm.train_clm_orpo import train as train_orpo

        train_orpo(config)

    else:
        raise ValueError(f"trainer `{config.trainer}` not supported")


if __name__ == "__main__":
    _args = parse_args()
    # Use a context manager so the config file handle is closed deterministically
    # (the previous `json.load(open(...))` left it to the garbage collector).
    with open(_args.training_config) as _config_file:
        training_config = json.load(_config_file)
    _config = LLMTrainingParams(**training_config)
    train(_config)
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that stores the PEFT model inside every checkpoint folder."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Save `kwargs["model"]` into `<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`.

        Returns:
            TrainerControl: The `control` object, unchanged.
        """
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        kwargs["model"].save_pretrained(checkpoint_folder)

        # An empty dict is written under the `pytorch_model.bin` name.
        # NOTE(review): presumably a placeholder so the checkpoint folder does not carry the
        # full model weights - confirm against the Trainer's checkpoint handling.
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        torch.save({}, pytorch_model_path)
        return control


class LoadBestPeftModelCallback(TrainerCallback):
    """Trainer callback that reloads the best adapter weights into the model when training ends."""

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Load `adapter_model.bin` from `state.best_model_checkpoint` into `kwargs["model"]`.

        When no best checkpoint was recorded (`state.best_model_checkpoint` is None) the model is
        left as it is instead of failing on path construction.

        Returns:
            TrainerControl: The `control` object, unchanged.
        """
        if state.best_model_checkpoint is None:
            # os.path.join(None, ...) would raise TypeError; there is nothing to reload.
            print("No best model checkpoint recorded; keeping the current peft model weights.")
            return control

        print(f"Loading best peft model from {state.best_model_checkpoint} (score: {state.best_metric}).")
        # NOTE(review): only the `adapter_model.bin` file name is handled here - confirm that the
        # checkpoints are written in that format.
        best_model_path = os.path.join(state.best_model_checkpoint, "adapter_model.bin")
        adapters_weights = torch.load(best_model_path)
        model = kwargs["model"]
        set_peft_model_state_dict(model, adapters_weights)
        return control


class SaveDeepSpeedPeftModelCallback(TrainerCallback):
    """
    Trainer callback that periodically saves the unwrapped DeepSpeed model to the output directory.

    Args:
        trainer: Trainer whose `accelerator` and `deepspeed` attributes are used for saving.
        save_steps (int, optional): Saving interval in steps. Defaults to 500.
    """

    def __init__(self, trainer, save_steps=500):
        self.trainer = trainer
        self.save_steps = save_steps

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Save the model to `args.output_dir` whenever `(global_step + 1)` is a multiple of `save_steps`.

        All processes synchronise before and after the save; only the main process writes.

        Returns:
            TrainerControl: The `control` object, unchanged.
        """
        # NOTE(review): the `+ 1` offset assumes a particular relation between `global_step` and
        # the step that just finished - confirm against the Trainer's step accounting.
        if (state.global_step + 1) % self.save_steps == 0:
            self.trainer.accelerator.wait_for_everyone()
            state_dict = self.trainer.accelerator.get_state_dict(self.trainer.deepspeed)
            unwrapped_model = self.trainer.accelerator.unwrap_model(self.trainer.deepspeed)
            if self.trainer.accelerator.is_main_process:
                unwrapped_model.save_pretrained(args.output_dir, state_dict=state_dict)
            self.trainer.accelerator.wait_for_everyone()
        return control
def train(config):
    """
    Run supervised fine-tuning (SFT) of a causal language model with TRL's `SFTTrainer`.

    The function prepares the data and tokenizer, builds an `SFTConfig` (with packing enabled),
    loads the model, optionally builds a LoRA configuration, trains, and finally hands over to
    `utils.post_training_steps`.

    Args:
        config (dict or LLMTrainingParams): Training configuration. A dict is converted to
            `LLMTrainingParams` first.
    """
    logger.info("Starting SFT training...")
    if isinstance(config, dict):
        config = LLMTrainingParams(**config)

    # Data and tokenizer.
    train_data, valid_data = utils.process_input_data(config)
    tokenizer = utils.get_tokenizer(config)
    train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)

    # Training arguments; the block size is resolved before it is copied into them.
    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    training_args = utils.configure_training_args(config, logging_steps)
    config = utils.configure_block_size(config, tokenizer)

    training_args["dataset_text_field"] = config.text_column
    training_args["max_seq_length"] = config.block_size
    training_args["packing"] = True
    sft_config = SFTConfig(**training_args)

    model = utils.get_model(config, tokenizer)

    # LoRA settings are only built when PEFT is enabled; otherwise the trainer receives None.
    lora_config = None
    if config.peft:
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=utils.get_target_modules(config),
        )

    logger.info("creating trainer")
    callbacks = utils.get_callbacks(config)

    # Evaluation data is only passed along when a validation split was configured.
    eval_data = None
    if config.valid_split is not None:
        eval_data = valid_data

    trainer = SFTTrainer(
        args=sft_config,
        model=model,
        callbacks=callbacks,
        train_dataset=train_data,
        eval_dataset=eval_data,
        peft_config=lora_config,
        processing_class=tokenizer,
    )

    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)
@monitor
def run(config):
    """
    Executes a series of operations based on the provided configuration.

    This function performs the following steps:
    1. Converts the configuration dictionary to a GenericParams object if necessary.
    2. Downloads the data repository specified in the configuration.
    3. Uninstalls any existing requirements specified in the configuration.
    4. Installs the necessary requirements specified in the configuration.
    5. Runs a command specified in the configuration.
    6. Pauses the space as specified in the configuration.

    Args:
        config (dict or GenericParams): The configuration for the operations to be performed.
    """
    if isinstance(config, dict):
        config = GenericParams(**config)

    # download the data repo
    logger.info("Downloading data repo...")
    utils.pull_dataset_repo(config)

    # uninstall the requirements first (log message previously misspelt as "Unintalling")
    logger.info("Uninstalling requirements...")
    utils.uninstall_requirements(config)

    # install the requirements
    logger.info("Installing requirements...")
    utils.install_requirements(config)

    # run the command
    logger.info("Running command...")
    utils.run_command(config)

    pause_space(config)


if __name__ == "__main__":
    args = parse_args()
    # Use a context manager so the config file handle is closed deterministically
    # (the previous `json.load(open(...))` left it to the garbage collector).
    with open(args.config) as _config_file:
        _config = json.load(_config_file)
    _config = GenericParams(**_config)
    run(_config)
11 | 12 | Attributes: 13 | username (str): The username for your Hugging Face account. 14 | project_name (str): The name of the project. 15 | data_path (str): The file path to the dataset. 16 | token (str): The authentication token for accessing Hugging Face Hub. 17 | script_path (str): The file path to the script to be executed. Path to script.py. 18 | env (Optional[Dict[str, str]]): A dictionary of environment variables to be set. 19 | args (Optional[Dict[str, str]]): A dictionary of arguments to be passed to the script. 20 | """ 21 | 22 | username: str = Field( 23 | None, title="Hugging Face Username", description="The username for your Hugging Face account." 24 | ) 25 | project_name: str = Field("project-name", title="Project Name", description="The name of the project.") 26 | data_path: str = Field(None, title="Data Path", description="The file path to the dataset.") 27 | token: str = Field(None, title="Hub Token", description="The authentication token for accessing Hugging Face Hub.") 28 | script_path: str = Field( 29 | None, title="Script Path", description="The file path to the script to be executed. Path to script.py" 30 | ) 31 | env: Optional[Dict[str, str]] = Field( 32 | None, title="Environment Variables", description="A dictionary of environment variables to be set." 33 | ) 34 | args: Optional[Dict[str, str]] = Field( 35 | None, title="Arguments", description="A dictionary of arguments to be passed to the script." 
class ImageClassificationDataset:
    """
    Index-based dataset wrapper for image classification.

    Args:
        data: Indexable collection of samples; each sample is indexed by the image and target
            column names taken from `config`.
        transforms (callable): Called as `transforms(image=<array>)`; its result must expose the
            transformed image under the "image" key.
        config (object): Object providing `image_column` and `target_column`.

    Attributes:
        data: The samples passed at construction.
        transforms (callable): The transform passed at construction.
        config (object): The configuration passed at construction.

    Example:
        dataset = ImageClassificationDataset(data, transforms, config)
        sample = dataset[0]
        pixel_values, label = sample["pixel_values"], sample["labels"]
    """

    def __init__(self, data, transforms, config):
        self.data, self.transforms, self.config = data, transforms, config

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, item):
        """
        Build the model inputs for the sample at index `item`.

        Returns:
            dict: "pixel_values" (float tensor of the transformed image with axes permuted by
                (2, 0, 1)) and "labels" (long tensor holding the target converted with `int`).
        """
        sample = self.data[item]
        raw_image = sample[self.config.image_column]
        label = int(sample[self.config.target_column])

        rgb_array = np.array(raw_image.convert("RGB"))
        transformed = self.transforms(image=rgb_array)["image"]
        # Move the last axis to the front - assumes the transform returns a
        # height x width x channels array (TODO confirm against the transform pipeline).
        channels_first = np.transpose(transformed, (2, 0, 1)).astype(np.float32)

        pixel_values = torch.tensor(channels_first, dtype=torch.float)
        labels = torch.tensor(label, dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": labels}
class ObjectDetectionDataset:
    """
    A dataset class for object detection tasks.

    Args:
        data: Indexable collection of entries; each entry is indexed by the image and objects
            column names taken from `config`. The objects entry provides "bbox" and "category".
        transforms (callable): Called as `transforms(image=..., bboxes=..., category=...)`; its
            result must expose "image", "bboxes" and "category".
        image_processor (callable): Called as
            `image_processor(images=..., annotations=..., return_tensors="pt")`; its result must
            expose "pixel_values" and "labels".
        config (object): Object providing `image_column` and `objects_column`.

    Attributes:
        data: The entries passed at construction.
        transforms (callable): The transform passed at construction.
        image_processor (callable): The processor passed at construction.
        config (object): The configuration passed at construction.

    Example:
        dataset = ObjectDetectionDataset(data, transforms, image_processor, config)
        image_data = dataset[0]
    """

    def __init__(self, data, transforms, image_processor, config):
        self.data = data
        self.transforms = transforms
        self.image_processor = image_processor
        self.config = config

    def __len__(self):
        """Return the number of entries."""
        return len(self.data)

    def __getitem__(self, item):
        """
        Transform the image and boxes at index `item` and run them through the image processor.

        Every annotation (box, category and area) is built from the transform OUTPUT. The area
        used to be read from the original boxes by index, which disagreed with the emitted
        "bbox" as soon as the transform rescaled boxes, and paired the wrong box with the
        annotation as soon as the transform dropped one.

        Returns:
            The image processor result with "pixel_values" and "labels" replaced by their first
            element.
        """
        entry = self.data[item]
        image = entry[self.config.image_column]
        objects = entry[self.config.objects_column]
        output = self.transforms(
            image=np.array(image.convert("RGB")), bboxes=objects["bbox"], category=objects["category"]
        )
        image = output["image"]
        image_id = str(item)
        annotations = []
        for j in range(len(output["bboxes"])):
            bbox = output["bboxes"][j]
            annotations.append(
                {
                    "image_id": image_id,
                    "category_id": output["category"][j],
                    "iscrowd": 0,
                    # width * height of the transformed box; boxes are treated as [x, y, w, h]
                    "area": bbox[2] * bbox[3],
                    "bbox": bbox,
                }
            )
        annotations = {"annotations": annotations, "image_id": image_id}
        result = self.image_processor(images=image, annotations=annotations, return_tensors="pt")
        # The processor is given a single image; keep only that element.
        result["pixel_values"] = result["pixel_values"][0]
        result["labels"] = result["labels"][0]
        return result
class Seq2SeqDataset:
    """
    Index-based dataset wrapper for sequence-to-sequence training.

    Args:
        data: Indexable collection of samples; each sample is indexed by the text and target
            column names taken from `config`.
        tokenizer (callable): Called once with the input text and once with `text_target=`; the
            target result must expose "input_ids".
        config (object): Object providing `text_column`, `target_column`, `max_seq_length` and
            `max_target_length`.

    Attributes:
        data: The samples passed at construction.
        tokenizer (callable): The tokenizer passed at construction.
        config (object): The configuration passed at construction.
        max_len_input (int): Copy of `config.max_seq_length`.
        max_len_target (int): Copy of `config.max_target_length`.
    """

    def __init__(self, data, tokenizer, config):
        self.data, self.tokenizer, self.config = data, tokenizer, config
        self.max_len_input = config.max_seq_length
        self.max_len_target = config.max_target_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, item):
        """
        Tokenize the input and target text of the sample at index `item`.

        Returns:
            The tokenizer output for the input text, with the target token ids stored under
            "labels".
        """
        sample = self.data[item]
        source_text = str(sample[self.config.text_column])
        target_text = str(sample[self.config.target_column])

        encoded = self.tokenizer(source_text, max_length=self.max_len_input, truncation=True)
        encoded_target = self.tokenizer(text_target=target_text, max_length=self.max_len_target, truncation=True)

        encoded["labels"] = encoded_target["input_ids"]
        return encoded
def create_model_card(config, trainer):
    """
    Generates a model card string based on the provided configuration and trainer.

    Args:
        config (object): Configuration object containing the following attributes:
            - valid_split (optional): If not None, the function will include evaluation scores.
            - data_path (str): Path to the dataset.
            - project_name (str): Name of the project.
            - model (str): Path or identifier of the model.
        trainer (object): Trainer object with an `evaluate` method that returns evaluation metrics.

    Returns:
        str: A formatted model card string containing dataset information, validation metrics, and base model details.
    """
    if config.valid_split is not None:
        prefix = "eval_"
        eval_scores = trainer.evaluate()
        # Strip the "eval_" prefix only from keys that actually carry it. Slicing every key by
        # len("eval_") turned an un-prefixed key (e.g. "epoch") into an empty metric name.
        eval_scores = [
            f"{key[len(prefix):] if key.startswith(prefix) else key}: {value}"
            for key, value in eval_scores.items()
        ]
        eval_scores = "\n\n".join(eval_scores)

    else:
        eval_scores = "No validation metrics available"

    # Local datasets (the project's own autotrain data or a directory on disk) get no dataset tag.
    if config.data_path == f"{config.project_name}/autotrain-data" or os.path.isdir(config.data_path):
        dataset_tag = ""
    else:
        dataset_tag = f"\ndatasets:\n- {config.data_path}"

    # A model loaded from a local directory has no base-model identifier to advertise.
    if os.path.isdir(config.model):
        base_model = ""
    else:
        base_model = f"\nbase_model: {config.model}"

    model_card = MODEL_CARD.format(
        dataset_tag=dataset_tag,
        validation_metrics=eval_scores,
        base_model=base_model,
    )
    return model_card
18 | valid_split (Optional[str]): Name of the validation data split. 19 | project_name (str): Name of the output directory. Default is "project-name". 20 | token (Optional[str]): Hub Token for authentication. 21 | push_to_hub (bool): Whether to push the model to the hub. Default is False. 22 | id_column (str): Name of the ID column. Default is "id". 23 | target_columns (Union[List[str], str]): Target column(s) in the dataset. Default is ["target"]. 24 | categorical_columns (Optional[List[str]]): List of categorical columns. 25 | numerical_columns (Optional[List[str]]): List of numerical columns. 26 | task (str): Type of task (e.g., "classification"). Default is "classification". 27 | num_trials (int): Number of trials for hyperparameter optimization. Default is 10. 28 | time_limit (int): Time limit for training in seconds. Default is 600. 29 | categorical_imputer (Optional[str]): Imputer strategy for categorical columns. 30 | numerical_imputer (Optional[str]): Imputer strategy for numerical columns. 31 | numeric_scaler (Optional[str]): Scaler strategy for numerical columns. 
class TextClassificationDataset:
    """
    Index-based dataset wrapper for text classification.

    Args:
        data: Indexable collection of samples; each sample is indexed by the text and target
            column names taken from `config`.
        tokenizer (callable): Called with the text plus `max_length`, `padding="max_length"` and
            `truncation=True`; its result must expose "input_ids" and "attention_mask" and may
            expose "token_type_ids".
        config (object): Object providing `text_column`, `target_column` and `max_seq_length`.

    Attributes:
        data: The samples passed at construction.
        tokenizer (callable): The tokenizer passed at construction.
        config (object): The configuration passed at construction.
        text_column (str): Copy of `config.text_column`.
        target_column (str): Copy of `config.target_column`.
    """

    def __init__(self, data, tokenizer, config):
        self.data, self.tokenizer, self.config = data, tokenizer, config
        self.text_column = config.text_column
        self.target_column = config.target_column

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, item):
        """
        Tokenize the sample at index `item` and pack it into long tensors.

        Returns:
            dict: "input_ids", "attention_mask", optionally "token_type_ids" (only when the
                tokenizer produced a non-None value for it), and "labels".
        """
        sample = self.data[item]
        text = str(sample[self.text_column])
        label = int(sample[self.target_column])

        encoded = self.tokenizer(
            text,
            max_length=self.config.max_seq_length,
            padding="max_length",
            truncation=True,
        )

        segment_ids = encoded["token_type_ids"] if "token_type_ids" in encoded else None

        features = {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
        }
        if segment_ids is not None:
            features["token_type_ids"] = torch.tensor(segment_ids, dtype=torch.long)
        # "labels" is inserted last so the key order matches both original return shapes.
        features["labels"] = torch.tensor(label, dtype=torch.long)
        return features
class TokenClassificationDataset:
    """
    Dataset wrapper that tokenizes pre-split words and aligns tag ids to the
    resulting sub-word tokens for token classification.

    Args:
        data: Indexable collection of rows. It must also expose
            ``features[config.tags_column].feature.names`` (the list of tag names).
        tokenizer: Callable tokenizer. Its return value must provide
            ``word_ids(batch_index=...)`` and support item assignment.
        config: Object providing ``tokens_column``, ``tags_column`` and
            ``max_seq_length``.

    Attributes:
        data: The rows passed at construction time.
        tokenizer: The tokenizer passed at construction time.
        config: The configuration passed at construction time.

    Methods:
        __len__():
            Number of rows in ``data``.

        __getitem__(item):
            Tokenizes row ``item`` and attaches a ``"labels"`` entry.

            Args:
                item (int): Index of the row to fetch.

            Returns:
                The tokenizer output for the row, extended with ``"labels"``.
    """

    def __init__(self, data, tokenizer, config):
        self.data, self.tokenizer, self.config = data, tokenizer, config

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        cfg = self.config
        words = self.data[item][cfg.tokens_column]
        tag_ids = self.data[item][cfg.tags_column]

        # Identity mapping over the known tag indices; looking a tag up in it
        # raises KeyError for any id outside the declared label set.
        class_names = self.data.features[cfg.tags_column].feature.names
        label_to_id = {idx: idx for idx, _ in enumerate(class_names)}

        encoding = self.tokenizer(
            words,
            max_length=cfg.max_seq_length,
            padding="max_length",
            truncation=True,
            is_split_into_words=True,
        )

        # Positions without a source word (word id None) get -100; every other
        # position, first sub-token or continuation alike, takes its word's tag.
        encoding["labels"] = [
            -100 if word_idx is None else label_to_id[tag_ids[word_idx]]
            for word_idx in encoding.word_ids(batch_index=0)
        ]
        return encoding
def token_classification_metrics(pred, label_list):
    """
    Score token-classification predictions with seqeval.

    Args:
        pred (tuple): ``(predictions, labels)``. ``predictions`` is reduced with
            ``argmax`` over axis 2, so it must have at least three axes;
            ``labels`` is iterated row by row in step with it.
        label_list (list): Label names, indexed by the integer ids found in the
            arg-maxed predictions and in ``labels``.

    Returns:
        dict: ``"precision"``, ``"recall"``, ``"f1"`` and ``"accuracy"`` as
        computed by ``seqeval.metrics`` on the label-name sequences.
    """
    scores, gold = pred
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold):
        # Positions whose gold id is -100 are excluded from scoring.
        kept = [(p, g) for p, g in zip(predicted_row, gold_row) if g != -100]
        true_predictions.append([label_list[p] for p, _ in kept])
        true_labels.append([label_list[g] for _, g in kept])

    return {
        "precision": metrics.precision_score(true_labels, true_predictions),
        "recall": metrics.recall_score(true_labels, true_predictions),
        "f1": metrics.f1_score(true_labels, true_predictions),
        "accuracy": metrics.accuracy_score(true_labels, true_predictions),
    }
def parse_args():
    """
    Parse the command line of the VLM trainer entry point.

    Returns:
        argparse.Namespace: Namespace whose ``training_config`` attribute holds
        the value of the mandatory ``--training_config`` option (a string).
    """
    # the end user supplies the path to training_config.json
    cli = argparse.ArgumentParser()
    cli.add_argument("--training_config", required=True, type=str)
    parsed = cli.parse_args()
    return parsed
if __name__ == "__main__":
    # Script entry point: read the JSON training config named on the command
    # line, turn it into VLMTrainingParams and start training.
    _args = parse_args()
    # Open the file in a context manager so the handle is closed as soon as the
    # JSON is parsed, instead of relying on garbage collection. JSON text is
    # UTF-8 by specification, so decode explicitly rather than with the
    # platform's locale default.
    with open(_args.training_config, encoding="utf-8") as _config_file:
        training_config = json.load(_config_file)
    _config = VLMTrainingParams(**training_config)
    train(_config)
def train(config):
    """
    Train a vision-language model with the Hugging Face ``Trainer``.

    Steps, in order:
      1. Load the training split and, when ``config.valid_split`` is set, the
         validation split.
      2. For the ``"captioning"`` trainer, set ``config.prompt_text_column`` to
         ``"caption"``.
      3. Load the processor, then build the training arguments, model and
         callbacks through the helpers in ``autotrain.trainers.vlm.utils``.
      4. Run ``Trainer.train()`` with ``collate_fn`` as data collator, then call
         ``utils.post_training_steps``.

    Args:
        config: Training parameters. Attributes read here: ``data_path``,
            ``project_name``, ``train_split``, ``valid_split``, ``token``,
            ``trainer``, ``model``.
    """

    def _read_split(split_spec):
        # When data_path equals "<project_name>/autotrain-data" the dataset is
        # read from disk and indexed by split name.
        if config.data_path == f"{config.project_name}/autotrain-data":
            return load_from_disk(config.data_path)[split_spec]
        # "name:split" selects a dataset configuration together with a split.
        if ":" in split_spec:
            dataset_config_name, split = split_spec.split(":")
            return load_dataset(
                config.data_path,
                name=dataset_config_name,
                split=split,
                token=config.token,
            )
        return load_dataset(
            config.data_path,
            split=split_spec,
            token=config.token,
        )

    train_data = _read_split(config.train_split)
    valid_data = _read_split(config.valid_split) if config.valid_split is not None else None

    logger.info(f"Train data: {train_data}")
    logger.info(f"Valid data: {valid_data}")

    if config.trainer == "captioning":
        config.prompt_text_column = "caption"

    processor = AutoProcessor.from_pretrained(config.model, token=config.token, trust_remote_code=ALLOW_REMOTE_CODE)

    logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
    hf_args = TrainingArguments(**utils.configure_training_args(config, logging_steps))
    model = utils.get_model(config)

    logger.info("creating trainer")
    trainer = Trainer(
        args=hf_args,
        model=model,
        callbacks=utils.get_callbacks(config),
        train_dataset=train_data,
        eval_dataset=valid_data,
        data_collator=partial(collate_fn, config=config, processor=processor),
    )
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    utils.post_training_steps(config, trainer)
31 | local (bool, optional): Flag to indicate if the training should be run locally. Defaults to False. 32 | wait (bool, optional): Flag to indicate if the function should wait for the process to complete. Defaults to False. 33 | 34 | Returns: 35 | int: Process ID of the launched training process. 36 | 37 | Raises: 38 | NotImplementedError: If the task_id does not match any of the predefined tasks. 39 | """ 40 | params = json.loads(params) 41 | if isinstance(params, str): 42 | params = json.loads(params) 43 | if task_id == 9: 44 | params = LLMTrainingParams(**params) 45 | elif task_id == 28: 46 | params = Seq2SeqParams(**params) 47 | elif task_id in (1, 2): 48 | params = TextClassificationParams(**params) 49 | elif task_id in (13, 14, 15, 16, 26): 50 | params = TabularParams(**params) 51 | elif task_id == 27: 52 | params = GenericParams(**params) 53 | elif task_id == 18: 54 | params = ImageClassificationParams(**params) 55 | elif task_id == 4: 56 | params = TokenClassificationParams(**params) 57 | elif task_id == 10: 58 | params = TextRegressionParams(**params) 59 | elif task_id == 29: 60 | params = ObjectDetectionParams(**params) 61 | elif task_id == 30: 62 | params = SentenceTransformersParams(**params) 63 | elif task_id == 24: 64 | params = ImageRegressionParams(**params) 65 | elif task_id == 31: 66 | params = VLMTrainingParams(**params) 67 | elif task_id == 5: 68 | params = ExtractiveQuestionAnsweringParams(**params) 69 | else: 70 | raise NotImplementedError 71 | 72 | params.save(output_dir=params.project_name) 73 | cmd = launch_command(params=params) 74 | cmd = [str(c) for c in cmd] 75 | env = os.environ.copy() 76 | process = subprocess.Popen(cmd, env=env) 77 | if wait: 78 | process.wait() 79 | return process.pid 80 | -------------------------------------------------------------------------------- /static/autotrain_homepage.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_homepage.png -------------------------------------------------------------------------------- /static/autotrain_model_choice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_model_choice.png -------------------------------------------------------------------------------- /static/autotrain_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_space.png -------------------------------------------------------------------------------- /static/autotrain_text_classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_text_classification.png -------------------------------------------------------------------------------- /static/cost.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/cost.png -------------------------------------------------------------------------------- /static/dreambooth1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/dreambooth1.jpeg -------------------------------------------------------------------------------- /static/dreambooth2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/dreambooth2.png -------------------------------------------------------------------------------- /static/duplicate_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/duplicate_space.png -------------------------------------------------------------------------------- /static/ext_qa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/ext_qa.png -------------------------------------------------------------------------------- /static/hub_model_choice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/hub_model_choice.png -------------------------------------------------------------------------------- /static/image_classification_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/image_classification_1.png -------------------------------------------------------------------------------- /static/img_reg_ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/img_reg_ui.png -------------------------------------------------------------------------------- /static/llm_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_1.png -------------------------------------------------------------------------------- /static/llm_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_2.png -------------------------------------------------------------------------------- /static/llm_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_3.png -------------------------------------------------------------------------------- /static/llm_orpo_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_orpo_example.png -------------------------------------------------------------------------------- /static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/logo.png -------------------------------------------------------------------------------- /static/model_choice_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/model_choice_1.png -------------------------------------------------------------------------------- /static/param_choice_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/param_choice_1.png -------------------------------------------------------------------------------- /static/param_choice_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/param_choice_2.png -------------------------------------------------------------------------------- /static/space_template_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_1.png -------------------------------------------------------------------------------- /static/space_template_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_2.png -------------------------------------------------------------------------------- /static/space_template_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_3.png -------------------------------------------------------------------------------- /static/space_template_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_4.png -------------------------------------------------------------------------------- /static/space_template_5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_5.png -------------------------------------------------------------------------------- /static/text_classification_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/text_classification_1.png -------------------------------------------------------------------------------- /static/ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/ui.png --------------------------------------------------------------------------------