69 | $ python train.py
70 | ```
71 |
72 | This will create a new project directory with the name `autotrain-llama32-1b-finetune` and start the training process.
73 | Once the training is complete, the model will be pushed to the Hugging Face Hub.
74 |
75 | Your HF_TOKEN and HF_USERNAME are only required if you want to push the model or if you are accessing a gated model or dataset.
76 |
77 | ## AutoTrainProject Class
78 |
79 | [[autodoc]] project.AutoTrainProject
80 |
81 | ## Parameters
82 |
83 | ### Text Tasks
84 |
85 | [[autodoc]] trainers.clm.params.LLMTrainingParams
86 |
87 | [[autodoc]] trainers.sent_transformers.params.SentenceTransformersParams
88 |
89 | [[autodoc]] trainers.seq2seq.params.Seq2SeqParams
90 |
91 | [[autodoc]] trainers.token_classification.params.TokenClassificationParams
92 |
93 | [[autodoc]] trainers.extractive_question_answering.params.ExtractiveQuestionAnsweringParams
94 |
95 | [[autodoc]] trainers.text_classification.params.TextClassificationParams
96 |
97 | [[autodoc]] trainers.text_regression.params.TextRegressionParams
98 |
99 | ### Image Tasks
100 |
101 | [[autodoc]] trainers.image_classification.params.ImageClassificationParams
102 |
103 | [[autodoc]] trainers.image_regression.params.ImageRegressionParams
104 |
105 | [[autodoc]] trainers.object_detection.params.ObjectDetectionParams
106 |
107 |
108 | ### Tabular Tasks
109 |
110 | [[autodoc]] trainers.tabular.params.TabularParams
--------------------------------------------------------------------------------
/docs/source/quickstart_spaces.mdx:
--------------------------------------------------------------------------------
1 | # Quickstart Guide to AutoTrain on Hugging Face Spaces
2 |
3 | AutoTrain on Hugging Face Spaces is the preferred choice for a streamlined experience in
4 | model training. This platform is optimized for ease of use, with pre-installed dependencies
5 | and managed hardware resources. AutoTrain on Hugging Face Spaces can be used both by
6 | no-code users and developers, making it versatile for various levels of expertise.
7 |
8 |
9 | ## Creating a New AutoTrain Space
10 |
11 | Getting started with AutoTrain is straightforward. Here’s how you can create your new space:
12 |
13 | 1. **Visit the AutoTrain Page**: To create a new space with the AutoTrain Docker image, go
14 | to the [AutoTrain Homepage](https://hf.co/autotrain) and click on "Create new project".
15 |
16 | 2. **Log In or View the Setup Screen**: If not logged in, you'll be prompted to do so. Then, you’ll see a screen similar to this:
17 |
18 | 
19 |
20 | 3. **Set Up Your Space**:
21 |
22 | - **Choose a Space Name**: Name your space something relevant to your project.
23 |
24 | - **Allocate Hardware Resources**: Select the necessary computational resources based on your project needs.
25 |
26 | - **Duplicate Space**: Click on "Duplicate Space" to initiate your AutoTrain space with the Docker image.
27 |
28 | 4. **Configuration Options**:
29 |
30 | - `PAUSE_ON_FAILURE`: Set this to 0 if you prefer the space not to pause on training failures. This is useful when you want to run many experiments back to back in the same space.
31 |
32 | 5. **Launch and Train**:
33 |
34 | - Once done, in a few seconds, the AutoTrain Space will be up and running and you will be presented with the following screen:
35 |
36 | 
37 |
38 | - From here, you can select tasks, upload datasets, choose models, adjust hyperparameters (if needed),
39 | and start the training process directly within the space.
40 |
41 | - The space will manage its own activity, shutting down post-training unless configured
42 | otherwise based on the `PAUSE_ON_FAILURE` setting.
43 |
44 | 6. **Monitoring Progress**:
45 |
46 | - All training logs and progress can be monitored via TensorBoard, accessible under
47 | `username/project_name` on the Hugging Face Hub.
48 |
49 | - Once training concludes successfully, you’ll find the model files in the same repository.
50 |
51 | 7. **Navigating the UI**:
52 |
53 | - If you need help understanding any UI elements, click on the small (i) information icons for detailed descriptions.
54 |
56 |
57 | For data formats and detailed parameter information, please see the Data Formats and Parameters section where we provide
58 | example datasets and detailed information about the parameters for each task supported by AutoTrain.
59 |
60 | ## Ensuring Your AutoTrain is Up-to-Date
61 |
62 | We are constantly adding new features and tasks to AutoTrain Advanced. To benefit from the latest features, tasks, and bug fixes, update your AutoTrain space regularly:
63 |
64 | - *Factory Reboot*: Navigate to the settings page of your space and click on "Factory reboot" to upgrade to the latest version of AutoTrain Advanced.
65 |
66 | 
67 |
68 | - *Note*: Simply "restarting" the space does not update it; a factory reboot is necessary for a complete update.
69 |
70 |
74 |
75 |
76 | With these steps, you can effortlessly initiate and manage your AutoTrain projects on
77 | Hugging Face Spaces, leveraging the platform's robust capabilities for your machine learning and AI
78 | needs.
79 |
--------------------------------------------------------------------------------
/docs/source/starting_ui.bck:
--------------------------------------------------------------------------------
1 | # Starting the UI
2 |
3 | The AutoTrain UI can be started in multiple ways depending on your needs.
4 | The UI is available on Hugging Face Spaces, on Colab, and locally.
5 |
6 | ## Hugging Face Spaces
7 |
8 | To start the UI on Hugging Face Spaces, you can simply click on the following link:
9 |
10 | [](https://huggingface.co/login?next=/spaces/autotrain-projects/autotrain-advanced?duplicate=true)
11 |
12 | Please make sure you keep the space private and attach appropriate hardware to the space.
13 | You can also read more about AutoTrain on the homepage and follow the link there to start your own training instance on
14 | Hugging Face Spaces. [Click here](https://huggingface.co/autotrain) to visit the homepage.
15 |
16 | ## Colab
17 |
18 | To start the UI on Colab, you can simply click on the following link:
19 |
20 | [](https://colab.research.google.com/github/huggingface/autotrain-advanced/blob/main/colabs/AutoTrain.ipynb)
21 |
22 | Please note, to run the app on Colab, you will need an ngrok token. You can get one by signing up for free on [ngrok](https://ngrok.com/).
23 | This is because Colab does not allow exposing ports to the internet directly.
24 |
25 |
26 | ## Locally
27 |
28 | To run the AutoTrain app locally, install the `autotrain-advanced` Python package:
29 |
30 | ```bash
31 | $ pip install autotrain-advanced
32 | ```
33 |
34 | and then run the following command:
35 |
36 | ```bash
37 | $ export HF_TOKEN=your_hugging_face_write_token
38 | $ autotrain app --host 127.0.0.1 --port 8000
39 | ```
40 |
41 | This will start the app on `http://127.0.0.1:8000`.
42 |
43 | AutoTrain doesn't install pytorch, torchaudio, torchvision, or any other dependencies. You will need to install them separately.
44 | We therefore recommend using a conda environment:
45 |
46 |
47 | ```bash
48 | $ conda create -n autotrain python=3.10
49 | $ conda activate autotrain
50 |
51 | $ pip install autotrain-advanced
52 |
53 | $ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
54 | $ conda install -c "nvidia/label/cuda-12.1.0" cuda-nvcc
55 | $ conda install xformers -c xformers
56 |
57 | $ python -m nltk.downloader punkt
58 | $ pip install flash-attn --no-build-isolation
59 | $ pip install deepspeed
60 |
61 | $ export HF_TOKEN=your_hugging_face_write_token
62 | $ autotrain app --host 127.0.0.1 --port 8000
63 | ```
64 |
65 | If you run into any issues, please report them on [GitHub issues](https://github.com/huggingface/autotrain-advanced/).
66 |
--------------------------------------------------------------------------------
/docs/source/support.mdx:
--------------------------------------------------------------------------------
1 | # Help and Support
2 |
3 | If you need assistance with AutoTrain Advanced or have questions about your projects,
4 | you can reach out through several dedicated support channels. We're here to help you
5 | navigate any issues you encounter, from technical queries to billing concerns.
6 | Below are the best ways to get support:
7 |
8 |
9 | - For technical support or to report a bug, you can [create an issue](https://github.com/huggingface/autotrain-advanced/issues/new)
10 | directly in the AutoTrain Advanced GitHub repository. The GitHub repository is ideal for tracking bugs,
11 | requesting features, or getting help with troubleshooting problems. When submitting an
12 | issue, please include as much detail as possible to help us provide the most
13 | relevant support quickly.
14 |
15 | - [Ask in the Hugging Face Forum](https://discuss.huggingface.co/c/autotrain/16). This space is perfect for asking questions,
16 | sharing your experiences, or discussing AutoTrain with other users and the Hugging Face
17 | team. The forum is a great resource for getting advice, learning best practices, and
18 | connecting with other machine learning practitioners.
19 |
20 | - For enterprise users or specific inquiries related to billing, please [email us](mailto:autotrain@hf.co) directly.
21 | This channel ensures that your more sensitive or account-specific issues are handled
22 | appropriately and confidentially. When emailing, please provide your username and
23 | project name so we can assist you efficiently.
24 |
25 | Please note: e-mail support is only available for pro/enterprise users or those with specific queries about billing.
26 |
27 |
28 | By utilizing these support channels, you can ensure that any hurdles you face while using
29 | AutoTrain Advanced are addressed promptly, allowing you to focus on achieving your project
30 | goals. Whether you're a beginner or an experienced user, we are here to support your
31 | journey in AI model training.
32 |
--------------------------------------------------------------------------------
/docs/source/tasks/object_detection.mdx:
--------------------------------------------------------------------------------
1 | # Object Detection
2 |
3 | Object detection is a form of supervised learning where a model is trained to identify
4 | and categorize objects within images. AutoTrain simplifies the process, enabling you to
5 | train a state-of-the-art object detection model by simply uploading labeled example images.
6 |
7 |
8 | ## Preparing your data
9 |
10 | To ensure your object detection model trains effectively, follow these guidelines for preparing your data:
11 |
12 |
13 | ### Organizing Images
14 |
15 |
16 | Prepare a zip file containing your images and metadata.jsonl.
17 |
18 |
19 | ```
20 | Archive.zip
21 | ├── 0001.png
22 | ├── 0002.png
23 | ├── 0003.png
24 | ├── .
25 | ├── .
26 | ├── .
27 | └── metadata.jsonl
28 | ```
29 |
30 | Example for `metadata.jsonl`:
31 |
32 | ```
33 | {"file_name": "0001.png", "objects": {"bbox": [[302.0, 109.0, 73.0, 52.0]], "category": [0]}}
34 | {"file_name": "0002.png", "objects": {"bbox": [[810.0, 100.0, 57.0, 28.0]], "category": [1]}}
35 | {"file_name": "0003.png", "objects": {"bbox": [[160.0, 31.0, 248.0, 616.0], [741.0, 68.0, 202.0, 401.0]], "category": [2, 2]}}
36 | ```
37 |
38 | Please note that bboxes need to be in COCO format `[x, y, width, height]`.
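
Many annotation tools export boxes as corner coordinates rather than COCO's `[x, y, width, height]`. The following is a minimal sketch of the conversion, assuming corner-format input `[x_min, y_min, x_max, y_max]`. The helper names `corners_to_coco` and `make_metadata_line` are hypothetical and are not part of AutoTrain.

```python
import json


def corners_to_coco(box):
    """Convert [x_min, y_min, x_max, y_max] to COCO [x, y, width, height]."""
    x_min, y_min, x_max, y_max = box
    if x_max <= x_min or y_max <= y_min:
        raise ValueError(f"degenerate box: {box}")
    return [float(x_min), float(y_min), float(x_max - x_min), float(y_max - y_min)]


def make_metadata_line(file_name, corner_boxes, categories):
    """Build one metadata.jsonl line for an image (one category id per box)."""
    if len(corner_boxes) != len(categories):
        raise ValueError("each box needs exactly one category id")
    record = {
        "file_name": file_name,
        "objects": {
            "bbox": [corners_to_coco(box) for box in corner_boxes],
            "category": list(categories),
        },
    }
    return json.dumps(record)


# Hypothetical corner-format annotation for the first example image.
line = make_metadata_line("0001.png", [[302, 109, 375, 161]], [0])
print(line)
```

With these inputs the printed line matches the first `metadata.jsonl` example above, since 375 - 302 = 73 and 161 - 109 = 52.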
39 |
40 |
41 | ### Image Requirements
42 |
43 | - Format: Ensure all images are in JPEG, JPG, or PNG format.
44 |
45 | - Quantity: Include at least 5 images to provide the model with sufficient examples for learning.
46 |
47 | - Exclusivity: The zip file should exclusively contain images and metadata.jsonl.
48 | No additional files or nested folders should be included.
49 |
50 |
51 | Some points to keep in mind:
52 |
53 | - The images must be jpeg, jpg or png.
54 | - There should be at least 5 images per split.
55 | - There must not be any other files in the zip file.
56 | - There must not be any other folders inside the zip folder.
57 |
58 | When the zip file is decompressed, it should create no folders: only images and metadata.jsonl.
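
The requirements above can be checked before uploading. This is a rough sketch using only the standard library; `check_archive` is a hypothetical helper, not an AutoTrain function, and the checks AutoTrain itself performs may differ.

```python
import io
import zipfile

ALLOWED_IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png")


def check_archive(zip_file, min_images=5):
    """Return a list of problems found in the archive (empty if none)."""
    problems = []
    with zipfile.ZipFile(zip_file) as archive:
        names = archive.namelist()
    # Only top-level files with an allowed extension count as images.
    images = [
        name
        for name in names
        if "/" not in name and name.lower().endswith(ALLOWED_IMAGE_EXTENSIONS)
    ]
    for name in names:
        if "/" in name:
            problems.append(f"nested path not allowed: {name}")
        elif name != "metadata.jsonl" and name not in images:
            problems.append(f"unexpected file: {name}")
    if "metadata.jsonl" not in names:
        problems.append("metadata.jsonl is missing")
    if len(images) < min_images:
        problems.append(f"only {len(images)} images, need at least {min_images}")
    return problems


# Build a small in-memory archive to demonstrate the check.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    for i in range(1, 6):
        archive.writestr(f"{i:04d}.png", b"placeholder bytes")
    archive.writestr("metadata.jsonl", "")
buffer.seek(0)
print(check_archive(buffer))
```

An empty list means the archive layout matches the rules listed above. The sketch does not validate the contents of `metadata.jsonl` or the image data itself.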
59 |
60 | ## Parameters
61 |
62 | [[autodoc]] trainers.object_detection.params.ObjectDetectionParams
63 |
--------------------------------------------------------------------------------
/docs/source/tasks/sentence_transformer.mdx:
--------------------------------------------------------------------------------
1 | # Sentence Transformers
2 |
3 | This task lets you easily train or fine-tune a Sentence Transformer model on your own dataset.
4 |
5 | AutoTrain supports the following types of sentence transformer finetuning:
6 |
7 | - `pair`: dataset with two sentences: anchor and positive
8 | - `pair_class`: dataset with two sentences, premise and hypothesis, and a target label
9 | - `pair_score`: dataset with two sentences, sentence1 and sentence2, and a target score
10 | - `triplet`: dataset with three sentences: anchor, positive and negative
11 | - `qa`: dataset with two sentences: query and answer
12 |
13 | ## Data Format
14 |
15 | Sentence Transformers finetuning accepts data in CSV/JSONL format. You can also use a dataset from Hugging Face Hub.
16 |
17 | ### `pair`
18 |
19 | For `pair` training, the data should be in the following format:
20 |
21 | | anchor | positive |
22 | |--------|----------|
23 | | hello | hi |
24 | | how are you | I am fine |
25 | | What is your name? | My name is Abhishek |
26 | | Which is the best programming language? | Python |
27 |
28 | ### `pair_class`
29 |
30 | For `pair_class` training, the data should be in the following format:
31 |
32 | | premise | hypothesis | label |
33 | |---------|------------|-------|
34 | | hello | hi | 1 |
35 | | how are you | I am fine | 0 |
36 | | What is your name? | My name is Abhishek | 1 |
37 | | Which is the best programming language? | Python | 1 |
38 |
39 | ### `pair_score`
40 |
41 | For `pair_score` training, the data should be in the following format:
42 |
43 | | sentence1 | sentence2 | score |
44 | |-----------|-----------|-------|
45 | | hello | hi | 0.8 |
46 | | how are you | I am fine | 0.2 |
47 | | What is your name? | My name is Abhishek | 0.9 |
48 | | Which is the best programming language? | Python | 0.7 |
49 |
50 | ### `triplet`
51 |
52 | For `triplet` training, the data should be in the following format:
53 |
54 | | anchor | positive | negative |
55 | |--------|----------|----------|
56 | | hello | hi | bye |
57 | | how are you | I am fine | I am not fine |
58 | | What is your name? | My name is Abhishek | Whats it to you? |
59 | | Which is the best programming language? | Python | Javascript |
60 |
61 | ### `qa`
62 |
63 | For `qa` training, the data should be in the following format:
64 |
65 | | query | answer |
66 | |-------|--------|
67 | | hello | hi |
68 | | how are you | I am fine |
69 | | What is your name? | My name is Abhishek |
70 | | Which is the best programming language? | Python |
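
As a quick sanity check before training, the column layouts in the tables above can be compared against a CSV header. This is a minimal sketch that assumes your file uses exactly the column names shown in this section; `EXPECTED_COLUMNS` and `missing_columns` are hypothetical names, not part of AutoTrain.

```python
import csv
import io

# Column names taken from the tables above, keyed by trainer type.
EXPECTED_COLUMNS = {
    "pair": ["anchor", "positive"],
    "pair_class": ["premise", "hypothesis", "label"],
    "pair_score": ["sentence1", "sentence2", "score"],
    "triplet": ["anchor", "positive", "negative"],
    "qa": ["query", "answer"],
}


def missing_columns(csv_text, trainer):
    """Return the expected columns that are absent from the CSV header."""
    header = next(csv.reader(io.StringIO(csv_text)))
    return [col for col in EXPECTED_COLUMNS[trainer] if col not in header]


sample = "anchor,positive,negative\nhello,hi,bye\n"
print(missing_columns(sample, "triplet"))     # []
print(missing_columns(sample, "pair_score"))  # ['sentence1', 'sentence2', 'score']
```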
71 |
72 |
73 | ## Parameters
74 |
75 | [[autodoc]] trainers.sent_transformers.params.SentenceTransformersParams
76 |
--------------------------------------------------------------------------------
/docs/source/tasks/seq2seq.mdx:
--------------------------------------------------------------------------------
1 | # Seq2Seq
2 |
3 | Seq2Seq is a task that involves converting a sequence of words into another sequence of words.
4 | It is used in machine translation, text summarization, and question answering.
5 |
6 | ## Data Format
7 |
8 | You can have the dataset as a CSV file:
9 |
10 | ```csv
11 | text,target
12 | "this movie is great","dieser Film ist großartig"
13 | "this movie is bad","dieser Film ist schlecht"
14 | .
15 | .
16 | .
17 | ```
18 |
19 | Or as a JSONL file:
20 |
21 | ```json
22 | {"text": "this movie is great", "target": "dieser Film ist großartig"}
23 | {"text": "this movie is bad", "target": "dieser Film ist schlecht"}
24 | .
25 | .
26 | .
27 | ```
28 |
29 |
30 | ## Columns
31 |
32 | Your CSV/JSONL dataset must have two columns: `text` and `target`.
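
If your data is in CSV and you would rather upload JSONL, the conversion is short. The sketch below assumes the `text` and `target` column names described above; `csv_to_jsonl` is a hypothetical helper, not an AutoTrain function.

```python
import csv
import io
import json


def csv_to_jsonl(csv_text):
    """Convert CSV text with `text` and `target` columns to JSONL text."""
    reader = csv.DictReader(io.StringIO(csv_text))
    missing = {"text", "target"} - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    lines = [
        # ensure_ascii=False keeps characters such as "ß" readable in the output.
        json.dumps({"text": row["text"], "target": row["target"]}, ensure_ascii=False)
        for row in reader
    ]
    return "\n".join(lines)


sample = 'text,target\n"this movie is great","dieser Film ist großartig"\n'
print(csv_to_jsonl(sample))
```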
33 |
34 |
35 | ## Parameters
36 |
37 | [[autodoc]] trainers.seq2seq.params.Seq2SeqParams
38 |
--------------------------------------------------------------------------------
/docs/source/tasks/tabular.mdx:
--------------------------------------------------------------------------------
1 | # Tabular Classification / Regression
2 |
3 | Using AutoTrain, you can train a model to classify or regress tabular data easily.
4 | All you need to do is select from a list of models and upload your dataset.
5 | Parameter tuning is done automatically.
6 |
7 | ## Models
8 |
9 | The following models are available for tabular classification / regression.
10 |
11 | - xgboost
12 | - random_forest
13 | - ridge
14 | - logistic_regression
15 | - svm
16 | - extra_trees
17 | - gradient_boosting
18 | - adaboost
19 | - decision_tree
20 | - knn
21 |
22 |
23 | ## Data Format
24 |
25 | ```csv
26 | id,category1,category2,feature1,target
27 | 1,A,X,0.3373961604172684,1
28 | 2,B,Z,0.6481718720511972,0
29 | 3,A,Y,0.36824153984054797,1
30 | 4,B,Z,0.9571551589530464,1
31 | 5,B,Z,0.14035078041264515,1
32 | 6,C,X,0.8700872583584364,1
33 | 7,A,Y,0.4736080452737105,0
34 | 8,C,Y,0.8009107519796442,1
35 | 9,A,Y,0.5204774795512048,0
36 | 10,A,Y,0.6788795301189603,0
37 | .
38 | .
39 | .
40 | ```
41 |
42 | ## Columns
43 |
44 | Your CSV dataset must have at least two columns: `id` and `target`.
45 |
46 |
47 | ## Parameters
48 |
49 | [[autodoc]] trainers.tabular.params.TabularParams
50 |
--------------------------------------------------------------------------------
/docs/source/tasks/token_classification.mdx:
--------------------------------------------------------------------------------
1 | # Token Classification
2 |
3 | Token classification is the task of classifying each token in a sequence. This can be used
4 | for Named Entity Recognition (NER), Part-of-Speech (POS) tagging, and more. Get your data ready in
5 | proper format and then with just a few clicks, your state-of-the-art model will be ready to
6 | be used in production.
7 |
8 | ## Data Format
9 |
10 | The data should be in the following CSV format:
11 |
12 | ```csv
13 | tokens,tags
14 | "['I', 'love', 'Paris']","['O', 'O', 'B-LOC']"
15 | "['I', 'live', 'in', 'New', 'York']","['O', 'O', 'O', 'B-LOC', 'I-LOC']"
16 | .
17 | .
18 | .
19 | ```
20 |
21 | or you can also use JSONL format:
22 |
23 | ```json
24 | {"tokens": ["I", "love", "Paris"],"tags": ["O", "O", "B-LOC"]}
25 | {"tokens": ["I", "live", "in", "New", "York"],"tags": ["O", "O", "O", "B-LOC", "I-LOC"]}
26 | .
27 | .
28 | .
29 | ```
30 |
31 | As you can see, we have two columns in the CSV file. One column is the tokens and the other
32 | is the tags. Both the columns are stringified lists! The tokens column contains the tokens
33 | of the sentence and the tags column contains the tags for each token.
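
Because both columns hold stringified lists, each cell has to be parsed back into a Python list before it can be inspected. A minimal sketch of that step, using `ast.literal_eval` from the standard library; `parse_rows` is a hypothetical helper and the length check is an assumption about what a well-formed row looks like:

```python
import ast
import csv
import io


def parse_rows(csv_text):
    """Yield (tokens, tags) pairs, parsing the stringified lists in each row."""
    for row in csv.DictReader(io.StringIO(csv_text)):
        tokens = ast.literal_eval(row["tokens"])
        tags = ast.literal_eval(row["tags"])
        if len(tokens) != len(tags):
            raise ValueError(f"{len(tokens)} tokens but {len(tags)} tags: {row}")
        yield tokens, tags


sample = (
    "tokens,tags\n"
    "\"['I', 'love', 'Paris']\",\"['O', 'O', 'B-LOC']\"\n"
)
for tokens, tags in parse_rows(sample):
    print(list(zip(tokens, tags)))  # [('I', 'O'), ('love', 'O'), ('Paris', 'B-LOC')]
```

`ast.literal_eval` only evaluates Python literals, so it is a safer choice than `eval` for cells that come from an untrusted file.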
34 |
35 | If your CSV is huge, you can divide it into multiple CSV files and upload them separately.
36 | Please make sure that the column names are the same in all CSV files.
37 |
38 | One way to divide the CSV file using pandas is as follows:
39 |
40 | ```python
41 | import pandas as pd
42 |
43 | # Set the chunk size
44 | chunk_size = 1000
45 | i = 1
46 |
47 | # Open the CSV file and read it in chunks
48 | for chunk in pd.read_csv('example.csv', chunksize=chunk_size):
49 |     # Save each chunk to a new file
50 |     chunk.to_csv(f'chunk_{i}.csv', index=False)
51 |     i += 1
52 | ```
53 |
54 |
55 | Sample dataset from the Hugging Face Hub: [conll2003](https://huggingface.co/datasets/eriktks/conll2003)
56 |
57 |
58 | ## Columns
59 |
60 | Your CSV/JSONL dataset must have two columns: `tokens` and `tags`.
61 |
62 |
63 | ## Parameters
64 |
65 | [[autodoc]] trainers.token_classification.params.TokenClassificationParams
66 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.4.23
2 | datasets[vision]~=3.2.0
3 | evaluate==0.4.3
4 | ipadic==1.0.0
5 | jiwer==3.0.5
6 | joblib==1.4.2
7 | loguru==0.7.3
8 | pandas==2.2.3
9 | nltk==3.9.1
10 | optuna==4.1.0
11 | Pillow==11.0.0
12 | sacremoses==0.1.1
13 | scikit-learn==1.6.0
14 | sentencepiece==0.2.0
15 | tqdm==4.67.1
16 | werkzeug==3.1.3
17 | xgboost==2.1.3
18 | huggingface_hub==0.27.0
19 | requests==2.32.3
20 | einops==0.8.0
21 | packaging==24.2
22 | cryptography==44.0.0
23 | nvitop==1.3.2
24 | # latest versions
25 | tensorboard==2.18.0
26 | peft==0.14.0
27 | trl==0.13.0
28 | tiktoken==0.8.0
29 | transformers==4.48.0
30 | accelerate==1.2.1
31 | bitsandbytes==0.45.0
32 | # extras
33 | rouge_score==0.1.2
34 | py7zr==0.22.0
35 | fastapi==0.115.6
36 | uvicorn==0.34.0
37 | python-multipart==0.0.20
38 | pydantic==2.10.4
39 | hf-transfer
40 | pyngrok==7.2.1
41 | authlib==1.4.0
42 | itsdangerous==2.2.0
43 | seqeval==1.2.2
44 | httpx==0.28.1
45 | pyyaml==6.0.2
46 | timm==1.0.12
47 | torchmetrics==1.6.0
48 | pycocotools==2.0.8
49 | sentence-transformers==3.3.1
50 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | license_files = LICENSE
3 | version = attr: autotrain.__version__
4 |
5 | [isort]
6 | ensure_newline_before_comments = True
7 | force_grid_wrap = 0
8 | include_trailing_comma = True
9 | line_length = 119
10 | lines_after_imports = 2
11 | multi_line_output = 3
12 | use_parentheses = True
13 |
14 | [flake8]
15 | ignore = E203, E501, W503
16 | max-line-length = 119
17 | per-file-ignores =
18 | # imported but unused
19 | __init__.py: F401, E402
20 | src/autotrain/params.py: F401
21 | exclude =
22 | .git,
23 | .venv,
24 | __pycache__,
25 | dist
26 | build
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | """
3 | HuggingFace / AutoTrain Advanced
4 | """
5 | import os
6 |
7 | from setuptools import find_packages, setup
8 |
9 |
10 | DOCLINES = __doc__.split("\n")
11 |
12 | this_directory = os.path.abspath(os.path.dirname(__file__))
13 | with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
14 |     LONG_DESCRIPTION = f.read()
15 |
16 | # get INSTALL_REQUIRES from requirements.txt
17 | INSTALL_REQUIRES = []
18 | requirements_path = os.path.join(this_directory, "requirements.txt")
19 | with open(requirements_path, encoding="utf-8") as f:
20 |     for line in f:
21 |         # Install 'bitsandbytes' only on Linux (the marker skips it on macOS and other platforms)
22 |         if "bitsandbytes" in line:
23 |             line = line.strip() + " ; sys_platform == 'linux'"
24 |             INSTALL_REQUIRES.append(line.strip())
25 |         else:
26 |             INSTALL_REQUIRES.append(line.strip())
27 |
28 | QUALITY_REQUIRE = [
29 | "black",
30 | "isort",
31 | "flake8==3.7.9",
32 | ]
33 |
34 | TESTS_REQUIRE = ["pytest"]
35 |
36 | CLIENT_REQUIRES = ["requests", "loguru"]
37 |
38 |
39 | EXTRAS_REQUIRE = {
40 | "base": INSTALL_REQUIRES,
41 | "dev": INSTALL_REQUIRES + QUALITY_REQUIRE + TESTS_REQUIRE,
42 | "quality": INSTALL_REQUIRES + QUALITY_REQUIRE,
43 | "docs": INSTALL_REQUIRES
44 | + [
45 | "recommonmark",
46 | "sphinx==3.1.2",
47 | "sphinx-markdown-tables",
48 | "sphinx-rtd-theme==0.4.3",
49 | "sphinx-copybutton",
50 | ],
51 | "client": CLIENT_REQUIRES,
52 | }
53 |
54 | setup(
55 | name="autotrain-advanced",
56 | description=DOCLINES[0],
57 | long_description=LONG_DESCRIPTION,
58 | long_description_content_type="text/markdown",
59 | author="HuggingFace Inc.",
60 | author_email="autotrain@huggingface.co",
61 | url="https://github.com/huggingface/autotrain-advanced",
62 | download_url="https://github.com/huggingface/autotrain-advanced/tags",
63 | license="Apache 2.0",
64 | package_dir={"": "src"},
65 | packages=find_packages("src"),
66 | extras_require=EXTRAS_REQUIRE,
67 | install_requires=INSTALL_REQUIRES,
68 | entry_points={"console_scripts": ["autotrain=autotrain.cli.autotrain:main"]},
69 | classifiers=[
70 | "Development Status :: 5 - Production/Stable",
71 | "Intended Audience :: Developers",
72 | "Intended Audience :: Education",
73 | "Intended Audience :: Science/Research",
74 | "License :: OSI Approved :: Apache Software License",
75 | "Operating System :: OS Independent",
76 | "Programming Language :: Python :: 3.8",
77 | "Programming Language :: Python :: 3.9",
78 | "Programming Language :: Python :: 3.10",
79 | "Programming Language :: Python :: 3.11",
80 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
81 | ],
82 | keywords="automl autonlp autotrain huggingface",
83 | data_files=[
84 | (
85 | "static",
86 | [
87 | "src/autotrain/app/static/logo.png",
88 | "src/autotrain/app/static/scripts/fetch_data_and_update_models.js",
89 | "src/autotrain/app/static/scripts/listeners.js",
90 | "src/autotrain/app/static/scripts/utils.js",
91 | "src/autotrain/app/static/scripts/poll.js",
92 | "src/autotrain/app/static/scripts/logs.js",
93 | ],
94 | ),
95 | (
96 | "templates",
97 | [
98 | "src/autotrain/app/templates/index.html",
99 | "src/autotrain/app/templates/error.html",
100 | "src/autotrain/app/templates/duplicate.html",
101 | "src/autotrain/app/templates/login.html",
102 | ],
103 | ),
104 | ],
105 | include_package_data=True,
106 | )
107 |
--------------------------------------------------------------------------------
/src/autotrain/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020-2023 The HuggingFace AutoTrain Authors
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | # pylint: enable=line-too-long
18 | import os
19 |
20 |
21 | os.environ["BITSANDBYTES_NOWELCOME"] = "1"
22 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
23 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
24 |
25 |
26 | import warnings
27 |
28 |
29 | try:
30 |     import torch._dynamo
31 |
32 |     torch._dynamo.config.suppress_errors = True
33 | except ImportError:
34 |     pass
35 |
36 | from autotrain.logging import Logger
37 |
38 |
39 | warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")
40 | warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
41 | warnings.filterwarnings("ignore", category=UserWarning, module="peft")
42 | warnings.filterwarnings("ignore", category=UserWarning, module="accelerate")
43 | warnings.filterwarnings("ignore", category=UserWarning, module="datasets")
44 | warnings.filterwarnings("ignore", category=FutureWarning, module="accelerate")
45 | warnings.filterwarnings("ignore", category=UserWarning, module="huggingface_hub")
46 |
47 | logger = Logger().get_logger()
48 | __version__ = "0.8.37.dev0"
49 |
50 |
51 | def is_colab():
52 |     try:
53 |         import google.colab
54 |
55 |         return True
56 |     except ImportError:
57 |         return False
58 |
59 |
60 | def is_unsloth_available():
61 |     try:
62 |         from unsloth import FastLanguageModel
63 |
64 |         return True
65 |     except Exception as e:
66 |         logger.warning("Unsloth not available, continuing without it")
67 |         logger.warning(e)
68 |         return False
69 |
--------------------------------------------------------------------------------
/src/autotrain/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/app/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/app/app.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from fastapi import FastAPI, Request
4 | from fastapi.responses import RedirectResponse
5 | from fastapi.staticfiles import StaticFiles
6 |
7 | from autotrain import __version__, logger
8 | from autotrain.app.api_routes import api_router
9 | from autotrain.app.oauth import attach_oauth
10 | from autotrain.app.ui_routes import ui_router
11 |
12 |
13 | logger.info("Starting AutoTrain...")
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | app = FastAPI()
16 | if "SPACE_ID" in os.environ:
17 |     attach_oauth(app)
18 |
19 | app.include_router(ui_router, prefix="/ui", include_in_schema=False)
20 | app.include_router(api_router, prefix="/api")
21 | static_path = os.path.join(BASE_DIR, "static")
22 | app.mount("/static", StaticFiles(directory=static_path), name="static")
23 | logger.info(f"AutoTrain version: {__version__}")
24 | logger.info("AutoTrain started successfully")
25 |
26 |
27 | @app.get("/")
28 | async def forward_to_ui(request: Request):
29 |     """
30 |     Forwards the incoming request to the UI endpoint.
31 |
32 |     Args:
33 |         request (Request): The incoming HTTP request.
34 |
35 |     Returns:
36 |         RedirectResponse: A response object that redirects to the UI endpoint,
37 |         including any query parameters from the original request.
38 |     """
39 |     query_params = request.query_params
40 |     url = "/ui/"
41 |     if query_params:
42 |         url += f"?{query_params}"
43 |     return RedirectResponse(url=url)
44 |
--------------------------------------------------------------------------------
/src/autotrain/app/db.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 |
3 |
4 | class AutoTrainDB:
5 |     """
6 |     A class to manage job records in a SQLite database.
7 |
8 |     Attributes:
9 |     -----------
10 |     db_path : str
11 |         The path to the SQLite database file.
12 |     conn : sqlite3.Connection
13 |         The SQLite database connection object.
14 |     c : sqlite3.Cursor
15 |         The SQLite database cursor object.
16 |
17 |     Methods:
18 |     --------
19 |     __init__(db_path):
20 |         Initializes the database connection and creates the jobs table if it does not exist.
21 |
22 |     create_jobs_table():
23 |         Creates the jobs table in the database if it does not exist.
24 |
25 |     add_job(pid):
26 |         Adds a new job with the given process ID (pid) to the jobs table.
27 |
28 |     get_running_jobs():
29 |         Retrieves a list of all running job process IDs (pids) from the jobs table.
30 |
31 |     delete_job(pid):
32 |         Deletes the job with the given process ID (pid) from the jobs table.
33 |     """
34 |
35 |     def __init__(self, db_path):
36 |         self.db_path = db_path
37 |         self.conn = sqlite3.connect(db_path)
38 |         self.c = self.conn.cursor()
39 |         self.create_jobs_table()
40 |
41 |     def create_jobs_table(self):
42 |         self.c.execute(
43 |             """CREATE TABLE IF NOT EXISTS jobs
44 |             (id INTEGER PRIMARY KEY, pid INTEGER)"""
45 |         )
46 |         self.conn.commit()
47 |
48 |     def add_job(self, pid):
49 |         # Use a parameterized query so the pid is never interpolated into the SQL text
50 |         self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
51 |         self.conn.commit()
52 |
53 |     def get_running_jobs(self):
54 |         self.c.execute("""SELECT pid FROM jobs""")
55 |         running_pids = self.c.fetchall()
56 |         running_pids = [pid[0] for pid in running_pids]
57 |         return running_pids
58 |
59 |     def delete_job(self, pid):
60 |         # Parameterized query, as in add_job
61 |         self.c.execute("DELETE FROM jobs WHERE pid=?", (pid,))
62 |         self.conn.commit()
63 |
--------------------------------------------------------------------------------
/src/autotrain/app/static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/app/static/logo.png
--------------------------------------------------------------------------------
/src/autotrain/app/static/scripts/fetch_data_and_update_models.js:
--------------------------------------------------------------------------------
1 | document.addEventListener('DOMContentLoaded', function () {
2 | function fetchDataAndUpdateModels() {
3 | const taskValue = document.getElementById('task').value;
4 | const baseModelSelect = document.getElementById('base_model');
5 | const queryParams = new URLSearchParams(window.location.search);
6 | const customModelsValue = queryParams.get('custom_models');
7 | const baseModelInput = document.getElementById('base_model_input');
8 | const baseModelCheckbox = document.getElementById('base_model_checkbox');
9 |
10 | let fetchURL = `/ui/model_choices/${taskValue}`;
11 | if (customModelsValue) {
12 | fetchURL += `?custom_models=${encodeURIComponent(customModelsValue)}`;
13 | }
14 | baseModelSelect.innerHTML = 'Fetching models...';
15 | fetch(fetchURL)
16 | .then(response => response.json())
17 | .then(data => {
18 | const baseModelSelect = document.getElementById('base_model');
19 | baseModelCheckbox.checked = false;
20 | baseModelSelect.classList.remove('hidden');
21 | baseModelInput.classList.add('hidden');
22 | baseModelSelect.innerHTML = ''; // Clear existing options
23 | data.forEach(model => {
24 | let option = document.createElement('option');
25 | option.value = model.id; // Assuming each model has an 'id'
26 | option.textContent = model.name; // Assuming each model has a 'name'
27 | baseModelSelect.appendChild(option);
28 | });
29 | })
30 | .catch(error => console.error('Error:', error));
31 | }
32 | document.getElementById('task').addEventListener('change', fetchDataAndUpdateModels);
33 | fetchDataAndUpdateModels();
34 | });
--------------------------------------------------------------------------------
/src/autotrain/app/static/scripts/logs.js:
--------------------------------------------------------------------------------
1 | document.addEventListener('DOMContentLoaded', function () {
2 | var fetchLogsInterval;
3 |
4 | // Function to check the modal's display property and fetch logs if visible
5 | function fetchAndDisplayLogs() {
6 | var modal = document.getElementById('logs-modal');
7 | var displayStyle = window.getComputedStyle(modal).display;
8 |
9 | // Check if the modal display property is 'flex'
10 | if (displayStyle === 'flex') {
11 | fetchLogs(); // Initial fetch when the modal is opened
12 |
13 | // Clear any existing interval to avoid duplicates
14 | clearInterval(fetchLogsInterval);
15 |
16 | // Set up the interval to fetch logs every 5 seconds
17 | fetchLogsInterval = setInterval(fetchLogs, 5000);
18 | } else {
19 | // Clear the interval when the modal is not displayed as 'flex'
20 | clearInterval(fetchLogsInterval);
21 | }
22 | }
23 |
24 | // Function to fetch logs from the server
25 | function fetchLogs() {
26 | fetch('/ui/logs')
27 | .then(response => response.json())
28 | .then(data => {
29 | var logContainer = document.getElementById('logContent');
30 | logContainer.innerHTML = ''; // Clear previous logs
31 |
32 | // Handling the case when logs are only available in local mode or no logs available
33 | if (typeof data.logs === 'string') {
34 | logContainer.textContent = data.logs;
35 | } else {
36 | // Assuming data.logs is an array of log entries
37 | data.logs.forEach(log => {
38 | if (log.trim().length > 0) {
39 | var p = document.createElement('p');
40 | p.textContent = log;
41 | logContainer.appendChild(p); // Appends logs in order received
42 | }
43 | });
44 | }
45 | })
46 | .catch(error => console.error('Error fetching logs:', error));
47 | }
48 |
49 | // Set up an observer to detect when the modal becomes visible or hidden
50 | var observer = new MutationObserver(function (mutations) {
51 | mutations.forEach(function (mutation) {
52 | if (mutation.attributeName === 'class') {
53 | fetchAndDisplayLogs();
54 | }
55 | });
56 | });
57 |
58 | var modal = document.getElementById('logs-modal');
59 | observer.observe(modal, {
60 | attributes: true //configure it to listen to attribute changes
61 | });
62 | });
--------------------------------------------------------------------------------
/src/autotrain/app/static/scripts/poll.js:
--------------------------------------------------------------------------------
1 | document.addEventListener('DOMContentLoaded', (event) => {
2 | function pollAccelerators() {
3 | const numAcceleratorsElement = document.getElementById('num_accelerators');
4 | if (autotrain_local_value === 0) {
5 | numAcceleratorsElement.innerText = 'Accelerators: Only available in local mode.';
6 | numAcceleratorsElement.style.display = 'block'; // Ensure the element is visible
7 | return;
8 | }
9 |
10 | // Send a request to the /ui/accelerators endpoint
11 | fetch('/ui/accelerators')
12 | .then(response => response.json()) // Assuming the response is in JSON format
13 | .then(data => {
14 | // Update the paragraph with the number of accelerators
15 | document.getElementById('num_accelerators').innerText = `Accelerators: ${data.accelerators}`;
16 | })
17 | .catch(error => {
18 | console.error('Error:', error);
19 | // Update the paragraph to show an error message
20 | document.getElementById('num_accelerators').innerText = 'Accelerators: Error fetching data';
21 | });
22 | }
23 | function pollModelTrainingStatus() {
24 | // In local mode, send a request to the /ui/is_model_training endpoint
25 |
26 | if (autotrain_local_value === 0) {
27 | const statusParagraph = document.getElementById('is_model_training');
28 | statusParagraph.innerText = 'Running jobs: Only available in local mode.';
29 | statusParagraph.style.display = 'block';
30 | return;
31 | }
32 | fetch('/ui/is_model_training')
33 | .then(response => response.json()) // Assuming the response is in JSON format
34 | .then(data => {
35 | // Construct the message to display
36 | let message = data.model_training ? 'Running job PID(s): ' + data.pids.join(', ') : 'No running jobs';
37 |
38 | // Update the paragraph with the status of model training
39 | let statusParagraph = document.getElementById('is_model_training');
40 | statusParagraph.innerText = message;
41 | let stopTrainingButton = document.getElementById('stop-training-button');
42 | let startTrainingButton = document.getElementById('start-training-button');
43 |
44 | // Change the text color based on the model training status
45 | if (data.model_training) {
46 | // Set text color to red if jobs are running
47 | statusParagraph.style.color = 'red';
48 | stopTrainingButton.style.display = 'block';
49 | startTrainingButton.style.display = 'none';
50 | } else {
51 | // Set text color to green if no jobs are running
52 | statusParagraph.style.color = 'green';
53 | stopTrainingButton.style.display = 'none';
54 | startTrainingButton.style.display = 'block';
55 | }
56 | })
57 | .catch(error => {
58 | console.error('Error:', error);
59 | // Update the paragraph to show an error message
60 | let statusParagraph = document.getElementById('is_model_training');
61 | statusParagraph.innerText = 'Error fetching training status';
62 | statusParagraph.style.color = 'red'; // Set error message color to red
63 | });
64 | }
65 |
66 | setInterval(pollAccelerators, 10000);
67 | setInterval(pollModelTrainingStatus, 5000);
68 | pollAccelerators();
69 | pollModelTrainingStatus();
70 | });
--------------------------------------------------------------------------------
/src/autotrain/app/templates/duplicate.html:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/autotrain/app/templates/error.html:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/autotrain/app/templates/login.html:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/autotrain/app/training_api.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | import signal
4 | import sys
5 | from contextlib import asynccontextmanager
6 |
7 | from fastapi import FastAPI
8 |
9 | from autotrain import logger
10 | from autotrain.app.db import AutoTrainDB
11 | from autotrain.app.utils import get_running_jobs, kill_process_by_pid
12 | from autotrain.utils import run_training
13 |
14 |
15 | HF_TOKEN = os.environ.get("HF_TOKEN")
16 | AUTOTRAIN_USERNAME = os.environ.get("AUTOTRAIN_USERNAME")
17 | PROJECT_NAME = os.environ.get("PROJECT_NAME")
18 | TASK_ID = int(os.environ.get("TASK_ID"))
19 | PARAMS = os.environ.get("PARAMS")
20 | DATA_PATH = os.environ.get("DATA_PATH")
21 | MODEL = os.environ.get("MODEL")
22 | DB = AutoTrainDB("autotrain.db")
23 |
24 |
25 | def graceful_exit(signum, frame):
26 | """
27 | Handles the SIGTERM signal to perform cleanup and exit the program gracefully.
28 |
29 | Args:
30 | signum (int): The signal number.
31 | frame (FrameType): The current stack frame (or None).
32 |
33 | Logs a message indicating that SIGTERM was received and then exits the program with status code 0.
34 | """
35 | logger.info("SIGTERM received. Performing cleanup...")
36 | sys.exit(0)
37 |
38 |
39 | signal.signal(signal.SIGTERM, graceful_exit)
40 |
41 |
42 | class BackgroundRunner:
43 | """
44 | A class to handle background running tasks.
45 |
46 | Methods
47 | -------
48 | run_main():
49 | Continuously checks for running jobs and shuts down the server if no jobs are found.
50 | """
51 |
52 | async def run_main(self):
53 | while True:
54 | running_jobs = get_running_jobs(DB)
55 | if not running_jobs:
56 | logger.info("No running jobs found. Shutting down the server.")
57 | kill_process_by_pid(os.getpid())
58 | await asyncio.sleep(30)
59 |
60 |
61 | runner = BackgroundRunner()
62 |
63 |
64 | @asynccontextmanager
65 | async def lifespan(app: FastAPI):
66 | """
67 | Manages the lifespan of the FastAPI application.
68 |
69 | This function is responsible for starting the training process and
70 | managing a background task runner. It logs the process ID of the
71 | training job, adds the job to the database, and ensures the background
72 | task is properly cancelled when the application shuts down.
73 |
74 | Args:
75 | app (FastAPI): The FastAPI application instance.
76 |
77 | Yields:
78 | None: This function is a generator that yields control back to the
79 | FastAPI application lifecycle.
80 | """
81 | process_pid = run_training(params=PARAMS, task_id=TASK_ID)
82 | logger.info(f"Started training with PID {process_pid}")
83 | DB.add_job(process_pid)
84 | task = asyncio.create_task(runner.run_main())
85 | yield
86 |
87 | task.cancel()
88 | try:
89 | await task
90 | except asyncio.CancelledError:
91 | logger.info("Background runner task cancelled.")
92 |
93 |
94 | api = FastAPI(lifespan=lifespan)
95 | logger.info(f"AUTOTRAIN_USERNAME: {AUTOTRAIN_USERNAME}")
96 | logger.info(f"PROJECT_NAME: {PROJECT_NAME}")
97 | logger.info(f"TASK_ID: {TASK_ID}")
98 | logger.info(f"DATA_PATH: {DATA_PATH}")
99 | logger.info(f"MODEL: {MODEL}")
100 |
101 |
102 | @api.get("/")
103 | async def root():
104 | return "Your model is being trained..."
105 |
106 |
107 | @api.get("/health")
108 | async def health():
109 | return "OK"
110 |
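The background runner above polls for running jobs and shuts the server down once none remain. The sketch below shows that polling pattern with plain `asyncio` only. `watch_jobs`, `get_running_jobs` and `on_idle` are hypothetical names for this example, and unlike the repository code the loop returns after the idle callback instead of killing the process.

```python
import asyncio


async def watch_jobs(get_running_jobs, on_idle, interval=0.01):
    """Poll until no jobs remain, then call on_idle once and stop."""
    while True:
        if not get_running_jobs():
            on_idle()
            return
        await asyncio.sleep(interval)


async def main():
    jobs = [101]  # pretend one training process is running
    events = []

    # Start the watcher in the background, as the lifespan handler does with create_task.
    task = asyncio.create_task(
        watch_jobs(lambda: list(jobs), lambda: events.append("shutdown"))
    )
    await asyncio.sleep(0.03)  # the watcher keeps polling while the job is alive
    jobs.clear()  # the job finishes
    await asyncio.wait_for(task, timeout=1)
    return events


events = asyncio.run(main())
```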
--------------------------------------------------------------------------------
/src/autotrain/backends/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/backends/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/backends/endpoints.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | from autotrain.backends.base import BaseBackend
4 |
5 |
6 | ENDPOINTS_URL = "https://api.endpoints.huggingface.cloud/v2/endpoint/"
7 |
8 |
9 | class EndpointsRunner(BaseBackend):
10 | """
11 | EndpointsRunner is responsible for creating and managing endpoint instances.
12 |
13 | Methods
14 | -------
15 | create():
16 | Creates an endpoint instance with the specified hardware and model parameters.
17 |
18 | create() Method
19 | ---------------
20 | Creates an endpoint instance with the specified hardware and model parameters.
21 |
22 | Parameters
23 | ----------
24 | None
25 |
26 | Returns
27 | -------
28 | str
29 | The name of the created endpoint instance.
30 |
31 | Raises
32 | ------
33 | requests.exceptions.RequestException
34 | If there is an issue with the HTTP request.
35 | """
36 |
37 | def create(self):
38 | hardware = self.available_hardware[self.backend]
39 | accelerator = hardware.split("_")[2]
40 | instance_size = hardware.split("_")[3]
41 | region = hardware.split("_")[1]
42 | vendor = hardware.split("_")[0]
43 | instance_type = hardware.split("_")[4]
44 | payload = {
45 | "accountId": self.username,
46 | "compute": {
47 | "accelerator": accelerator,
48 | "instanceSize": instance_size,
49 | "instanceType": instance_type,
50 | "scaling": {"maxReplica": 1, "minReplica": 1},
51 | },
52 | "model": {
53 | "framework": "custom",
54 | "image": {
55 | "custom": {
56 | "env": {
57 | "HF_TOKEN": self.params.token,
58 | "AUTOTRAIN_USERNAME": self.username,
59 | "PROJECT_NAME": self.params.project_name,
60 | "PARAMS": self.params.model_dump_json(),
61 | "DATA_PATH": self.params.data_path,
62 | "TASK_ID": str(self.task_id),
63 | "MODEL": self.params.model,
64 | "ENDPOINT_ID": f"{self.username}/{self.params.project_name}",
65 | },
66 | "health_route": "/",
67 | "port": 7860,
68 | "url": "public.ecr.aws/z4c3o6n6/autotrain-api:latest",
69 | }
70 | },
71 | "repository": "autotrain-projects/autotrain-advanced",
72 | "revision": "main",
73 | "task": "custom",
74 | },
75 | "name": self.params.project_name,
76 | "provider": {"region": region, "vendor": vendor},
77 | "type": "protected",
78 | }
79 | headers = {"Authorization": f"Bearer {self.params.token}"}
80 | r = requests.post(
81 | ENDPOINTS_URL + self.username,
82 | json=payload,
83 | headers=headers,
84 | timeout=120,
85 | )
86 | return r.json()["name"]
87 |
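`create()` reads five underscore-separated fields from the hardware string: vendor, region, accelerator, instance size and instance type, in that order. A small sketch of that parsing with an explicit length check follows. `EndpointHardware`, `parse_hardware` and the sample string are assumptions made for illustration; real hardware identifiers may look different.

```python
from typing import NamedTuple


class EndpointHardware(NamedTuple):
    vendor: str
    region: str
    accelerator: str
    instance_size: str
    instance_type: str


def parse_hardware(hardware: str) -> EndpointHardware:
    """Split a hardware string into its five positional fields."""
    parts = hardware.split("_")
    if len(parts) != 5:
        raise ValueError(f"expected 5 underscore-separated fields, got {len(parts)}: {hardware!r}")
    return EndpointHardware(*parts)


# Hypothetical value; field order follows the indices used in create().
hw = parse_hardware("aws_us-east-1_gpu_x1_example-type")
```

Failing early with a `ValueError` gives a clearer message than an `IndexError` from a positional lookup.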
--------------------------------------------------------------------------------
/src/autotrain/backends/local.py:
--------------------------------------------------------------------------------
1 | from autotrain import logger
2 | from autotrain.backends.base import BaseBackend
3 | from autotrain.utils import run_training
4 |
5 |
6 | class LocalRunner(BaseBackend):
7 | """
8 | LocalRunner is a class that inherits from BaseBackend and is responsible for managing local training tasks.
9 |
10 | Methods:
11 | create():
12 | Starts the local training process by retrieving parameters and task ID from environment variables.
13 | Logs the start of the training process.
14 | Runs the training with the specified parameters and task ID.
15 | If the `wait` attribute is False, logs the training process ID (PID).
16 | Returns the training process ID (PID).
17 | """
18 |
19 | def create(self):
20 | logger.info("Starting local training...")
21 | params = self.env_vars["PARAMS"]
22 | task_id = int(self.env_vars["TASK_ID"])
23 | training_pid = run_training(params, task_id, local=True, wait=self.wait)
24 | if not self.wait:
25 | logger.info(f"Training PID: {training_pid}")
26 | return training_pid
27 |
--------------------------------------------------------------------------------
/src/autotrain/backends/spaces.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | from huggingface_hub import HfApi
4 |
5 | from autotrain.backends.base import BaseBackend
6 | from autotrain.trainers.generic.params import GenericParams
7 |
8 |
9 | _DOCKERFILE = """
10 | FROM huggingface/autotrain-advanced:latest
11 |
12 | CMD pip uninstall -y autotrain-advanced && pip install -U autotrain-advanced && autotrain api --port 7860 --host 0.0.0.0
13 | """
14 |
15 | # format _DOCKERFILE: join lines with spaces, then turn each double space (a former blank line) back into a newline
16 | _DOCKERFILE = _DOCKERFILE.replace("\n", " ").replace("  ", "\n").strip()
17 |
18 |
19 | class SpaceRunner(BaseBackend):
20 | """
21 | SpaceRunner is a backend class responsible for creating and managing training jobs on Hugging Face Spaces.
22 |
23 | Methods
24 | -------
25 | _create_readme():
26 | Creates a README.md file content for the space.
27 |
28 | _add_secrets(api, space_id):
29 | Adds necessary secrets to the space repository.
30 |
31 | create():
32 | Creates a new space repository, adds secrets, and uploads necessary files.
33 | """
34 |
35 | def _create_readme(self):
36 | _readme = "---\n"
37 | _readme += f"title: {self.params.project_name}\n"
38 | _readme += "emoji: 🚀\n"
39 | _readme += "colorFrom: green\n"
40 | _readme += "colorTo: indigo\n"
41 | _readme += "sdk: docker\n"
42 | _readme += "pinned: false\n"
43 | _readme += "tags:\n"
44 | _readme += "- autotrain\n"
45 | _readme += "duplicated_from: autotrain-projects/autotrain-advanced\n"
46 | _readme += "---\n"
47 | _readme = io.BytesIO(_readme.encode())
48 | return _readme
49 |
50 | def _add_secrets(self, api, space_id):
51 | if isinstance(self.params, GenericParams):
52 | for k, v in self.params.env.items():
53 | api.add_space_secret(repo_id=space_id, key=k, value=v)
54 | self.params.env = {}
55 |
56 | api.add_space_secret(repo_id=space_id, key="HF_TOKEN", value=self.params.token)
57 | api.add_space_secret(repo_id=space_id, key="AUTOTRAIN_USERNAME", value=self.username)
58 | api.add_space_secret(repo_id=space_id, key="PROJECT_NAME", value=self.params.project_name)
59 | api.add_space_secret(repo_id=space_id, key="TASK_ID", value=str(self.task_id))
60 | api.add_space_secret(repo_id=space_id, key="PARAMS", value=self.params.model_dump_json())
61 | api.add_space_secret(repo_id=space_id, key="DATA_PATH", value=self.params.data_path)
62 |
63 | if not isinstance(self.params, GenericParams):
64 | api.add_space_secret(repo_id=space_id, key="MODEL", value=self.params.model)
65 |
66 | def create(self):
67 | api = HfApi(token=self.params.token)
68 | space_id = f"{self.username}/autotrain-{self.params.project_name}"
69 | api.create_repo(
70 | repo_id=space_id,
71 | repo_type="space",
72 | space_sdk="docker",
73 | space_hardware=self.available_hardware[self.backend],
74 | private=True,
75 | )
76 | self._add_secrets(api, space_id)
77 | api.set_space_sleep_time(repo_id=space_id, sleep_time=604800)
78 | readme = self._create_readme()
79 | api.upload_file(
80 | path_or_fileobj=readme,
81 | path_in_repo="README.md",
82 | repo_id=space_id,
83 | repo_type="space",
84 | )
85 |
86 | _dockerfile = io.BytesIO(_DOCKERFILE.encode())
87 | api.upload_file(
88 | path_or_fileobj=_dockerfile,
89 | path_in_repo="Dockerfile",
90 | repo_id=space_id,
91 | repo_type="space",
92 | )
93 | return space_id
94 |
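The Dockerfile template is normalized by joining its lines with spaces and then splitting again wherever a blank line used to be. The sketch below walks through that step on a made-up template. The image name and command are placeholders, and `DOUBLE_SPACE` is spelled out only to make the two-character separator visible.

```python
DOUBLE_SPACE = " " * 2  # a blank line becomes two adjacent spaces after joining

template = """
FROM example/image:latest

CMD echo one && echo two
"""

# Step 1: every newline becomes a space, so the blank line yields a double space.
joined = template.replace("\n", " ")
# Step 2: each double space becomes a newline again; single spaces inside a line are untouched.
formatted = joined.replace(DOUBLE_SPACE, "\n").strip()
lines = formatted.splitlines()
```

The approach relies on instructions being separated by exactly one blank line and on no instruction containing two consecutive spaces.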
--------------------------------------------------------------------------------
/src/autotrain/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from argparse import ArgumentParser
3 |
4 |
5 | class BaseAutoTrainCommand(ABC):
6 | @staticmethod
7 | @abstractmethod
8 | def register_subcommand(parser: ArgumentParser):
9 | raise NotImplementedError()
10 |
11 | @abstractmethod
12 | def run(self):
13 | raise NotImplementedError()
14 |
--------------------------------------------------------------------------------
/src/autotrain/cli/autotrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from autotrain import __version__, logger
4 | from autotrain.cli.run_api import RunAutoTrainAPICommand
5 | from autotrain.cli.run_app import RunAutoTrainAppCommand
6 | from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand
7 | from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand
8 | from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand
9 | from autotrain.cli.run_llm import RunAutoTrainLLMCommand
10 | from autotrain.cli.run_object_detection import RunAutoTrainObjectDetectionCommand
11 | from autotrain.cli.run_sent_tranformers import RunAutoTrainSentenceTransformersCommand
12 | from autotrain.cli.run_seq2seq import RunAutoTrainSeq2SeqCommand
13 | from autotrain.cli.run_setup import RunSetupCommand
14 | from autotrain.cli.run_spacerunner import RunAutoTrainSpaceRunnerCommand
15 | from autotrain.cli.run_tabular import RunAutoTrainTabularCommand
16 | from autotrain.cli.run_text_classification import RunAutoTrainTextClassificationCommand
17 | from autotrain.cli.run_text_regression import RunAutoTrainTextRegressionCommand
18 | from autotrain.cli.run_token_classification import RunAutoTrainTokenClassificationCommand
19 | from autotrain.cli.run_tools import RunAutoTrainToolsCommand
20 | from autotrain.parser import AutoTrainConfigParser
21 |
22 |
23 | def main():
24 | parser = argparse.ArgumentParser(
25 | "AutoTrain advanced CLI",
26 | usage="autotrain <command> [<args>]",
27 | epilog="For more information about a command, run: `autotrain <command> --help`",
28 | )
29 | parser.add_argument("--version", "-v", help="Display AutoTrain version", action="store_true")
30 | parser.add_argument("--config", help="Optional configuration file", type=str)
31 | commands_parser = parser.add_subparsers(help="commands")
32 |
33 | # Register commands
34 | RunAutoTrainAppCommand.register_subcommand(commands_parser)
35 | RunAutoTrainLLMCommand.register_subcommand(commands_parser)
36 | RunSetupCommand.register_subcommand(commands_parser)
37 | RunAutoTrainAPICommand.register_subcommand(commands_parser)
38 | RunAutoTrainTextClassificationCommand.register_subcommand(commands_parser)
39 | RunAutoTrainImageClassificationCommand.register_subcommand(commands_parser)
40 | RunAutoTrainTabularCommand.register_subcommand(commands_parser)
41 | RunAutoTrainSpaceRunnerCommand.register_subcommand(commands_parser)
42 | RunAutoTrainSeq2SeqCommand.register_subcommand(commands_parser)
43 | RunAutoTrainTokenClassificationCommand.register_subcommand(commands_parser)
44 | RunAutoTrainToolsCommand.register_subcommand(commands_parser)
45 | RunAutoTrainTextRegressionCommand.register_subcommand(commands_parser)
46 | RunAutoTrainObjectDetectionCommand.register_subcommand(commands_parser)
47 | RunAutoTrainSentenceTransformersCommand.register_subcommand(commands_parser)
48 | RunAutoTrainImageRegressionCommand.register_subcommand(commands_parser)
49 | RunAutoTrainExtractiveQACommand.register_subcommand(commands_parser)
50 |
51 | args = parser.parse_args()
52 |
53 | if args.version:
54 | print(__version__)
55 | exit(0)
56 |
57 | if args.config:
58 | logger.info(f"Using AutoTrain configuration: {args.config}")
59 | cp = AutoTrainConfigParser(args.config)
60 | cp.run()
61 | exit(0)
62 |
63 | if not hasattr(args, "func"):
64 | parser.print_help()
65 | exit(1)
66 |
67 | command = args.func(args)
68 | command.run()
69 |
70 |
71 | if __name__ == "__main__":
72 | main()
73 |
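The CLI follows a registration pattern: each command class adds its own subparser, stores a factory under `func`, and `main()` builds the command from the parsed arguments and calls `run()`. A reduced sketch of that flow is shown here. `BaseCommand`, `HelloCommand` and `dispatch` are hypothetical names for illustration and are not part of AutoTrain.

```python
import argparse
from abc import ABC, abstractmethod


class BaseCommand(ABC):
    @staticmethod
    @abstractmethod
    def register_subcommand(subparsers):
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        raise NotImplementedError()


class HelloCommand(BaseCommand):
    @staticmethod
    def register_subcommand(subparsers):
        parser = subparsers.add_parser("hello", description="Say hello")
        parser.add_argument("--name", type=str, default="world")
        # The factory is stored on the namespace and called after parsing.
        parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name):
        self.name = name

    def run(self):
        return f"hello {self.name}"


def dispatch(argv):
    parser = argparse.ArgumentParser("demo")
    subparsers = parser.add_subparsers(help="commands")
    HelloCommand.register_subcommand(subparsers)
    args = parser.parse_args(argv)
    if not hasattr(args, "func"):
        return None  # no subcommand given; the real CLI prints help here
    return args.func(args).run()


result = dispatch(["hello", "--name", "autotrain"])
```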
--------------------------------------------------------------------------------
/src/autotrain/cli/run_api.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from . import BaseAutoTrainCommand
4 |
5 |
6 | def run_api_command_factory(args):
7 | return RunAutoTrainAPICommand(
8 | args.port,
9 | args.host,
10 | args.task,
11 | )
12 |
13 |
14 | class RunAutoTrainAPICommand(BaseAutoTrainCommand):
15 | """
16 | Command to run the AutoTrain API.
17 |
18 | This command sets up and runs the AutoTrain API using the specified host and port.
19 |
20 | Methods
21 | -------
22 | register_subcommand(parser: ArgumentParser)
23 | Registers the 'api' subcommand and its arguments to the provided parser.
24 |
25 | __init__(port: int, host: str, task: str)
26 | Initializes the command with the specified port, host, and task.
27 |
28 | run()
29 | Runs the AutoTrain API using the uvicorn server.
30 | """
31 |
32 | @staticmethod
33 | def register_subcommand(parser: ArgumentParser):
34 | run_api_parser = parser.add_parser(
35 | "api",
36 | description="✨ Run AutoTrain API",
37 | )
38 | run_api_parser.add_argument(
39 | "--port",
40 | type=int,
41 | default=7860,
42 | help="Port to run the api on",
43 | required=False,
44 | )
45 | run_api_parser.add_argument(
46 | "--host",
47 | type=str,
48 | default="127.0.0.1",
49 | help="Host to run the api on",
50 | required=False,
51 | )
52 | run_api_parser.add_argument(
53 | "--task",
54 | type=str,
55 | required=False,
56 | help="Task to run",
57 | )
58 | run_api_parser.set_defaults(func=run_api_command_factory)
59 |
60 | def __init__(self, port, host, task):
61 | self.port = port
62 | self.host = host
63 | self.task = task
64 |
65 | def run(self):
66 | import uvicorn
67 |
68 | from autotrain.app.training_api import api
69 |
70 | uvicorn.run(api, host=self.host, port=self.port)
71 |
--------------------------------------------------------------------------------
/src/autotrain/cli/run_seq2seq.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from autotrain import logger
4 | from autotrain.cli.utils import get_field_info
5 | from autotrain.project import AutoTrainProject
6 | from autotrain.trainers.seq2seq.params import Seq2SeqParams
7 |
8 | from . import BaseAutoTrainCommand
9 |
10 |
11 | def run_seq2seq_command_factory(args):
12 | return RunAutoTrainSeq2SeqCommand(args)
13 |
14 |
15 | class RunAutoTrainSeq2SeqCommand(BaseAutoTrainCommand):
16 | @staticmethod
17 | def register_subcommand(parser: ArgumentParser):
18 | arg_list = get_field_info(Seq2SeqParams)
19 | arg_list = [
20 | {
21 | "arg": "--train",
22 | "help": "Command to train the model",
23 | "required": False,
24 | "action": "store_true",
25 | },
26 | {
27 | "arg": "--deploy",
28 | "help": "Command to deploy the model (limited availability)",
29 | "required": False,
30 | "action": "store_true",
31 | },
32 | {
33 | "arg": "--inference",
34 | "help": "Command to run inference (limited availability)",
35 | "required": False,
36 | "action": "store_true",
37 | },
38 | {
39 | "arg": "--backend",
40 | "help": "Backend",
41 | "required": False,
42 | "type": str,
43 | "default": "local",
44 | },
45 | ] + arg_list
46 | run_seq2seq_parser = parser.add_parser("seq2seq", description="✨ Run AutoTrain Seq2Seq")
47 | for arg in arg_list:
48 | names = [arg["arg"]] + arg.get("alias", [])
49 | if "action" in arg:
50 | run_seq2seq_parser.add_argument(
51 | *names,
52 | dest=arg["arg"].replace("--", "").replace("-", "_"),
53 | help=arg["help"],
54 | required=arg.get("required", False),
55 | action=arg.get("action"),
56 | default=arg.get("default"),
57 | )
58 | else:
59 | run_seq2seq_parser.add_argument(
60 | *names,
61 | dest=arg["arg"].replace("--", "").replace("-", "_"),
62 | help=arg["help"],
63 | required=arg.get("required", False),
64 | type=arg.get("type"),
65 | default=arg.get("default"),
66 | choices=arg.get("choices"),
67 | )
68 | run_seq2seq_parser.set_defaults(func=run_seq2seq_command_factory)
69 |
70 | def __init__(self, args):
71 | self.args = args
72 |
73 | store_true_arg_names = ["train", "deploy", "inference", "auto_find_batch_size", "push_to_hub", "peft"]
74 | for arg_name in store_true_arg_names:
75 | if getattr(self.args, arg_name) is None:
76 | setattr(self.args, arg_name, False)
77 |
78 | if self.args.train:
79 | if self.args.project_name is None:
80 | raise ValueError("Project name must be specified")
81 | if self.args.data_path is None:
82 | raise ValueError("Data path must be specified")
83 | if self.args.model is None:
84 | raise ValueError("Model must be specified")
85 | if self.args.push_to_hub:
86 | if self.args.username is None:
87 | raise ValueError("Username must be specified for push to hub")
88 | else:
89 | raise ValueError("Must specify --train, --deploy or --inference")
90 |
91 | def run(self):
92 | logger.info("Running Seq2Seq")
93 | if self.args.train:
94 | params = Seq2SeqParams(**vars(self.args))
95 | project = AutoTrainProject(params=params, backend=self.args.backend, process=True)
96 | job_id = project.create()
97 | logger.info(f"Job ID: {job_id}")
98 |
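`register_subcommand` turns a list of argument dictionaries into `argparse` options, and `__init__` later resets unset boolean flags to `False`. The sketch below, under simplified assumptions, shows why that reset is needed: a `store_true` option registered with `default=None` stays `None` when the flag is absent. The `--max-steps` option and its `--steps` alias are invented for this example.

```python
from argparse import ArgumentParser

arg_list = [
    {"arg": "--train", "help": "Train the model", "required": False, "action": "store_true"},
    {"arg": "--backend", "help": "Backend", "required": False, "type": str, "default": "local"},
    # Hypothetical option, used only to show aliases and dashes in names.
    {"arg": "--max-steps", "help": "Example numeric option", "type": int, "alias": ["--steps"]},
]

parser = ArgumentParser("demo")
for arg in arg_list:
    names = [arg["arg"]] + arg.get("alias", [])
    dest = arg["arg"].replace("--", "").replace("-", "_")
    common = {"dest": dest, "help": arg["help"], "required": arg.get("required", False)}
    if "action" in arg:
        # default=None overrides store_true's usual default of False.
        parser.add_argument(*names, action=arg["action"], default=arg.get("default"), **common)
    else:
        parser.add_argument(
            *names, type=arg.get("type"), default=arg.get("default"), choices=arg.get("choices"), **common
        )

args = parser.parse_args(["--train", "--steps", "10"])
unset = parser.parse_args([])

# Normalize missing boolean flags the way the command's __init__ does.
if unset.train is None:
    normalized_train = False
else:
    normalized_train = unset.train
```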
--------------------------------------------------------------------------------
/src/autotrain/cli/run_setup.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from argparse import ArgumentParser
3 |
4 | from autotrain import logger
5 |
6 | from . import BaseAutoTrainCommand
7 |
8 |
9 | def run_app_command_factory(args):
10 | return RunSetupCommand(args.update_torch, args.colab)
11 |
12 |
13 | class RunSetupCommand(BaseAutoTrainCommand):
14 | @staticmethod
15 | def register_subcommand(parser: ArgumentParser):
16 | run_setup_parser = parser.add_parser(
17 | "setup",
18 | description="✨ Run AutoTrain setup",
19 | )
20 | run_setup_parser.add_argument(
21 | "--update-torch",
22 | action="store_true",
23 | help="Update PyTorch to latest version",
24 | )
25 | run_setup_parser.add_argument(
26 | "--colab",
27 | action="store_true",
28 | help="Run setup for Google Colab",
29 | )
30 | run_setup_parser.set_defaults(func=run_app_command_factory)
31 |
32 | def __init__(self, update_torch: bool, colab: bool = False):
33 | self.update_torch = update_torch
34 | self.colab = colab
35 |
36 | def run(self):
37 | if self.colab:
38 | cmd = "pip install -U xformers==0.0.24"
39 | else:
40 | cmd = "pip uninstall -y xformers"
41 | cmd = cmd.split()
42 | pipe = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
43 | logger.info("Updating xformers (install on Colab, uninstall otherwise)")
44 | _, _ = pipe.communicate()
45 | logger.info("Finished xformers setup step")
46 |
47 | if self.update_torch:
48 | cmd = "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121"
49 | cmd = cmd.split()
50 | pipe = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
51 | logger.info("Installing latest PyTorch")
52 | _, _ = pipe.communicate()
53 | logger.info("Successfully installed latest PyTorch")
54 |
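The setup command waits for each `pip` process with `communicate()` but does not look at the exit status. A minimal sketch of checking the return code before reporting success is given below. `run_and_report` is a hypothetical helper, and the child commands are harmless Python one-liners rather than `pip` calls.

```python
import subprocess
import sys


def run_and_report(cmd):
    """Run a command, wait for it, and report success based on its exit code."""
    pipe = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = pipe.communicate()
    return pipe.returncode == 0, stdout.decode().strip(), stderr.decode().strip()


# A command that succeeds and prints something.
ok, out, _ = run_and_report([sys.executable, "-c", "print('done')"])
# A command that exits with a non-zero status.
failed_ok, _, _ = run_and_report([sys.executable, "-c", "import sys; sys.exit(3)"])
```

With a check like this, a "successfully installed" message can be logged only when the first element of the result is `True`.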
--------------------------------------------------------------------------------
/src/autotrain/cli/run_tools.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from . import BaseAutoTrainCommand
4 |
5 |
6 | def run_tools_command_factory(args):
7 | return RunAutoTrainToolsCommand(args)
8 |
9 |
10 | class RunAutoTrainToolsCommand(BaseAutoTrainCommand):
11 | @staticmethod
12 | def register_subcommand(parser: ArgumentParser):
13 | run_app_parser = parser.add_parser("tools", help="Run AutoTrain tools")
14 | subparsers = run_app_parser.add_subparsers(title="tools", dest="tool_name")
15 |
16 | merge_llm_parser = subparsers.add_parser(
17 | "merge-llm-adapter",
18 | help="Merge LLM Adapter tool",
19 | )
20 | merge_llm_parser.add_argument(
21 | "--base-model-path",
22 | type=str,
23 | help="Base model path",
24 | )
25 | merge_llm_parser.add_argument(
26 | "--adapter-path",
27 | type=str,
28 | help="Adapter path",
29 | )
30 | merge_llm_parser.add_argument(
31 | "--token",
32 | type=str,
33 | help="Token",
34 | default=None,
35 | required=False,
36 | )
37 | merge_llm_parser.add_argument(
38 | "--pad-to-multiple-of",
39 | type=int,
40 | help="Pad to multiple of",
41 | default=None,
42 | required=False,
43 | )
44 | merge_llm_parser.add_argument(
45 | "--output-folder",
46 | type=str,
47 | help="Output folder",
48 | required=False,
49 | default=None,
50 | )
51 | merge_llm_parser.add_argument(
52 | "--push-to-hub",
53 | action="store_true",
54 | help="Push to Hugging Face Hub",
55 | required=False,
56 | )
57 | merge_llm_parser.set_defaults(func=run_tools_command_factory, merge_llm_adapter=True)
58 |
59 | convert_to_kohya_parser = subparsers.add_parser("convert_to_kohya", help="Convert to Kohya tool")
60 | convert_to_kohya_parser.add_argument(
61 | "--input-path",
62 | type=str,
63 | help="Input path",
64 | )
65 | convert_to_kohya_parser.add_argument(
66 | "--output-path",
67 | type=str,
68 | help="Output path",
69 | )
70 | convert_to_kohya_parser.set_defaults(func=run_tools_command_factory, convert_to_kohya=True)
71 |
72 | def __init__(self, args):
73 | self.args = args
74 |
75 | def run(self):
76 | if getattr(self.args, "merge_llm_adapter", False):
77 | self.run_merge_llm_adapter()
78 | if getattr(self.args, "convert_to_kohya", False):
79 | self.run_convert_to_kohya()
80 |
81 | def run_merge_llm_adapter(self):
82 | from autotrain.tools.merge_adapter import merge_llm_adapter
83 |
84 | merge_llm_adapter(
85 | base_model_path=self.args.base_model_path,
86 | adapter_path=self.args.adapter_path,
87 | token=self.args.token,
88 | output_folder=self.args.output_folder,
89 | pad_to_multiple_of=self.args.pad_to_multiple_of,
90 | push_to_hub=self.args.push_to_hub,
91 | )
92 |
93 | def run_convert_to_kohya(self):
94 | from autotrain.tools.convert_to_kohya import convert_to_kohya
95 |
96 | convert_to_kohya(
97 | input_path=self.args.input_path,
98 | output_path=self.args.output_path,
99 | )
100 |
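The sub-command registration above follows a common argparse pattern: each sub-parser stores a factory under `func` via `set_defaults`, and the caller builds and runs the command from the parsed namespace. A minimal, self-contained sketch of that pattern follows. `EchoCommand`, `build_parser`, and the `demo` program name are hypothetical stand-ins for illustration and are not part of AutoTrain.

```python
from argparse import ArgumentParser


class EchoCommand:
    """Hypothetical command object; stands in for RunAutoTrainToolsCommand."""

    def __init__(self, args):
        self.args = args

    def run(self):
        return f"merge {self.args.adapter_path} into {self.args.base_model_path}"


def build_parser():
    parser = ArgumentParser(prog="demo")
    commands = parser.add_subparsers(dest="command")
    tools = commands.add_parser("tools")
    tool_parsers = tools.add_subparsers(title="tools", dest="tool_name")
    merge = tool_parsers.add_parser("merge-llm-adapter")
    merge.add_argument("--base-model-path", type=str)
    merge.add_argument("--adapter-path", type=str)
    # The factory and a marker flag are stored on the parsed namespace.
    merge.set_defaults(func=EchoCommand, merge_llm_adapter=True)
    return parser


args = build_parser().parse_args(
    ["tools", "merge-llm-adapter", "--base-model-path", "base", "--adapter-path", "adapter"]
)
command = args.func(args)  # build the command from the stored factory
result = command.run()
print(result)
```

Note that argparse converts the dashes in option names to underscores, which is why `--base-model-path` is read back as `args.base_model_path`.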
--------------------------------------------------------------------------------
/src/autotrain/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | HF_API = os.getenv("HF_API", "https://huggingface.co")
5 |
--------------------------------------------------------------------------------
/src/autotrain/logging.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from dataclasses import dataclass
3 |
4 | from loguru import logger
5 |
6 |
7 | IS_ACCELERATE_AVAILABLE = False
8 |
9 | try:
10 | from accelerate.state import PartialState
11 |
12 | IS_ACCELERATE_AVAILABLE = True
13 | except ImportError:
14 | pass
15 |
16 |
17 | @dataclass
18 | class Logger:
19 | """
20 | A custom logger class that sets up and manages logging configuration.
21 |
22 | Methods
23 | -------
24 | __post_init__():
25 | Initializes the logger with a specific format and sets up the logger.
26 |
27 | _should_log(record):
28 | Determines if a log record should be logged based on the process state.
29 |
30 | setup_logger():
31 | Configures the logger to output to stdout with the specified format and filter.
32 |
33 | get_logger():
34 | Returns the configured logger instance.
35 | """
36 |
37 | def __post_init__(self):
38 | self.log_format = (
39 | "{level: <8} | "
40 | "{time:YYYY-MM-DD HH:mm:ss} | "
41 | "{name}:{function}:{line} - "
42 | "{message}"
43 | )
44 | self.logger = logger
45 | self.setup_logger()
46 |
47 | def _should_log(self, record):
48 | if not IS_ACCELERATE_AVAILABLE:
49 | return True
50 | return PartialState().is_main_process
51 |
52 | def setup_logger(self):
53 | self.logger.remove()
54 | self.logger.add(
55 | sys.stdout,
56 | format=self.log_format,
57 | filter=self._should_log,
58 | )
59 |
60 | def get_logger(self):
61 | return self.logger
62 |
--------------------------------------------------------------------------------
/src/autotrain/params.py:
--------------------------------------------------------------------------------
1 | from autotrain.trainers.clm.params import LLMTrainingParams
2 | from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
3 | from autotrain.trainers.image_classification.params import ImageClassificationParams
4 | from autotrain.trainers.image_regression.params import ImageRegressionParams
5 | from autotrain.trainers.object_detection.params import ObjectDetectionParams
6 | from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
7 | from autotrain.trainers.seq2seq.params import Seq2SeqParams
8 | from autotrain.trainers.tabular.params import TabularParams
9 | from autotrain.trainers.text_classification.params import TextClassificationParams
10 | from autotrain.trainers.text_regression.params import TextRegressionParams
11 | from autotrain.trainers.token_classification.params import TokenClassificationParams
12 | from autotrain.trainers.vlm.params import VLMTrainingParams
13 |
--------------------------------------------------------------------------------
/src/autotrain/preprocessor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/preprocessor/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/tasks.py:
--------------------------------------------------------------------------------
1 | NLP_TASKS = {
2 | "text_binary_classification": 1,
3 | "text_multi_class_classification": 2,
4 | "text_token_classification": 4,
5 | "text_extractive_question_answering": 5,
6 | "text_summarization": 8,
7 | "text_single_column_regression": 10,
8 | "speech_recognition": 11,
9 | "natural_language_inference": 22,
10 | "lm_training": 9,
11 | "seq2seq": 28, # 27 is reserved for generic training
12 | "sentence_transformers": 30,
13 | "vlm": 31,
14 | }
15 |
16 | VISION_TASKS = {
17 | "image_binary_classification": 17,
18 | "image_multi_class_classification": 18,
19 | "image_single_column_regression": 24,
20 | "image_object_detection": 29,
21 | }
22 |
23 | TABULAR_TASKS = {
24 | "tabular_binary_classification": 13,
25 | "tabular_multi_class_classification": 14,
26 | "tabular_multi_label_classification": 15,
27 | "tabular_single_column_regression": 16,
28 | "tabular": 26,
29 | }
30 |
31 |
32 | TASKS = {
33 | **NLP_TASKS,
34 | **VISION_TASKS,
35 | **TABULAR_TASKS,
36 | }
37 |
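Because `TASKS` is built by unpacking the three group dictionaries, a reverse lookup from numeric ID to task name is straightforward as long as the IDs stay unique across groups. The sketch below uses a small subset of the entries above. `ID_TO_TASK` is a hypothetical helper name and does not exist in the module.

```python
# Subset of the task tables, copied from the definitions above.
NLP_TASKS = {"lm_training": 9, "seq2seq": 28}
VISION_TASKS = {"image_object_detection": 29}
TABULAR_TASKS = {"tabular": 26}

# Later dictionaries would silently override earlier keys with the same name.
TASKS = {**NLP_TASKS, **VISION_TASKS, **TABULAR_TASKS}

# A reverse mapping is only safe when every ID is unique.
assert len(set(TASKS.values())) == len(TASKS)
ID_TO_TASK = {task_id: name for name, task_id in TASKS.items()}

print(ID_TO_TASK[28])
```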
--------------------------------------------------------------------------------
/src/autotrain/tests/test_cli.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/tests/test_cli.py
--------------------------------------------------------------------------------
/src/autotrain/tests/test_dummy.py:
--------------------------------------------------------------------------------
1 | def test_dummy():
2 | assert 1 + 1 == 2
3 |
--------------------------------------------------------------------------------
/src/autotrain/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/tools/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/tools/convert_to_kohya.py:
--------------------------------------------------------------------------------
1 | from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
2 | from safetensors.torch import load_file, save_file
3 |
4 | from autotrain import logger
5 |
6 |
7 | def convert_to_kohya(input_path, output_path):
8 | """
9 | Converts a Lora state dictionary to a Kohya state dictionary and saves it to the specified output path.
10 |
11 | Args:
12 | input_path (str): The file path to the input Lora state dictionary.
13 | output_path (str): The file path where the converted Kohya state dictionary will be saved.
14 |
15 | Returns:
16 | None
17 | """
18 | logger.info(f"Converting Lora state dict from {input_path} to Kohya state dict at {output_path}")
19 | lora_state_dict = load_file(input_path)
20 | peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
21 | kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
22 | save_file(kohya_state_dict, output_path)
23 | logger.info(f"Kohya state dict saved at {output_path}")
24 |
--------------------------------------------------------------------------------
/src/autotrain/tools/merge_adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from peft import PeftModel
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 |
5 | from autotrain import logger
6 | from autotrain.trainers.common import ALLOW_REMOTE_CODE
7 |
8 |
9 | def merge_llm_adapter(
10 | base_model_path, adapter_path, token, output_folder=None, pad_to_multiple_of=None, push_to_hub=False
11 | ):
12 | """
13 | Merges a language model adapter into a base model and optionally saves or pushes the merged model.
14 |
15 | Args:
16 | base_model_path (str): Path to the base model.
17 | adapter_path (str): Path to the adapter model.
18 | token (str): Authentication token for accessing the models.
19 | output_folder (str, optional): Directory to save the merged model. Defaults to None.
20 | pad_to_multiple_of (int, optional): If specified, pad the token embeddings to a multiple of this value. Defaults to None.
21 | push_to_hub (bool, optional): If True, push the merged model to the Hugging Face Hub. Defaults to False.
22 |
23 | Raises:
24 | ValueError: If neither `output_folder` nor `push_to_hub` is specified.
25 |
26 | Returns:
27 | None
28 | """
29 | if output_folder is None and push_to_hub is False:
30 | raise ValueError("You must specify either --output-folder or --push-to-hub")
31 |
32 | logger.info("Loading base model and adapter...")
33 | base_model = AutoModelForCausalLM.from_pretrained(
34 | base_model_path,
35 | torch_dtype=torch.float16,
36 | low_cpu_mem_usage=True,
37 | trust_remote_code=ALLOW_REMOTE_CODE,
38 | token=token,
39 | )
40 |
41 | tokenizer = AutoTokenizer.from_pretrained(
42 | adapter_path,
43 | trust_remote_code=ALLOW_REMOTE_CODE,
44 | token=token,
45 | )
46 | if pad_to_multiple_of:
47 | base_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=pad_to_multiple_of)
48 | else:
49 | base_model.resize_token_embeddings(len(tokenizer))
50 |
51 | model = PeftModel.from_pretrained(
52 | base_model,
53 | adapter_path,
54 | token=token,
55 | )
56 | model = model.merge_and_unload()
57 |
58 | if output_folder is not None:
59 | logger.info("Saving target model...")
60 | model.save_pretrained(output_folder)
61 | tokenizer.save_pretrained(output_folder)
62 | logger.info(f"Model saved to {output_folder}")
63 |
64 | if push_to_hub:
65 | logger.info("Pushing model to Hugging Face Hub...")
66 | model.push_to_hub(adapter_path)
67 | tokenizer.push_to_hub(adapter_path)
68 | logger.info(f"Model pushed to Hugging Face Hub as {adapter_path}")
69 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/clm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/clm/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/clm/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | from autotrain.trainers.clm.params import LLMTrainingParams
5 | from autotrain.trainers.common import monitor
6 |
7 |
8 | def parse_args():
9 | # get training_config.json from the end user
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--training_config", type=str, required=True)
12 | return parser.parse_args()
13 |
14 |
15 | @monitor
16 | def train(config):
17 | if isinstance(config, dict):
18 | config = LLMTrainingParams(**config)
19 |
20 | if config.trainer == "default":
21 | from autotrain.trainers.clm.train_clm_default import train as train_default
22 |
23 | train_default(config)
24 |
25 | elif config.trainer == "sft":
26 | from autotrain.trainers.clm.train_clm_sft import train as train_sft
27 |
28 | train_sft(config)
29 |
30 | elif config.trainer == "reward":
31 | from autotrain.trainers.clm.train_clm_reward import train as train_reward
32 |
33 | train_reward(config)
34 |
35 | elif config.trainer == "dpo":
36 | from autotrain.trainers.clm.train_clm_dpo import train as train_dpo
37 |
38 | train_dpo(config)
39 |
40 | elif config.trainer == "orpo":
41 | from autotrain.trainers.clm.train_clm_orpo import train as train_orpo
42 |
43 | train_orpo(config)
44 |
45 | else:
46 | raise ValueError(f"trainer `{config.trainer}` not supported")
47 |
48 |
49 | if __name__ == "__main__":
50 | _args = parse_args()
51 | training_config = json.load(open(_args.training_config))
52 | _config = LLMTrainingParams(**training_config)
53 | train(_config)
54 |
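The `if`/`elif` chain above maps each trainer name to a module named `train_clm_<trainer>` and imports it lazily. One way to express the same mapping is to resolve the module path from the name, as in the sketch below. `SUPPORTED_TRAINERS` and `trainer_module` are hypothetical names introduced for illustration; the actual code keeps the explicit branches.

```python
SUPPORTED_TRAINERS = ("default", "sft", "reward", "dpo", "orpo")


def trainer_module(trainer):
    """Return the dotted module path for a trainer name, or raise ValueError."""
    if trainer not in SUPPORTED_TRAINERS:
        raise ValueError(f"trainer `{trainer}` not supported")
    return f"autotrain.trainers.clm.train_clm_{trainer}"


# With autotrain installed, importlib.import_module(trainer_module(name)).train(config)
# would perform the same lazy import as the branches above (an assumption, not tested here).
print(trainer_module("sft"))
```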
--------------------------------------------------------------------------------
/src/autotrain/trainers/clm/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from peft import set_peft_model_state_dict
5 | from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
6 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
7 |
8 |
9 | class SavePeftModelCallback(TrainerCallback):
10 | def on_save(
11 | self,
12 | args: TrainingArguments,
13 | state: TrainerState,
14 | control: TrainerControl,
15 | **kwargs,
16 | ):
17 | checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
18 |
19 | kwargs["model"].save_pretrained(checkpoint_folder)
20 |
21 | pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
22 | torch.save({}, pytorch_model_path)
23 | return control
24 |
25 |
26 | class LoadBestPeftModelCallback(TrainerCallback):
27 | def on_train_end(
28 | self,
29 | args: TrainingArguments,
30 | state: TrainerState,
31 | control: TrainerControl,
32 | **kwargs,
33 | ):
34 | print(f"Loading best peft model from {state.best_model_checkpoint} (score: {state.best_metric}).")
35 | best_model_path = os.path.join(state.best_model_checkpoint, "adapter_model.bin")
36 | adapters_weights = torch.load(best_model_path)
37 | model = kwargs["model"]
38 | set_peft_model_state_dict(model, adapters_weights)
39 | return control
40 |
41 |
42 | class SaveDeepSpeedPeftModelCallback(TrainerCallback):
43 | def __init__(self, trainer, save_steps=500):
44 | self.trainer = trainer
45 | self.save_steps = save_steps
46 |
47 | def on_step_end(
48 | self,
49 | args: TrainingArguments,
50 | state: TrainerState,
51 | control: TrainerControl,
52 | **kwargs,
53 | ):
54 | if (state.global_step + 1) % self.save_steps == 0:
55 | self.trainer.accelerator.wait_for_everyone()
56 | state_dict = self.trainer.accelerator.get_state_dict(self.trainer.deepspeed)
57 | unwrapped_model = self.trainer.accelerator.unwrap_model(self.trainer.deepspeed)
58 | if self.trainer.accelerator.is_main_process:
59 | unwrapped_model.save_pretrained(args.output_dir, state_dict=state_dict)
60 | self.trainer.accelerator.wait_for_everyone()
61 | return control
62 |
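The save condition in `SaveDeepSpeedPeftModelCallback` is `(state.global_step + 1) % self.save_steps == 0`, so the save fires one step before each multiple of `save_steps`. The sketch below isolates that arithmetic. `should_save` is a hypothetical helper, not part of the callback.

```python
def should_save(global_step, save_steps=500):
    """Mirror the callback's modulo check on the step counter."""
    return (global_step + 1) % save_steps == 0


# Collect the steps in 1..1500 at which the check passes.
trigger_steps = [step for step in range(1, 1501) if should_save(step)]
print(trigger_steps)
```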
--------------------------------------------------------------------------------
/src/autotrain/trainers/clm/train_clm_orpo.py:
--------------------------------------------------------------------------------
1 | from peft import LoraConfig
2 | from transformers.trainer_callback import PrinterCallback
3 | from trl import ORPOConfig, ORPOTrainer
4 |
5 | from autotrain import logger
6 | from autotrain.trainers.clm import utils
7 | from autotrain.trainers.clm.params import LLMTrainingParams
8 |
9 |
10 | def train(config):
11 | logger.info("Starting ORPO training...")
12 | if isinstance(config, dict):
13 | config = LLMTrainingParams(**config)
14 | train_data, valid_data = utils.process_input_data(config)
15 | tokenizer = utils.get_tokenizer(config)
16 | train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)
17 |
18 | logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
19 | training_args = utils.configure_training_args(config, logging_steps)
20 | config = utils.configure_block_size(config, tokenizer)
21 |
22 | training_args["max_length"] = config.block_size
23 | training_args["max_prompt_length"] = config.max_prompt_length
24 | training_args["max_completion_length"] = config.max_completion_length
25 | args = ORPOConfig(**training_args)
26 |
27 | model = utils.get_model(config, tokenizer)
28 |
29 | if config.peft:
30 | peft_config = LoraConfig(
31 | r=config.lora_r,
32 | lora_alpha=config.lora_alpha,
33 | lora_dropout=config.lora_dropout,
34 | bias="none",
35 | task_type="CAUSAL_LM",
36 | target_modules=utils.get_target_modules(config),
37 | )
38 |
39 | logger.info("creating trainer")
40 | callbacks = utils.get_callbacks(config)
41 | trainer_args = dict(
42 | args=args,
43 | model=model,
44 | callbacks=callbacks,
45 | )
46 |
47 | trainer = ORPOTrainer(
48 | **trainer_args,
49 | train_dataset=train_data,
50 | eval_dataset=valid_data if config.valid_split is not None else None,
51 | processing_class=tokenizer,
52 | peft_config=peft_config if config.peft else None,
53 | )
54 |
55 | trainer.remove_callback(PrinterCallback)
56 | trainer.train()
57 | utils.post_training_steps(config, trainer)
58 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/clm/train_clm_sft.py:
--------------------------------------------------------------------------------
1 | from peft import LoraConfig
2 | from transformers.trainer_callback import PrinterCallback
3 | from trl import SFTConfig, SFTTrainer
4 |
5 | from autotrain import logger
6 | from autotrain.trainers.clm import utils
7 | from autotrain.trainers.clm.params import LLMTrainingParams
8 |
9 |
10 | def train(config):
11 | logger.info("Starting SFT training...")
12 | if isinstance(config, dict):
13 | config = LLMTrainingParams(**config)
14 | train_data, valid_data = utils.process_input_data(config)
15 | tokenizer = utils.get_tokenizer(config)
16 | train_data, valid_data = utils.process_data_with_chat_template(config, tokenizer, train_data, valid_data)
17 |
18 | logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
19 | training_args = utils.configure_training_args(config, logging_steps)
20 | config = utils.configure_block_size(config, tokenizer)
21 |
22 | training_args["dataset_text_field"] = config.text_column
23 | training_args["max_seq_length"] = config.block_size
24 | training_args["packing"] = True
25 | args = SFTConfig(**training_args)
26 |
27 | model = utils.get_model(config, tokenizer)
28 |
29 | if config.peft:
30 | peft_config = LoraConfig(
31 | r=config.lora_r,
32 | lora_alpha=config.lora_alpha,
33 | lora_dropout=config.lora_dropout,
34 | bias="none",
35 | task_type="CAUSAL_LM",
36 | target_modules=utils.get_target_modules(config),
37 | )
38 |
39 | logger.info("creating trainer")
40 | callbacks = utils.get_callbacks(config)
41 | trainer_args = dict(
42 | args=args,
43 | model=model,
44 | callbacks=callbacks,
45 | )
46 | trainer = SFTTrainer(
47 | **trainer_args,
48 | train_dataset=train_data,
49 | eval_dataset=valid_data if config.valid_split is not None else None,
50 | peft_config=peft_config if config.peft else None,
51 | processing_class=tokenizer,
52 | )
53 |
54 | trainer.remove_callback(PrinterCallback)
55 | trainer.train()
56 | utils.post_training_steps(config, trainer)
57 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/extractive_question_answering/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/extractive_question_answering/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/generic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/generic/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/generic/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | from autotrain import logger
5 | from autotrain.trainers.common import monitor, pause_space
6 | from autotrain.trainers.generic import utils
7 | from autotrain.trainers.generic.params import GenericParams
8 |
9 |
10 | def parse_args():
11 | # get training_config.json from the end user
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--config", type=str, required=True)
14 | return parser.parse_args()
15 |
16 |
17 | @monitor
18 | def run(config):
19 | """
20 | Executes a series of operations based on the provided configuration.
21 |
22 | This function performs the following steps:
23 | 1. Converts the configuration dictionary to a GenericParams object if necessary.
24 | 2. Downloads the data repository specified in the configuration.
25 | 3. Uninstalls any existing requirements specified in the configuration.
26 | 4. Installs the necessary requirements specified in the configuration.
27 | 5. Runs a command specified in the configuration.
28 | 6. Pauses the space as specified in the configuration.
29 |
30 | Args:
31 | config (dict or GenericParams): The configuration for the operations to be performed.
32 | """
33 | if isinstance(config, dict):
34 | config = GenericParams(**config)
35 |
36 | # download the data repo
37 | logger.info("Downloading data repo...")
38 | utils.pull_dataset_repo(config)
39 |
40 | logger.info("Uninstalling requirements...")
41 | utils.uninstall_requirements(config)
42 |
43 | # install the requirements
44 | logger.info("Installing requirements...")
45 | utils.install_requirements(config)
46 |
47 | # run the command
48 | logger.info("Running command...")
49 | utils.run_command(config)
50 |
51 | pause_space(config)
52 |
53 |
54 | if __name__ == "__main__":
55 | args = parse_args()
56 | _config = json.load(open(args.config))
57 | _config = GenericParams(**_config)
58 | run(_config)
59 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/generic/params.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | from pydantic import Field
4 |
5 | from autotrain.trainers.common import AutoTrainParams
6 |
7 |
8 | class GenericParams(AutoTrainParams):
9 | """
10 | GenericParams is a class that holds configuration parameters for an AutoTrain SpaceRunner project.
11 |
12 | Attributes:
13 | username (str): The username for your Hugging Face account.
14 | project_name (str): The name of the project.
15 | data_path (str): The file path to the dataset.
16 | token (str): The authentication token for accessing Hugging Face Hub.
17 | script_path (str): The file path to the script to be executed. Path to script.py.
18 | env (Optional[Dict[str, str]]): A dictionary of environment variables to be set.
19 | args (Optional[Dict[str, str]]): A dictionary of arguments to be passed to the script.
20 | """
21 |
22 | username: str = Field(
23 | None, title="Hugging Face Username", description="The username for your Hugging Face account."
24 | )
25 | project_name: str = Field("project-name", title="Project Name", description="The name of the project.")
26 | data_path: str = Field(None, title="Data Path", description="The file path to the dataset.")
27 | token: str = Field(None, title="Hub Token", description="The authentication token for accessing Hugging Face Hub.")
28 | script_path: str = Field(
29 | None, title="Script Path", description="The file path to the script to be executed. Path to script.py"
30 | )
31 | env: Optional[Dict[str, str]] = Field(
32 | None, title="Environment Variables", description="A dictionary of environment variables to be set."
33 | )
34 | args: Optional[Dict[str, str]] = Field(
35 | None, title="Arguments", description="A dictionary of arguments to be passed to the script."
36 | )
37 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/image_classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/image_classification/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/image_classification/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class ImageClassificationDataset:
6 | """
7 | A custom dataset class for image classification tasks.
8 |
9 | Args:
10 | data (list): A list of data samples, where each sample is a dictionary containing image and target information.
11 | transforms (callable): A function/transform that takes in an image and returns a transformed version.
12 | config (object): A configuration object containing the column names for images and targets.
13 |
14 | Attributes:
15 | data (list): The dataset containing image and target information.
16 | transforms (callable): The transformation function to be applied to the images.
17 | config (object): The configuration object with image and target column names.
18 |
19 | Methods:
20 | __len__(): Returns the number of samples in the dataset.
21 | __getitem__(item): Retrieves the image and target at the specified index, applies transformations, and returns them in a dictionary of tensors with keys `pixel_values` and `labels`.
22 |
23 | Example:
24 | dataset = ImageClassificationDataset(data, transforms, config)
25 | sample = dataset[0]  # {"pixel_values": ..., "labels": ...}
26 | """
27 |
28 | def __init__(self, data, transforms, config):
29 | self.data = data
30 | self.transforms = transforms
31 | self.config = config
32 |
33 | def __len__(self):
34 | return len(self.data)
35 |
36 | def __getitem__(self, item):
37 | image = self.data[item][self.config.image_column]
38 | target = int(self.data[item][self.config.target_column])
39 |
40 | image = self.transforms(image=np.array(image.convert("RGB")))["image"]
41 | image = np.transpose(image, (2, 0, 1)).astype(np.float32)
42 |
43 | return {
44 | "pixel_values": torch.tensor(image, dtype=torch.float),
45 | "labels": torch.tensor(target, dtype=torch.long),
46 | }
47 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/image_regression/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/image_regression/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/image_regression/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class ImageRegressionDataset:
6 | """
7 | A dataset class for image regression tasks.
8 |
9 | Args:
10 | data (list): A list of data points where each data point is a dictionary containing image and target information.
11 | transforms (callable): A function/transform that takes in an image and returns a transformed version.
12 | config (object): A configuration object that contains the column names for images and targets.
13 |
14 | Attributes:
15 | data (list): The input data.
16 | transforms (callable): The transformation function.
17 | config (object): The configuration object.
18 |
19 | Methods:
20 | __len__(): Returns the number of data points in the dataset.
21 | __getitem__(item): Returns a dictionary containing the transformed image and the target value for the given index.
22 | """
23 |
24 | def __init__(self, data, transforms, config):
25 | self.data = data
26 | self.transforms = transforms
27 | self.config = config
28 |
29 | def __len__(self):
30 | return len(self.data)
31 |
32 | def __getitem__(self, item):
33 | image = self.data[item][self.config.image_column]
34 | target = self.data[item][self.config.target_column]
35 |
36 | image = self.transforms(image=np.array(image.convert("RGB")))["image"]
37 | image = np.transpose(image, (2, 0, 1)).astype(np.float32)
38 |
39 | return {
40 | "pixel_values": torch.tensor(image, dtype=torch.float),
41 | "labels": torch.tensor(target, dtype=torch.float),
42 | }
43 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/object_detection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/object_detection/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/object_detection/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class ObjectDetectionDataset:
5 | """
6 | A dataset class for object detection tasks.
7 |
8 | Args:
9 | data (list): A list of data entries where each entry is a dictionary containing image and object information.
10 | transforms (callable): A function or transform to apply to the images and bounding boxes.
11 | image_processor (callable): A function or processor to convert images and annotations into the desired format.
12 | config (object): A configuration object containing column names for images and objects.
13 |
14 | Attributes:
15 | data (list): The dataset containing image and object information.
16 | transforms (callable): The transform function to apply to the images and bounding boxes.
17 | image_processor (callable): The processor to convert images and annotations into the desired format.
18 | config (object): The configuration object with column names for images and objects.
19 |
20 | Methods:
21 | __len__(): Returns the number of items in the dataset.
22 | __getitem__(item): Retrieves and processes the image and annotations for the given index.
23 |
24 | Example:
25 | dataset = ObjectDetectionDataset(data, transforms, image_processor, config)
26 | image_data = dataset[0]
27 | """
28 |
29 | def __init__(self, data, transforms, image_processor, config):
30 | self.data = data
31 | self.transforms = transforms
32 | self.image_processor = image_processor
33 | self.config = config
34 |
35 | def __len__(self):
36 | return len(self.data)
37 |
38 | def __getitem__(self, item):
39 | image = self.data[item][self.config.image_column]
40 | objects = self.data[item][self.config.objects_column]
41 | output = self.transforms(
42 | image=np.array(image.convert("RGB")), bboxes=objects["bbox"], category=objects["category"]
43 | )
44 | image = output["image"]
45 | annotations = []
46 | for j in range(len(output["bboxes"])):
47 | annotations.append(
48 | {
49 | "image_id": str(item),
50 | "category_id": output["category"][j],
51 | "iscrowd": 0,
52 | "area": objects["bbox"][j][2] * objects["bbox"][j][3],  # bbox is [x, y, w, h], so area = w * h
53 | "bbox": output["bboxes"][j],
54 | }
55 | )
56 | annotations = {"annotations": annotations, "image_id": str(item)}
57 | result = self.image_processor(images=image, annotations=annotations, return_tensors="pt")
58 | result["pixel_values"] = result["pixel_values"][0]
59 | result["labels"] = result["labels"][0]
60 | return result
61 |
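The annotation-building loop in `__getitem__` can be sketched in isolation as follows. This is a minimal illustration, not the library's code: `build_coco_annotations` is a hypothetical helper, boxes are assumed to be in `[x, y, w, h]` format, and the area is computed from the same boxes that are stored (the class above reads width and height from the untransformed `objects["bbox"]` instead).

```python
def build_coco_annotations(image_id, bboxes, categories):
    """Build a COCO-style annotation dict from [x, y, w, h] boxes (hypothetical helper)."""
    if len(bboxes) != len(categories):
        raise ValueError("bboxes and categories must have the same length")
    annotations = []
    for bbox, category in zip(bboxes, categories):
        x, y, w, h = bbox
        annotations.append(
            {
                "image_id": str(image_id),
                "category_id": category,
                "iscrowd": 0,
                "area": w * h,  # area derived from the box that is actually stored
                "bbox": [x, y, w, h],
            }
        )
    return {"image_id": str(image_id), "annotations": annotations}


example = build_coco_annotations(0, [[10, 20, 30, 40], [0, 0, 5, 5]], [1, 2])
print(example["annotations"][0]["area"])
```

The returned dictionary has the same shape as the `annotations` value passed to `image_processor` above.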
--------------------------------------------------------------------------------
/src/autotrain/trainers/sent_transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/sent_transformers/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/seq2seq/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/seq2seq/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/seq2seq/dataset.py:
--------------------------------------------------------------------------------
1 | class Seq2SeqDataset:
2 | """
3 | A dataset class for sequence-to-sequence tasks.
4 |
5 | Args:
6 | data (list): The dataset containing input and target sequences.
7 | tokenizer (PreTrainedTokenizer): The tokenizer to process the text data.
8 | config (object): Configuration object containing dataset parameters.
9 |
10 | Attributes:
11 | data (list): The dataset containing input and target sequences.
12 | tokenizer (PreTrainedTokenizer): The tokenizer to process the text data.
13 | config (object): Configuration object containing dataset parameters.
14 | max_len_input (int): Maximum length for input sequences.
15 | max_len_target (int): Maximum length for target sequences.
16 |
17 | Methods:
18 | __len__(): Returns the number of samples in the dataset.
19 | __getitem__(item): Returns the tokenized input and target sequences for a given index.
20 | """
21 |
22 | def __init__(self, data, tokenizer, config):
23 | self.data = data
24 | self.tokenizer = tokenizer
25 | self.config = config
26 | self.max_len_input = self.config.max_seq_length
27 | self.max_len_target = self.config.max_target_length
28 |
29 | def __len__(self):
30 | return len(self.data)
31 |
32 | def __getitem__(self, item):
33 | text = str(self.data[item][self.config.text_column])
34 | target = str(self.data[item][self.config.target_column])
35 |
36 | model_inputs = self.tokenizer(text, max_length=self.max_len_input, truncation=True)
37 |
38 | labels = self.tokenizer(text_target=target, max_length=self.max_len_target, truncation=True)
39 |
40 | model_inputs["labels"] = labels["input_ids"]
41 | return model_inputs
42 |
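To see what `Seq2SeqDataset.__getitem__` produces without loading a real model, the sketch below repeats the same two tokenizer calls against a stand-in tokenizer. `ToyTokenizer` and `encode_pair` are hypothetical names used only for illustration; a real `PreTrainedTokenizer` adds special tokens and uses a fixed vocabulary.

```python
class ToyTokenizer:
    """Hypothetical whitespace tokenizer standing in for a real PreTrainedTokenizer."""

    def __init__(self):
        self.vocab = {}

    def _encode(self, text, max_length):
        # Assign the next free integer id to each unseen token, then truncate.
        ids = [self.vocab.setdefault(tok, len(self.vocab) + 1) for tok in text.split()]
        return ids[:max_length]

    def __call__(self, text=None, text_target=None, max_length=None, truncation=True):
        source = text if text is not None else text_target
        ids = self._encode(source, max_length)
        return {"input_ids": ids, "attention_mask": [1] * len(ids)}


def encode_pair(tokenizer, text, target, max_len_input, max_len_target):
    """Tokenize input and target separately, then attach the target ids as labels."""
    model_inputs = tokenizer(text, max_length=max_len_input, truncation=True)
    labels = tokenizer(text_target=target, max_length=max_len_target, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


sample = encode_pair(ToyTokenizer(), "translate this short sentence", "short sentence", 3, 2)
print(sample)
```

Input and target are truncated independently, so `max_seq_length` and `max_target_length` can differ.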
--------------------------------------------------------------------------------
/src/autotrain/trainers/seq2seq/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import evaluate
4 | import nltk
5 | import numpy as np
6 |
7 |
8 | ROUGE_METRIC = evaluate.load("rouge")
9 |
10 | MODEL_CARD = """
11 | ---
12 | library_name: transformers
13 | tags:
14 | - autotrain
15 | - text2text-generation{base_model}
16 | widget:
17 | - text: "I love AutoTrain"{dataset_tag}
18 | ---
19 |
20 | # Model Trained Using AutoTrain
21 |
22 | - Problem type: Seq2Seq
23 |
24 | ## Validation Metrics
25 | {validation_metrics}
26 | """
27 |
28 |
29 | def _seq2seq_metrics(pred, tokenizer):
30 | """
31 | Compute sequence-to-sequence metrics for predictions and labels.
32 |
33 | Args:
34 | pred (tuple): A tuple containing predictions and labels.
35 | Predictions and labels are expected to be token IDs.
36 | tokenizer (PreTrainedTokenizer): The tokenizer used for decoding the predictions and labels.
37 |
38 | Returns:
39 | dict: A dictionary containing the computed ROUGE metrics and the average length of the generated sequences.
40 | The keys are the metric names and the values are the corresponding scores rounded to four decimal places.
41 | """
42 | predictions, labels = pred
43 | decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
44 |
45 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
46 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
47 |
48 | decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
49 | decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
50 |
51 | result = ROUGE_METRIC.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
52 | result = {key: value * 100 for key, value in result.items()}
53 |
54 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
55 | result["gen_len"] = np.mean(prediction_lens)
56 |
57 | return {k: round(v, 4) for k, v in result.items()}
58 |
59 |
60 | def create_model_card(config, trainer):
61 | """
62 | Generates a model card string based on the provided configuration and trainer.
63 |
64 | Args:
65 | config (object): Configuration object containing the following attributes:
66 | - valid_split (optional): If not None, the function will include evaluation scores.
67 | - data_path (str): Path to the dataset.
68 | - project_name (str): Name of the project.
69 | - model (str): Path or identifier of the model.
70 | trainer (object): Trainer object with an `evaluate` method that returns evaluation metrics.
71 |
72 | Returns:
73 | str: A formatted model card string containing dataset information, validation metrics, and base model details.
74 | """
75 | if config.valid_split is not None:
76 | eval_scores = trainer.evaluate()
77 | eval_scores = [f"{k[len('eval_'):]}: {v}" for k, v in eval_scores.items()]
78 | eval_scores = "\n\n".join(eval_scores)
79 |
80 | else:
81 | eval_scores = "No validation metrics available"
82 |
83 | if config.data_path == f"{config.project_name}/autotrain-data" or os.path.isdir(config.data_path):
84 | dataset_tag = ""
85 | else:
86 | dataset_tag = f"\ndatasets:\n- {config.data_path}"
87 |
88 | if os.path.isdir(config.model):
89 | base_model = ""
90 | else:
91 | base_model = f"\nbase_model: {config.model}"
92 |
93 | model_card = MODEL_CARD.format(
94 | dataset_tag=dataset_tag,
95 | validation_metrics=eval_scores,
96 | base_model=base_model,
97 | )
98 | return model_card
99 |
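Two steps in `_seq2seq_metrics` are easy to miss: label positions set to `-100` are replaced with the pad token id before decoding, and `gen_len` counts only non-pad tokens in each prediction. A plain-list sketch of both steps is shown below; the helper names are hypothetical and the code uses lists where the original uses NumPy arrays.

```python
IGNORE_INDEX = -100  # value used to mask label positions that should not be scored


def restore_padding(label_rows, pad_token_id):
    """Replace ignored label positions with the pad id so the rows can be decoded."""
    return [[tok if tok != IGNORE_INDEX else pad_token_id for tok in row] for row in label_rows]


def mean_generated_length(prediction_rows, pad_token_id):
    """Average number of non-pad tokens per prediction."""
    lengths = [sum(1 for tok in row if tok != pad_token_id) for row in prediction_rows]
    return sum(lengths) / len(lengths)


labels = restore_padding([[5, 6, -100], [7, -100, -100]], pad_token_id=0)
gen_len = mean_generated_length([[5, 6, 0], [7, 0, 0]], pad_token_id=0)
print(labels, gen_len)
```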
--------------------------------------------------------------------------------
/src/autotrain/trainers/tabular/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/tabular/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/tabular/params.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | from pydantic import Field
4 |
5 | from autotrain.trainers.common import AutoTrainParams
6 |
7 |
8 | class TabularParams(AutoTrainParams):
9 | """
10 | TabularParams is a configuration class for tabular data training parameters.
11 |
12 | Attributes:
13 | data_path (str): Path to the dataset.
14 | model (str): Name of the model to use. Default is "xgboost".
15 | username (Optional[str]): Hugging Face Username.
16 | seed (int): Random seed for reproducibility. Default is 42.
17 | train_split (str): Name of the training data split. Default is "train".
18 | valid_split (Optional[str]): Name of the validation data split.
19 | project_name (str): Name of the output directory. Default is "project-name".
20 | token (Optional[str]): Hub Token for authentication.
21 | push_to_hub (bool): Whether to push the model to the hub. Default is False.
22 | id_column (str): Name of the ID column. Default is "id".
23 | target_columns (Union[List[str], str]): Target column(s) in the dataset. Default is ["target"].
24 | categorical_columns (Optional[List[str]]): List of categorical columns.
25 | numerical_columns (Optional[List[str]]): List of numerical columns.
26 | task (str): Type of task (e.g., "classification"). Default is "classification".
27 | num_trials (int): Number of trials for hyperparameter optimization. Default is 10.
28 | time_limit (int): Time limit for training in seconds. Default is 600.
29 | categorical_imputer (Optional[str]): Imputer strategy for categorical columns.
30 | numerical_imputer (Optional[str]): Imputer strategy for numerical columns.
31 | numeric_scaler (Optional[str]): Scaler strategy for numerical columns.
32 | """
33 |
34 | data_path: str = Field(None, title="Data path")
35 | model: str = Field("xgboost", title="Model name")
36 | username: Optional[str] = Field(None, title="Hugging Face Username")
37 | seed: int = Field(42, title="Seed")
38 | train_split: str = Field("train", title="Train split")
39 | valid_split: Optional[str] = Field(None, title="Validation split")
40 | project_name: str = Field("project-name", title="Output directory")
41 | token: Optional[str] = Field(None, title="Hub Token")
42 | push_to_hub: bool = Field(False, title="Push to hub")
43 | id_column: str = Field("id", title="ID column")
44 | target_columns: Union[List[str], str] = Field(["target"], title="Target column(s)")
45 | categorical_columns: Optional[List[str]] = Field(None, title="Categorical columns")
46 | numerical_columns: Optional[List[str]] = Field(None, title="Numerical columns")
47 | task: str = Field("classification", title="Task")
48 | num_trials: int = Field(10, title="Number of trials")
49 | time_limit: int = Field(600, title="Time limit")
50 | categorical_imputer: Optional[str] = Field(None, title="Categorical imputer")
51 | numerical_imputer: Optional[str] = Field(None, title="Numerical imputer")
52 | numeric_scaler: Optional[str] = Field(None, title="Numeric scaler")
53 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/text_classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/text_classification/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/text_classification/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TextClassificationDataset:
5 | """
6 | A dataset class for text classification tasks.
7 |
8 | Args:
9 | data (list): The dataset containing text and target columns.
10 | tokenizer (PreTrainedTokenizer): The tokenizer to preprocess the text data.
11 | config (object): Configuration object containing dataset parameters.
12 |
13 | Attributes:
14 | data (list): The dataset containing text and target columns.
15 | tokenizer (PreTrainedTokenizer): The tokenizer to preprocess the text data.
16 | config (object): Configuration object containing dataset parameters.
17 | text_column (str): The name of the column containing text data.
18 | target_column (str): The name of the column containing target labels.
19 |
20 | Methods:
21 | __len__(): Returns the number of samples in the dataset.
22 | __getitem__(item): Returns a dictionary containing tokenized input ids, attention mask, token type ids (if available), and target labels for the given item index.
23 | """
24 |
25 | def __init__(self, data, tokenizer, config):
26 | self.data = data
27 | self.tokenizer = tokenizer
28 | self.config = config
29 | self.text_column = self.config.text_column
30 | self.target_column = self.config.target_column
31 |
32 | def __len__(self):
33 | return len(self.data)
34 |
35 | def __getitem__(self, item):
36 | text = str(self.data[item][self.text_column])
37 | target = self.data[item][self.target_column]
38 | target = int(target)
39 | inputs = self.tokenizer(
40 | text,
41 | max_length=self.config.max_seq_length,
42 | padding="max_length",
43 | truncation=True,
44 | )
45 |
46 | ids = inputs["input_ids"]
47 | mask = inputs["attention_mask"]
48 |
49 | if "token_type_ids" in inputs:
50 | token_type_ids = inputs["token_type_ids"]
51 | else:
52 | token_type_ids = None
53 |
54 | if token_type_ids is not None:
55 | return {
56 | "input_ids": torch.tensor(ids, dtype=torch.long),
57 | "attention_mask": torch.tensor(mask, dtype=torch.long),
58 | "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
59 | "labels": torch.tensor(target, dtype=torch.long),
60 | }
61 | return {
62 | "input_ids": torch.tensor(ids, dtype=torch.long),
63 | "attention_mask": torch.tensor(mask, dtype=torch.long),
64 | "labels": torch.tensor(target, dtype=torch.long),
65 | }
66 |
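The tokenizer call above combines `padding="max_length"` with `truncation=True`, so every sample comes back with exactly `max_seq_length` ids. The sketch below shows that effect on a plain list of ids. `pad_or_truncate` is a hypothetical helper that assumes right-side padding and ignores special tokens, both of which a real tokenizer handles itself.

```python
def pad_or_truncate(input_ids, max_length, pad_token_id=0):
    """Cut or right-pad a list of token ids to exactly max_length entries."""
    ids = list(input_ids[:max_length])
    attention_mask = [1] * len(ids)  # 1 marks real tokens
    padding = max_length - len(ids)
    ids += [pad_token_id] * padding
    attention_mask += [0] * padding  # 0 marks padding
    return {"input_ids": ids, "attention_mask": attention_mask}


short = pad_or_truncate([1, 2, 3], max_length=5)
long = pad_or_truncate([1, 2, 3, 4, 5, 6], max_length=4)
print(short, long)
```

Fixed-length outputs are what allow `__getitem__` to convert each field directly with `torch.tensor`.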
--------------------------------------------------------------------------------
/src/autotrain/trainers/text_regression/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/text_regression/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/text_regression/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TextRegressionDataset:
5 | """
6 | A custom dataset class for text regression tasks for AutoTrain.
7 |
8 | Args:
9 | data (list of dict): The dataset containing text and target values.
10 | tokenizer (PreTrainedTokenizer): The tokenizer to preprocess the text data.
11 | config (object): Configuration object containing dataset parameters.
12 |
13 | Attributes:
14 | data (list of dict): The dataset containing text and target values.
15 | tokenizer (PreTrainedTokenizer): The tokenizer to preprocess the text data.
16 | config (object): Configuration object containing dataset parameters.
17 | text_column (str): The column name for text data in the dataset.
18 | target_column (str): The column name for target values in the dataset.
19 | max_len (int): The maximum sequence length for tokenized inputs.
20 |
21 | Methods:
22 | __len__(): Returns the number of samples in the dataset.
23 | __getitem__(item): Returns a dictionary containing tokenized inputs and target value for a given index.
24 | """
25 |
26 | def __init__(self, data, tokenizer, config):
27 | self.data = data
28 | self.tokenizer = tokenizer
29 | self.config = config
30 | self.text_column = self.config.text_column
31 | self.target_column = self.config.target_column
32 | self.max_len = self.config.max_seq_length
33 |
34 | def __len__(self):
35 | return len(self.data)
36 |
37 | def __getitem__(self, item):
38 | text = str(self.data[item][self.text_column])
39 | target = float(self.data[item][self.target_column])
40 | inputs = self.tokenizer(
41 | text,
42 | max_length=self.max_len,
43 | padding="max_length",
44 | truncation=True,
45 | )
46 |
47 | ids = inputs["input_ids"]
48 | mask = inputs["attention_mask"]
49 |
50 | if "token_type_ids" in inputs:
51 | token_type_ids = inputs["token_type_ids"]
52 | else:
53 | token_type_ids = None
54 |
55 | if token_type_ids is not None:
56 | return {
57 | "input_ids": torch.tensor(ids, dtype=torch.long),
58 | "attention_mask": torch.tensor(mask, dtype=torch.long),
59 | "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
60 | "labels": torch.tensor(target, dtype=torch.float),
61 | }
62 | return {
63 | "input_ids": torch.tensor(ids, dtype=torch.long),
64 | "attention_mask": torch.tensor(mask, dtype=torch.long),
65 | "labels": torch.tensor(target, dtype=torch.float),
66 | }
67 |
--------------------------------------------------------------------------------
/src/autotrain/trainers/token_classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/token_classification/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/token_classification/dataset.py:
--------------------------------------------------------------------------------
1 | class TokenClassificationDataset:
2 | """
3 | A dataset class for token classification tasks.
4 |
5 | Args:
6 | data (Dataset): The dataset containing the text and tags.
7 | tokenizer (PreTrainedTokenizer): The tokenizer to be used for tokenizing the text.
8 | config (Config): Configuration object containing necessary parameters.
9 |
10 | Attributes:
11 | data (Dataset): The dataset containing the text and tags.
12 | tokenizer (PreTrainedTokenizer): The tokenizer to be used for tokenizing the text.
13 | config (Config): Configuration object containing necessary parameters.
14 |
15 | Methods:
16 | __len__():
17 | Returns the number of samples in the dataset.
18 |
19 | __getitem__(item):
20 | Retrieves a tokenized sample and its corresponding labels.
21 |
22 | Args:
23 | item (int): The index of the sample to retrieve.
24 |
25 | Returns:
26 | dict: A dictionary containing tokenized text and corresponding labels.
27 | """
28 |
29 | def __init__(self, data, tokenizer, config):
30 | self.data = data
31 | self.tokenizer = tokenizer
32 | self.config = config
33 |
34 | def __len__(self):
35 | return len(self.data)
36 |
37 | def __getitem__(self, item):
38 | text = self.data[item][self.config.tokens_column]
39 | tags = self.data[item][self.config.tags_column]
40 |
41 | label_list = self.data.features[self.config.tags_column].feature.names
42 | label_to_id = {i: i for i in range(len(label_list))}
43 |
44 | tokenized_text = self.tokenizer(
45 | text,
46 | max_length=self.config.max_seq_length,
47 | padding="max_length",
48 | truncation=True,
49 | is_split_into_words=True,
50 | )
51 |
52 | word_ids = tokenized_text.word_ids(batch_index=0)
53 | previous_word_idx = None
54 | label_ids = []
55 | for word_idx in word_ids:
56 | if word_idx is None:
57 | label_ids.append(-100)
58 | elif word_idx != previous_word_idx:
59 | label_ids.append(label_to_id[tags[word_idx]])
60 | else:
61 | label_ids.append(label_to_id[tags[word_idx]])
62 | previous_word_idx = word_idx
63 |
64 | tokenized_text["labels"] = label_ids
65 | return tokenized_text
66 |
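The label alignment in `__getitem__` can be tried on its own with a hand-written `word_ids` list. In the class above both non-`None` branches append the word's tag, so every sub-token of a word is labelled. The sketch below reproduces that behaviour by default and adds a `label_all_subtokens` switch, which is an assumption introduced here and not an option in the source, to show the common alternative of labelling only the first sub-token.

```python
def align_labels(word_ids, tags, label_all_subtokens=True, ignore_index=-100):
    """Map word-level tags onto sub-tokens using the tokenizer's word_ids."""
    label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            # Special tokens and padding carry no word index.
            label_ids.append(ignore_index)
        elif word_idx != previous_word_idx or label_all_subtokens:
            label_ids.append(tags[word_idx])
        else:
            # Continuation sub-token, masked when only first sub-tokens are labelled.
            label_ids.append(ignore_index)
        previous_word_idx = word_idx
    return label_ids


word_ids = [None, 0, 1, 1, 2, None]  # the second word is split into two sub-tokens
tags = [3, 0, 5]
all_labelled = align_labels(word_ids, tags)
first_only = align_labels(word_ids, tags, label_all_subtokens=False)
print(all_labelled, first_only)
```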
--------------------------------------------------------------------------------
/src/autotrain/trainers/token_classification/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from seqeval import metrics
5 |
6 |
7 | MODEL_CARD = """
8 | ---
9 | library_name: transformers
10 | tags:
11 | - autotrain
12 | - token-classification{base_model}
13 | widget:
14 | - text: "I love AutoTrain"{dataset_tag}
15 | ---
16 |
17 | # Model Trained Using AutoTrain
18 |
19 | - Problem type: Token Classification
20 |
21 | ## Validation Metrics
22 | {validation_metrics}
23 | """
24 |
25 |
26 | def token_classification_metrics(pred, label_list):
27 | """
28 | Compute token classification metrics including precision, recall, F1 score, and accuracy.
29 |
30 | Args:
31 | pred (tuple): A tuple containing predictions and labels.
32 | Predictions should be a 3D array (batch_size, sequence_length, num_labels).
33 | Labels should be a 2D array (batch_size, sequence_length).
34 | label_list (list): A list of label names corresponding to the indices used in predictions and labels.
35 |
36 | Returns:
37 | dict: A dictionary containing the following metrics:
38 | - "precision": Precision score of the token classification.
39 | - "recall": Recall score of the token classification.
40 | - "f1": F1 score of the token classification.
41 | - "accuracy": Accuracy score of the token classification.
42 | """
43 | predictions, labels = pred
44 | predictions = np.argmax(predictions, axis=2)
45 |
46 | true_predictions = [
47 | [label_list[predi] for (predi, lbl) in zip(prediction, label) if lbl != -100]
48 | for prediction, label in zip(predictions, labels)
49 | ]
50 | true_labels = [
51 | [label_list[lbl] for (predi, lbl) in zip(prediction, label) if lbl != -100]
52 | for prediction, label in zip(predictions, labels)
53 | ]
54 |
55 | results = {
56 | "precision": metrics.precision_score(true_labels, true_predictions),
57 | "recall": metrics.recall_score(true_labels, true_predictions),
58 | "f1": metrics.f1_score(true_labels, true_predictions),
59 | "accuracy": metrics.accuracy_score(true_labels, true_predictions),
60 | }
61 | return results
62 |
63 |
64 | def create_model_card(config, trainer):
65 | """
66 | Generates a model card string based on the provided configuration and trainer.
67 |
68 | Args:
69 | config (object): Configuration object containing model and dataset information.
70 | trainer (object): Trainer object used to evaluate the model.
71 |
72 | Returns:
73 | str: A formatted model card string with dataset tags, validation metrics, and base model information.
74 | """
75 | if config.valid_split is not None:
76 | eval_scores = trainer.evaluate()
77 | valid_metrics = ["eval_loss", "eval_precision", "eval_recall", "eval_f1", "eval_accuracy"]
78 | eval_scores = [f"{k[len('eval_'):]}: {v}" for k, v in eval_scores.items() if k in valid_metrics]
79 | eval_scores = "\n\n".join(eval_scores)
80 | else:
81 | eval_scores = "No validation metrics available"
82 |
83 | if config.data_path == f"{config.project_name}/autotrain-data" or os.path.isdir(config.data_path):
84 | dataset_tag = ""
85 | else:
86 | dataset_tag = f"\ndatasets:\n- {config.data_path}"
87 |
88 | if os.path.isdir(config.model):
89 | base_model = ""
90 | else:
91 | base_model = f"\nbase_model: {config.model}"
92 |
93 | model_card = MODEL_CARD.format(
94 | dataset_tag=dataset_tag,
95 | validation_metrics=eval_scores,
96 | base_model=base_model,
97 | )
98 | return model_card
99 |
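Before calling `seqeval`, `token_classification_metrics` drops every position whose label is `-100` and converts the remaining ids to label names. A minimal sketch of that filtering step follows. `strip_ignored` is a hypothetical helper, and it takes predictions that are already class ids (after the `argmax` over the last axis).

```python
def strip_ignored(prediction_ids, label_ids, label_list, ignore_index=-100):
    """Keep only scored positions and convert ids to label names."""
    true_predictions, true_labels = [], []
    for pred_row, label_row in zip(prediction_ids, label_ids):
        kept = [(p, l) for p, l in zip(pred_row, label_row) if l != ignore_index]
        true_predictions.append([label_list[p] for p, _ in kept])
        true_labels.append([label_list[l] for _, l in kept])
    return true_predictions, true_labels


label_list = ["O", "B-PER", "I-PER"]
preds, golds = strip_ignored([[0, 1, 0, 0]], [[-100, 1, 2, -100]], label_list)
print(preds, golds)
```

The two nested lists line up position by position, which is the input format the `seqeval` scoring functions above receive.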
--------------------------------------------------------------------------------
/src/autotrain/trainers/vlm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/vlm/__init__.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/vlm/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | from autotrain.trainers.common import monitor
5 | from autotrain.trainers.vlm import utils
6 | from autotrain.trainers.vlm.params import VLMTrainingParams
7 |
8 |
9 | def parse_args():
10 | # get training_config.json from the end user
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--training_config", type=str, required=True)
13 | return parser.parse_args()
14 |
15 |
16 | @monitor
17 | def train(config):
18 | if isinstance(config, dict):
19 | config = VLMTrainingParams(**config)
20 |
21 | if not utils.check_model_support(config):
22 | raise ValueError(f"model `{config.model}` not supported")
23 |
24 | if config.trainer in ("vqa", "captioning"):
25 | from autotrain.trainers.vlm.train_vlm_generic import train as train_generic
26 |
27 | train_generic(config)
28 |
29 | else:
30 | raise ValueError(f"trainer `{config.trainer}` not supported")
31 |
32 |
33 | if __name__ == "__main__":
34 | _args = parse_args()
35 | training_config = json.load(open(_args.training_config))
36 | _config = VLMTrainingParams(**training_config)
37 | train(_config)
38 |
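The entry point expects a single `--training_config` argument that points to a JSON file. The sketch below writes such a file to a temporary directory and reads it back the same way. The field names `model`, `trainer` and `project_name` appear in the code above, but the values are made up, and `load_training_config` is a hypothetical helper that closes the file after reading.

```python
import argparse
import json
import os
import tempfile


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--training_config", type=str, required=True)
    return parser.parse_args(argv)


def load_training_config(path):
    """Read the JSON config and close the file afterwards."""
    with open(path, encoding="utf-8") as config_file:
        return json.load(config_file)


# Illustrative values only; a real config carries every field VLMTrainingParams needs.
example_config = {"model": "some-org/some-vlm", "trainer": "vqa", "project_name": "demo-project"}

with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, "training_config.json")
    with open(config_path, "w", encoding="utf-8") as config_file:
        json.dump(example_config, config_file)
    args = parse_args(["--training_config", config_path])
    loaded = load_training_config(args.training_config)

print(loaded["trainer"])
```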
--------------------------------------------------------------------------------
/src/autotrain/trainers/vlm/dataset.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/src/autotrain/trainers/vlm/dataset.py
--------------------------------------------------------------------------------
/src/autotrain/trainers/vlm/train_vlm_generic.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from datasets import load_dataset, load_from_disk
4 | from transformers import AutoProcessor, Trainer, TrainingArguments
5 | from transformers.trainer_callback import PrinterCallback
6 |
7 | from autotrain import logger
8 | from autotrain.trainers.common import ALLOW_REMOTE_CODE
9 | from autotrain.trainers.vlm import utils
10 |
11 |
12 | def collate_fn(examples, config, processor):
13 | prompts = ["answer " + example[config.prompt_text_column] for example in examples]
14 | labels = [example[config.text_column] for example in examples]
15 | images = [example[config.image_column].convert("RGB") for example in examples]
16 | tokens = processor(
17 | text=prompts,
18 | images=images,
19 | suffix=labels,
20 | return_tensors="pt",
21 | padding="longest",
22 | tokenize_newline_separately=False,
23 | )
24 | return tokens
25 |
26 |
27 | def train(config):
28 | valid_data = None
29 | if config.data_path == f"{config.project_name}/autotrain-data":
30 | train_data = load_from_disk(config.data_path)[config.train_split]
31 | else:
32 | if ":" in config.train_split:
33 | dataset_config_name, split = config.train_split.split(":")
34 | train_data = load_dataset(
35 | config.data_path,
36 | name=dataset_config_name,
37 | split=split,
38 | token=config.token,
39 | )
40 | else:
41 | train_data = load_dataset(
42 | config.data_path,
43 | split=config.train_split,
44 | token=config.token,
45 | )
46 |
47 | if config.valid_split is not None:
48 | if config.data_path == f"{config.project_name}/autotrain-data":
49 | valid_data = load_from_disk(config.data_path)[config.valid_split]
50 | else:
51 | if ":" in config.valid_split:
52 | dataset_config_name, split = config.valid_split.split(":")
53 | valid_data = load_dataset(
54 | config.data_path,
55 | name=dataset_config_name,
56 | split=split,
57 | token=config.token,
58 | )
59 | else:
60 | valid_data = load_dataset(
61 | config.data_path,
62 | split=config.valid_split,
63 | token=config.token,
64 | )
65 |
66 | logger.info(f"Train data: {train_data}")
67 | logger.info(f"Valid data: {valid_data}")
68 |
69 | if config.trainer == "captioning":
70 | config.prompt_text_column = "caption"
71 |
72 | processor = AutoProcessor.from_pretrained(config.model, token=config.token, trust_remote_code=ALLOW_REMOTE_CODE)
73 |
74 | logging_steps = utils.configure_logging_steps(config, train_data, valid_data)
75 | training_args = utils.configure_training_args(config, logging_steps)
76 |
77 | args = TrainingArguments(**training_args)
78 | model = utils.get_model(config)
79 |
80 | logger.info("creating trainer")
81 | callbacks = utils.get_callbacks(config)
82 | trainer_args = dict(
83 | args=args,
84 | model=model,
85 | callbacks=callbacks,
86 | )
87 |
88 | col_fn = partial(collate_fn, config=config, processor=processor)
89 |
90 | trainer = Trainer(
91 | **trainer_args,
92 | train_dataset=train_data,
93 | eval_dataset=valid_data,
94 | data_collator=col_fn,
95 | )
96 | trainer.remove_callback(PrinterCallback)
97 | trainer.train()
98 | utils.post_training_steps(config, trainer)
99 |
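`train` accepts split names in two forms: a plain split such as `train`, or `config_name:split`, where the part before the colon is passed to `load_dataset` as `name`. That parsing rule can be sketched as a small function. `parse_split` is a hypothetical helper, and it splits on the first colon only, whereas the code above calls `split(":")` and would fail to unpack a value containing more than one colon.

```python
def parse_split(split_spec):
    """Return (dataset_config_name, split) for 'name:split' or (None, split) otherwise."""
    if ":" in split_spec:
        dataset_config_name, split = split_spec.split(":", 1)
        return dataset_config_name, split
    return None, split_spec


with_config = parse_split("en:train")
plain = parse_split("validation")
print(with_config, plain)
```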
--------------------------------------------------------------------------------
/src/autotrain/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import subprocess
4 |
5 | from autotrain.commands import launch_command
6 | from autotrain.trainers.clm.params import LLMTrainingParams
7 | from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams
8 | from autotrain.trainers.generic.params import GenericParams
9 | from autotrain.trainers.image_classification.params import ImageClassificationParams
10 | from autotrain.trainers.image_regression.params import ImageRegressionParams
11 | from autotrain.trainers.object_detection.params import ObjectDetectionParams
12 | from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
13 | from autotrain.trainers.seq2seq.params import Seq2SeqParams
14 | from autotrain.trainers.tabular.params import TabularParams
15 | from autotrain.trainers.text_classification.params import TextClassificationParams
16 | from autotrain.trainers.text_regression.params import TextRegressionParams
17 | from autotrain.trainers.token_classification.params import TokenClassificationParams
18 | from autotrain.trainers.vlm.params import VLMTrainingParams
19 |
20 |
21 | ALLOW_REMOTE_CODE = os.environ.get("ALLOW_REMOTE_CODE", "true").lower() == "true"
22 |
23 |
24 | def run_training(params, task_id, local=False, wait=False):
25 | """
26 | Run the training process based on the provided parameters and task ID.
27 |
28 | Args:
29 | params (str): JSON string of the parameters required for training.
30 | task_id (int): Identifier for the type of task to be performed.
31 | local (bool, optional): Flag to indicate if the training should be run locally. Currently unused by this function. Defaults to False.
32 | wait (bool, optional): Flag to indicate if the function should wait for the process to complete. Defaults to False.
33 |
34 | Returns:
35 | int: Process ID of the launched training process.
36 |
37 | Raises:
38 | NotImplementedError: If the task_id does not match any of the predefined tasks.
39 | """
40 | params = json.loads(params)
41 | if isinstance(params, str):
42 | params = json.loads(params)
43 | if task_id == 9:
44 | params = LLMTrainingParams(**params)
45 | elif task_id == 28:
46 | params = Seq2SeqParams(**params)
47 | elif task_id in (1, 2):
48 | params = TextClassificationParams(**params)
49 | elif task_id in (13, 14, 15, 16, 26):
50 | params = TabularParams(**params)
51 | elif task_id == 27:
52 | params = GenericParams(**params)
53 | elif task_id == 18:
54 | params = ImageClassificationParams(**params)
55 | elif task_id == 4:
56 | params = TokenClassificationParams(**params)
57 | elif task_id == 10:
58 | params = TextRegressionParams(**params)
59 | elif task_id == 29:
60 | params = ObjectDetectionParams(**params)
61 | elif task_id == 30:
62 | params = SentenceTransformersParams(**params)
63 | elif task_id == 24:
64 | params = ImageRegressionParams(**params)
65 | elif task_id == 31:
66 | params = VLMTrainingParams(**params)
67 | elif task_id == 5:
68 | params = ExtractiveQuestionAnsweringParams(**params)
69 | else:
70 | raise NotImplementedError(f"task_id {task_id} is not supported")
71 |
72 | params.save(output_dir=params.project_name)
73 | cmd = launch_command(params=params)
74 | cmd = [str(c) for c in cmd]
75 | env = os.environ.copy()
76 | process = subprocess.Popen(cmd, env=env)
77 | if wait:
78 | process.wait()
79 | return process.pid
80 |
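The `if`/`elif` chain in `run_training` is a mapping from task id to parameter class, preceded by a JSON decode that tolerates a doubly encoded string. The sketch below expresses the same mapping as a lookup table. It stores class names as strings so that it runs without `autotrain` installed, and `decode_params` and `resolve_params_class_name` are hypothetical helper names.

```python
import json

# Task ids and class names copied from the branches of run_training.
TASK_ID_TO_PARAMS = {
    1: "TextClassificationParams",
    2: "TextClassificationParams",
    4: "TokenClassificationParams",
    5: "ExtractiveQuestionAnsweringParams",
    9: "LLMTrainingParams",
    10: "TextRegressionParams",
    13: "TabularParams",
    14: "TabularParams",
    15: "TabularParams",
    16: "TabularParams",
    18: "ImageClassificationParams",
    24: "ImageRegressionParams",
    26: "TabularParams",
    27: "GenericParams",
    28: "Seq2SeqParams",
    29: "ObjectDetectionParams",
    30: "SentenceTransformersParams",
    31: "VLMTrainingParams",
}


def decode_params(params):
    """Decode a JSON string, decoding a second time if the payload was encoded twice."""
    params = json.loads(params)
    if isinstance(params, str):
        params = json.loads(params)
    return params


def resolve_params_class_name(task_id):
    """Look up the parameter class name for a task id."""
    try:
        return TASK_ID_TO_PARAMS[task_id]
    except KeyError:
        raise NotImplementedError(f"task_id {task_id} is not supported") from None


double_encoded = json.dumps(json.dumps({"project_name": "demo-project"}))
decoded = decode_params(double_encoded)
print(decoded, resolve_params_class_name(26))
```

A table keeps each id next to its class and makes an unsupported id fail with a message that names it.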
--------------------------------------------------------------------------------
/static/autotrain_homepage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_homepage.png
--------------------------------------------------------------------------------
/static/autotrain_model_choice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_model_choice.png
--------------------------------------------------------------------------------
/static/autotrain_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_space.png
--------------------------------------------------------------------------------
/static/autotrain_text_classification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/autotrain_text_classification.png
--------------------------------------------------------------------------------
/static/cost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/cost.png
--------------------------------------------------------------------------------
/static/dreambooth1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/dreambooth1.jpeg
--------------------------------------------------------------------------------
/static/dreambooth2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/dreambooth2.png
--------------------------------------------------------------------------------
/static/duplicate_space.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/duplicate_space.png
--------------------------------------------------------------------------------
/static/ext_qa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/ext_qa.png
--------------------------------------------------------------------------------
/static/hub_model_choice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/hub_model_choice.png
--------------------------------------------------------------------------------
/static/image_classification_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/image_classification_1.png
--------------------------------------------------------------------------------
/static/img_reg_ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/img_reg_ui.png
--------------------------------------------------------------------------------
/static/llm_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_1.png
--------------------------------------------------------------------------------
/static/llm_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_2.png
--------------------------------------------------------------------------------
/static/llm_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_3.png
--------------------------------------------------------------------------------
/static/llm_orpo_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/llm_orpo_example.png
--------------------------------------------------------------------------------
/static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/logo.png
--------------------------------------------------------------------------------
/static/model_choice_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/model_choice_1.png
--------------------------------------------------------------------------------
/static/param_choice_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/param_choice_1.png
--------------------------------------------------------------------------------
/static/param_choice_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/param_choice_2.png
--------------------------------------------------------------------------------
/static/space_template_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_1.png
--------------------------------------------------------------------------------
/static/space_template_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_2.png
--------------------------------------------------------------------------------
/static/space_template_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_3.png
--------------------------------------------------------------------------------
/static/space_template_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_4.png
--------------------------------------------------------------------------------
/static/space_template_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/space_template_5.png
--------------------------------------------------------------------------------
/static/text_classification_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/text_classification_1.png
--------------------------------------------------------------------------------
/static/ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/autotrain-advanced/b5c98fbe3aab61101e9c8f7f6c64407cbb68e400/static/ui.png
--------------------------------------------------------------------------------