├── .DS_Store
├── .github
├── dependabot.yml
└── workflows
│ ├── pre-commit.yaml
│ └── release.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── classifier
├── AutoModelForSeq2SeqLM
│ ├── classify_and_evaluate.py
│ └── flan-t5-finetuning.py
├── AutoModelForSequenceClassification
│ ├── classify_and_evaluate.py
│ └── flan-t5-finetuning.py
└── data_loader.py
├── data
├── ecommerce_kaggle_dataset.csv
├── evaluation.png
└── evaluation_classification_model.png
├── poetry.lock
└── pyproject.toml
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VanekPetr/flan-t5-text-classifier/4a7d73fd5f9921794e726a197eaaaf322436a64c/.DS_Store
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "pip" # See documentation for possible values
9 | directory: "/" # Location of package manifests
10 | schedule:
11 | interval: "weekly" # How often to check for updates
12 |
13 | - package-ecosystem: "github-actions"
14 | directory: "/"
15 | schedule:
16 | # Check for updates to GitHub Actions every week
17 | interval: "weekly"
18 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yaml:
--------------------------------------------------------------------------------
1 | # Run pre-commit on all files in the repository
2 | name: pre-commit
3 |
4 | on:
5 | pull_request:
6 | push:
7 |
8 | jobs:
9 | pre-commit:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - uses: pre-commit/action@v3.0.1
14 | with:
15 | extra_args: '--verbose --all-files'
16 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | name: Bump version
2 | on:
3 | push:
4 | branches:
5 | - main
6 |
7 | jobs:
8 | build:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Bump version and push tag
13 | id: tag_version
14 | uses: mathieudutour/github-tag-action@v6.2
15 | with:
16 | github_token: ${{ secrets.GITHUB_TOKEN }}
17 | - name: Create a GitHub release
18 | uses: ncipollo/release-action@v1
19 | with:
20 | tag: ${{ steps.tag_version.outputs.new_tag }}
21 | name: Release ${{ steps.tag_version.outputs.new_tag }}
22 | body: ${{ steps.tag_version.outputs.changelog }}
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | data/.DS_Store
162 | .DS_Store
163 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.5.0
4 | hooks:
5 | - id: end-of-file-fixer
6 | - id: trailing-whitespace
7 |
8 | - repo: https://github.com/psf/black
9 | rev: 24.1.1
10 | hooks:
11 | - id: black
12 |
13 | - repo: https://github.com/astral-sh/ruff-pre-commit
14 | rev: 'v0.2.0'
15 | hooks:
16 | - id: ruff
17 | args: [ --fix, --exit-non-zero-on-fix ]
18 |
19 | - repo: https://github.com/asottile/pyupgrade
20 | rev: v3.15.0
21 | hooks:
22 | - id: pyupgrade
23 |
24 | - repo: https://github.com/python-jsonschema/check-jsonschema
25 | rev: 0.27.4
26 | hooks:
27 | - id: check-dependabot
28 | args: ["--verbose"]
29 | - id: check-github-workflows
30 | args: ["--verbose"]
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Petr Vanek
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # flan-t5 Fine-tuning for Text Classification
2 |
3 | This GitHub project is aimed at fine-tuning the flan t5 model for a text classification task using an E-commerce
4 | text dataset for 4 categories - "Electronics", "Household", "Books" and "Clothing & Accessories".
5 |
6 |
7 | ## Evaluation Statistics
8 | ### AutoModelForSequenceClassification
9 | Text Classification model can be found on [HuggingFace](https://huggingface.co/VanekPetr/flan-t5-small-ecommerce-text-classification). The model is trained on the dataset and evaluated on the test set. The evaluation metrics are as follows:
10 |
11 |
12 |
13 |
14 | The main **advantage** of this model is that together with prediction it outputs the confidence score for each class. This can be used to filter out the predictions with low confidence.
15 |
16 | ### AutoModelForSeq2SeqLM
17 | Text2Text Generation model can be found on [HuggingFace](https://huggingface.co/VanekPetr/flan-t5-base-ecommerce-text-classification). The model is trained on the dataset and evaluated on the test set. The evaluation metrics are as follows:
18 |
19 |
20 |
21 |
22 | ## Dataset
23 |
24 | The dataset is a classification-based E-commerce text dataset that covers almost 80% of any E-commerce website. The dataset consists of product and description data for 4 categories. The dataset can be found [here](https://doi.org/10.5281/zenodo.3355823).
25 |
26 | ## Project Features
27 |
28 | The project employs the tokenizer of flan-t5 by Hugging Face, which helps in splitting the input text into a format that is understandable by the model.
29 |
30 | An evaluation function has been implemented for post-processing the labels and predictions, which will also handle sequence length adjustments.
31 |
32 | The project uses `Seq2SeqTrainer` and `SequenceClassification` for training the model. It also includes a helper function to preprocess the dataset.
33 |
34 | ## Usage
35 |
36 | To train a model, run the `flan-t5-finetuning.py` script in either `classifier/AutoModelForSeq2SeqLM` or `classifier/AutoModelForSequenceClassification`.
37 |
38 | The 'train' function fine-tunes the flan-t5 model, trains it with the dataset, outputs the metrics, creates a model card and pushes the model to Hugging Face model hub.
39 |
40 | The preprocess function tokenizes the inputs, and also handles tokenization of the target labels. The compute_metrics function evaluates the model performance based on the F1 metric.
41 |
42 | ## Getting Started
43 |
44 | STEP 1: create and activate python virtual environment
45 | ``` bash
46 | python -m venv venv
47 | source venv/bin/activate
48 | ```
49 |
50 | STEP 2: install requirements with [poetry](https://python-poetry.org/docs/#installing-with-the-official-installer)
51 | ``` bash
52 | poetry install -vv
53 | ```
54 |
55 | ## Versioning
56 |
57 | We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/VanekPetr/flan-t5-text-classifier/tags).
58 |
59 | ## License
60 |
61 | This repository is licensed under [MIT](LICENSE) (c) 2023 Petr Vanek.
62 |
--------------------------------------------------------------------------------
/classifier/AutoModelForSeq2SeqLM/classify_and_evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.metrics import classification_report
3 | from tqdm.auto import tqdm
4 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5 |
6 | from classifier.data_loader import load_dataset
7 |
dataset = load_dataset()

# Hub id of the fine-tuned text2text checkpoint. It is named once so that the
# tokenizer and the model below are always loaded from the same repository.
MODEL_ID = "VanekPetr/flan-t5-base-ecommerce-text-classification"

# Load model and tokenizer from the hub
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
# Run on the GPU when one is available, otherwise on the CPU
model.to("cuda" if torch.cuda.is_available() else "cpu")
18 |
19 |
def classify(texts_to_classify: list[str]) -> list[str]:
    """Classify a batch of texts using the model.

    Args:
        texts_to_classify: Texts to label. The value is forwarded unchanged to
            the tokenizer, which pads/truncates every text to 256 tokens.

    Returns:
        One decoded generation (special tokens removed) per input text, in
        input order. An empty list or tuple returns an empty list without
        touching the tokenizer or the model.
    """
    # Nothing to generate for an empty batch.
    if isinstance(texts_to_classify, (list, tuple)) and not texts_to_classify:
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(
        texts_to_classify,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=150,
            num_beams=2,
            early_stopping=True,
        )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
44 |
45 |
def evaluate():
    """Evaluate the model on the test dataset and print a classification report.

    The test split is classified in fixed-size batches; generated strings are
    compared against the stringified reference labels.
    """
    batch_size = 16  # Adjust the batch size to the available GPU capacity

    # Read each column once up front. Indexing the split by column name inside
    # the loop would fetch the whole column again for every batch.
    test_split = dataset["test"]
    texts = test_split["text"]
    labels = test_split["label"]

    predictions_list, labels_list = [], []
    # One progress tick per batch; the range length equals the number of batches.
    for start in tqdm(range(0, len(texts), batch_size), desc="Evaluating"):
        stop = start + batch_size
        predictions_list.extend(classify(texts[start:stop]))
        labels_list.extend(str(label) for label in labels[start:stop])

    report = classification_report(labels_list, predictions_list)
    print(report)


if __name__ == "__main__":
    evaluate()
74 |
--------------------------------------------------------------------------------
/classifier/AutoModelForSeq2SeqLM/flan-t5-finetuning.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import evaluate
4 | import nltk
5 | import numpy as np
6 | from datasets import Dataset, concatenate_datasets
7 | from huggingface_hub import HfFolder
8 | from nltk.tokenize import sent_tokenize
9 | from transformers import (
10 | AutoModelForSeq2SeqLM,
11 | AutoTokenizer,
12 | DataCollatorForSeq2Seq,
13 | Seq2SeqTrainer,
14 | Seq2SeqTrainingArguments,
15 | )
16 |
17 | from classifier.data_loader import load_dataset
18 |
MODEL_ID = "google/flan-t5-base"
REPOSITORY_ID = f"{MODEL_ID.split('/')[1]}-ecommerce-text-classification"

# Load dataset
dataset = load_dataset()

# Load tokenizer of FLAN-t5
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Metric
metric = evaluate.load("f1")

# Train and test rows combined; both length scans below run over this set.
_all_rows = concatenate_datasets([dataset["train"], dataset["test"]])

# The maximum total input sequence length after tokenization.
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_inputs = _all_rows.map(
    lambda x: tokenizer(x["text"], truncation=True),
    batched=True,
    remove_columns=["text", "label"],
)
max_source_length = max(len(ids) for ids in tokenized_inputs["input_ids"])
print(f"Max source length: {max_source_length}")

# The maximum total sequence length for target text after tokenization.
# Sequences longer than this will be truncated, sequences shorter will be padded.
tokenized_targets = _all_rows.map(
    lambda x: tokenizer(x["label"], truncation=True),
    batched=True,
    remove_columns=["text", "label"],
)
max_target_length = max(len(ids) for ids in tokenized_targets["input_ids"])
print(f"Max target length: {max_target_length}")
50 |
# Training configuration for the seq2seq trainer
training_args = Seq2SeqTrainingArguments(
    # Output directory and Hub publishing
    output_dir=REPOSITORY_ID,
    push_to_hub=True,
    hub_model_id=REPOSITORY_ID,
    hub_strategy="every_save",
    hub_token=HfFolder.get_token(),
    # Optimisation
    num_train_epochs=2,
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=False,  # Overflows with fp16
    predict_with_generate=True,
    # Logging, evaluation and checkpointing
    logging_dir=f"{REPOSITORY_ID}/logs",
    logging_strategy="epoch",
    report_to="tensorboard",
    evaluation_strategy="no",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=False,
)
72 |
73 |
def preprocess_function(sample: Dataset, padding: str = "max_length") -> dict:
    """Tokenize a batch of texts and their target labels.

    Args:
        sample: Batch exposing the "text" and "label" columns.
        padding: Padding strategy passed to both tokenizer calls.

    Returns:
        The tokenized inputs with an added "labels" entry holding the target
        token ids.
    """
    source_texts = list(sample["text"])

    # Tokenize the inputs, capped at the longest source length seen in the data
    model_inputs = tokenizer(
        source_texts, max_length=max_source_length, padding=padding, truncation=True
    )

    # Tokenize the targets through the `text_target` keyword argument
    target_encoding = tokenizer(
        text_target=sample["label"],
        max_length=max_target_length,
        padding=padding,
        truncation=True,
    )
    target_ids = target_encoding["input_ids"]

    # When padding to a fixed length, swap every pad token id in the targets for
    # -100 so that padding is ignored in the loss.
    if padding == "max_length":
        pad_id = tokenizer.pad_token_id
        target_ids = [
            [-100 if token == pad_id else token for token in row]
            for row in target_ids
        ]

    model_inputs["labels"] = target_ids
    return model_inputs
103 |
104 |
def postprocess_text(
    preds: List[str], labels: List[str]
) -> Tuple[List[str], List[str]]:
    """Strip every string and place each of its sentences on a separate line.

    Args:
        preds: Decoded predictions.
        labels: Decoded reference labels.

    Returns:
        The cleaned predictions and labels, in that order.
    """
    stripped_preds = [pred.strip() for pred in preds]
    stripped_labels = [label.strip() for label in labels]

    def _one_sentence_per_line(text: str) -> str:
        # Sentences are separated by newlines, as rougeLSum expects
        return "\n".join(sent_tokenize(text))

    return (
        [_one_sentence_per_line(pred) for pred in stripped_preds],
        [_one_sentence_per_line(label) for label in stripped_labels],
    )
117 |
118 |
def compute_metrics(eval_preds):
    """Compute macro F1 (as a percentage) and the mean generated length.

    Args:
        eval_preds: Pair of (predictions, labels) token-id arrays. If the
            predictions come as a tuple, only its first element is used.

    Returns:
        Dict with the metric values scaled to percentages and rounded to four
        decimals, plus "gen_len", the mean number of non-pad tokens per
        prediction.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 marks ignored positions and cannot be decoded; replace it with the
    # pad token id in both the predictions and the labels.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    # The F1 metric scores integer class ids, not free text, so every distinct
    # string (reference or generated) is mapped to an integer id first. A
    # generated string that matches no reference label becomes its own class
    # and therefore lowers the macro average.
    distinct_texts = sorted(set(decoded_labels) | set(decoded_preds))
    class_ids = {text: idx for idx, text in enumerate(distinct_texts)}

    result = metric.compute(
        predictions=[class_ids[text] for text in decoded_preds],
        references=[class_ids[text] for text in decoded_labels],
        average="macro",
    )
    result = {k: round(v * 100, 4) for k, v in result.items()}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result
140 |
141 |
def train() -> None:
    """Fine-tune the seq2seq model, then save and publish the result.

    Tokenizes the dataset, trains with `Seq2SeqTrainer`, saves the tokenizer
    under REPOSITORY_ID, creates a model card and pushes to the Hub.
    """
    encoded = dataset.map(
        preprocess_function, batched=True, remove_columns=["text", "label"]
    )
    feature_names = list(encoded["train"].features)
    print(f"Keys of tokenized dataset: {feature_names}")

    # Base checkpoint to fine-tune
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

    # NOTE(review): presumably needed by `sent_tokenize` in postprocess_text — confirm
    nltk.download("punkt")

    # Label padding uses -100 so the tokenizer pad token is ignored in the loss
    collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=seq2seq_model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8,
    )

    trainer = Seq2SeqTrainer(
        model=seq2seq_model,
        args=training_args,
        data_collator=collator,
        train_dataset=encoded["train"],
        eval_dataset=encoded["test"],
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()

    # Save and publish
    tokenizer.save_pretrained(REPOSITORY_ID)
    trainer.create_model_card()
    trainer.push_to_hub()


if __name__ == "__main__":
    train()
186 |
--------------------------------------------------------------------------------
/classifier/AutoModelForSequenceClassification/classify_and_evaluate.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | from typing import List, Tuple
3 |
4 | import torch
5 | from loguru import logger
6 | from sklearn.metrics import classification_report
7 | from tqdm.auto import tqdm
8 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
9 |
10 | from classifier.data_loader import id2label, load_dataset
11 |
dataset = load_dataset("AutoModelForSequenceClassification")

# Hub id used for both the model and its tokenizer
MODEL_ID = "VanekPetr/flan-t5-small-ecommerce-text-classification"

model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
# Place the model on the GPU when available, otherwise on the CPU
model.to("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
21 |
22 |
def classify(texts_to_classify: List[str]) -> List[Tuple[str, float]]:
    """Classify a list of texts using the model.

    Args:
        texts_to_classify: Texts to label; each is truncated to 512 tokens.

    Returns:
        One (label, confidence) pair per input text, in input order. The
        confidence is the softmax probability of the predicted class as a plain
        Python float. An empty input returns an empty list without touching the
        tokenizer or the model.
    """
    # Nothing to classify for an empty batch.
    if not texts_to_classify:
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Tokenize all texts in the batch (the timing below includes this step)
    start = time()
    inputs = tokenizer(
        texts_to_classify,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Get predictions; inference only, so no gradients are needed
    with torch.no_grad():
        outputs = model(**inputs)
    logger.debug(
        f"Classification of {len(texts_to_classify)} examples took {time() - start} seconds"
    )

    # Turn the logits into a probability distribution over the classes
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Top class and its probability (certainty) for each input text
    confidences, predicted_classes = torch.max(probs, dim=1)

    # `.tolist()` yields plain Python ints/floats, so the labels are looked up
    # with int keys and the confidences match the annotated `float` type.
    predicted_labels = [id2label[class_id] for class_id in predicted_classes.tolist()]
    return list(zip(predicted_labels, confidences.tolist()))
60 |
61 |
def evaluate():
    """Evaluate the model on the test dataset and print a classification report.

    The test split is classified in fixed-size batches; predicted label names
    are compared against the reference ids mapped through `id2label`.
    """
    batch_size = 16  # Adjust the batch size to the available GPU capacity

    # Read each column once up front. Indexing the split by column name inside
    # the loop would fetch the whole column again for every batch.
    test_split = dataset["test"]
    texts = test_split["text"]
    label_ids = test_split["label"]

    predictions_list, labels_list = [], []
    # One progress tick per batch; the range length equals the number of batches.
    for start in tqdm(range(0, len(texts), batch_size), desc="Evaluating"):
        stop = start + batch_size
        predictions_list.extend(classify(texts[start:stop]))
        labels_list.extend(id2label[label_id] for label_id in label_ids[start:stop])

    # Only the predicted label of each (label, confidence) pair is scored
    report = classification_report(labels_list, [pair[0] for pair in predictions_list])
    print(report)


if __name__ == "__main__":
    evaluate()
90 |
--------------------------------------------------------------------------------
/classifier/AutoModelForSequenceClassification/flan-t5-finetuning.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | import numpy as np
3 | from huggingface_hub import HfFolder
4 | from sklearn.metrics import precision_recall_fscore_support
5 | from transformers import (
6 | AutoConfig,
7 | AutoModelForSequenceClassification,
8 | AutoTokenizer,
9 | Trainer,
10 | TrainingArguments,
11 | )
12 |
13 | from classifier.data_loader import id2label, label2id, load_dataset
14 |
MODEL_ID = "google/flan-t5-small"
REPOSITORY_ID = f"{MODEL_ID.split('/')[1]}-ecommerce-text-classification"

# Config carrying the number of labels and both label mappings
config = AutoConfig.from_pretrained(
    MODEL_ID,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(label2id),
)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, config=config)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

training_args = TrainingArguments(
    # Output directory and Hub publishing
    output_dir=REPOSITORY_ID,
    push_to_hub=True,
    hub_model_id=REPOSITORY_ID,
    hub_strategy="every_save",
    hub_token=HfFolder.get_token(),
    # Optimisation
    num_train_epochs=2,
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=False,  # Overflows with fp16
    # Logging and checkpointing
    logging_strategy="steps",
    logging_steps=100,
    report_to="tensorboard",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=False,
)
42 |
43 |
def tokenize_function(examples) -> dict:
    """Tokenize the "text" column, padding to max length and truncating."""
    texts = examples["text"]
    return tokenizer(texts, truncation=True, padding="max_length")
47 |
48 |
def compute_metrics(eval_pred) -> dict:
    """Compute macro-averaged precision, recall and F1 for evaluation.

    Args:
        eval_pred: Pair of (logits, labels). If the logits come as a tuple,
            only its first element is used.

    Returns:
        Dict with "precision", "recall" and "f1".
    """
    logits, labels = eval_pred
    if isinstance(
        logits, tuple
    ):  # if the model also returns hidden_states or attentions
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    # The task has more than two classes, so average="binary" is rejected by
    # scikit-learn; macro-average over the classes instead. zero_division=0
    # scores a class with no predictions as 0 instead of emitting a warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="macro", zero_division=0
    )
    return {"precision": precision, "recall": recall, "f1": f1}
61 |
62 |
def train() -> None:
    """Train the model and save it to the Hugging Face Hub.

    Loads the dataset with integer labels, tokenizes it, trains, saves the
    tokenizer under REPOSITORY_ID, creates a model card, pushes to the Hub and
    finally prints the evaluation metrics.
    """
    raw_splits = load_dataset("AutoModelForSequenceClassification")
    encoded = raw_splits.map(tokenize_function, batched=True)

    # NOTE(review): nothing visible in this script uses NLTK tokenizers —
    # confirm whether this download is still needed.
    nltk.download("punkt")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=encoded["train"],
        eval_dataset=encoded["test"],
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()

    # Save, publish, then evaluate
    tokenizer.save_pretrained(REPOSITORY_ID)
    trainer.create_model_card()
    trainer.push_to_hub()
    eval_metrics = trainer.evaluate()
    print(eval_metrics)


if __name__ == "__main__":
    train()
92 |
--------------------------------------------------------------------------------
/classifier/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | from datasets import Dataset
5 |
# Repository root: the parent of the directory that contains this file
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))

# Category name -> integer class id
label2id = {
    "Books": 0,
    "Clothing & Accessories": 1,
    "Electronics": 2,
    "Household": 3,
}
# Integer class id -> category name (inverse of label2id)
id2label = {class_id: name for name, class_id in label2id.items()}
10 |
11 |
def load_dataset(model_type: str = "") -> Dataset:
    """Load the e-commerce CSV and split it into train and test sets.

    Args:
        model_type: When equal to "AutoModelForSequenceClassification", labels
            are converted to integer ids through `label2id`; otherwise they
            stay as strings.

    Returns:
        The result of `train_test_split`, with 20% of the rows in the test
        split. NOTE(review): callers index it with "train"/"test", so the
        `Dataset` annotation looks too narrow — confirm.
    """
    csv_path = os.path.join(ROOT_DIR, "data", "ecommerce_kaggle_dataset.csv")
    # The file has no header row; the columns are the label then the text.
    frame = pd.read_csv(csv_path, header=None, names=["label", "text"])

    frame["label"] = frame["label"].astype(str)
    if model_type == "AutoModelForSequenceClassification":
        # Convert labels to integers
        frame["label"] = frame["label"].map(label2id)

    frame["text"] = frame["text"].astype(str)

    dataset = Dataset.from_pandas(frame)
    dataset = dataset.shuffle(seed=42)
    # Seed the split too. Unseeded, every process draws a different test set,
    # so a separate evaluation run would score rows the model was trained on.
    dataset = dataset.train_test_split(test_size=0.2, seed=42)

    return dataset


if __name__ == "__main__":
    print(load_dataset())
37 |
--------------------------------------------------------------------------------
/data/evaluation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VanekPetr/flan-t5-text-classifier/4a7d73fd5f9921794e726a197eaaaf322436a64c/data/evaluation.png
--------------------------------------------------------------------------------
/data/evaluation_classification_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VanekPetr/flan-t5-text-classifier/4a7d73fd5f9921794e726a197eaaaf322436a64c/data/evaluation_classification_model.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "flan-t5-text-classifier"
3 | version = "0.0.0"
4 | description = "Flan t5 model for a text classification."
5 | authors = ["Petr Vanek"]
6 | readme = "README.md"
7 | repository = "https://github.com/VanekPetr/flan-t5-text-classifier"
8 |
9 | [tool.poetry.dependencies]
10 | python = ">=3.9,<3.13"
11 | pandas = "2.2.2"
12 | pre-commit = "3.8.0"
13 | loguru = "0.7.2"
14 | datasets = "2.20.0"
15 | transformers = "4.43.3"
16 | scikit-learn = "1.5.1"
17 | tqdm = "4.66.4"
18 | torch = "2.4.0"
19 | evaluate = "0.4.2"
20 | nltk = "3.8.1"
21 | accelerate = "0.33.0"
22 | tensorboardX = "2.6.2.2"
23 |
24 | [tool.poetry.group.test.dependencies]
25 | pytest = "*"
26 | pytest-cov = "*"
27 | pre-commit = "*"
28 |
[build-system]
# The backend module `poetry.core.masonry.api` is provided by the `poetry-core`
# package, so build front-ends only need that package, not the full Poetry CLI.
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
32 |
33 | [tool.ruff]
34 | lint.select = ["E", "F", "I"]
35 | line-length = 120
36 | target-version = "py310"
37 | exclude = [
38 | "*__init__.py"
39 | ]
40 |
--------------------------------------------------------------------------------