├── .dockerignore ├── .gitignore ├── .idea ├── .gitignore ├── LLM_Data_Annotation.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── other.xml └── vcs.xml ├── Architecture.png ├── Dockerfile ├── LICENSE ├── README.md ├── annotation ├── annotate_davinci.py ├── annotate_gpt35.py ├── cleanlab_label_issues.py └── data_versioning.py ├── app.py ├── data ├── original │ └── train.csv └── unannotated │ ├── unannotated_200.csv │ ├── unannotated_50.csv │ ├── unannotated_70.csv │ └── unannotated_sentiment_dataset.csv ├── frontend.py ├── frontend_resources ├── cleanlab_processing_style.html └── train_model_style.html ├── main.py ├── models └── train_bert.py ├── openai ├── key_mock.txt └── organization_mock.txt └── requirements.txt /.dockerignore: -------------------------------------------------------------------------------- 1 | # Virtual environment 2 | venv/ 3 | 4 | # Machine Learning artifacts and logs 5 | mlruns/ 6 | 7 | # Editor and IDE directories 8 | .idea/ 9 | 10 | # Python generated cache 11 | __pycache__/ 12 | 13 | # Temporary or intermediary data files 14 | data/annotated/*.csv 15 | data/filtered/*.csv 16 | data/merged/*.csv 17 | data/trainsets/*.csv 18 | data/testsets/*.csv 19 | data/cleaned/*.csv 20 | 21 | # Model weights 22 | models/**/*.pt 23 | 24 | # Additional files to exclude 25 | *.log 26 | *.pyc 27 | *.pyo 28 | *.egg-info 29 | *.dist-info 30 | .DS_Store 31 | Thumbs.db 32 | *.git/ 33 | *.gitignore -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | key.txt 2 | organization.txt 3 | venv/ 4 | mlruns/ 5 | .idea/ 6 | __pycache__/ 7 | data/annotated/*.csv 8 | data/filtered/*.csv 9 | data/merged/*.csv 10 | data/trainsets/*.csv 11 | data/testsets/*.csv 12 | data/cleaned/*.csv 13 | models/**/*.pt 14 | .ssh/ -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | # Zeppelin ignored files 10 | /ZeppelinRemoteNotebooks/ 11 | -------------------------------------------------------------------------------- /.idea/LLM_Data_Annotation.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 17 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /.idea/vcs.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saran9991/llm-data-annotation/8b7b2539c213bb27cc2acd140feabfb5fb16cd90/Architecture.png -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as the parent image 2 | FROM python:3.9-slim 3 | 4 | # Set environment variables 5 | ENV PYTHONDONTWRITEBYTECODE 1 6 | ENV PYTHONUNBUFFERED 1 7 | 8 | # Set the working directory inside the container 9 | WORKDIR /app 10 | 11 | # Install system dependencies 12 | RUN apt-get update \ 13 | && apt-get install -y --no-install-recommends gcc libpq-dev \ 14 | && apt-get clean \ 15 | && rm -rf /var/lib/apt/lists/* 16 | 17 | # Install Python dependencies 18 | COPY requirements.txt /app/ 19 | RUN pip install --upgrade pip \ 20 | && pip install -r requirements.txt 21 | 22 | # Copy the current directory contents into the container 23 | COPY . /app/ 24 | 25 | # Run uvicorn for FastAPI when the container launches 26 | CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sarandeep Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Large Language Models for Efficient Data Annotation and Model Fine-Tuning with Iterative Active Learning 2 | 3 | 4 | This framework combines human expertise with the efficiency of Large Language Models (LLMs) like OpenAI's GPT-3.5 to simplify dataset annotation and model improvement. The iterative approach ensures the continuous improvement of data quality, and consequently, the performance of models fine-tuned using this data. This not only saves time but also enables the creation of customized LLMs that leverage both human annotators and LLM-based precision. 5 |

[Architecture diagram: Architecture.png] 6 | 7 | 8 |
9 | 10 | ## Features 11 | 12 | 1. **Dataset Uploading and Annotation** 13 | - Upload CSV datasets. 14 | - Leverage GPT-3.5 to automatically annotate datasets. 15 | - Preview the annotations, highlighting low-confidence score rows. 16 | 17 | 2. **Manual Annotation Corrections** 18 | - Display the annotated dataset for user-based corrections. 19 | - User can update labels for specific rows. 20 | 21 | 3. **CleanLab: Confident Learning Approach** 22 | - Utilizes confident learning to identify and rectify label issues. 23 | - Automatically displays rows with potential label errors for user-based corrections. 24 | 25 | 4. **Data Versioning and Saving** 26 | - Merge user corrections with the annotated dataset. 27 | - Advanced data versioning ensures unique dataset versions are saved for every update. 28 | 29 | 5. **Model Training** 30 | - Train a BERT model on the cleaned dataset. 31 | - Track and reproduce model versions seamlessly using [MLflow](https://mlflow.org/). 32 | 33 | ## Setup 34 | 35 | ### Prerequisites 36 | 37 | 1. Install the required packages: 38 | ```bash 39 | pip install -r requirements.txt 40 | ``` 41 | 42 | ### Running the Tool 43 | 44 | 1. **Start the FastAPI backend**: 45 | ```bash 46 | uvicorn app:app --reload 47 | ``` 48 | 49 | 2. **Run the Streamlit app**: 50 | ```bash 51 | streamlit run frontend.py 52 | ``` 53 | 54 | 3. **Launch MLflow UI**: 55 | To view models, metrics, and registered models, you can access the MLflow UI with the following command: 56 | ```bash 57 | mlflow ui 58 | ``` 59 | 60 | 4. **Access the provided links in your web browser**: 61 | - For the main application, access the Streamlit link. 62 | - For MLflow's tracking interface, by default, you can navigate to `http://127.0.0.1:5000`. 63 | 64 | 5. **Follow the on-screen prompts** to upload, annotate, correct, and train on your dataset. 65 | 66 | ## Why Confident Learning? 67 | 68 | [Confident learning](https://arxiv.org/abs/1911.00068) has emerged as a groundbreaking technique in supervised learning and weak-supervision. It aims at characterizing label noise, finding label errors, and learning efficiently with noisy labels. By pruning noisy data and ranking examples to train with confidence, this method ensures a clean and reliable dataset, enhancing the overall model performance. 69 | 70 | ## License 71 | 72 | This project is open-sourced under the [MIT License](LICENSE). 
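
For a quick check that the backend described in the setup steps is working, the annotation endpoint defined in `app.py` can also be called directly, without going through the Streamlit UI. The snippet below is a minimal sketch rather than part of the repository: it assumes the FastAPI server is running on its default address (`http://127.0.0.1:8000`, the same address `frontend.py` posts to) and that the uploaded CSV has a `text` column, as `app.py` expects. The sample path points at one of the bundled datasets under `data/unannotated/`.

```python
# Minimal sketch: POST a CSV with a 'text' column to the running backend
# and inspect the returned file paths (see the /annotate_dataset/ route in app.py).
import requests

with open("data/unannotated/unannotated_50.csv", "rb") as f:
    response = requests.post(
        "http://127.0.0.1:8000/annotate_dataset/",
        files={"file": ("unannotated_50.csv", f, "text/csv")},
    )

result = response.json()
print(result["status"])         # "success" when annotation completes
print(result["path"])           # versioned CSV saved under data/annotated/
print(result["filtered_path"])  # rows with confidence_scores < 1, for manual review
```

The response mirrors what the Streamlit front end consumes: it reloads the annotated CSV from `path` and surfaces the low-confidence rows from `filtered_path` for manual correction.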
73 | 74 | --- 75 | 76 | -------------------------------------------------------------------------------- /annotation/annotate_davinci.py: -------------------------------------------------------------------------------- 1 | import openai 2 | 3 | with open('openai/organization.txt', 'r') as file: 4 | openai.organization = file.read().strip() 5 | 6 | with open('openai/key.txt', 'r') as file: 7 | openai.api_key = file.read().strip() 8 | 9 | accumulated_tokens = 0 10 | accumulated_cost = 0 11 | cost_per_token = 0.0035 / 1000 12 | 13 | def analyze_davinci(text): 14 | global accumulated_tokens 15 | global accumulated_cost 16 | 17 | prompt = f"Sentiment analysis for the following text in a single word: positive, neutral, negative: \"{text}\"" 18 | 19 | response = openai.Completion.create( 20 | engine="text-davinci-003", 21 | prompt=prompt, 22 | max_tokens=10, 23 | temperature=0 24 | ) 25 | 26 | total_tokens_used = response['usage']['total_tokens'] 27 | print(f"Total tokens used for this call: {total_tokens_used}") 28 | 29 | call_cost = total_tokens_used * cost_per_token 30 | accumulated_cost += call_cost 31 | accumulated_tokens += total_tokens_used 32 | print(f"Cost for this call: {call_cost}") 33 | print(f"Accumulated tokens so far: {accumulated_tokens}") 34 | print(f"Accumulated cost so far: {accumulated_cost}\n") 35 | 36 | response_text = response.choices[0].text.strip().lower() 37 | 38 | return response_text 39 | -------------------------------------------------------------------------------- /annotation/annotate_gpt35.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from tenacity import retry, stop_after_attempt, wait_random_exponential 3 | from typing import Tuple, List 4 | 5 | def read_api_information(file_path: str) -> str: 6 | """ 7 | Read and return API information from a given file. 8 | 9 | Parameters: 10 | - file_path (str): Path to the file containing API information. 11 | 12 | Returns: 13 | - str: API information read from the file. 14 | """ 15 | with open(file_path, 'r') as file: 16 | return file.read().strip() 17 | 18 | openai.organization = read_api_information('openai/organization.txt') 19 | openai.api_key = read_api_information('openai/key.txt') 20 | 21 | accumulated_tokens = 0 22 | accumulated_cost = 0 23 | cost_per_token = 0.0035 / 1000 24 | index = 0 25 | 26 | def calculate_cost(total_tokens_used: int) -> Tuple[float, List[str]]: 27 | """ 28 | Calculate the cost and return logs. 29 | 30 | Parameters: 31 | - total_tokens_used (int): Total tokens used for an API call. 32 | 33 | Returns: 34 | - Tuple[float, List[str]]: A tuple containing the call cost and a list of logs. 35 | """ 36 | global accumulated_tokens, accumulated_cost, index 37 | 38 | call_cost = total_tokens_used * cost_per_token 39 | accumulated_cost += call_cost 40 | accumulated_tokens += total_tokens_used 41 | index += 1 42 | 43 | logs = [ 44 | f"Total tokens used for this call: {total_tokens_used}", 45 | f"Index: {index}", 46 | f"Cost for this call: {call_cost}", 47 | f"Accumulated tokens so far: {accumulated_tokens}", 48 | f"Accumulated cost so far: {accumulated_cost}\n" 49 | ] 50 | for log in logs: 51 | print(log) 52 | 53 | return call_cost, logs 54 | 55 | @retry(wait=wait_random_exponential(max=2), stop=stop_after_attempt(2)) 56 | def analyze_gpt35(text: str) -> Tuple[str, float, List[str]]: 57 | """ 58 | Analyze text and classify its sentiment using GPT-3.5. 59 | 60 | Parameters: 61 | - text (str): The text to be analyzed for sentiment. 
62 | 63 | Returns: 64 | - Tuple[str, float, List[str]]: A tuple containing the primary sentiment classification, 65 | confidence score, and a list of logs. 66 | """ 67 | messages = [ 68 | {"role": "system", "content": "Your task is to analyze text and classify its sentiment as either 'positive', 'negative', or 'neutral' in a single word."}, 69 | {"role": "user", "content": f"Classify the sentiment of: '{text}'."} 70 | ] 71 | 72 | response = openai.ChatCompletion.create( 73 | model="gpt-3.5-turbo", 74 | messages=messages, 75 | max_tokens=3, 76 | n=3, 77 | temperature=0.5 78 | ) 79 | 80 | total_tokens_used = response['usage']['total_tokens'] 81 | _, logs = calculate_cost(total_tokens_used) 82 | 83 | response_texts = [choice.message.content.strip().lower() for choice in response.choices] 84 | primary_response = response_texts[0] 85 | confidence_score = response_texts.count(primary_response) / 3 86 | 87 | return primary_response, confidence_score, logs 88 | -------------------------------------------------------------------------------- /annotation/cleanlab_label_issues.py: -------------------------------------------------------------------------------- 1 | from cleanlab.classification import CleanLearning 2 | from sklearn.preprocessing import LabelEncoder 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | import copy 6 | from typing import Any, List, Union 7 | 8 | 9 | def find_label_issues(clf: Any, data_path: str) -> pd.DataFrame: 10 | """ 11 | Find label issues using CleanLearning with the given classifier and data. 12 | 13 | Parameters: 14 | - clf (Any): Classifier model. 15 | - data_path (str): Path to the data CSV file. 16 | 17 | Returns: 18 | - pd.DataFrame: DataFrame containing top 20 rows with label issues. 19 | """ 20 | data = pd.read_csv(data_path, encoding='unicode_escape') 21 | raw_texts, raw_labels = data["text"].values, data["predicted_labels"].values 22 | raw_train_texts, raw_test_texts, raw_train_labels, raw_test_labels = train_test_split(raw_texts, raw_labels, 23 | test_size=0.2) 24 | 25 | cv_n_folds = 3 26 | model_copy = copy.deepcopy(clf) 27 | cl = CleanLearning(model_copy, cv_n_folds=cv_n_folds) 28 | 29 | encoder = LabelEncoder() 30 | encoder.fit(raw_train_labels) 31 | train_labels = encoder.transform(raw_train_labels) 32 | test_labels = encoder.transform(raw_test_labels) 33 | 34 | label_issues = cl.find_label_issues(X=raw_train_texts, labels=train_labels) 35 | lowest_quality_labels = label_issues["label_quality"].argsort().to_numpy() 36 | 37 | top_20_error_rows = get_dataframe_by_index(lowest_quality_labels[:20], raw_train_texts, raw_train_labels, encoder, 38 | label_issues) 39 | return top_20_error_rows 40 | 41 | 42 | def get_dataframe_by_index(index: List[int], raw_train_texts: List[str], raw_train_labels: List[str], 43 | encoder: LabelEncoder, label_issues: Union[dict, pd.DataFrame]) -> pd.DataFrame: 44 | """ 45 | Create a DataFrame containing selected rows based on the given index. 46 | 47 | Parameters: 48 | - index (List[int]): List of indices to select rows. 49 | - raw_train_texts (List[str]): List of training texts. 50 | - raw_train_labels (List[str]): List of training labels. 51 | - encoder (LabelEncoder): Label encoder. 52 | - label_issues (Union[dict, pd.DataFrame]): Label issues information. 53 | 54 | Returns: 55 | - pd.DataFrame: DataFrame containing selected rows. 
56 | """ 57 | df = pd.DataFrame( 58 | { 59 | "text": raw_train_texts, 60 | "given_label": raw_train_labels, 61 | "predicted_label": encoder.inverse_transform(label_issues["predicted_label"]), 62 | "quality": label_issues["label_quality"] 63 | } 64 | ) 65 | 66 | return df.iloc[index] 67 | -------------------------------------------------------------------------------- /annotation/data_versioning.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def get_next_version(directory: str, prefix: str = 'annotated_') -> int: 4 | """ 5 | Get the next available version number for a file with a given prefix in a directory. 6 | 7 | Parameters: 8 | - directory (str): Directory path where the files are located. 9 | - prefix (str): Prefix for the filename (default is 'annotated_'). 10 | 11 | Returns: 12 | - int: The next available version number. 13 | """ 14 | existing_files = os.listdir(directory) 15 | versions = [int(f.split('_')[-1].split('.')[0]) for f in existing_files if f.startswith(prefix)] 16 | if versions: 17 | return max(versions) + 1 18 | else: 19 | return 1 -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, UploadFile, File 2 | import pandas as pd 3 | from annotation.annotate_gpt35 import analyze_gpt35 4 | from annotation.data_versioning import get_next_version 5 | from pathlib import Path 6 | from typing import List, Dict, Union, Tuple 7 | 8 | app = FastAPI() 9 | 10 | 11 | def annotate_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str]]: 12 | """ 13 | Annotate the dataframe along with confidence scores. 14 | 15 | Parameters: 16 | - df (pd.DataFrame): Input DataFrame containing text data. 17 | 18 | Returns: 19 | - Tuple[pd.DataFrame, List[str]]: A tuple containing the annotated DataFrame and a list of logs. 20 | """ 21 | sentiments, confidence_scores, all_logs = [], [], [] 22 | 23 | for _, row in df.iterrows(): 24 | text = row['text'] 25 | sentiment, confidence, logs = analyze_gpt35(text) 26 | sentiments.append(sentiment) 27 | confidence_scores.append(confidence) 28 | all_logs.extend(logs) 29 | 30 | df['predicted_labels'] = sentiments 31 | df['confidence_scores'] = confidence_scores 32 | return df, all_logs 33 | 34 | 35 | def save_dataframe(path: str, prefix: str, df: pd.DataFrame) -> str: 36 | """ 37 | Save dataframe to a versioned CSV file and return the file path. 38 | 39 | Parameters: 40 | - path (str): Directory path where the CSV file will be saved. 41 | - prefix (str): Prefix for the CSV file name. 42 | - df (pd.DataFrame): DataFrame to be saved. 43 | 44 | Returns: 45 | - str: File path of the saved CSV file. 46 | """ 47 | version = get_next_version(path, prefix) 48 | file_path = Path(path) / f"{prefix}{version}.csv" 49 | df.to_csv(file_path, index=False) 50 | return str(file_path) 51 | 52 | 53 | @app.post("/annotate_dataset/") 54 | async def annotate_dataset(file: UploadFile = File(...)) -> Dict[str, Union[str, List[str]]]: 55 | """ 56 | Annotate a dataset with text data and return results. 57 | 58 | Parameters: 59 | - file (UploadFile): Uploaded CSV file containing text data. 60 | 61 | Returns: 62 | - Dict[str, Union[str, List[str]]]: A dictionary containing the status, file paths, and logs. 
63 | """ 64 | df = pd.read_csv(file.file) 65 | 66 | # Annotate the dataframe 67 | df, all_logs = annotate_dataframe(df) 68 | 69 | # Save the annotated dataframe 70 | annotated_path = save_dataframe("data/annotated", 'annotated_', df) 71 | 72 | # Filter the dataset where confidence_scores < 1 and save 73 | filtered_dataset = df[df['confidence_scores'] < 1] 74 | filtered_path = save_dataframe("data/filtered", 'filtered_', filtered_dataset) 75 | 76 | return { 77 | "status": "success", 78 | "path": annotated_path, 79 | "filtered_path": filtered_path, 80 | "logs": all_logs 81 | } 82 | -------------------------------------------------------------------------------- /data/original/train.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saran9991/llm-data-annotation/8b7b2539c213bb27cc2acd140feabfb5fb16cd90/data/original/train.csv -------------------------------------------------------------------------------- /data/unannotated/unannotated_200.csv: -------------------------------------------------------------------------------- 1 | ,text 2 | 0," I`d have responded, if I were going" 3 | 1, Sooo SAD I will miss you here in San Diego!!! 4 | 2,my boss is bullying me... 5 | 3, what interview! leave me alone 6 | 4," Sons of ****, why couldn`t they put them on the releases we already bought" 7 | 5,http://www.dothebouncy.com/smf - some shameless plugging for the best Rangers forum on earth 8 | 6,2am feedings for the baby are fun when he is all smiles and coos 9 | 7,Soooo high 10 | 8, Both of you 11 | 9, Journey!? Wow... u just became cooler. hehe... (is that possible!?) 12 | 10," as much as i love to be hopeful, i reckon the chances are minimal =P i`m never gonna get my cake and stuff" 13 | 11,I really really like the song Love Story by Taylor Swift 14 | 12,My Sharpie is running DANGERously low on ink 15 | 13,i want to go to music tonight but i lost my voice. 16 | 14,test test from the LG enV2 17 | 15,"Uh oh, I am sunburned" 18 | 16," S`ok, trying to plot alternatives as we speak *sigh*" 19 | 17,"i`ve been sick for the past few days and thus, my hair looks wierd. if i didnt have a hat on it would look... http://tinyurl.com/mnf4kw" 20 | 18,is back home now gonna miss every one 21 | 19,Hes just not that into you 22 | 20," oh Marly, I`m so sorry!! I hope you find her soon!! <3 <3" 23 | 21,Playing Ghost Online is really interesting. The new updates are Kirin pet and Metamorph for third job. Can`t wait to have a dragon pet 24 | 22,is cleaning the house for her family who is comming later today.. 25 | 23,gotta restart my computer .. I thought Win7 was supposed to put an end to the constant rebootiness 26 | 24,SEe waT I Mean bOuT FoLL0w fRiiDaYs... It`S cALLed LoSe f0LloWeRs FridAy... smH 27 | 25,"the free fillin` app on my ipod is fun, im addicted" 28 | 26, I`m sorry. 29 | 27,On the way to Malaysia...no internet access to Twit 30 | 28,juss came backk from Berkeleyy ; omg its madd fun out there havent been out there in a minute . whassqoodd ? 31 | 29,Went to sleep and there is a power cut in Noida Power back up not working too 32 | 30,I`m going home now. Have you seen my new twitter design? Quite....heavenly isn`****? 33 | 31,i hope unni will make the audition . fighting dahye unni ! 34 | 32, If it is any consolation I got my BMI tested hahaha it says I am obesed well so much for being unhappy for about 10 minutes. 35 | 33, That`s very funny. Cute kids. 36 | 34," Ahhh, I slept through the game. I`m gonna try my best to watch tomorrow though. 
I hope we play Army." 37 | 35,"Thats it, its the end. Tears for Fears vs Eric Prydz, DJ Hero http://bit.ly/2Hpbg4" 38 | 36,Born and raised in NYC and living in Texas for the past 10 years! I still miss NY 39 | 37,"just in case you wonder, we are really busy today and this coming with with adding tons of new blogs and updates stay tuned" 40 | 38,i`m soooooo sleeeeepy!!! the last day o` school was today....sniffle.... 41 | 39,"A little happy for the wine jeje ok it`sm my free time so who cares, jaja i love this day" 42 | 40," Car not happy, big big dent in boot! Hoping theyre not going to write it off, crossing fingers and waiting" 43 | 41,im an avid fan of **** magazine and i love your magazines 44 | 42,MAYDAY?! 45 | 43,"RATT ROCKED NASHVILLE TONITE..ONE THING SUCKED, NO ENCORE! LIKE IN THE 80`S THEY STILL HAVE A FUN SHOW. PEARCY HAS THAT HOTT BAD BOY LOOK" 46 | 44, I love to! But I`m only available from 5pm. and where dear? Would love to help convert her vids.� 47 | 45,The girl in the hair salon asked me 'Shall I trim your eyebrows!' How old do I feel? 48 | 46,egh blah and boooooooooooo i dunno wanna go to work HANGOVERS SUCKKKKKK Im a drunk mess! 49 | 47,:visiting my friendster and facebook 50 | 48,"i donbt like to peel prawns, i also dont like going shopping, running out of money and crawling round the car looking for more" 51 | 49, which case? I got a new one last week and I`m not thrilled at all with mine. 52 | 50, Then you should check out http://twittersucks.com and connect with other tweeple who hate twitter 53 | 51," also bored at school, its my third freelesson( freistunde )" 54 | 52, hm... Both of us I guess... 55 | 53, it is ****...u have dissappointed me that past few days 56 | 54, romance zero is funny 57 | 55, I`d rather do the early run..but I am a morning runner 58 | 56,Bah a coworker ran into work late and her bag smacked into my knee it really hurts now 59 | 57,will be back later. http://plurk.com/p/rp3k7 60 | 58,Aw. Torn ace of hearts #Hunchback 61 | 59, what fun are you speaking of? 62 | 60,"i lost all my friends, i`m alone and sleepy..i wanna go home" 63 | 61, haha yes 64 | 62,I give in to easily 65 | 63,What better way to spoil mum than to let her kick back and relax over a nice meal and a bottle of her favorite wine? Our wine was a red 66 | 64,mannnn..... _ got an iphone!!! im jealous.... http://bit.ly/NgnaR 67 | 65,Is at a photoshoot. 68 | 66, He`s awesome... Have you worked with him before? He`s a good friend. 69 | 67,Yay playing a show tonight! Boo it`s gonna soggy and I`m at work right before playing 70 | 68,Chilliin 71 | 69," If you know such agent, do let me know" 72 | 70,I still smell of smoke #kitchenfire 73 | 71, a celtics-lakers rematch sounds better don`t you think? lol 74 | 72,"Anyone have an extra Keane ticket? I promise to buy you a drink and take rad pics for your FB / Blog / Flickr., etc" 75 | 73,"'you can ride one, you can catch one, but its not summer til you pop open one' ?" 76 | 74, she is good! so gor-juz yea i kno i asked her yesterday when we were at tha hospital if she talked to u and she said no 77 | 75,OK - I`m out of here for now. Just popped in to say Hi and check on things. I`ll probably head to the guttah later on tonight 78 | 76,"WOW, i AM REALLY MiSSiN THE FAM(iLY) TODAY. BADDD." 79 | 77, My sources say no 80 | 78,I am sooo tired 81 | 79," Hey, you change your twitter account, and you didn`t even tell me..." 82 | 80, THANK YYYYYYYYYOOOOOOOOOOUUUUU! 83 | 81, lucky kid...i so wanna see loserville pity im in oz.... 
84 | 82,fell asleep waiting for my ride! 85 | 83,Sick. With a flu like thing. 86 | 84,"Still no reply from about my SimFinger problem So no iRape parody video until I get a response, sorry guys" 87 | 85,Happy Star Wars day everyone! and Enjoy the holiday (UK) 88 | 86, Miles from you I`m in Essex so give me plenty of warning so I can arrive in time to get at least one of those free beers. 89 | 87,"His snoring is so annoying n it keeps me from sleeping (like right now, lol) but I honestly wud miss it if it eva left I love him." 90 | 88,i miss you bby wish you were going tomorrow to make me do good. 91 | 89, Well what im working on isn`t QUITE ready to post about publicly (still beta testing) but its a cool new script I coded 92 | 90, SWEEEEET - San Fran is awesome!!!! Love it there 93 | 91,_Mounce yes and it lasts way past my bedtime! 94 | 92, Hi how are you doing ??? *just joined twitter...* 95 | 93,waiting for sleeping pills to kick in... gonna be so tired at work tomorrow 96 | 94,eating ice cream and then getting ready for graduation. 97 | 95,Happy Mothers day to all you Mums out there 98 | 96," CASEY`S GONE?!?! BUT WHY?! So, she piddled a little on the carpet. She`s prolly freaked cause it`s new. Can we get her back?" 99 | 97, hemp cloth is marvelous but unfortunately no 100 | 98,Gonna read a story bout adam lambert online then bed. Nighty night 101 | 99, We saw that in none 3D - the baddie`s the best 102 | 100,4am. And Im on the beach. Pretty 103 | 101," Certainly not Cheers than, huh?" 104 | 102,"1 week post my'horrible, traumatic jumping cholla accident.'-cholla`s next dirty trick:pieces are starting to emerge from my hand! Ouch!" 105 | 103,i realy wanted to go out cause its so nice but everybodys busy 106 | 104, Awesome. I`m down in Ocean Beach (if you know where that is.) By the way. 'YourBiggestFan' I`m a re-al big fan of you-rs. 107 | 105,at least I get to watch over time Let`s go Pens!! 108 | 106, cool i wear black most of the time when i go out 109 | 107, haha I do not know how to work blip apart from the obvious! thanks for reblipping my song have a nice day **** 110 | 108, have a safe trip joshy poo.......you`ll knock them dead at your speech 111 | 109," woof, I wish I was allowed to go" 112 | 110,if u have a friendster add me!!!!!!!!! my email adress add me loco_crime_1st.com add me leave some comment 113 | 111, has tickets.......? 114 | 112,"Thank you, Afrin Nasal Spray! Also, I got a giant teacup tonight!" 115 | 113, ACSM. it`s unfathomable. i think the other one .. and the .. is one that should be kept to the comfort of our bedrooms. yes? 116 | 114,"Aww, I love my daddy! He works 7 days a week almost all day and still tries to go to SF with all of us" 117 | 115,So many tests todayyy I don`t feel confident about anyy. 118 | 116,graduation is done im a little sad.. anyone want to hang out??? 119 | 117, hahaa your awesomee ! 120 | 118,holy smokes! star trek was freaking awesomeeeee 121 | 119,"I hate Fallout 3 it keeps making me jump, I`m also low on health, money, ammo and food don`t worry I`ll get through it." 122 | 120," I had it! On my itunes, but then I lost all my songs." 123 | 121,What`s with the gloomy weather? The sun must be too tired to come out and play heading to victoria gardens for some impulse buys haha 124 | 122,"Not looking forward to next week: Maths, Geography, English and French exams, totalling 7 hours" 125 | 123," Poor you Get outside and sleep in the garden, the sun will do you good. But don`t forget suncream!!!" 
126 | 124,not well 127 | 125, Not a prob hun 128 | 126,"at dads, watching some mtv and am going on sims2 in a minutee" 129 | 127, Absolutely 130 | 128, what`s the matter chickadee? 131 | 129, hey mia! totally adore your music. when will your cd be out? 132 | 130,"Shopping. Cleaning. BMFing. Webcam chatting with nephews. Nothing spesh, but a good bank holiday Monday nonetheless" 133 | 131, =O you need to ask him something? Lmao I love him too 134 | 132, those splinters look very painful...but you were being very heroic saving mr. Pickle 135 | 133, why are you sad? 136 | 134, Nice to see you tweeting! It`s Sunday 10th May and we`re celebrating Mother`s Day here today. So be nice to yer Mom 137 | 135,decided 2 trans frm relaxed 2 natural hair but i wish my whole head looked like my roots. Age of the instant gratification.... 138 | 136, Namaskar & Namaste r both the same. Marathi people say Namaskar! its a marathi word.... should i ? ...naaaah ! 139 | 137," Congrats! I cuss like that in a matter of minutes, But didn`t know until now there is a reward for it." 140 | 138,Humous and Dorito`s.... Oh yes 141 | 139,"missed all the awesome weather, because she was in a movie!" 142 | 140,Today is going to be a normal day for I hope. We had a group of pilots from a large airline come in last night so it was too much drink 143 | 141,"These kids are terrible! If I was in Good Evans, I`d call Childline" 144 | 142,"Unfortunatley, AerLingus no longer fly to Copenhagen so we`re have to fly Ryanair to Billund and drive up to Copenhagen one of the days!" 145 | 143," What`s sad is that I actually had to google that term. That sucks, tho." 146 | 144,Hate fighting 147 | 145," I watched that too!!! I didnt want her to win, but she put up a good fight..lol" 148 | 146,Car-warmed Sprite tastes like sore throat 149 | 147,Just came 11th in cross country and beat dumbo 150 | 148,Candle wax is very enjoyable. 151 | 149," She`s unassuming and unpretentious. She`s just, as. I suppose that`s why she`s so endearing--because we can relate to her" 152 | 150,tomorrow valeria`s lunch!!! going to get my hair done but im arraving late got my cousins babtizm or whatever you spell it 153 | 151,goooooddd morning tweets!! week three OF my workout. did i mention i got my new glasses yesterday?!? 154 | 152, me too. I hate my computer so much.. 155 | 153, fine! Going to do my big walk today 20 or so miles 156 | 154,I WANT RED CRUISERS!! i don`t like the other ones. LMFAO! 157 | 155,Mmmmmmmm... ? it in the morning 158 | 156, me neither 159 | 157,"Has about 10 hours work to do, on a Sunday. Boo. I will find time for a two hour lunchbreak though. Yeah" 160 | 158,Bugger. forgot I still have washing in my machine 161 | 159,"_Laurie sending love, blessings & healing thoughts to you & family peace" 162 | 160,My back hurts...really bad 163 | 161," ah yes, I know that feeling" 164 | 162,Night of the cookers with my dad 165 | 163,My modem has been offline for a week now... God bless the 3g network. Tim just left... Again!! May schedule has been brutal 166 | 164, Nope I am in Coquitlam 167 | 165, Had parent teacher thing yesterday!! So boring going to skl on saturday!! lol 168 | 166, #lichfield #tweetup sounds like fun Hope to see you and everyone else there! 169 | 167,Big booming thunder storm almost here. Maybe we can all go home early??? Ah... probably not. 170 | 168,Few Bevvies 2day in twn..great on a day off!! 171 | 169,first night in myers. just not the same w/out lydia! but i`m actually excited about this summer! 
172 | 170, good morning 173 | 171, its the best show EVER! 174 | 172,URL in previous post (to timer job) should be http://bit.ly/a4Fdb. I`d removed space which messed up URL. ^ES 175 | 173,i think iv hurt my tooth and eilish and cassie are having a drawing competiton to draw cookies and pineapples haha :L . 176 | 174, I want to know when the auditions are Mander! Text or...reply please! 177 | 175,or even NOOOOO NOT THE SECRET NAMEREBECCA PLEASE 178 | 176, I miss my neice can`t wait to see her bad n grown ****! Lol 179 | 177,i need to get my computer fixed 180 | 178,really hopes her car`s illness is not terminal... 181 | 179,"All the cool people I want to find for following today are #English, and I guess the English don`t tweet." 182 | 180, no sir...i woulda put honey...but i don`t have any 183 | 181,who watched X-men origins: wolverine? i totally loved it! haha 184 | 182, I VOTED!!! do u have a personal myspace? i keep talking to fakes i <3 you. u helped me thru the hrdest time of my life! (: x 185 | 183, I`m sad that I missed you guys last night! 186 | 184,Finally got a call for marriage counseling 3 days late.... 187 | 185, ok then 188 | 186,_420 why baby? 189 | 187,today was the last day of high school for me and i ended up going home sick! ... stupid dead rats 190 | 188,We`re having an impromptu pool party... Except I don`t know how to swim so I can`t get in 191 | 189,lost my tooth 2day whilst i was eating gum...oww 192 | 190,happy 1 year! <3 193 | 191,"Oh, I HELLA forgot to say my official good morning Like to hear it? Here it go! Goooooooooooood Morrrrrrrrning Twitterville! Lol" 194 | 192, *phew* Will make a note in case anyone else runs into the same issue� 195 | 193, WHAT ABOUT ME ?? I VOTE EVERY DAY FOR YOU !!!!! 196 | 194,I`m starving!! This diet is killing me but I can`t eat after 8pm 197 | 195, i talk to you 198 | 196,im soo bored...im deffo missing my music channels 199 | 197, nite nite bday girl have fun at concert 200 | 198,"Had nicotine replacement patch on for 4 hours. So far, so good, but I did sleep for most of those 4 hours. Getting a bit twitchy now" 201 | 199,_Sanderson What`s with Twatter lately? Either I can`t get on or the replies don`t turn up! 202 | -------------------------------------------------------------------------------- /data/unannotated/unannotated_50.csv: -------------------------------------------------------------------------------- 1 | ,text 2 | 0," I`d have responded, if I were going" 3 | 1, Sooo SAD I will miss you here in San Diego!!! 4 | 2,my boss is bullying me... 5 | 3, what interview! leave me alone 6 | 4," Sons of ****, why couldn`t they put them on the releases we already bought" 7 | 5,http://www.dothebouncy.com/smf - some shameless plugging for the best Rangers forum on earth 8 | 6,2am feedings for the baby are fun when he is all smiles and coos 9 | 7,Soooo high 10 | 8, Both of you 11 | 9, Journey!? Wow... u just became cooler. hehe... (is that possible!?) 12 | 10," as much as i love to be hopeful, i reckon the chances are minimal =P i`m never gonna get my cake and stuff" 13 | 11,I really really like the song Love Story by Taylor Swift 14 | 12,My Sharpie is running DANGERously low on ink 15 | 13,i want to go to music tonight but i lost my voice. 16 | 14,test test from the LG enV2 17 | 15,"Uh oh, I am sunburned" 18 | 16," S`ok, trying to plot alternatives as we speak *sigh*" 19 | 17,"i`ve been sick for the past few days and thus, my hair looks wierd. if i didnt have a hat on it would look... 
http://tinyurl.com/mnf4kw" 20 | 18,is back home now gonna miss every one 21 | 19,Hes just not that into you 22 | 20," oh Marly, I`m so sorry!! I hope you find her soon!! <3 <3" 23 | 21,Playing Ghost Online is really interesting. The new updates are Kirin pet and Metamorph for third job. Can`t wait to have a dragon pet 24 | 22,is cleaning the house for her family who is comming later today.. 25 | 23,gotta restart my computer .. I thought Win7 was supposed to put an end to the constant rebootiness 26 | 24,SEe waT I Mean bOuT FoLL0w fRiiDaYs... It`S cALLed LoSe f0LloWeRs FridAy... smH 27 | 25,"the free fillin` app on my ipod is fun, im addicted" 28 | 26, I`m sorry. 29 | 27,On the way to Malaysia...no internet access to Twit 30 | 28,juss came backk from Berkeleyy ; omg its madd fun out there havent been out there in a minute . whassqoodd ? 31 | 29,Went to sleep and there is a power cut in Noida Power back up not working too 32 | 30,I`m going home now. Have you seen my new twitter design? Quite....heavenly isn`****? 33 | 31,i hope unni will make the audition . fighting dahye unni ! 34 | 32, If it is any consolation I got my BMI tested hahaha it says I am obesed well so much for being unhappy for about 10 minutes. 35 | 33, That`s very funny. Cute kids. 36 | 34," Ahhh, I slept through the game. I`m gonna try my best to watch tomorrow though. I hope we play Army." 37 | 35,"Thats it, its the end. Tears for Fears vs Eric Prydz, DJ Hero http://bit.ly/2Hpbg4" 38 | 36,Born and raised in NYC and living in Texas for the past 10 years! I still miss NY 39 | 37,"just in case you wonder, we are really busy today and this coming with with adding tons of new blogs and updates stay tuned" 40 | 38,i`m soooooo sleeeeepy!!! the last day o` school was today....sniffle.... 41 | 39,"A little happy for the wine jeje ok it`sm my free time so who cares, jaja i love this day" 42 | 40," Car not happy, big big dent in boot! Hoping theyre not going to write it off, crossing fingers and waiting" 43 | 41,im an avid fan of **** magazine and i love your magazines 44 | 42,MAYDAY?! 45 | 43,"RATT ROCKED NASHVILLE TONITE..ONE THING SUCKED, NO ENCORE! LIKE IN THE 80`S THEY STILL HAVE A FUN SHOW. PEARCY HAS THAT HOTT BAD BOY LOOK" 46 | 44, I love to! But I`m only available from 5pm. and where dear? Would love to help convert her vids.� 47 | 45,The girl in the hair salon asked me 'Shall I trim your eyebrows!' How old do I feel? 48 | 46,egh blah and boooooooooooo i dunno wanna go to work HANGOVERS SUCKKKKKK Im a drunk mess! 49 | 47,:visiting my friendster and facebook 50 | 48,"i donbt like to peel prawns, i also dont like going shopping, running out of money and crawling round the car looking for more" 51 | 49, which case? I got a new one last week and I`m not thrilled at all with mine. 52 | -------------------------------------------------------------------------------- /data/unannotated/unannotated_70.csv: -------------------------------------------------------------------------------- 1 | ,text 2 | 0," I`d have responded, if I were going" 3 | 1, Sooo SAD I will miss you here in San Diego!!! 4 | 2,my boss is bullying me... 5 | 3, what interview! leave me alone 6 | 4," Sons of ****, why couldn`t they put them on the releases we already bought" 7 | 5,http://www.dothebouncy.com/smf - some shameless plugging for the best Rangers forum on earth 8 | 6,2am feedings for the baby are fun when he is all smiles and coos 9 | 7,Soooo high 10 | 8, Both of you 11 | 9, Journey!? Wow... u just became cooler. hehe... (is that possible!?) 
12 | 10," as much as i love to be hopeful, i reckon the chances are minimal =P i`m never gonna get my cake and stuff" 13 | 11,I really really like the song Love Story by Taylor Swift 14 | 12,My Sharpie is running DANGERously low on ink 15 | 13,i want to go to music tonight but i lost my voice. 16 | 14,test test from the LG enV2 17 | 15,"Uh oh, I am sunburned" 18 | 16," S`ok, trying to plot alternatives as we speak *sigh*" 19 | 17,"i`ve been sick for the past few days and thus, my hair looks wierd. if i didnt have a hat on it would look... http://tinyurl.com/mnf4kw" 20 | 18,is back home now gonna miss every one 21 | 19,Hes just not that into you 22 | 20," oh Marly, I`m so sorry!! I hope you find her soon!! <3 <3" 23 | 21,Playing Ghost Online is really interesting. The new updates are Kirin pet and Metamorph for third job. Can`t wait to have a dragon pet 24 | 22,is cleaning the house for her family who is comming later today.. 25 | 23,gotta restart my computer .. I thought Win7 was supposed to put an end to the constant rebootiness 26 | 24,SEe waT I Mean bOuT FoLL0w fRiiDaYs... It`S cALLed LoSe f0LloWeRs FridAy... smH 27 | 25,"the free fillin` app on my ipod is fun, im addicted" 28 | 26, I`m sorry. 29 | 27,On the way to Malaysia...no internet access to Twit 30 | 28,juss came backk from Berkeleyy ; omg its madd fun out there havent been out there in a minute . whassqoodd ? 31 | 29,Went to sleep and there is a power cut in Noida Power back up not working too 32 | 30,I`m going home now. Have you seen my new twitter design? Quite....heavenly isn`****? 33 | 31,i hope unni will make the audition . fighting dahye unni ! 34 | 32, If it is any consolation I got my BMI tested hahaha it says I am obesed well so much for being unhappy for about 10 minutes. 35 | 33, That`s very funny. Cute kids. 36 | 34," Ahhh, I slept through the game. I`m gonna try my best to watch tomorrow though. I hope we play Army." 37 | 35,"Thats it, its the end. Tears for Fears vs Eric Prydz, DJ Hero http://bit.ly/2Hpbg4" 38 | 36,Born and raised in NYC and living in Texas for the past 10 years! I still miss NY 39 | 37,"just in case you wonder, we are really busy today and this coming with with adding tons of new blogs and updates stay tuned" 40 | 38,i`m soooooo sleeeeepy!!! the last day o` school was today....sniffle.... 41 | 39,"A little happy for the wine jeje ok it`sm my free time so who cares, jaja i love this day" 42 | 40," Car not happy, big big dent in boot! Hoping theyre not going to write it off, crossing fingers and waiting" 43 | 41,im an avid fan of **** magazine and i love your magazines 44 | 42,MAYDAY?! 45 | 43,"RATT ROCKED NASHVILLE TONITE..ONE THING SUCKED, NO ENCORE! LIKE IN THE 80`S THEY STILL HAVE A FUN SHOW. PEARCY HAS THAT HOTT BAD BOY LOOK" 46 | 44, I love to! But I`m only available from 5pm. and where dear? Would love to help convert her vids.� 47 | 45,The girl in the hair salon asked me 'Shall I trim your eyebrows!' How old do I feel? 48 | 46,egh blah and boooooooooooo i dunno wanna go to work HANGOVERS SUCKKKKKK Im a drunk mess! 49 | 47,:visiting my friendster and facebook 50 | 48,"i donbt like to peel prawns, i also dont like going shopping, running out of money and crawling round the car looking for more" 51 | 49, which case? I got a new one last week and I`m not thrilled at all with mine. 52 | 50, Then you should check out http://twittersucks.com and connect with other tweeple who hate twitter 53 | 51," also bored at school, its my third freelesson( freistunde )" 54 | 52, hm... Both of us I guess... 
55 | 53, it is ****...u have dissappointed me that past few days 56 | 54, romance zero is funny 57 | 55, I`d rather do the early run..but I am a morning runner 58 | 56,Bah a coworker ran into work late and her bag smacked into my knee it really hurts now 59 | 57,will be back later. http://plurk.com/p/rp3k7 60 | 58,Aw. Torn ace of hearts #Hunchback 61 | 59, what fun are you speaking of? 62 | 60,"i lost all my friends, i`m alone and sleepy..i wanna go home" 63 | 61, haha yes 64 | 62,I give in to easily 65 | 63,What better way to spoil mum than to let her kick back and relax over a nice meal and a bottle of her favorite wine? Our wine was a red 66 | 64,mannnn..... _ got an iphone!!! im jealous.... http://bit.ly/NgnaR 67 | 65,Is at a photoshoot. 68 | 66, He`s awesome... Have you worked with him before? He`s a good friend. 69 | 67,Yay playing a show tonight! Boo it`s gonna soggy and I`m at work right before playing 70 | 68,Chilliin 71 | 69," If you know such agent, do let me know" 72 | -------------------------------------------------------------------------------- /frontend.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from pathlib import Path 3 | import streamlit as st 4 | import pandas as pd 5 | from models.train_bert import train_bert 6 | import torch 7 | from sklearn.model_selection import train_test_split 8 | from annotation.data_versioning import get_next_version 9 | from annotation.cleanlab_label_issues import find_label_issues 10 | 11 | # Setting Streamlit page config 12 | st.set_page_config(page_title="Data Annotation", page_icon="🚀", layout="wide") 13 | 14 | # Inline CSS 15 | st.markdown( 16 | """ 17 | 27 | """, 28 | unsafe_allow_html=True 29 | ) 30 | 31 | st.title("LLM Seminar Data Annotation ✏️") 32 | st.write("An interactive tool to annotate your dataset, preview annotations, and save changes.") 33 | 34 | # Session state initialization 35 | session_keys = ["iteration", "initial_training", "top20_status", "stop_iterations", "display_top_20", "next_iteration"] 36 | for key in session_keys: 37 | if key not in st.session_state: 38 | st.session_state[key] = 1 if key == "iteration" else False 39 | 40 | # Helper functions 41 | def cleanlab_style() -> str: 42 | """Loads cleanlab processing style from the frontend resources.""" 43 | with open('frontend_resources/cleanlab_processing_style.html', 'r') as file: 44 | return file.read() 45 | 46 | def train_model_style(epoch_value: int) -> str: 47 | """Returns the train model style string with formatted epoch input.""" 48 | with open("frontend_resources/train_model_style.html", "r") as f: 49 | content = f.read() 50 | return content.format(epoch_input=epoch_value) 51 | 52 | 53 | uploaded_file = st.file_uploader("Choose a dataset (CSV)", type="csv") 54 | 55 | if uploaded_file: 56 | # This button is responsible for annotating the dataset using GPT 3.5 57 | if st.button("Annotate"): 58 | with st.spinner('Annotating rows using GPT 3.5...'): 59 | files = {'file': uploaded_file.getvalue()} 60 | response = requests.post("http://127.0.0.1:8000/annotate_dataset/", files=files) 61 | 62 | if response.json()["status"] == "success": 63 | st.success(f"Dataset annotated successfully! 
Saved to {response.json()['path']}") 64 | st.session_state.annotated_path = response.json()["path"] 65 | st.session_state.dataset = pd.read_csv(uploaded_file).drop(columns='Unnamed: 0') 66 | 67 | filtered_dataset_url = response.json()['path'].replace('annotated', 'filtered') 68 | st.session_state.filtered_dataset = pd.read_csv(filtered_dataset_url).drop(columns='Unnamed: 0') 69 | low_confidence_rows = len( 70 | st.session_state.filtered_dataset[st.session_state.filtered_dataset['confidence_scores'] < 1]) 71 | 72 | st.markdown( 73 | f"
{low_confidence_rows} rows " 75 | f"with annotation confidence less than 1, please annotate these manually
", 76 | unsafe_allow_html=True 77 | ) 78 | 79 | else: 80 | st.error("Failed to annotate dataset.") 81 | 82 | if "filtered_dataset" in st.session_state: 83 | # Shows the rows with low confidence values 84 | st.dataframe(st.session_state.filtered_dataset, use_container_width=True) 85 | 86 | row_options = list(st.session_state.filtered_dataset.index) 87 | row_selection = st.selectbox("Edit label for row:", options=row_options) 88 | label_options = ["negative", "neutral", "positive"] 89 | new_label = st.selectbox("Select new label:", options=label_options) 90 | 91 | col1, col2, col3 = st.columns([1, 2, 1]) 92 | with col1: 93 | # Human annotating the GPT 3.5 annotated dataset 94 | if st.button("Update Label"): 95 | st.session_state.filtered_dataset.loc[row_selection, "predicted_labels"] = new_label 96 | 97 | st.markdown( 98 | """ 99 | 113 | """, 114 | unsafe_allow_html=True 115 | ) 116 | # Merge the human annotated low confidence rows with GPT 3.5 Annotated Dataset 117 | if st.button("Merge and Save", key="customSaveButton"): 118 | try: 119 | if 'dataset' not in st.session_state: 120 | st.warning("No original dataset available for merging.") 121 | 122 | annotated_dataset = pd.read_csv(st.session_state.annotated_path).drop(columns='Unnamed: 0') 123 | merged_dataset = annotated_dataset.merge(st.session_state.filtered_dataset[['text', 'predicted_labels']], 124 | on='text', how='left') 125 | merged_dataset['predicted_labels'] = merged_dataset['predicted_labels_y'].combine_first( 126 | merged_dataset['predicted_labels_x']) 127 | 128 | merged_dataset = merged_dataset.drop(columns=['predicted_labels_x', 'predicted_labels_y']) 129 | # get_next_version provides for rudimentary data versioning 130 | # The GPT 3.5 + Human annotated dataset ( merged dataset ) is saved under data/merged 131 | version = get_next_version("data/merged", 'merged_') 132 | save_path = Path("data/merged") / f"merged_{version}.csv" 133 | merged_dataset.reset_index(drop=True) 134 | merged_dataset.to_csv(save_path, index=False) 135 | st.session_state.save_path = str(save_path) 136 | 137 | st.success(f"Dataset saved successfully at {save_path}") 138 | st.session_state.merged_successful = True 139 | # To get more consistent test results across models, we use a static test set 140 | st.success('Allocating 20% of the rows as a hold-out test set') 141 | 142 | train_data, test_data = train_test_split(merged_dataset, test_size=0.2) 143 | 144 | # Train data is saved under data/trainsets 145 | train_version = get_next_version("data/trainsets", 'train_') 146 | train_save_path = Path("data/trainsets") / f"train_{train_version}.csv" 147 | train_data.reset_index(drop=True) 148 | train_data.to_csv(train_save_path, index=False) 149 | 150 | # Test data is saved under data/testsets 151 | test_version = get_next_version("data/testsets", 'test_') 152 | test_save_path = Path("data/testsets") / f"test_{test_version}.csv" 153 | test_data.reset_index(drop=True) 154 | test_data.to_csv(test_save_path, index=False) 155 | # Logging these to st.session_state for later use 156 | st.session_state.test_set_path = test_save_path 157 | st.session_state.train_save_path = train_save_path 158 | 159 | except Exception as e: 160 | st.error(f"An error occurred: {e}") 161 | # After the GPT 3.5 + Human annotated dataset has been saved, train BERT model on it 162 | if st.session_state.get('merged_successful'): 163 | st.write("----") 164 | st.session_state.experiment_name = st.text_input("Enter the experiment name:", 165 | value="llm_seminar_data_annotation") 166 | 167 | 
epoch_input = int(st.text_input("Enter the number of Epochs for BERT Training:", value="1")) 168 | model_name_inp = st.text_input("Enter the model name:", value="bert_sentiment_gpt35") 169 | # Training BERT on the GPT 3.5 + Human annotated dataset for n epochs 170 | if st.button("Train Model"): 171 | training_message = train_model_style(epoch_input) 172 | st.markdown(training_message, unsafe_allow_html=True) 173 | progress_bar = st.empty() 174 | 175 | # This method is responsible for showing the model training progress bar 176 | def update_progress(current_epoch, total_epochs): 177 | progress = current_epoch / total_epochs 178 | progress_bar.progress(progress) 179 | 180 | 181 | if not hasattr(st.session_state, 'save_path'): 182 | st.warning("No dataset available for training. Please upload, annotate, and then merge first.") 183 | else: 184 | # Training BERT 185 | model_path, val_acc, model = train_bert(model_path=f"models/{model_name_inp}.pt", 186 | train_data_path=st.session_state.train_save_path, 187 | test_data_path=st.session_state.test_set_path, 188 | experiment_name=st.session_state.experiment_name, 189 | epoch_input=epoch_input, 190 | model_name_inp=model_name_inp, 191 | progress_callback=update_progress 192 | ) 193 | 194 | st.success(f"Model trained successfully and saved at {model_path}", icon='✅') 195 | st.write(f"Current Model's trained Validation Accuracy: {val_acc:.2f}") 196 | st.session_state.model_path = model_path 197 | st.session_state.initial_model = model 198 | st.session_state.initial_training = True 199 | 200 | if st.session_state.get('stop_iterations', False): 201 | st.stop() 202 | # After the initial BERT Training on the GPT 3.5 + Human annotated dataset, we use CleanLab to find label issues 203 | # This is done in an iterative manner to enhance the quality of the final dataset and the model test results 204 | if st.session_state.get('initial_training') and not getattr(st.session_state, 'stop_iterations', False): 205 | heading_style = cleanlab_style() 206 | st.markdown(heading_style, unsafe_allow_html=True) 207 | st.write("----") 208 | st.subheader(f"Iteration: {st.session_state.iteration}") 209 | # If it's the first iteration, we take the BERT model trained on the initial Human + GPT dataset 210 | if st.session_state.iteration == 1 and st.session_state.initial_training: 211 | st.session_state.current_model = st.session_state.initial_model 212 | st.session_state.current_data_path = st.session_state.train_save_path 213 | print('Iteration 1 data_path:', st.session_state.current_data_path) 214 | 215 | # Else we choose the previous iteration's dataset and model 216 | else: 217 | # 218 | model_path = f"models/model_cleanlab_{st.session_state.iteration - 1}.pt" 219 | loaded_model = torch.load(model_path) 220 | st.session_state.current_model = loaded_model 221 | st.session_state.current_data_path = f"data/cleaned/cleaned_{st.session_state.iteration - 1}.csv" 222 | print('Iteration:', st.session_state.iteration, ' data_path:', st.session_state.current_data_path) 223 | 224 | # Button to find label issues 225 | if st.button("Find Label Issues", key="find_issues"): 226 | st.session_state.top_20 = find_label_issues(st.session_state.current_model, 227 | st.session_state.current_data_path) 228 | st.success('These are the top 20 labels in the dataset with lowest label quality:') 229 | st.session_state.display_top_20 = True 230 | st.session_state.top20_status = True 231 | 232 | if st.session_state.display_top_20: 233 | # If the CleanLab process was a success, we display 20 rows with 
the lowest label scores 234 | st.dataframe(st.session_state.top_20, use_container_width=True) 235 | 236 | if st.session_state.top20_status: 237 | # Providing an interface for the human to annotate these rows with label issues 238 | st.subheader("Label Issues for Annotation") 239 | row_options = list(st.session_state.top_20.index) 240 | row_selection = st.selectbox("Edit label for row:", options=row_options, key="row_selection") 241 | label_options = ["negative", "neutral", "positive"] 242 | new_label = st.selectbox("Select new label:", options=label_options, key="new_label_selection") 243 | 244 | col1, col2, col3 = st.columns([1, 2, 1]) 245 | with col1: 246 | if st.button("Update Label", key='update_iterative_button'): 247 | st.session_state.top_20.loc[row_selection, 'predicted_labels'] = new_label 248 | st.session_state.display_top_20 = True 249 | 250 | # We finally merge these annotated rows with the dataset 251 | if st.button("Merge and Save Cleaned Data", key="merge_clean"): 252 | original_data = pd.read_csv(st.session_state.current_data_path) 253 | 254 | merged_dataset = original_data.merge(st.session_state.top_20[['text', 'predicted_labels']], 255 | on='text', how='left') 256 | 257 | merged_dataset['predicted_labels'] = merged_dataset['predicted_labels_y'].combine_first( 258 | merged_dataset['predicted_labels_x']) 259 | 260 | merged_dataset = merged_dataset.drop(columns=['predicted_labels_x', 'predicted_labels_y']) 261 | 262 | # The cleaned dataset is saved under data/cleaned/cleaned_i , where i denotes the iteration number 263 | save_cleaned_path = f"data/cleaned/cleaned_{st.session_state.iteration}.csv" 264 | merged_dataset.to_csv(save_cleaned_path, index=False) 265 | st.success(f"Cleaned data saved at: {save_cleaned_path}") 266 | st.session_state.save_cleaned_path = save_cleaned_path 267 | setattr(st.session_state, f'data_cleaning_{st.session_state.iteration}', True) 268 | 269 | # Once the cleaned dataset has been saved, we train BERT on it and get evaluation metrics 270 | if getattr(st.session_state, f'data_cleaning_{st.session_state.iteration}', False): 271 | st.write("----") 272 | epoch_input = int( 273 | st.text_input("Enter the number of Epochs for BERT Training:", value="1", key='ep_cl_inp')) 274 | model_name_inp = st.text_input("Enter the model name:", value="bert_sentiment_cleanlab", key='mname_cl_inp') 275 | if st.button("Train Model on Cleaned Data"): 276 | training_message = train_model_style(epoch_input) 277 | st.markdown(training_message, unsafe_allow_html=True) 278 | 279 | progress_bar = st.empty() 280 | 281 | # This method is responsible for showing the model training progress bar 282 | def update_progress(current_epoch, total_epochs): 283 | progress = current_epoch / total_epochs 284 | progress_bar.progress(progress) 285 | 286 | 287 | if not hasattr(st.session_state, 'save_cleaned_path'): 288 | st.warning("No dataset available for training. 
Please upload, annotate, and then merge first.") 289 | else: 290 | # Training BERT on cleaned dataset for iteration i 291 | model_path, val_acc, model = train_bert( 292 | model_path=f"models/model_cleanlab_{st.session_state.iteration}.pt", 293 | train_data_path=st.session_state.save_cleaned_path, 294 | test_data_path=st.session_state.test_set_path, 295 | experiment_name=st.session_state.experiment_name, 296 | epoch_input=epoch_input, 297 | model_name_inp=model_name_inp, 298 | progress_callback=update_progress) 299 | st.success(f"Model trained successfully and saved at {model_path}", icon='✅') 300 | st.write(f"Model's trained Validation Accuracy on Cleaned Data: {val_acc:.2f}") 301 | st.session_state.model_path = model_path 302 | st.session_state.current_model = model 303 | st.session_state.bert_clean_training = True 304 | setattr(st.session_state, f'iteration_{st.session_state.iteration}', True) 305 | 306 | # The iteration number is increased if data cleaning for previous iteration is completed 307 | # and if the entire iteration has processed 308 | if (getattr(st.session_state, f'iteration_{st.session_state.iteration}', False) 309 | and getattr(st.session_state, f'data_cleaning_{st.session_state.iteration}', False)): 310 | st.session_state.iteration += 1 311 | st.session_state.top20_status = False 312 | setattr(st.session_state, f'iteration_{st.session_state.iteration}', False) 313 | 314 | col_next, col_stop = st.columns(2) 315 | 316 | with col_next: 317 | if st.button("Next Iteration"): 318 | st.write("----") 319 | st.session_state.display_top_20 = False 320 | st.session_state.next_iteration = True 321 | 322 | with col_stop: 323 | st.session_state.display_top_20 = False 324 | st.session_state.next_iteration = False 325 | if st.button('Stop Iterative CleanLab processing'): 326 | st.write("----") 327 | st.session_state.stop_iterations = True 328 | -------------------------------------------------------------------------------- /frontend_resources/cleanlab_processing_style.html: -------------------------------------------------------------------------------- 1 | 36 |
37 | CleanLab Processing
-------------------------------------------------------------------------------- /frontend_resources/train_model_style.html: --------------------------------------------------------------------------------
24 | Model is Training for {epoch_input} epochs...
25 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # This script is for annotating and BERT model fitting in the backend. 2 | # It does not involve any user interface (UI) operations. 3 | 4 | from annotation.annotate_gpt35 import analyze_gpt35 5 | import pandas as pd 6 | from sklearn.metrics import accuracy_score 7 | from models.train_bert import train_bert 8 | from annotation.data_versioning import get_next_version 9 | import os 10 | 11 | def annotate_dataset(filepath: str) -> str: 12 | """ 13 | Annotate a dataset using GPT-3.5 and save the annotated dataset. 14 | 15 | Parameters: 16 | - filepath (str): Path to the unannotated dataset CSV file. 17 | 18 | Returns: 19 | - str: Path to the annotated dataset CSV file. 20 | """ 21 | unannotated = pd.read_csv(filepath, encoding='unicode_escape', index_col=[0]) 22 | original_dataset = pd.read_csv('data/original/train.csv',encoding='unicode_escape') 23 | 24 | sentiments_and_scores = unannotated['text'].apply(analyze_gpt35) 25 | unannotated['predicted_labels'] = [x[0] for x in sentiments_and_scores] 26 | unannotated['confidence_score'] = [x[1] for x in sentiments_and_scores] 27 | 28 | unannotated = unannotated[unannotated['predicted_labels'].isin(['positive', 'negative', 'neutral'])] 29 | 30 | version = get_next_version('data/annotated', 'annotated_') 31 | annotated_file_path = os.path.join('data', 'annotated', f"annotated_{version}.csv") 32 | unannotated.to_csv(annotated_file_path) 33 | print('Annotated dataset head: ', unannotated.head()) 34 | 35 | accuracy = accuracy_score( 36 | original_dataset.loc[list(set(unannotated.index) & set(original_dataset.index)), 'sentiment'].values, 37 | unannotated.loc[list(set(unannotated.index) & set(original_dataset.index)), 'predicted_labels'].values 38 | ) 39 | print(f"Accuracy of GPT 3.5's annotations: {accuracy}") 40 | return annotated_file_path 41 | 42 | if __name__ == "__main__": 43 | annotate_dataset('data/unannotated/unannotated_50.csv') 44 | train_bert('models/bert_sentiment_gpt35_1000_model3.pt', 'data/annotated/gpt35_conf_scores_1000_preproc.csv') 45 | -------------------------------------------------------------------------------- /models/train_bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from transformers import BertTokenizerFast, BertForSequenceClassification 4 | from torch.optim import AdamW 5 | from sklearn.metrics import classification_report 6 | import pandas as pd 7 | from tqdm import tqdm 8 | import torch.nn as nn 9 | import numpy as np 10 | import mlflow 11 | import mlflow.pytorch 12 | from mlflow.tracking import MlflowClient 13 | from sklearn.base import BaseEstimator 14 | from sklearn.preprocessing import LabelEncoder 15 | from typing import Dict, Tuple, Optional, List 16 | from pathlib import Path 17 | from torch.optim.lr_scheduler import StepLR 18 | 19 | #Custom class for BERT to align it with scikit-learn 20 | class BertSentimentClassifier(BaseEstimator): 21 | def __init__(self, model_path: str = 'bert-base-uncased', device: Optional[torch.device] = None, epochs: int = 1): 22 | """ 23 | Initialize the BertSentimentClassifier with the given parameters. 24 | 25 | Parameters: 26 | - model_path (str): The path to the BERT model or the model name (default is 'bert-base-uncased'). 
27 | - device (Optional[torch.device]): The device to use for model training (default is 'cuda' if available, else 'cpu'). 28 | - epochs (int): The number of training epochs (default is 1). 29 | """ 30 | self.model_path = model_path 31 | self.tokenizer = BertTokenizerFast.from_pretrained(model_path) 32 | self.model = BertForSequenceClassification.from_pretrained(model_path, num_labels=3) 33 | self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu') 34 | self.model.to(self.device) 35 | self.max_len = 128 36 | self.epochs = epochs 37 | 38 | def fit(self, X, y, progress_callback=None) -> None: 39 | """ 40 | Train the model using the given data. 41 | 42 | Parameters: 43 | - X: Input data. 44 | - y: Target labels. 45 | - progress_callback: A callback function to track training progress on UI (default is None). 46 | """ 47 | self.classes_ = np.unique(y) 48 | 49 | train_data = CustomDataset(X, y, self.tokenizer, max_len=self.max_len) 50 | train_loader = DataLoader(train_data, batch_size=16, shuffle=True) 51 | 52 | optimizer = AdamW(self.model.parameters(), lr=2e-5) 53 | 54 | for epoch in range(self.epochs): 55 | train_acc, train_loss = train_epoch(self.model, train_loader, optimizer, self.device) 56 | print(f'Epoch {epoch + 1}/{self.epochs} - Train loss: {train_loss}, accuracy: {train_acc}') 57 | 58 | if progress_callback: 59 | progress_callback(epoch + 1, self.epochs) 60 | 61 | 62 | def predict(self, X) -> np.array: 63 | """ 64 | Predict the class labels for the given data. 65 | 66 | Parameters: 67 | - X: Input data. 68 | 69 | Returns: 70 | - np.array: Predicted class labels. 71 | """ 72 | X_list = X.tolist() 73 | encoding = self.tokenizer.batch_encode_plus( 74 | X_list, 75 | add_special_tokens=True, 76 | max_length=128, 77 | return_token_type_ids=False, 78 | padding='max_length', 79 | return_attention_mask=True, 80 | return_tensors='pt', 81 | ) 82 | 83 | input_ids = encoding['input_ids'].to(self.device) 84 | attention_mask = encoding['attention_mask'].to(self.device) 85 | 86 | with torch.no_grad(): 87 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 88 | _, preds = torch.max(outputs.logits, dim=1) 89 | 90 | return self.classes_[preds.cpu().numpy()] 91 | 92 | def predict_proba(self, X) -> np.array: 93 | """ 94 | Predict the class probabilities for the given data. 95 | 96 | Parameters: 97 | - X: Input data. 98 | 99 | Returns: 100 | - np.array: Predicted class probabilities. 101 | """ 102 | X_list = X.tolist() # Convert to list 103 | encoding = self.tokenizer.batch_encode_plus( 104 | X_list, # Updated this line 105 | add_special_tokens=True, 106 | max_length=128, 107 | return_token_type_ids=False, 108 | padding='max_length', 109 | return_attention_mask=True, 110 | return_tensors='pt', 111 | ) 112 | 113 | input_ids = encoding['input_ids'].to(self.device) 114 | attention_mask = encoding['attention_mask'].to(self.device) 115 | 116 | with torch.no_grad(): 117 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 118 | 119 | # Convert logits to probabilities 120 | probs = torch.nn.functional.softmax(outputs.logits, dim=1) 121 | 122 | return probs.cpu().numpy() 123 | 124 | def score(self, X, y) -> float: 125 | """ 126 | Compute the accuracy of the model on the given test data and labels. 127 | 128 | Parameters: 129 | - X: Input data. 130 | - y: Target labels. 131 | 132 | Returns: 133 | - float: Accuracy score. 
134 | """ 135 | y_pred = self.predict(X) 136 | accuracy = (y_pred == y).mean() 137 | return accuracy 138 | 139 | def set_model_weights(self, state_dict: dict) -> None: 140 | """ 141 | Set the model weights from the given state dictionary. 142 | 143 | Parameters: 144 | - state_dict (dict): The state dictionary containing model weights. 145 | """ 146 | self.model.load_state_dict(state_dict) 147 | 148 | def state_dict(self) -> dict: 149 | """ 150 | Get the model's state dictionary. 151 | 152 | Returns: 153 | - dict: Model's state dictionary. 154 | """ 155 | return self.model.state_dict() 156 | 157 | def load_state_dict(self, state_dict: dict) -> None: 158 | """ 159 | Load the model weights from the given state dictionary. 160 | 161 | Parameters: 162 | - state_dict (dict): The state dictionary containing model weights. 163 | """ 164 | return self.model.load_state_dict(state_dict) 165 | 166 | 167 | class CustomDataset(Dataset): 168 | def __init__(self, texts: List[str], targets: List[int], tokenizer, max_len: int): 169 | """ 170 | Initialize the CustomDataset with texts, targets, tokenizer, and a maximum sequence length. 171 | 172 | Parameters: 173 | - texts (List[str]): A list of input texts. 174 | - targets (List[int]): A list of target labels or scores associated with the texts. 175 | - tokenizer: An NLP tokenizer. 176 | - max_len (int): The maximum sequence length to which the input texts should be truncated or padded. 177 | """ 178 | self.texts = texts 179 | self.targets = targets 180 | self.tokenizer = tokenizer 181 | self.max_len = max_len 182 | 183 | def __len__(self) -> int: 184 | """ 185 | Return the number of items in the dataset. 186 | 187 | Returns: 188 | - int: The number of items in the dataset. 189 | """ 190 | return len(self.texts) 191 | 192 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 193 | """ 194 | Retrieve and tokenize the text at the given index and return it with the associated target. 195 | 196 | Parameters: 197 | - idx (int): The index of the item to retrieve. 198 | 199 | Returns: 200 | - dict: A dictionary containing the following elements: 201 | - 'text' (str): The original text. 202 | - 'input_ids' (torch.Tensor): The tokenized and encoded input text as a flattened tensor. 203 | - 'attention_mask' (torch.Tensor): The attention mask indicating which tokens are part of the input text (flattened tensor). 204 | - 'targets' (torch.Tensor): The target label or score as a PyTorch tensor with dtype=torch.long, indicating it's for classification tasks. 205 | """ 206 | text = str(self.texts[idx]) 207 | target = self.targets[idx] 208 | 209 | encoding = self.tokenizer.encode_plus( 210 | text, 211 | add_special_tokens=True, 212 | max_length=self.max_len, 213 | return_token_type_ids=False, 214 | padding='max_length', 215 | return_attention_mask=True, 216 | return_tensors='pt', 217 | ) 218 | 219 | return { 220 | 'text': text, 221 | 'input_ids': encoding['input_ids'].flatten(), 222 | 'attention_mask': encoding['attention_mask'].flatten(), 223 | 'targets': torch.tensor(target, dtype=torch.long) 224 | } 225 | 226 | def train_epoch(model, data_loader, optimizer, device, scheduler: Optional[StepLR] = None) -> Tuple[float, float]: 227 | """ 228 | Trains the model for one epoch. 229 | 230 | Parameters: 231 | - model (nn.Module): The PyTorch model to be trained. 232 | - data_loader (DataLoader): DataLoader providing the training data. 233 | - optimizer (torch.optim.Optimizer): The optimizer for updating model parameters. 
234 | - device (torch.device): The device (e.g., 'cuda' or 'cpu') where the model and data should be loaded. 235 | - scheduler (Optional[torch.optim.lr_scheduler._LRScheduler]): An optional learning rate scheduler. 236 | 237 | Returns: 238 | - Tuple containing: 239 | 1. Training accuracy. 240 | 2. Average training loss. 241 | """ 242 | model = model.train() 243 | losses = [] 244 | correct_predictions = 0 245 | 246 | for d in tqdm(data_loader): 247 | input_ids = d["input_ids"].to(device) 248 | attention_mask = d["attention_mask"].to(device) 249 | targets = d["targets"].to(device) 250 | 251 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=targets) 252 | loss = outputs.loss 253 | logits = outputs.logits 254 | 255 | _, preds = torch.max(logits, dim=1) 256 | correct_predictions += torch.sum(preds == targets) 257 | 258 | losses.append(loss.item()) 259 | 260 | loss.backward() 261 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 262 | optimizer.step() 263 | if scheduler: 264 | scheduler.step() 265 | optimizer.zero_grad() 266 | 267 | return correct_predictions.double() / len(data_loader.dataset), np.mean(losses) 268 | 269 | def eval_model(model: nn.Module, 270 | data_loader: DataLoader, 271 | device: torch.device, 272 | sentiments: Dict[str, int]) -> Tuple[float, str]: 273 | """ 274 | Evaluate the model on a dataset. 275 | 276 | Parameters: 277 | - model: The model to be evaluated. 278 | - data_loader: DataLoader providing the evaluation data. 279 | - device: Device (e.g., 'cuda' or 'cpu') where the model and data should be loaded. 280 | - sentiments: Dictionary of sentiment classes. 281 | 282 | Returns: 283 | - Tuple containing: 284 | 1. Evaluation accuracy. 285 | 2. Classification report string. 286 | """ 287 | model = model.eval() 288 | 289 | correct_predictions = 0 290 | predictions = [] 291 | real_values = [] 292 | 293 | with torch.no_grad(): 294 | for d in tqdm(data_loader): 295 | input_ids = d["input_ids"].to(device) 296 | attention_mask = d["attention_mask"].to(device) 297 | targets = d["targets"].to(device) 298 | 299 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=targets) 300 | _, preds = torch.max(outputs.logits, dim=1) 301 | 302 | predictions.extend(preds) 303 | real_values.extend(targets) 304 | correct_predictions += torch.sum(preds == targets) 305 | 306 | predictions = torch.stack(predictions).cpu() 307 | real_values = torch.stack(real_values).cpu() 308 | return correct_predictions.double() / len(data_loader.dataset), classification_report(real_values, predictions, target_names=sentiments.keys()) 309 | 310 | 311 | def train_bert(model_path: str, train_data_path: str, test_data_path: str, experiment_name: str, 312 | epoch_input: int, model_name_inp: str, progress_callback=None): 313 | """ 314 | Train a BERT-based sentiment classifier and log metrics using MLflow. 315 | 316 | Parameters: 317 | - model_path (str): The path where the trained model will be saved. 318 | - train_data_path (str): Path to the training data CSV file. 319 | - test_data_path (str): Path to the testing data CSV file. 320 | - experiment_name (str): Name of the MLflow experiment. 321 | - epoch_input (int): Number of training epochs. 322 | - model_name_inp (str): Name for the registered MLflow model. 323 | - progress_callback: A callback function to track training progress (default is None). 324 | 325 | Returns: 326 | - Tuple containing: 327 | 1. Path to the saved model. 328 | 2. Accuracy on the test data. 329 | 3. 
Trained BERTSentimentClassifier instance. 330 | """ 331 | EXPERIMENT_NAME = experiment_name 332 | client = MlflowClient() 333 | experiment_id = client.get_experiment_by_name(EXPERIMENT_NAME) 334 | if experiment_id is None: 335 | experiment_id = mlflow.create_experiment(EXPERIMENT_NAME) 336 | else: 337 | experiment_id = experiment_id.experiment_id 338 | model_name = model_name_inp 339 | 340 | with mlflow.start_run(experiment_id=experiment_id): 341 | 342 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 343 | 344 | mlflow.log_param("epochs", epoch_input) 345 | mlflow.log_param("model_name", model_name) 346 | 347 | #sentiments = {'positive': 0, 'neutral': 1, 'negative': 2} 348 | train_data = pd.read_csv(train_data_path, encoding='unicode_escape') 349 | train_data = train_data.dropna() 350 | test_data = pd.read_csv(test_data_path, encoding='unicode_escape') 351 | test_data = test_data.dropna() 352 | 353 | raw_train_texts, raw_train_labels = train_data["text"].values, train_data["predicted_labels"].values 354 | raw_test_texts, raw_test_labels = test_data["text"].values, test_data["predicted_labels"].values 355 | 356 | # Label encoding the labels 357 | encoder = LabelEncoder() 358 | encoder.fit(raw_train_labels) 359 | train_labels = encoder.transform(raw_train_labels) 360 | test_labels = encoder.transform(raw_test_labels) 361 | 362 | # Define classifier and fit 363 | clf = BertSentimentClassifier(epochs= epoch_input) 364 | clf.fit(raw_train_texts, train_labels, progress_callback) 365 | initial_accuracy = clf.score(raw_test_texts, test_labels) 366 | print(f'Accuracy: {initial_accuracy:.4f}') 367 | 368 | mlflow.log_metric("accuracy", initial_accuracy) 369 | mlflow.pytorch.log_model(clf.model, "model") 370 | 371 | mlflow.register_model( 372 | model_uri=f"runs:/{mlflow.active_run().info.run_id}/model", 373 | name=model_name 374 | ) 375 | torch.save(clf, model_path) 376 | mlflow.end_run() 377 | return model_path, initial_accuracy, clf 378 | 379 | -------------------------------------------------------------------------------- /openai/key_mock.txt: -------------------------------------------------------------------------------- 1 | sk-key -------------------------------------------------------------------------------- /openai/organization_mock.txt: -------------------------------------------------------------------------------- 1 | org-id -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cleanlab==2.4.0 2 | fastapi==0.101.0 3 | mlflow==2.5.0 4 | numpy==1.25.1 5 | pandas==2.0.3 6 | Requests==2.31.0 7 | scikit_learn==1.3.0 8 | streamlit==1.25.0 9 | tenacity==8.2.2 10 | torch==2.0.1+cu118 11 | tqdm==4.65.0 12 | transformers==4.31.0 13 | --------------------------------------------------------------------------------
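
For reference, a minimal sketch of calling the training entry point in models/train_bert.py outside the Streamlit UI. The CSV paths, experiment name, and model name below are illustrative placeholders (the real paths are produced at runtime by the app); the only assumptions carried over from the code above are that both CSVs contain 'text' and 'predicted_labels' columns and that train_bert returns the saved path, the test accuracy, and the fitted scikit-learn-style wrapper.

# Minimal usage sketch of train_bert; every path and name here is a placeholder.
import numpy as np
from models.train_bert import train_bert

saved_path, test_acc, clf = train_bert(
    model_path="models/bert_sentiment_example.pt",       # where the fitted classifier is pickled
    train_data_path="data/annotated/annotated_1.csv",    # CSV with 'text' and 'predicted_labels'
    test_data_path="data/testsets/test_1.csv",           # held-out CSV with the same columns
    experiment_name="sentiment_annotation_demo",         # MLflow experiment to log into
    epoch_input=1,
    model_name_inp="bert_sentiment_example",
)
print(f"saved to {saved_path}, test accuracy {test_acc:.4f}")

# The returned object is the BertSentimentClassifier wrapper, so scikit-learn
# style calls such as predict and predict_proba work directly.
probs = clf.predict_proba(np.array(["the battery life is fantastic"]))
print(probs)  # one row of class probabilities for [class0, class1, class2]

Because the whole wrapper is serialized with torch.save, torch.load(saved_path) later restores a ready-to-use classifier, which is what the CleanLab iterations in frontend.py rely on.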
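
The iterative clean-up in frontend.py calls find_label_issues from annotation/cleanlab_label_issues.py (not shown in this excerpt). The snippet below is only an illustrative sketch of how cleanlab's public API can be paired with the BaseEstimator-compatible wrapper to rank suspect labels; it is not a copy of that module, and the file paths are again placeholders.

# Illustrative sketch only: ranking label issues with cleanlab's public API.
# Paths are placeholders; in-sample probabilities are used for brevity, whereas
# cleanlab generally prefers out-of-sample (cross-validated) probabilities.
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from cleanlab.filter import find_label_issues

data = pd.read_csv("data/annotated/annotated_1.csv")
texts = data["text"].values
labels = LabelEncoder().fit_transform(data["predicted_labels"].values)  # ints 0..2

clf = torch.load("models/bert_sentiment_example.pt")   # wrapper saved by the sketch above
pred_probs = clf.predict_proba(texts)                  # shape (n_rows, 3)

# Indices of the most suspicious rows, worst first; the UI reviews the top 20.
issue_idx = find_label_issues(
    labels=labels,
    pred_probs=pred_probs,
    return_indices_ranked_by="self_confidence",
)
print(data.iloc[issue_idx[:20]][["text", "predicted_labels"]])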
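
Finally, the "Merge and Save Cleaned Data" step in frontend.py folds the human-corrected rows back into the working dataset with a left merge followed by combine_first. A toy illustration of that pandas pattern (invented rows, same column names as the app's datasets):

# Toy illustration of the merge used in frontend.py: corrected labels from the
# reviewed rows override the originals, all other rows are kept unchanged.
import pandas as pd

original = pd.DataFrame({
    "text": ["great phone", "meh", "terrible battery"],
    "predicted_labels": ["positive", "positive", "negative"],   # second label is wrong
})
corrected = pd.DataFrame({
    "text": ["meh"],
    "predicted_labels": ["neutral"],                            # human correction
})

merged = original.merge(corrected, on="text", how="left")
merged["predicted_labels"] = merged["predicted_labels_y"].combine_first(
    merged["predicted_labels_x"]
)
merged = merged.drop(columns=["predicted_labels_x", "predicted_labels_y"])
print(merged)   # 'meh' now carries the corrected 'neutral' label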