├── .DS_Store ├── .github ├── dependabot.yml └── workflows │ ├── pre-commit.yaml │ └── release.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── classifier ├── AutoModelForSeq2SeqLM │ ├── classify_and_evaluate.py │ └── flan-t5-finetuning.py ├── AutoModelForSequenceClassification │ ├── classify_and_evaluate.py │ └── flan-t5-finetuning.py └── data_loader.py ├── data ├── ecommerce_kaggle_dataset.csv ├── evaluation.png └── evaluation_classification_model.png ├── poetry.lock └── pyproject.toml /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VanekPetr/flan-t5-text-classifier/4a7d73fd5f9921794e726a197eaaaf322436a64c/.DS_Store -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" # How often to check for updates 12 | 13 | - package-ecosystem: "github-actions" 14 | directory: "/" 15 | schedule: 16 | # Check for updates to GitHub Actions every week 17 | interval: "weekly" 18 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | # Run pre-commit on all files in the repository 2 | name: pre-commit 3 | 4 | on: 5 | pull_request: 6 | push: 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: pre-commit/action@v3.0.1 14 | with: 15 | extra_args: '--verbose --all-files' 16 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Bump version 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Bump version and push tag 13 | id: tag_version 14 | uses: mathieudutour/github-tag-action@v6.2 15 | with: 16 | github_token: ${{ secrets.GITHUB_TOKEN }} 17 | - name: Create a GitHub release 18 | uses: ncipollo/release-action@v1 19 | with: 20 | tag: ${{ steps.tag_version.outputs.new_tag }} 21 | name: Release ${{ steps.tag_version.outputs.new_tag }} 22 | body: ${{ steps.tag_version.outputs.changelog }} 23 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | data/.DS_Store 162 | .DS_Store 163 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | 8 | - repo: https://github.com/psf/black 9 | rev: 24.1.1 10 | hooks: 11 | - id: black 12 | 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: 'v0.2.0' 15 | hooks: 16 | - id: ruff 17 | args: [ --fix, --exit-non-zero-on-fix ] 18 | 19 | - repo: https://github.com/asottile/pyupgrade 20 | rev: v3.15.0 21 | hooks: 22 | - id: pyupgrade 23 | 24 | - repo: https://github.com/python-jsonschema/check-jsonschema 25 | rev: 0.27.4 26 | hooks: 27 | - id: check-dependabot 28 | args: ["--verbose"] 29 | - id: check-github-workflows 30 | args: ["--verbose"] 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Petr Vanek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # flan-t5 Fine-tuning for Text Classification 2 | 3 | This GitHub project is aimed at fine-tuning the flan t5 model for a text classification task using an E-commerce 4 | text dataset for 4 categories - "Electronics", "Household", "Books" and "Clothing & Accessories". 5 | 6 | 7 | ## Evaluation Statistics 8 | ### AutoModelForSequenceClassification 9 | Text Classification model can be found on [HuggingFace](https://huggingface.co/VanekPetr/flan-t5-small-ecommerce-text-classification). The model is trained on the dataset and evaluated on the test set. The evaluation metrics are as follows: 10 |

11 | 12 |

13 | 14 | The main **advantage** of this model is that, together with the prediction, it outputs a confidence score (the softmax probability of the predicted class). This can be used to filter out predictions with low confidence. 15 | 16 | ### AutoModelForSeq2SeqLM 17 | The Text2Text Generation model can be found on [HuggingFace](https://huggingface.co/VanekPetr/flan-t5-base-ecommerce-text-classification). The model is trained on the dataset and evaluated on the test set. The evaluation metrics are as follows: 18 |

19 | 20 |

21 | 22 | ## Dataset 23 | 24 | The dataset is a classification-based E-commerce text dataset whose 4 categories cover almost 80% of the products on a typical E-commerce website. The dataset consists of product and description data for these 4 categories. The dataset can be found [here](https://doi.org/10.5281/zenodo.3355823). 25 | 26 | ## Project Features 27 | 28 | The project uses the flan-t5 tokenizer from Hugging Face to convert the input text into a format the model can understand. 29 | 30 | An evaluation function post-processes the labels and predictions before scoring: it replaces the `-100` label padding so the labels can be decoded, and it also reports the mean generated sequence length. 31 | 32 | The project uses `Seq2SeqTrainer` (for `AutoModelForSeq2SeqLM`) and `Trainer` (for `AutoModelForSequenceClassification`) to train the model. It also includes a helper function to preprocess the dataset. 33 | 34 | ## Usage 35 | 36 | To train a model, run the `flan-t5-finetuning.py` script in `classifier/AutoModelForSeq2SeqLM` or `classifier/AutoModelForSequenceClassification`; the `classify_and_evaluate.py` script next to it evaluates the published model on the test split. The scripts import `classifier.data_loader`, so run them from the repository root with the root on `PYTHONPATH`, e.g. `PYTHONPATH=. python classifier/AutoModelForSequenceClassification/flan-t5-finetuning.py`. 37 | 38 | The `train` function fine-tunes the flan-t5 model on the dataset, outputs the metrics, creates a model card and pushes the model to the Hugging Face model hub. 39 | 40 | The preprocess function tokenizes the inputs and also the target labels. The `compute_metrics` function evaluates the model performance based on the F1 metric. 41 | 42 | ## Getting Started 43 | 44 | STEP 1: create and activate a Python virtual environment 45 | ``` bash 46 | python -m venv venv 47 | source venv/bin/activate 48 | ``` 49 | 50 | STEP 2: install requirements with [poetry](https://python-poetry.org/docs/#installing-with-the-official-installer) 51 | ``` bash 52 | poetry install -vv 53 | ``` 54 | 55 | ## Versioning 56 | 57 | We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/VanekPetr/flan-t5-text-classifier/tags). 58 | 59 | ## License 60 | 61 | This repository is licensed under [MIT](LICENSE) (c) 2023 Petr Vanek.
62 | -------------------------------------------------------------------------------- /classifier/AutoModelForSeq2SeqLM/classify_and_evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics import classification_report 3 | from tqdm.auto import tqdm 4 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 5 | 6 | from classifier.data_loader import load_dataset 7 | 8 | dataset = load_dataset() 9 | 10 | # Load model and tokenizer from the hub 11 | tokenizer = AutoTokenizer.from_pretrained( 12 | "VanekPetr/flan-t5-base-ecommerce-text-classification" 13 | ) 14 | model = AutoModelForSeq2SeqLM.from_pretrained( 15 | "VanekPetr/flan-t5-base-ecommerce-text-classification" 16 | ) 17 | model.to("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def classify(texts_to_classify: str): 21 | """Classify a batch of texts using the model.""" 22 | inputs = tokenizer( 23 | texts_to_classify, 24 | padding="max_length", 25 | truncation=True, 26 | max_length=256, 27 | return_tensors="pt", 28 | ) 29 | inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu") 30 | 31 | with torch.no_grad(): 32 | outputs = model.generate( 33 | inputs["input_ids"], 34 | attention_mask=inputs["attention_mask"], 35 | max_length=150, 36 | num_beams=2, 37 | early_stopping=True, 38 | ) 39 | 40 | predictions = [ 41 | tokenizer.decode(output, skip_special_tokens=True) for output in outputs 42 | ] 43 | return predictions 44 | 45 | 46 | def evaluate(): 47 | """Evaluate the model on the test dataset.""" 48 | predictions_list, labels_list = [], [] 49 | 50 | batch_size = 16 # Adjust batch size based GPU capacity 51 | num_batches = len(dataset["test"]) // batch_size + ( 52 | 0 if len(dataset["test"]) % batch_size == 0 else 1 53 | ) 54 | progress_bar = tqdm(total=num_batches, desc="Evaluating") 55 | 56 | for i in range(0, len(dataset["test"]), batch_size): 57 | batch_texts = dataset["test"]["text"][i : i + 
batch_size] 58 | batch_labels = dataset["test"]["label"][i : i + batch_size] 59 | 60 | batch_predictions = classify(batch_texts) 61 | 62 | predictions_list.extend(batch_predictions) 63 | labels_list.extend([str(label) for label in batch_labels]) 64 | 65 | progress_bar.update(1) 66 | 67 | progress_bar.close() 68 | report = classification_report(labels_list, predictions_list) 69 | print(report) 70 | 71 | 72 | if __name__ == "__main__": 73 | evaluate() 74 | -------------------------------------------------------------------------------- /classifier/AutoModelForSeq2SeqLM/flan-t5-finetuning.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import evaluate 4 | import nltk 5 | import numpy as np 6 | from datasets import Dataset, concatenate_datasets 7 | from huggingface_hub import HfFolder 8 | from nltk.tokenize import sent_tokenize 9 | from transformers import ( 10 | AutoModelForSeq2SeqLM, 11 | AutoTokenizer, 12 | DataCollatorForSeq2Seq, 13 | Seq2SeqTrainer, 14 | Seq2SeqTrainingArguments, 15 | ) 16 | 17 | from classifier.data_loader import load_dataset 18 | 19 | MODEL_ID = "google/flan-t5-base" 20 | REPOSITORY_ID = f"{MODEL_ID.split('/')[1]}-ecommerce-text-classification" 21 | 22 | # Load dataset 23 | dataset = load_dataset() 24 | 25 | # Load tokenizer of FLAN-t5 26 | tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) 27 | 28 | # Metric 29 | metric = evaluate.load("f1") 30 | 31 | # The maximum total input sequence length after tokenization. 32 | # Sequences longer than this will be truncated, sequences shorter will be padded. 
33 | tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map( 34 | lambda x: tokenizer(x["text"], truncation=True), 35 | batched=True, 36 | remove_columns=["text", "label"], 37 | ) 38 | max_source_length = max([len(x) for x in tokenized_inputs["input_ids"]]) 39 | print(f"Max source length: {max_source_length}") 40 | 41 | # The maximum total sequence length for target text after tokenization. 42 | # Sequences longer than this will be truncated, sequences shorter will be padded." 43 | tokenized_targets = concatenate_datasets([dataset["train"], dataset["test"]]).map( 44 | lambda x: tokenizer(x["label"], truncation=True), 45 | batched=True, 46 | remove_columns=["text", "label"], 47 | ) 48 | max_target_length = max([len(x) for x in tokenized_targets["input_ids"]]) 49 | print(f"Max target length: {max_target_length}") 50 | 51 | # Define training args 52 | training_args = Seq2SeqTrainingArguments( 53 | output_dir=REPOSITORY_ID, 54 | per_device_train_batch_size=8, 55 | per_device_eval_batch_size=8, 56 | predict_with_generate=True, 57 | fp16=False, # Overflows with fp16 58 | learning_rate=3e-4, 59 | num_train_epochs=2, 60 | logging_dir=f"{REPOSITORY_ID}/logs", # logging & evaluation strategies 61 | logging_strategy="epoch", 62 | evaluation_strategy="no", 63 | save_strategy="epoch", 64 | save_total_limit=2, 65 | load_best_model_at_end=False, 66 | report_to="tensorboard", 67 | push_to_hub=True, 68 | hub_strategy="every_save", 69 | hub_model_id=REPOSITORY_ID, 70 | hub_token=HfFolder.get_token(), 71 | ) 72 | 73 | 74 | def preprocess_function(sample: Dataset, padding: str = "max_length") -> dict: 75 | """Preprocess the dataset.""" 76 | 77 | # add prefix to the input for t5 78 | inputs = [item for item in sample["text"]] 79 | 80 | # tokenize inputs 81 | model_inputs = tokenizer( 82 | inputs, max_length=max_source_length, padding=padding, truncation=True 83 | ) 84 | 85 | # Tokenize targets with the `text_target` keyword argument 86 | labels = tokenizer( 87 
| text_target=sample["label"], 88 | max_length=max_target_length, 89 | padding=padding, 90 | truncation=True, 91 | ) 92 | 93 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 94 | # padding in the loss. 95 | if padding == "max_length": 96 | labels["input_ids"] = [ 97 | [(la if la != tokenizer.pad_token_id else -100) for la in label] 98 | for label in labels["input_ids"] 99 | ] 100 | 101 | model_inputs["labels"] = labels["input_ids"] 102 | return model_inputs 103 | 104 | 105 | def postprocess_text( 106 | preds: List[str], labels: List[str] 107 | ) -> Tuple[List[str], List[str]]: 108 | """helper function to postprocess text""" 109 | preds = [pred.strip() for pred in preds] 110 | labels = [label.strip() for label in labels] 111 | 112 | # rougeLSum expects newline after each sentence 113 | preds = ["\n".join(sent_tokenize(pred)) for pred in preds] 114 | labels = ["\n".join(sent_tokenize(label)) for label in labels] 115 | 116 | return preds, labels 117 | 118 | 119 | def compute_metrics(eval_preds): 120 | preds, labels = eval_preds 121 | if isinstance(preds, tuple): 122 | preds = preds[0] 123 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 124 | # Replace -100 in the labels as we can't decode them. 
125 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 126 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 127 | 128 | # Some simple post-processing 129 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 130 | 131 | result = metric.compute( 132 | predictions=decoded_preds, references=decoded_labels, average="macro" 133 | ) 134 | result = {k: round(v * 100, 4) for k, v in result.items()} 135 | prediction_lens = [ 136 | np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds 137 | ] 138 | result["gen_len"] = np.mean(prediction_lens) 139 | return result 140 | 141 | 142 | def train() -> None: 143 | """Train the model.""" 144 | 145 | tokenized_dataset = dataset.map( 146 | preprocess_function, batched=True, remove_columns=["text", "label"] 147 | ) 148 | print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}") 149 | 150 | # load model from the hub 151 | model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID) 152 | 153 | nltk.download("punkt") 154 | 155 | # we want to ignore tokenizer pad token in the loss 156 | label_pad_token_id = -100 157 | # Data collator 158 | data_collator = DataCollatorForSeq2Seq( 159 | tokenizer, 160 | model=model, 161 | label_pad_token_id=label_pad_token_id, 162 | pad_to_multiple_of=8, 163 | ) 164 | 165 | # Create Trainer instance 166 | trainer = Seq2SeqTrainer( 167 | model=model, 168 | args=training_args, 169 | data_collator=data_collator, 170 | train_dataset=tokenized_dataset["train"], 171 | eval_dataset=tokenized_dataset["test"], 172 | compute_metrics=compute_metrics, 173 | ) 174 | 175 | # TRAIN 176 | trainer.train() 177 | 178 | # SAVE 179 | tokenizer.save_pretrained(REPOSITORY_ID) 180 | trainer.create_model_card() 181 | trainer.push_to_hub() 182 | 183 | 184 | if __name__ == "__main__": 185 | train() 186 | -------------------------------------------------------------------------------- 
/classifier/AutoModelForSequenceClassification/classify_and_evaluate.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from typing import List, Tuple 3 | 4 | import torch 5 | from loguru import logger 6 | from sklearn.metrics import classification_report 7 | from tqdm.auto import tqdm 8 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 9 | 10 | from classifier.data_loader import id2label, load_dataset 11 | 12 | dataset = load_dataset("AutoModelForSequenceClassification") 13 | 14 | # Load the model and tokenizer 15 | MODEL_ID = "VanekPetr/flan-t5-small-ecommerce-text-classification" 16 | 17 | model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID) 18 | model.to("cuda") if torch.cuda.is_available() else model.to("cpu") 19 | 20 | tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) 21 | 22 | 23 | def classify(texts_to_classify: List[str]) -> List[Tuple[str, float]]: 24 | """Classify a list of texts using the model.""" 25 | 26 | # Tokenize all texts in the batch 27 | start = time() 28 | inputs = tokenizer( 29 | texts_to_classify, 30 | return_tensors="pt", 31 | max_length=512, 32 | truncation=True, 33 | padding=True, 34 | ) 35 | inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu") 36 | 37 | # Get predictions 38 | with torch.no_grad(): 39 | outputs = model(**inputs) 40 | logger.debug( 41 | f"Classification of {len(texts_to_classify)} examples took {time() - start} seconds" 42 | ) 43 | 44 | # Process the outputs to get the probability distribution 45 | logits = outputs.logits 46 | probs = torch.nn.functional.softmax(logits, dim=-1) 47 | 48 | # Get the top class and the corresponding probability (certainty) for each input text 49 | confidences, predicted_classes = torch.max(probs, dim=1) 50 | predicted_classes = ( 51 | predicted_classes.cpu().numpy() 52 | ) # Move to CPU for numpy conversion if needed 53 | confidences = confidences.cpu().numpy() # Same here 
54 | 55 | # Map predicted class IDs to labels 56 | predicted_labels = [id2label[class_id] for class_id in predicted_classes] 57 | 58 | # Zip together the predicted labels and confidences and convert to a list of tuples 59 | return list(zip(predicted_labels, confidences)) 60 | 61 | 62 | def evaluate(): 63 | """Evaluate the model on the test dataset.""" 64 | predictions_list, labels_list = [], [] 65 | 66 | batch_size = 16 # Adjust batch size based GPU capacity 67 | num_batches = len(dataset["test"]) // batch_size + ( 68 | 0 if len(dataset["test"]) % batch_size == 0 else 1 69 | ) 70 | progress_bar = tqdm(total=num_batches, desc="Evaluating") 71 | 72 | for i in range(0, len(dataset["test"]), batch_size): 73 | batch_texts = dataset["test"]["text"][i : i + batch_size] 74 | batch_labels = dataset["test"]["label"][i : i + batch_size] 75 | 76 | batch_predictions = classify(batch_texts) 77 | 78 | predictions_list.extend(batch_predictions) 79 | labels_list.extend([id2label[label_id] for label_id in batch_labels]) 80 | 81 | progress_bar.update(1) 82 | 83 | progress_bar.close() 84 | report = classification_report(labels_list, [pair[0] for pair in predictions_list]) 85 | print(report) 86 | 87 | 88 | if __name__ == "__main__": 89 | evaluate() 90 | -------------------------------------------------------------------------------- /classifier/AutoModelForSequenceClassification/flan-t5-finetuning.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import numpy as np 3 | from huggingface_hub import HfFolder 4 | from sklearn.metrics import precision_recall_fscore_support 5 | from transformers import ( 6 | AutoConfig, 7 | AutoModelForSequenceClassification, 8 | AutoTokenizer, 9 | Trainer, 10 | TrainingArguments, 11 | ) 12 | 13 | from classifier.data_loader import id2label, label2id, load_dataset 14 | 15 | MODEL_ID = "google/flan-t5-small" 16 | REPOSITORY_ID = f"{MODEL_ID.split('/')[1]}-ecommerce-text-classification" 17 | 18 | config = 
AutoConfig.from_pretrained( 19 | MODEL_ID, num_labels=len(label2id), id2label=id2label, label2id=label2id 20 | ) 21 | model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, config=config) 22 | tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) 23 | 24 | training_args = TrainingArguments( 25 | num_train_epochs=2, 26 | output_dir=REPOSITORY_ID, 27 | logging_strategy="steps", 28 | logging_steps=100, 29 | report_to="tensorboard", 30 | per_device_train_batch_size=8, 31 | per_device_eval_batch_size=8, 32 | fp16=False, # Overflows with fp16 33 | learning_rate=3e-4, 34 | save_strategy="epoch", 35 | save_total_limit=2, 36 | load_best_model_at_end=False, 37 | push_to_hub=True, 38 | hub_strategy="every_save", 39 | hub_model_id=REPOSITORY_ID, 40 | hub_token=HfFolder.get_token(), 41 | ) 42 | 43 | 44 | def tokenize_function(examples) -> dict: 45 | """Tokenize the text column in the dataset""" 46 | return tokenizer(examples["text"], padding="max_length", truncation=True) 47 | 48 | 49 | def compute_metrics(eval_pred) -> dict: 50 | """Compute metrics for evaluation""" 51 | logits, labels = eval_pred 52 | if isinstance( 53 | logits, tuple 54 | ): # if the model also returns hidden_states or attentions 55 | logits = logits[0] 56 | predictions = np.argmax(logits, axis=-1) 57 | precision, recall, f1, _ = precision_recall_fscore_support( 58 | labels, predictions, average="binary" 59 | ) 60 | return {"precision": precision, "recall": recall, "f1": f1} 61 | 62 | 63 | def train() -> None: 64 | """ 65 | Train the model and save it to the Hugging Face Hub. 
66 | """ 67 | dataset = load_dataset("AutoModelForSequenceClassification") 68 | tokenized_datasets = dataset.map(tokenize_function, batched=True) 69 | 70 | nltk.download("punkt") 71 | 72 | trainer = Trainer( 73 | model=model, 74 | args=training_args, 75 | train_dataset=tokenized_datasets["train"], 76 | eval_dataset=tokenized_datasets["test"], 77 | compute_metrics=compute_metrics, 78 | ) 79 | 80 | # TRAIN 81 | trainer.train() 82 | 83 | # SAVE AND EVALUATE 84 | tokenizer.save_pretrained(REPOSITORY_ID) 85 | trainer.create_model_card() 86 | trainer.push_to_hub() 87 | print(trainer.evaluate()) 88 | 89 | 90 | if __name__ == "__main__": 91 | train() 92 | -------------------------------------------------------------------------------- /classifier/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | from datasets import Dataset 5 | 6 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 7 | 8 | label2id = {"Books": 0, "Clothing & Accessories": 1, "Electronics": 2, "Household": 3} 9 | id2label = {id: label for label, id in label2id.items()} 10 | 11 | 12 | def load_dataset(model_type: str = "") -> Dataset: 13 | """Load dataset.""" 14 | dataset_ecommerce_pandas = pd.read_csv( 15 | ROOT_DIR + "/data/ecommerce_kaggle_dataset.csv", 16 | header=None, 17 | names=["label", "text"], 18 | ) 19 | 20 | dataset_ecommerce_pandas["label"] = dataset_ecommerce_pandas["label"].astype(str) 21 | if model_type == "AutoModelForSequenceClassification": 22 | # Convert labels to integers 23 | dataset_ecommerce_pandas["label"] = dataset_ecommerce_pandas["label"].map( 24 | label2id 25 | ) 26 | 27 | dataset_ecommerce_pandas["text"] = dataset_ecommerce_pandas["text"].astype(str) 28 | dataset = Dataset.from_pandas(dataset_ecommerce_pandas) 29 | dataset = dataset.shuffle(seed=42) 30 | dataset = dataset.train_test_split(test_size=0.2) 31 | 32 | return dataset 33 | 34 | 35 | if __name__ == 
"__main__": 36 | print(load_dataset()) 37 | -------------------------------------------------------------------------------- /data/evaluation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VanekPetr/flan-t5-text-classifier/4a7d73fd5f9921794e726a197eaaaf322436a64c/data/evaluation.png -------------------------------------------------------------------------------- /data/evaluation_classification_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VanekPetr/flan-t5-text-classifier/4a7d73fd5f9921794e726a197eaaaf322436a64c/data/evaluation_classification_model.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "flan-t5-text-classifier" 3 | version = "0.0.0" 4 | description = "Flan t5 model for a text classification." 5 | authors = ["Petr Vanek"] 6 | readme = "README.md" 7 | repository = "https://github.com/VanekPetr/flan-t5-text-classifier" 8 | 9 | [tool.poetry.dependencies] 10 | python = ">=3.9,<3.13" 11 | pandas = "2.2.2" 12 | pre-commit = "3.8.0" 13 | loguru = "0.7.2" 14 | datasets = "2.20.0" 15 | transformers = "4.43.3" 16 | scikit-learn = "1.5.1" 17 | tqdm = "4.66.4" 18 | torch = "2.4.0" 19 | evaluate = "0.4.2" 20 | nltk = "3.8.1" 21 | accelerate = "0.33.0" 22 | tensorboardX = "2.6.2.2" 23 | 24 | [tool.poetry.group.test.dependencies] 25 | pytest = "*" 26 | pytest-cov = "*" 27 | pre-commit = "*" 28 | 29 | [build-system] 30 | requires = ["poetry>=1.6.0"] 31 | build-backend = "poetry.core.masonry.api" 32 | 33 | [tool.ruff] 34 | lint.select = ["E", "F", "I"] 35 | line-length = 120 36 | target-version = "py310" 37 | exclude = [ 38 | "*__init__.py" 39 | ] 40 | --------------------------------------------------------------------------------