├── .data_tracking └── data_info.txt ├── .flake8 ├── .github └── workflows │ └── test_cli.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── app ├── __init__.py ├── app.py ├── assets │ ├── Report.pdf │ └── bootstrap.css ├── callbacks.py ├── index.py └── layouts.py ├── data ├── README.md └── exp_data.csv.gz ├── notebooks ├── README.md └── active_learning_data.ipynb ├── pyproject.toml ├── requirements.txt ├── scripts ├── __init__.py ├── annotator.py ├── config.py ├── data.py ├── download_data.py └── train.py └── tests ├── README.md └── annotation ├── __init__.py └── test_annotator.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = venv 3 | ignore = E501, W503, E226 4 | max-line-length = 79 5 | 6 | # E501: Line too long 7 | # W503: Line break occurred before binary operator 8 | # E226: Missing white space around arithmetic operator 9 | -------------------------------------------------------------------------------- /.github/workflows/test_cli.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: CI 4 | 5 | # Controls when the action will run. 
6 | on: 7 | # Triggers the workflow on pull request events, but only for the main branch 8 | pull_request: 9 | branches: [ main ] 10 | 11 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 12 | jobs: 13 | # This workflow contains a single job called "build" 14 | build: 15 | # The type of runner that the job will run on 16 | runs-on: ubuntu-latest 17 | 18 | # Steps represent a sequence of tasks that will be executed as part of the job 19 | steps: 20 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 21 | - uses: actions/checkout@v2 22 | 23 | # Installs the requirements and runs the tests using the runner's shell 24 | - name: Test the cli 25 | run: | 26 | pip install -r requirements.txt 27 | pytest -v 28 | 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .pytest_cache 3 | *.pyc 4 | runs 5 | .vscode 6 | wandb 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v3.4.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: check-yaml 9 | - id: check-added-large-files 10 | args: ['--maxkb=1000'] 11 | - id: check-ast 12 | - id: check-json 13 | - id: check-merge-conflict 14 | - id: detect-private-key 15 | - repo: https://github.com/psf/black 16 | rev: 20.8b1 17 | hooks: 18 | - id: black 19 | args: [] 20 | files: . 21 | - repo: https://gitlab.com/PyCQA/flake8 22 | rev: 3.9.0 23 | hooks: 24 | - id: flake8 25 | - repo: https://github.com/PyCQA/isort 26 | rev: 5.8.0 27 | hooks: 28 | - id: isort 29 | args: [] 30 | files: .
31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Abinaya Mahendiran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Styling 2 | .PHONY: style 3 | style: 4 | black . 5 | flake8 6 | isort . 
7 | 8 | # Tests 9 | .PHONY: test 10 | test: 11 | pytest --version 12 | pytest --cov scripts --cov-report html 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Active Learning in NLP 3 | 4 | The aim of this project is to find out whether active learning can help build better models in the NLP domain using less, but good-quality, data. 5 | The project is undertaken as part of the Full Stack Deep Learning Course, 2021. The code is not production-ready yet; it is still at an experimental stage. 6 | 7 | ## Initial Plan 8 | The initial plan is to build an end-to-end project that contains the following components: 9 | - Active learning (using custom code or an available library) 10 | - Multi-class classification model using Transformers 11 | - Experiment tracking using MLflow or wandb 12 | - Labeling using Label Studio 13 | - App using Streamlit/Dash 14 | - Explainability for NLP models - stretch goal 15 | - Unit testing using pytest/unittest 16 | - CI/CD using GitHub Actions 17 | 18 | ## Components completed 19 | Currently, the following components are present (some of which are not yet 100% complete): 20 | - Active learning is performed using random sampling and uncertainty-based sampling (least confidence, entropy-based) 21 | - Multi-class classification of news articles is done using the Simple Transformers library 22 | - Experiments are tracked using wandb 23 | - CLI-based annotation tool 24 | - GUI-based Dash app for annotation (functionality is not complete yet) 25 | - Pytest and coverage (doesn't cover all the code yet) 26 | - CI/CD using GitHub Actions (in progress) 27 | - Explainability of NLP models (will be done in the future) 28 | - Expose API using FastAPI/Flask (will be done in the future) 29 | 30 | ## Authors 31 | - [@AbinayaM02](https://github.com/AbinayaM02) 32 | - [@datafool](https://github.com/datafool) 33 | 34 | ## Demo 35 | 36 | Insert
gif or link to demo (will be provided later) 37 | 38 | ## Documentation 39 | 40 | [Documentation] (will be updated later) 41 | 42 | 43 | ## Environment Variables 44 | 45 | To run this project, you will need to add the project root to your `PYTHONPATH` environment variable: 46 | ``` 47 | export PYTHONPATH="${PYTHONPATH}:<path-to-this-repository>" 48 | ``` 49 | ## Running the CLI tool 50 | 51 | To run the CLI tool, use the following command: 52 | ``` 53 | python scripts/annotator.py \ 54 | \ 55 | \ 56 | 57 | ``` 58 | 59 | ## Running the GUI annotation tool 60 | 61 | To run the GUI tool, use the following command: 62 | ``` 63 | python app/index.py 64 | ``` 65 | 66 | ## Train model 67 | 68 | To train the model on the news corpus (downloading it directly from Hugging Face Datasets), run the following: 69 | ``` 70 | python scripts/download_data.py 71 | python scripts/train.py 72 | ``` 73 | If you've downloaded the data from the Kaggle competition instead, use the following commands: 74 | ``` 75 | # To prepare the data 76 | python scripts/data.py 77 | 78 | # To train the model 79 | jupyter notebook 80 | ``` 81 | 82 | ## Running Tests 83 | 84 | To run the tests, use the following command: 85 | ``` 86 | make test 87 | ``` 88 | 89 | ## Running Styling 90 | 91 | To run the code formatters and linters (black, flake8, isort) on this project, use the following command: 92 | ``` 93 | make style 94 | ``` 95 | 96 | ## Roadmap 97 | - Fix the issues with the annotation app (Dash) 98 | - Add test cases for all the modules 99 | - Add features to train models and run inference from the Dash app 100 | - Fix GitHub Actions 101 | - Dockerize the application 102 | - Add documentation 103 | 104 | 105 | ## Screenshots 106 | 107 | ### CLI 108 | ![CLI tool](https://user-images.githubusercontent.com/28945722/118386233-ad2e1d80-b633-11eb-8b1a-326c03e398b8.png) 109 | 110 | ### Dash Tool 111 | ![Dash : Home](https://user-images.githubusercontent.com/28945722/118368098-fabc7300-b5be-11eb-8774-da6dcceab501.png) 112 | ![Dash : Annotation
details](https://user-images.githubusercontent.com/28945722/118368130-20e21300-b5bf-11eb-893e-756693583463.png) 113 | ![Dash : Annotation](https://user-images.githubusercontent.com/28945722/118386173-2e38e500-b633-11eb-90e9-7a2453b448e8.png) 114 | 115 | 116 | ## Acknowledgement 117 | 118 | - [@katherinepeterson](https://www.github.com/katherinepeterson) for development and design of README.so. 119 | - [@GokuMohandas](https://github.com/GokuMohandas) for the course on MLOps (https://github.com/GokuMohandas/MadeWithML). 120 | 121 | ## License 122 | 123 | [MIT](https://choosealicense.com/licenses/mit/) 124 | 125 | 126 | ## Appendix 127 | 128 | Any additional information goes here 129 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AbinayaM02/Active_Learning_in_NLP/9c6bb281c5508a6117fa3c829c54562b3d58253a/app/__init__.py -------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | import dash 2 | import dash_bootstrap_components as dbc 3 | 4 | # Define the app 5 | app = dash.Dash(__name__, external_stylesheets=[dbc.themes.YETI], suppress_callback_exceptions=True) 6 | server = app.server -------------------------------------------------------------------------------- /app/assets/Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AbinayaM02/Active_Learning_in_NLP/9c6bb281c5508a6117fa3c829c54562b3d58253a/app/assets/Report.pdf -------------------------------------------------------------------------------- /app/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import json 4 | 5 | # import pandas as pd 6 | from dash.dependencies 
import Input, Output, State 7 | from layouts import (annotate_data_dir, 8 | annotation_layout, 9 | instruction_example_tabs, 10 | instruction_tab_content, 11 | example_tab_content, 12 | tagging_layout, 13 | popup_layout, 14 | stat_layout) 15 | from scripts import annotator 16 | from app import app 17 | 18 | # Callback for slider 19 | @app.callback( 20 | Output('slider-output-container', 'children'), 21 | Input('selected-samples', 'value') 22 | ) 23 | def update_output(value): 24 | return 25 | 26 | # Callback for popup 27 | @app.callback(Output('confirm', 'displayed'), 28 | Input('choose-data', 'value')) 29 | def display_confirm(value): 30 | if value == None: 31 | return True 32 | return False 33 | 34 | # Callback for submit 35 | @app.callback( 36 | [ 37 | Output("selected-data-method", "children"), 38 | Output("annotate-text", "data"), 39 | ], 40 | [ 41 | Input("submit-val", "n_clicks"), 42 | Input('confirm', 'submit_n_clicks'), 43 | Input("choose-data", "value"), 44 | Input("choose-annotate-method", "value"), 45 | Input("selected-samples", "value"), 46 | ], 47 | ) 48 | def get_data_for_annotation(n_clicks, submit_n_clicks, data, method, sample_size): 49 | if submit_n_clicks == 0: 50 | return [True, True] 51 | if n_clicks: 52 | data_path = os.path.join(annotate_data_dir, data) 53 | df_annotate, _ = annotator.get_data(data_path, method, sample_size) 54 | df_annotate["sampling_method"] = method 55 | return [True, df_annotate.to_json(orient="columns")] 56 | 57 | # Callback for tabs 58 | @app.callback(Output('instruction-example-tab', 'children'), 59 | Input('tabs-content', 'value')) 60 | def render_content(tab): 61 | if tab == 'Annotation Instructions': 62 | return instruction_tab_content 63 | elif tab == 'Annotation Examples': 64 | return example_tab_content 65 | 66 | # Callback for annotation_layout 67 | @app.callback( 68 | [ 69 | Output("title-body", "children"), 70 | Output("description-body", "children"), 71 | ], 72 | [ 73 | Input("annotate-next", "n_clicks"), 74 
| Input("annotate-prev", "n_clicks"), 75 | Input("news-class", 'value') 76 | ], 77 | State("annotate-text", "data") 78 | ) 79 | def sample_data(n_clicks_next, n_clicks_back, news_type, df_json): 80 | df = pd.DataFrame.from_dict(json.loads(df_json)) 81 | if n_clicks_back <= 0 or n_clicks_next <= 0: 82 | df_index = df["idx"].values[0] 83 | title = df["text"].values[0] 84 | description = df["title"].values[0] 85 | if n_clicks_next: 86 | df_index = df["idx"].values[n_clicks_next] 87 | title = df["text"].values[n_clicks_next] 88 | description = df["title"].values[n_clicks_next] 89 | if n_clicks_back: 90 | n_clicks_next = n_clicks_next - 1 91 | df_index = df["idx"].values[n_clicks_next] 92 | title = df["text"].values[n_clicks_next] 93 | description = df["title"].values[n_clicks_next] 94 | df.loc[df["idx"] == df_index, "annotated_labels"]= news_type - 1 95 | print(n_clicks_next, n_clicks_back) 96 | print(df) 97 | return [f"Title:\n {title}", f"Description:\n {description}"] 98 | 99 | -------------------------------------------------------------------------------- /app/index.py: -------------------------------------------------------------------------------- 1 | import dash_bootstrap_components as dbc 2 | import dash_core_components as dcc 3 | import dash_html_components as html 4 | from dash.dependencies import Input, Output 5 | from layouts import ( 6 | annotation_layout, 7 | instruction_example_tabs, 8 | report_layout, 9 | sidebar, 10 | sidebar_content, 11 | stat_layout, 12 | tagging_layout, 13 | ) 14 | import callbacks 15 | from app import app 16 | 17 | # Define app layout 18 | app.layout = dbc.Container( 19 | [ 20 | dcc.Location(id="url", refresh=False), 21 | # html.Div(id="page-content") 22 | sidebar, 23 | sidebar_content, 24 | dcc.Store(id="annotate-text", storage_type="local"), 25 | ], 26 | fluid=True, 27 | ) 28 | 29 | 30 | @app.callback(Output("page-content", "children"), Input("url", "pathname")) 31 | def display_page(pathname): 32 | if pathname == "/home": 33 | 
return instruction_example_tabs 34 | elif pathname == "/annotate": 35 | return annotation_layout 36 | elif pathname == "/annotate_info": 37 | return tagging_layout 38 | elif pathname == "/annotate-stat": 39 | return stat_layout 40 | elif pathname == "/report": 41 | return report_layout 42 | return dbc.Jumbotron( 43 | [ 44 | html.H1("404: Not found", className="text-danger"), 45 | html.Hr(), 46 | html.P(f"The pathname {pathname} was not recognised..."), 47 | ] 48 | ) 49 | 50 | 51 | if __name__ == "__main__": 52 | app.run_server( 53 | debug=True, 54 | ) 55 | -------------------------------------------------------------------------------- /app/layouts.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import dash_core_components as dcc 4 | import dash_bootstrap_components as dbc 5 | import dash_html_components as html 6 | 7 | 8 | colors = {"background": "#000000", "text": "#7FDBFF"} 9 | class_map = {1: "World News", 2: "Sports", 3: "Business", 4: "Sci/Tech", 0: "Not Sure"} 10 | annotate_data_dir = Path(__file__).resolve().parents[1] / "data/output/20210421" 11 | 12 | data_file_dropdown = dcc.Dropdown( 13 | id="choose-data", 14 | options=[ 15 | {"label": i, "value": i} for i in os.listdir(annotate_data_dir) 16 | ], 17 | style={"width": "75%"}, 18 | 19 | ) 20 | 21 | annotation_method_dropdown = dcc.Dropdown( 22 | id="choose-annotate-method", 23 | options=[ 24 | {"label": "Random Sampling", "value": "random"}, 25 | {"label": "Least Confidence Sampling", "value": "least",}, 26 | {"label": "Margin Sampling", "value": "margin"}, 27 | {"label": "Entropy Base Sampling", "value": "entropy"}, 28 | ], 29 | value="random", 30 | style={"width": "75%"}, 31 | ) 32 | 33 | sample_slider = dcc.Slider( 34 | id='selected-samples', 35 | min=0, 36 | max=1000, 37 | step=10, 38 | marks={i: '{}'.format(i) for i in range(1001) if i%100 == 0}, 39 | value=50, 40 | updatemode='drag', 41 | tooltip={'always_visible': 
False} 42 | ) 43 | 44 | popup_layout = html.Div( 45 | [ 46 | dcc.ConfirmDialog( 47 | id='confirm', 48 | message='No file is chosen! Do you want to continue?', 49 | ), 50 | html.Div(id='output-confirm') 51 | ] 52 | ) 53 | 54 | tagging_layout = html.Div( 55 | [ 56 | dbc.CardHeader( 57 | children="annotate.it", 58 | style={"textAlign": "center"}, 59 | ), 60 | popup_layout, 61 | dbc.Card( 62 | [ 63 | html.Div( 64 | [ 65 | html.Br(), 66 | html.Br(), 67 | ] 68 | ), 69 | dbc.Row( 70 | [ 71 | dbc.Col(dbc.FormGroup([dbc.Label("Choose data to annotate"), data_file_dropdown]), width={"size": 5, "order": "first", "offset": 1}), 72 | dbc.Col(dbc.FormGroup([dbc.Label("Choose annotation method"), annotation_method_dropdown]), width={"size": 5, "order": "last", "offset": 1}), 73 | ], 74 | justify="around" 75 | ), 76 | html.Div( 77 | [ 78 | html.Br(), 79 | html.Br(), 80 | dbc.Label("Select number of samples"), 81 | sample_slider, 82 | html.Div(id="slider-output-container"), 83 | ] 84 | ), 85 | html.Div( 86 | [ 87 | html.Br(), 88 | html.Br(), 89 | dbc.Row( 90 | [ 91 | dbc.Col(dbc.Button("Submit", id="submit-val", n_clicks=0, 92 | color="primary", size="md", className="mt-auto", href="/annotate"), 93 | width=4), 94 | ], 95 | justify="around", 96 | ), 97 | ], 98 | style={'textAlign':'center', 99 | 'margin':'auto'} 100 | ), 101 | ] 102 | ), 103 | html.H1(id="selected-data-method") 104 | ], 105 | className="m-4 px-2", 106 | ) 107 | 108 | instruction_tab_content = dbc.Card( 109 | children=[ 110 | dbc.CardBody( 111 | [ 112 | html.P( 113 | """In this exercise we will be labeling news into one of the below four categories. 
114 | In case, news is ambigous, please choose option `Not Sure`""" 115 | ), 116 | html.Ol( 117 | [ 118 | html.Li(html.B("World News")), 119 | html.Li(html.B("Sports")), 120 | html.Li(html.B("Business")), 121 | html.Li(html.B("Sci/Tech")), 122 | html.Li(html.B("Not sure")), 123 | ] 124 | ), 125 | ], 126 | className="mt-3", 127 | ), 128 | ], 129 | ) 130 | 131 | example_tab_content = dbc.Card( 132 | children=[ 133 | dbc.CardBody( 134 | [ 135 | html.P( 136 | """Examples for each of the categories are shown below:""" 137 | ), 138 | html.Ol( 139 | [ 140 | html.Li(html.B("World News")), 141 | html.B("Title: "), 142 | html.P( 143 | """ White House Proposes Cuts in Salmon Areas (AP) """ 144 | ), 145 | html.B("Description: "), 146 | html.P( 147 | """ 148 | AP - The Bush administration Tuesday proposed large cuts in 149 | federally designated areas in the Northwest and California meant 150 | to aid the recovery of threatened or endangered salmon. 151 | Protection would focus instead on rivers where the fish now thrive. 152 | """ 153 | ), 154 | html.Li(html.B("Sports News")), 155 | html.B("Title: "), 156 | html.P( 157 | """ Wannstedt Steps Down as Dolphins Coach """ 158 | ), 159 | html.B("Description: "), 160 | html.P( 161 | """ 162 | DAVIE, Fla. (Sports Network) - Dave Wannstedt resigned Tuesday as head 163 | coach of the Miami Dolphins after the team sunk to an NFL-worst 1-8 record. 164 | Defensive coordinator Jim Bates will take over as interim coach for 165 | the remainder of the season. 166 | """ 167 | ), 168 | html.Li(html.B("Business News")), 169 | html.B("Title: "), 170 | html.P( 171 | """ Credit Issuers Shares Dented by Kerry Plan """ 172 | ), 173 | html.B("Description: "), 174 | html.P( 175 | """ 176 | Financial companies were under scrutiny Friday after Sen. 177 | John Kerry vowed to push for legislation that would curb credit card fees 178 | and protect homebuyers from unfair lending practices. 
179 | """ 180 | ), 181 | html.Li(html.B("Sci/Tech News")), 182 | html.B("Title: "), 183 | html.P( 184 | """ Titan on Tuesday """ 185 | ), 186 | html.B("Description: "), 187 | html.P( 188 | """ 189 | On Tuesday, October 26, the Cassini spacecraft will approach Saturn #39;s 190 | largest moon, Titan. Cassini will fly by Titan at a distance of 1,200 kilometers 191 | (745 miles) above the surface, nearly 300 times closer than the first Cassini 192 | flyby of Titan on July 3. 193 | """ 194 | ), 195 | ] 196 | ), 197 | ], 198 | className="mt-3", 199 | ), 200 | ] 201 | ) 202 | 203 | instruction_example_tabs = html.Div( 204 | [ 205 | dbc.CardHeader( 206 | children="Annotation Instructions and Examples", 207 | style={"textAlign": "center"}, 208 | ), 209 | dbc.Tabs( 210 | [ 211 | dbc.Tab(instruction_tab_content, label="Annotation Instructions", tab_id="tab-instruction"), 212 | dbc.Tab(example_tab_content, label="Annotation Examples", tab_id="tab-example"), 213 | ], 214 | id="tabs-content", 215 | ), 216 | html.Div(id='instruction-example-tab'), 217 | ] 218 | ) 219 | 220 | annotation_layout = html.Div( 221 | [ 222 | dbc.CardHeader( 223 | children="Annotate.it", 224 | style={"textAlign": "center"}, 225 | ), 226 | dbc.Card( 227 | [ 228 | dbc.CardBody( 229 | [ 230 | html.H5("Title:\n", className="card-title"), 231 | html.P("IPL 2021 is being played in India"), 232 | 233 | ], 234 | id="title-body" 235 | ), 236 | dbc.CardBody( 237 | [ 238 | html.H5("Description:\n", className="card-title"), 239 | html.P("IPL 2021 is being played in India"), 240 | 241 | ], 242 | id="description-body" 243 | ), 244 | dbc.CardBody( 245 | [ 246 | dbc.Label("Please choose one of the five categories that describes the text well!", 247 | align="start", size="md"), 248 | dbc.RadioItems( 249 | id="news-class", 250 | options=[ 251 | {"label": "World News", "value": 1}, 252 | {"label": "Sports", "value": 2}, 253 | {"label": "Business", "value": 3}, 254 | {"label": "Sci/Tech", "value": 4}, 255 | {"label": "Not 
Sure", "value": 0}, 256 | ], 257 | value=0, 258 | inline=True, 259 | style={"textAlign": "center"}, 260 | ), 261 | ] 262 | ), 263 | html.Div(id="annotate-data"), 264 | html.Br(), 265 | html.Br(), 266 | html.Br(), 267 | html.Div( 268 | [ 269 | dcc.Link(dbc.Button("Back", id="annotate-prev", n_clicks=0, color="secondary", size="md"), href="/annotate"), 270 | dcc.Link(dbc.Button("Save", id="save-link", n_clicks=0, color="primary", size="md"), href="/annotate"), 271 | dcc.Link(dbc.Button("Next", id="annotate-next", n_clicks=0, color="secondary", size="md"), href="/annotate"), 272 | ], 273 | style={"textAlign": 'center'} 274 | ), 275 | ], 276 | ), 277 | ] 278 | ) 279 | 280 | stat_layout = html.Div( 281 | [ 282 | dbc.CardHeader( 283 | children="Annotate.it : Statistics", 284 | style={"textAlign": "center"}, 285 | ), 286 | dbc.Card( 287 | [ 288 | dbc.CardBody( 289 | [ 290 | dcc.Graph(id="stat-graph") 291 | ] 292 | ) 293 | ] 294 | ), 295 | ], 296 | id="stat-container" 297 | ) 298 | 299 | report_layout = html.Div( 300 | [ 301 | dbc.CardHeader( 302 | 'Active learning in NLP', 303 | style={"textAlign": "center"} 304 | ), 305 | dbc.CardBody( 306 | html.Iframe( 307 | src=os.path.join("assets", "sample.pdf"), 308 | style={"width": "910px", "height": "700px"} 309 | ), 310 | ) 311 | ] 312 | ) 313 | # styling the sidebar 314 | SIDEBAR_STYLE = { 315 | "position": "fixed", 316 | "top": 0, 317 | "left": 0, 318 | "bottom": 0, 319 | "width": "16rem", 320 | "padding": "2rem 1rem", 321 | "background-color": "#f8f9fa", 322 | } 323 | 324 | # padding for the page content 325 | CONTENT_STYLE = { 326 | "margin-left": "18rem", 327 | "margin-right": "2rem", 328 | "padding": "2rem 1rem", 329 | } 330 | 331 | sidebar = html.Div( 332 | [ 333 | html.H4("Active Learning", className="display-6"), 334 | html.Hr(), 335 | dbc.Nav( 336 | [ 337 | dbc.NavLink("Home", href="/home", active="exact"), 338 | dbc.NavLink("Annotate", href="/annotate_info", active="exact"), 339 | dbc.NavLink("Annotation Statistics", 
href="/annotate-stat", active="exact"), 340 | dbc.NavLink("Project Report", href="/report", active="exact") 341 | ], 342 | vertical=True, 343 | pills=True, 344 | ), 345 | ], 346 | style=SIDEBAR_STYLE, 347 | ) 348 | 349 | sidebar_content = html.Div(id="page-content", children=[], style=CONTENT_STYLE) 350 | 351 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Add all the data files in this folder. 2 | 3 | Please run download_data.py to download the data and store it in the data folder automatically before running the training script. 4 | -------------------------------------------------------------------------------- /data/exp_data.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AbinayaM02/Active_Learning_in_NLP/9c6bb281c5508a6117fa3c829c54562b3d58253a/data/exp_data.csv.gz -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | Add notebooks in this folder. 
2 | -------------------------------------------------------------------------------- /notebooks/active_learning_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "active_learning_data.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "machine_shape": "hm" 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU", 19 | "widgets": { 20 | "application/vnd.jupyter.widget-state+json": { 21 | "e9b2650e436a4a3d9e0aa4cbb7c3c679": { 22 | "model_module": "@jupyter-widgets/controls", 23 | "model_name": "HBoxModel", 24 | "state": { 25 | "_view_name": "HBoxView", 26 | "_dom_classes": [], 27 | "_model_name": "HBoxModel", 28 | "_view_module": "@jupyter-widgets/controls", 29 | "_model_module_version": "1.5.0", 30 | "_view_count": null, 31 | "_view_module_version": "1.5.0", 32 | "box_style": "", 33 | "layout": "IPY_MODEL_bc17c094006a48d29d6e2cc708a35d0b", 34 | "_model_module": "@jupyter-widgets/controls", 35 | "children": [ 36 | "IPY_MODEL_cde8d171084446e58e8a064e01a7a501", 37 | "IPY_MODEL_b7d5e68eda2c492a9e80525b69d6dd21", 38 | "IPY_MODEL_0de1f93062784a38be5e1016bbf29c05" 39 | ] 40 | } 41 | }, 42 | "bc17c094006a48d29d6e2cc708a35d0b": { 43 | "model_module": "@jupyter-widgets/base", 44 | "model_name": "LayoutModel", 45 | "state": { 46 | "_view_name": "LayoutView", 47 | "grid_template_rows": null, 48 | "right": null, 49 | "justify_content": null, 50 | "_view_module": "@jupyter-widgets/base", 51 | "overflow": null, 52 | "_model_module_version": "1.2.0", 53 | "_view_count": null, 54 | "flex_flow": null, 55 | "width": null, 56 | "min_width": null, 57 | "border": null, 58 | "align_items": null, 59 | "bottom": null, 60 | "_model_module": "@jupyter-widgets/base", 61 | "top": null, 62 | "grid_column": null, 63 | "overflow_y": 
null, 64 | "overflow_x": null, 65 | "grid_auto_flow": null, 66 | "grid_area": null, 67 | "grid_template_columns": null, 68 | "flex": null, 69 | "_model_name": "LayoutModel", 70 | "justify_items": null, 71 | "grid_row": null, 72 | "max_height": null, 73 | "align_content": null, 74 | "visibility": null, 75 | "align_self": null, 76 | "height": null, 77 | "min_height": null, 78 | "padding": null, 79 | "grid_auto_rows": null, 80 | "grid_gap": null, 81 | "max_width": null, 82 | "order": null, 83 | "_view_module_version": "1.2.0", 84 | "grid_template_areas": null, 85 | "object_position": null, 86 | "object_fit": null, 87 | "grid_auto_columns": null, 88 | "margin": null, 89 | "display": null, 90 | "left": null 91 | } 92 | }, 93 | "cde8d171084446e58e8a064e01a7a501": { 94 | "model_module": "@jupyter-widgets/controls", 95 | "model_name": "HTMLModel", 96 | "state": { 97 | "_view_name": "HTMLView", 98 | "style": "IPY_MODEL_25b6388694ff46eca2c9dbcf49b14ac7", 99 | "_dom_classes": [], 100 | "description": "", 101 | "_model_name": "HTMLModel", 102 | "placeholder": "​", 103 | "_view_module": "@jupyter-widgets/controls", 104 | "_model_module_version": "1.5.0", 105 | "value": " 0%", 106 | "_view_count": null, 107 | "_view_module_version": "1.5.0", 108 | "description_tooltip": null, 109 | "_model_module": "@jupyter-widgets/controls", 110 | "layout": "IPY_MODEL_1cf29525e90c45f3bda46b4d4ea0e9d3" 111 | } 112 | }, 113 | "b7d5e68eda2c492a9e80525b69d6dd21": { 114 | "model_module": "@jupyter-widgets/controls", 115 | "model_name": "FloatProgressModel", 116 | "state": { 117 | "_view_name": "ProgressView", 118 | "style": "IPY_MODEL_2910f4ee601241408647ff845c2713d4", 119 | "_dom_classes": [], 120 | "description": "", 121 | "_model_name": "FloatProgressModel", 122 | "bar_style": "danger", 123 | "max": 67200, 124 | "_view_module": "@jupyter-widgets/controls", 125 | "_model_module_version": "1.5.0", 126 | "value": 135, 127 | "_view_count": null, 128 | "_view_module_version": "1.5.0", 129 | 
"orientation": "horizontal", 130 | "min": 0, 131 | "description_tooltip": null, 132 | "_model_module": "@jupyter-widgets/controls", 133 | "layout": "IPY_MODEL_0a817039c2cf4f23be507e0b0b7df501" 134 | } 135 | }, 136 | "0de1f93062784a38be5e1016bbf29c05": { 137 | "model_module": "@jupyter-widgets/controls", 138 | "model_name": "HTMLModel", 139 | "state": { 140 | "_view_name": "HTMLView", 141 | "style": "IPY_MODEL_6af6b17b562141f59a0b4fae42b8c1b4", 142 | "_dom_classes": [], 143 | "description": "", 144 | "_model_name": "HTMLModel", 145 | "placeholder": "​", 146 | "_view_module": "@jupyter-widgets/controls", 147 | "_model_module_version": "1.5.0", 148 | "value": " 135/67200 [00:18<1:39:35, 11.22it/s]", 149 | "_view_count": null, 150 | "_view_module_version": "1.5.0", 151 | "description_tooltip": null, 152 | "_model_module": "@jupyter-widgets/controls", 153 | "layout": "IPY_MODEL_1b4aeeac358841a4884e02ac756f8a3c" 154 | } 155 | }, 156 | "25b6388694ff46eca2c9dbcf49b14ac7": { 157 | "model_module": "@jupyter-widgets/controls", 158 | "model_name": "DescriptionStyleModel", 159 | "state": { 160 | "_view_name": "StyleView", 161 | "_model_name": "DescriptionStyleModel", 162 | "description_width": "", 163 | "_view_module": "@jupyter-widgets/base", 164 | "_model_module_version": "1.5.0", 165 | "_view_count": null, 166 | "_view_module_version": "1.2.0", 167 | "_model_module": "@jupyter-widgets/controls" 168 | } 169 | }, 170 | "1cf29525e90c45f3bda46b4d4ea0e9d3": { 171 | "model_module": "@jupyter-widgets/base", 172 | "model_name": "LayoutModel", 173 | "state": { 174 | "_view_name": "LayoutView", 175 | "grid_template_rows": null, 176 | "right": null, 177 | "justify_content": null, 178 | "_view_module": "@jupyter-widgets/base", 179 | "overflow": null, 180 | "_model_module_version": "1.2.0", 181 | "_view_count": null, 182 | "flex_flow": null, 183 | "width": null, 184 | "min_width": null, 185 | "border": null, 186 | "align_items": null, 187 | "bottom": null, 188 | "_model_module": 
"@jupyter-widgets/base", 189 | "top": null, 190 | "grid_column": null, 191 | "overflow_y": null, 192 | "overflow_x": null, 193 | "grid_auto_flow": null, 194 | "grid_area": null, 195 | "grid_template_columns": null, 196 | "flex": null, 197 | "_model_name": "LayoutModel", 198 | "justify_items": null, 199 | "grid_row": null, 200 | "max_height": null, 201 | "align_content": null, 202 | "visibility": null, 203 | "align_self": null, 204 | "height": null, 205 | "min_height": null, 206 | "padding": null, 207 | "grid_auto_rows": null, 208 | "grid_gap": null, 209 | "max_width": null, 210 | "order": null, 211 | "_view_module_version": "1.2.0", 212 | "grid_template_areas": null, 213 | "object_position": null, 214 | "object_fit": null, 215 | "grid_auto_columns": null, 216 | "margin": null, 217 | "display": null, 218 | "left": null 219 | } 220 | }, 221 | "2910f4ee601241408647ff845c2713d4": { 222 | "model_module": "@jupyter-widgets/controls", 223 | "model_name": "ProgressStyleModel", 224 | "state": { 225 | "_view_name": "StyleView", 226 | "_model_name": "ProgressStyleModel", 227 | "description_width": "", 228 | "_view_module": "@jupyter-widgets/base", 229 | "_model_module_version": "1.5.0", 230 | "_view_count": null, 231 | "_view_module_version": "1.2.0", 232 | "bar_color": null, 233 | "_model_module": "@jupyter-widgets/controls" 234 | } 235 | }, 236 | "0a817039c2cf4f23be507e0b0b7df501": { 237 | "model_module": "@jupyter-widgets/base", 238 | "model_name": "LayoutModel", 239 | "state": { 240 | "_view_name": "LayoutView", 241 | "grid_template_rows": null, 242 | "right": null, 243 | "justify_content": null, 244 | "_view_module": "@jupyter-widgets/base", 245 | "overflow": null, 246 | "_model_module_version": "1.2.0", 247 | "_view_count": null, 248 | "flex_flow": null, 249 | "width": null, 250 | "min_width": null, 251 | "border": null, 252 | "align_items": null, 253 | "bottom": null, 254 | "_model_module": "@jupyter-widgets/base", 255 | "top": null, 256 | "grid_column": null, 257 | 
"overflow_y": null, 258 | "overflow_x": null, 259 | "grid_auto_flow": null, 260 | "grid_area": null, 261 | "grid_template_columns": null, 262 | "flex": null, 263 | "_model_name": "LayoutModel", 264 | "justify_items": null, 265 | "grid_row": null, 266 | "max_height": null, 267 | "align_content": null, 268 | "visibility": null, 269 | "align_self": null, 270 | "height": null, 271 | "min_height": null, 272 | "padding": null, 273 | "grid_auto_rows": null, 274 | "grid_gap": null, 275 | "max_width": null, 276 | "order": null, 277 | "_view_module_version": "1.2.0", 278 | "grid_template_areas": null, 279 | "object_position": null, 280 | "object_fit": null, 281 | "grid_auto_columns": null, 282 | "margin": null, 283 | "display": null, 284 | "left": null 285 | } 286 | }, 287 | "6af6b17b562141f59a0b4fae42b8c1b4": { 288 | "model_module": "@jupyter-widgets/controls", 289 | "model_name": "DescriptionStyleModel", 290 | "state": { 291 | "_view_name": "StyleView", 292 | "_model_name": "DescriptionStyleModel", 293 | "description_width": "", 294 | "_view_module": "@jupyter-widgets/base", 295 | "_model_module_version": "1.5.0", 296 | "_view_count": null, 297 | "_view_module_version": "1.2.0", 298 | "_model_module": "@jupyter-widgets/controls" 299 | } 300 | }, 301 | "1b4aeeac358841a4884e02ac756f8a3c": { 302 | "model_module": "@jupyter-widgets/base", 303 | "model_name": "LayoutModel", 304 | "state": { 305 | "_view_name": "LayoutView", 306 | "grid_template_rows": null, 307 | "right": null, 308 | "justify_content": null, 309 | "_view_module": "@jupyter-widgets/base", 310 | "overflow": null, 311 | "_model_module_version": "1.2.0", 312 | "_view_count": null, 313 | "flex_flow": null, 314 | "width": null, 315 | "min_width": null, 316 | "border": null, 317 | "align_items": null, 318 | "bottom": null, 319 | "_model_module": "@jupyter-widgets/base", 320 | "top": null, 321 | "grid_column": null, 322 | "overflow_y": null, 323 | "overflow_x": null, 324 | "grid_auto_flow": null, 325 | "grid_area": 
null, 326 | "grid_template_columns": null, 327 | "flex": null, 328 | "_model_name": "LayoutModel", 329 | "justify_items": null, 330 | "grid_row": null, 331 | "max_height": null, 332 | "align_content": null, 333 | "visibility": null, 334 | "align_self": null, 335 | "height": null, 336 | "min_height": null, 337 | "padding": null, 338 | "grid_auto_rows": null, 339 | "grid_gap": null, 340 | "max_width": null, 341 | "order": null, 342 | "_view_module_version": "1.2.0", 343 | "grid_template_areas": null, 344 | "object_position": null, 345 | "object_fit": null, 346 | "grid_auto_columns": null, 347 | "margin": null, 348 | "display": null, 349 | "left": null 350 | } 351 | }, 352 | "87b7ef958f914aab8e6213d3f2176300": { 353 | "model_module": "@jupyter-widgets/controls", 354 | "model_name": "HBoxModel", 355 | "state": { 356 | "_view_name": "HBoxView", 357 | "_dom_classes": [], 358 | "_model_name": "HBoxModel", 359 | "_view_module": "@jupyter-widgets/controls", 360 | "_model_module_version": "1.5.0", 361 | "_view_count": null, 362 | "_view_module_version": "1.5.0", 363 | "box_style": "", 364 | "layout": "IPY_MODEL_b206209e6ef7409aae13bb41baf7c7bc", 365 | "_model_module": "@jupyter-widgets/controls", 366 | "children": [ 367 | "IPY_MODEL_4c42bed0b1524e699d58e8785e4c1c3c", 368 | "IPY_MODEL_8881de3a8e1f44bab6fe8c01f11cc7a7", 369 | "IPY_MODEL_fb536d3431d9479eb912eb0c113662a4" 370 | ] 371 | } 372 | }, 373 | "b206209e6ef7409aae13bb41baf7c7bc": { 374 | "model_module": "@jupyter-widgets/base", 375 | "model_name": "LayoutModel", 376 | "state": { 377 | "_view_name": "LayoutView", 378 | "grid_template_rows": null, 379 | "right": null, 380 | "justify_content": null, 381 | "_view_module": "@jupyter-widgets/base", 382 | "overflow": null, 383 | "_model_module_version": "1.2.0", 384 | "_view_count": null, 385 | "flex_flow": null, 386 | "width": null, 387 | "min_width": null, 388 | "border": null, 389 | "align_items": null, 390 | "bottom": null, 391 | "_model_module": "@jupyter-widgets/base", 
392 | "top": null, 393 | "grid_column": null, 394 | "overflow_y": null, 395 | "overflow_x": null, 396 | "grid_auto_flow": null, 397 | "grid_area": null, 398 | "grid_template_columns": null, 399 | "flex": null, 400 | "_model_name": "LayoutModel", 401 | "justify_items": null, 402 | "grid_row": null, 403 | "max_height": null, 404 | "align_content": null, 405 | "visibility": null, 406 | "align_self": null, 407 | "height": null, 408 | "min_height": null, 409 | "padding": null, 410 | "grid_auto_rows": null, 411 | "grid_gap": null, 412 | "max_width": null, 413 | "order": null, 414 | "_view_module_version": "1.2.0", 415 | "grid_template_areas": null, 416 | "object_position": null, 417 | "object_fit": null, 418 | "grid_auto_columns": null, 419 | "margin": null, 420 | "display": null, 421 | "left": null 422 | } 423 | }, 424 | "4c42bed0b1524e699d58e8785e4c1c3c": { 425 | "model_module": "@jupyter-widgets/controls", 426 | "model_name": "HTMLModel", 427 | "state": { 428 | "_view_name": "HTMLView", 429 | "style": "IPY_MODEL_704ea72ddf774d5dba11d98ca9abf584", 430 | "_dom_classes": [], 431 | "description": "", 432 | "_model_name": "HTMLModel", 433 | "placeholder": "​", 434 | "_view_module": "@jupyter-widgets/controls", 435 | "_model_module_version": "1.5.0", 436 | "value": "100%", 437 | "_view_count": null, 438 | "_view_module_version": "1.5.0", 439 | "description_tooltip": null, 440 | "_model_module": "@jupyter-widgets/controls", 441 | "layout": "IPY_MODEL_e6d4f400022344f99053b6521e505933" 442 | } 443 | }, 444 | "8881de3a8e1f44bab6fe8c01f11cc7a7": { 445 | "model_module": "@jupyter-widgets/controls", 446 | "model_name": "FloatProgressModel", 447 | "state": { 448 | "_view_name": "ProgressView", 449 | "style": "IPY_MODEL_b6488a6d0ec646f1bd16bfdaf2ba9828", 450 | "_dom_classes": [], 451 | "description": "", 452 | "_model_name": "FloatProgressModel", 453 | "bar_style": "success", 454 | "max": 4200, 455 | "_view_module": "@jupyter-widgets/controls", 456 | "_model_module_version": 
"1.5.0", 457 | "value": 4200, 458 | "_view_count": null, 459 | "_view_module_version": "1.5.0", 460 | "orientation": "horizontal", 461 | "min": 0, 462 | "description_tooltip": null, 463 | "_model_module": "@jupyter-widgets/controls", 464 | "layout": "IPY_MODEL_0e8f39bc412d425e98327d74e5403854" 465 | } 466 | }, 467 | "fb536d3431d9479eb912eb0c113662a4": { 468 | "model_module": "@jupyter-widgets/controls", 469 | "model_name": "HTMLModel", 470 | "state": { 471 | "_view_name": "HTMLView", 472 | "style": "IPY_MODEL_d120778b937c4c97806532b2db11aada", 473 | "_dom_classes": [], 474 | "description": "", 475 | "_model_name": "HTMLModel", 476 | "placeholder": "​", 477 | "_view_module": "@jupyter-widgets/controls", 478 | "_model_module_version": "1.5.0", 479 | "value": " 4200/4200 [02:19<00:00, 30.17it/s]", 480 | "_view_count": null, 481 | "_view_module_version": "1.5.0", 482 | "description_tooltip": null, 483 | "_model_module": "@jupyter-widgets/controls", 484 | "layout": "IPY_MODEL_c6e4edff105640baa81a42c97c78b578" 485 | } 486 | }, 487 | "704ea72ddf774d5dba11d98ca9abf584": { 488 | "model_module": "@jupyter-widgets/controls", 489 | "model_name": "DescriptionStyleModel", 490 | "state": { 491 | "_view_name": "StyleView", 492 | "_model_name": "DescriptionStyleModel", 493 | "description_width": "", 494 | "_view_module": "@jupyter-widgets/base", 495 | "_model_module_version": "1.5.0", 496 | "_view_count": null, 497 | "_view_module_version": "1.2.0", 498 | "_model_module": "@jupyter-widgets/controls" 499 | } 500 | }, 501 | "e6d4f400022344f99053b6521e505933": { 502 | "model_module": "@jupyter-widgets/base", 503 | "model_name": "LayoutModel", 504 | "state": { 505 | "_view_name": "LayoutView", 506 | "grid_template_rows": null, 507 | "right": null, 508 | "justify_content": null, 509 | "_view_module": "@jupyter-widgets/base", 510 | "overflow": null, 511 | "_model_module_version": "1.2.0", 512 | "_view_count": null, 513 | "flex_flow": null, 514 | "width": null, 515 | "min_width": null, 
516 | "border": null, 517 | "align_items": null, 518 | "bottom": null, 519 | "_model_module": "@jupyter-widgets/base", 520 | "top": null, 521 | "grid_column": null, 522 | "overflow_y": null, 523 | "overflow_x": null, 524 | "grid_auto_flow": null, 525 | "grid_area": null, 526 | "grid_template_columns": null, 527 | "flex": null, 528 | "_model_name": "LayoutModel", 529 | "justify_items": null, 530 | "grid_row": null, 531 | "max_height": null, 532 | "align_content": null, 533 | "visibility": null, 534 | "align_self": null, 535 | "height": null, 536 | "min_height": null, 537 | "padding": null, 538 | "grid_auto_rows": null, 539 | "grid_gap": null, 540 | "max_width": null, 541 | "order": null, 542 | "_view_module_version": "1.2.0", 543 | "grid_template_areas": null, 544 | "object_position": null, 545 | "object_fit": null, 546 | "grid_auto_columns": null, 547 | "margin": null, 548 | "display": null, 549 | "left": null 550 | } 551 | }, 552 | "b6488a6d0ec646f1bd16bfdaf2ba9828": { 553 | "model_module": "@jupyter-widgets/controls", 554 | "model_name": "ProgressStyleModel", 555 | "state": { 556 | "_view_name": "StyleView", 557 | "_model_name": "ProgressStyleModel", 558 | "description_width": "", 559 | "_view_module": "@jupyter-widgets/base", 560 | "_model_module_version": "1.5.0", 561 | "_view_count": null, 562 | "_view_module_version": "1.2.0", 563 | "bar_color": null, 564 | "_model_module": "@jupyter-widgets/controls" 565 | } 566 | }, 567 | "0e8f39bc412d425e98327d74e5403854": { 568 | "model_module": "@jupyter-widgets/base", 569 | "model_name": "LayoutModel", 570 | "state": { 571 | "_view_name": "LayoutView", 572 | "grid_template_rows": null, 573 | "right": null, 574 | "justify_content": null, 575 | "_view_module": "@jupyter-widgets/base", 576 | "overflow": null, 577 | "_model_module_version": "1.2.0", 578 | "_view_count": null, 579 | "flex_flow": null, 580 | "width": null, 581 | "min_width": null, 582 | "border": null, 583 | "align_items": null, 584 | "bottom": null, 585 | 
"_model_module": "@jupyter-widgets/base", 586 | "top": null, 587 | "grid_column": null, 588 | "overflow_y": null, 589 | "overflow_x": null, 590 | "grid_auto_flow": null, 591 | "grid_area": null, 592 | "grid_template_columns": null, 593 | "flex": null, 594 | "_model_name": "LayoutModel", 595 | "justify_items": null, 596 | "grid_row": null, 597 | "max_height": null, 598 | "align_content": null, 599 | "visibility": null, 600 | "align_self": null, 601 | "height": null, 602 | "min_height": null, 603 | "padding": null, 604 | "grid_auto_rows": null, 605 | "grid_gap": null, 606 | "max_width": null, 607 | "order": null, 608 | "_view_module_version": "1.2.0", 609 | "grid_template_areas": null, 610 | "object_position": null, 611 | "object_fit": null, 612 | "grid_auto_columns": null, 613 | "margin": null, 614 | "display": null, 615 | "left": null 616 | } 617 | }, 618 | "d120778b937c4c97806532b2db11aada": { 619 | "model_module": "@jupyter-widgets/controls", 620 | "model_name": "DescriptionStyleModel", 621 | "state": { 622 | "_view_name": "StyleView", 623 | "_model_name": "DescriptionStyleModel", 624 | "description_width": "", 625 | "_view_module": "@jupyter-widgets/base", 626 | "_model_module_version": "1.5.0", 627 | "_view_count": null, 628 | "_view_module_version": "1.2.0", 629 | "_model_module": "@jupyter-widgets/controls" 630 | } 631 | }, 632 | "c6e4edff105640baa81a42c97c78b578": { 633 | "model_module": "@jupyter-widgets/base", 634 | "model_name": "LayoutModel", 635 | "state": { 636 | "_view_name": "LayoutView", 637 | "grid_template_rows": null, 638 | "right": null, 639 | "justify_content": null, 640 | "_view_module": "@jupyter-widgets/base", 641 | "overflow": null, 642 | "_model_module_version": "1.2.0", 643 | "_view_count": null, 644 | "flex_flow": null, 645 | "width": null, 646 | "min_width": null, 647 | "border": null, 648 | "align_items": null, 649 | "bottom": null, 650 | "_model_module": "@jupyter-widgets/base", 651 | "top": null, 652 | "grid_column": null, 653 | 
"overflow_y": null, 654 | "overflow_x": null, 655 | "grid_auto_flow": null, 656 | "grid_area": null, 657 | "grid_template_columns": null, 658 | "flex": null, 659 | "_model_name": "LayoutModel", 660 | "justify_items": null, 661 | "grid_row": null, 662 | "max_height": null, 663 | "align_content": null, 664 | "visibility": null, 665 | "align_self": null, 666 | "height": null, 667 | "min_height": null, 668 | "padding": null, 669 | "grid_auto_rows": null, 670 | "grid_gap": null, 671 | "max_width": null, 672 | "order": null, 673 | "_view_module_version": "1.2.0", 674 | "grid_template_areas": null, 675 | "object_position": null, 676 | "object_fit": null, 677 | "grid_auto_columns": null, 678 | "margin": null, 679 | "display": null, 680 | "left": null 681 | } 682 | } 683 | } 684 | } 685 | }, 686 | "cells": [ 687 | { 688 | "cell_type": "markdown", 689 | "metadata": { 690 | "id": "3HoAB1ZDp1hU" 691 | }, 692 | "source": [ 693 | "# Introduction" 694 | ] 695 | }, 696 | { 697 | "cell_type": "markdown", 698 | "metadata": { 699 | "id": "_mec9E23pnhC" 700 | }, 701 | "source": [ 702 | "This notebook can be used to run experiments that train a RoBERTa model on different splits of the AG News Corpus dataset. The main aim is to train a model that classifies news articles into one of four categories:\n", 703 | "1. World News\n", 704 | "2. Sports News\n", 705 | "3. Business News\n", 706 | "4. 
Science / Technology News" 707 | ] 708 | }, 709 | { 710 | "cell_type": "markdown", 711 | "metadata": { 712 | "id": "8n8UqVfQqeuu" 713 | }, 714 | "source": [ 715 | "# Libraries Needed" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "metadata": { 721 | "id": "VjML8q1V8Bnv" 722 | }, 723 | "source": [ 724 | "# !pip install simpletransformers" 725 | ], 726 | "execution_count": null, 727 | "outputs": [] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "metadata": { 732 | "colab": { 733 | "base_uri": "https://localhost:8080/" 734 | }, 735 | "id": "naRwwMZLN8XO", 736 | "outputId": "7e3ffef3-59c5-4090-cdd6-67725d7461c5" 737 | }, 738 | "source": [ 739 | "from google.colab import drive\n", 740 | "drive.mount('/content/drive')" 741 | ], 742 | "execution_count": null, 743 | "outputs": [ 744 | { 745 | "output_type": "stream", 746 | "text": [ 747 | "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" 748 | ], 749 | "name": "stdout" 750 | } 751 | ] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "metadata": { 756 | "id": "GOolGKUQOK71" 757 | }, 758 | "source": [ 759 | "import pandas as pd\n", 760 | "import numpy as np\n", 761 | "import torch\n", 762 | "\n", 763 | "# from simpletransformers.classification import ClassificationModel, ClassificationArgs\n", 764 | "# import pandas as pd\n", 765 | "import logging\n", 766 | "\n", 767 | "\n", 768 | "logging.basicConfig(level=logging.INFO)\n", 769 | "transformers_logger = logging.getLogger(\"transformers\")\n", 770 | "transformers_logger.setLevel(logging.WARNING)\n", 771 | "from IPython.display import display\n", 772 | "from sklearn.metrics import accuracy_score\n", 773 | "import os\n", 774 | "import torch\n", 775 | "import torch.nn as nn\n", 776 | "# import wandb\n", 777 | "import json" 778 | ], 779 | "execution_count": null, 780 | "outputs": [] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "metadata": { 785 | "colab": { 786 | "base_uri": 
"https://localhost:8080/" 787 | }, 788 | "id": "putIoVOxw0vj", 789 | "outputId": "19374cc0-07af-4064-c84b-4aef8a922502" 790 | }, 791 | "source": [ 792 | "torch.cuda.is_available()" 793 | ], 794 | "execution_count": null, 795 | "outputs": [ 796 | { 797 | "output_type": "execute_result", 798 | "data": { 799 | "text/plain": [ 800 | "True" 801 | ] 802 | }, 803 | "metadata": { 804 | "tags": [] 805 | }, 806 | "execution_count": 4 807 | } 808 | ] 809 | }, 810 | { 811 | "cell_type": "markdown", 812 | "metadata": { 813 | "id": "ZCPO4NzPqrJ-" 814 | }, 815 | "source": [ 816 | "# Data" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "metadata": { 822 | "id": "IG7wPkwlOZwd" 823 | }, 824 | "source": [ 825 | "df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/train.csv\", index_col=False)" 826 | ], 827 | "execution_count": null, 828 | "outputs": [] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "metadata": { 833 | "id": "gMnWSvy1h-n8" 834 | }, 835 | "source": [ 836 | "test_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/test.csv\", index_col=False)" 837 | ], 838 | "execution_count": null, 839 | "outputs": [] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "metadata": { 844 | "colab": { 845 | "base_uri": "https://localhost:8080/", 846 | "height": 204 847 | }, 848 | "id": "JdMavsXeOlas", 849 | "outputId": "843bc5bd-d465-4930-c5f8-a80fa8ed75d9" 850 | }, 851 | "source": [ 852 | "df.head()" 853 | ], 854 | "execution_count": null, 855 | "outputs": [ 856 | { 857 | "output_type": "execute_result", 858 | "data": { 859 | "text/html": [ 860 | "
\n", 861 | "\n", 874 | "\n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | "
Class IndexTitleDescription
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...
13Carlyle Looks Toward Commercial Aerospace (Reu...Reuters - Private investment firm Carlyle Grou...
23Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...
33Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\\f...
43Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...
\n", 916 | "
" 917 | ], 918 | "text/plain": [ 919 | " Class Index ... Description\n", 920 | "0 3 ... Reuters - Short-sellers, Wall Street's dwindli...\n", 921 | "1 3 ... Reuters - Private investment firm Carlyle Grou...\n", 922 | "2 3 ... Reuters - Soaring crude prices plus worries\\ab...\n", 923 | "3 3 ... Reuters - Authorities have halted oil export\\f...\n", 924 | "4 3 ... AFP - Tearaway world oil prices, toppling reco...\n", 925 | "\n", 926 | "[5 rows x 3 columns]" 927 | ] 928 | }, 929 | "metadata": { 930 | "tags": [] 931 | }, 932 | "execution_count": 5 933 | } 934 | ] 935 | }, 936 | { 937 | "cell_type": "markdown", 938 | "metadata": { 939 | "id": "lx71wm_bLh5d" 940 | }, 941 | "source": [ 942 | "1-World\n", 943 | "\n", 944 | "2-Sports\n", 945 | "\n", 946 | "3-Business \n", 947 | "\n", 948 | "4-Sci/Tech" 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "metadata": { 954 | "colab": { 955 | "base_uri": "https://localhost:8080/" 956 | }, 957 | "id": "NdqkPRoUxVZR", 958 | "outputId": "8af2a544-64fd-4221-982c-d830c065377f" 959 | }, 960 | "source": [ 961 | "df['text'] = df['Title'].str.lower() + \" \" + df['Description'].str.lower()\n", 962 | "df['labels'] = df['Class Index'] - 1\n", 963 | "\n", 964 | "test_df['text'] = test_df['Title'].str.lower() + \" \" + test_df['Description'].str.lower()\n", 965 | "test_df['labels'] = test_df['Class Index'] - 1" 966 | ], 967 | "execution_count": null, 968 | "outputs": [ 969 | { 970 | "output_type": "stream", 971 | "text": [ 972 | "INFO:numexpr.utils:NumExpr defaulting to 2 threads.\n" 973 | ], 974 | "name": "stderr" 975 | } 976 | ] 977 | }, 978 | { 979 | "cell_type": "code", 980 | "metadata": { 981 | "colab": { 982 | "base_uri": "https://localhost:8080/", 983 | "height": 1000 984 | }, 985 | "id": "YoYDss_RxpM2", 986 | "outputId": "19eba657-24ea-47f4-9156-faa3c374e874" 987 | }, 988 | "source": [ 989 | "display(df)\n", 990 | "display(test_df)" 991 | ], 992 | "execution_count": null, 993 | "outputs": [ 994 | { 995 | "output_type": 
"display_data", 996 | "data": { 997 | "text/html": [ 998 | "
\n", 999 | "\n", 1012 | "\n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | "
Class IndexTitleDescriptiontextlabels
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...wall st. bears claw back into the black (reute...2
13Carlyle Looks Toward Commercial Aerospace (Reu...Reuters - Private investment firm Carlyle Grou...carlyle looks toward commercial aerospace (reu...2
23Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...oil and economy cloud stocks' outlook (reuters...2
33Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\\f...iraq halts oil exports from main southern pipe...2
43Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...oil prices soar to all-time record, posing new...2
..................
1199951Pakistan's Musharraf Says Won't Quit as Army C...KARACHI (Reuters) - Pakistani President Perve...pakistan's musharraf says won't quit as army c...0
1199962Renteria signing a top-shelf dealRed Sox general manager Theo Epstein acknowled...renteria signing a top-shelf deal red sox gene...1
1199972Saban not going to Dolphins yetThe Miami Dolphins will put their courtship of...saban not going to dolphins yet the miami dolp...1
1199982Today's NFL gamesPITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...today's nfl games pittsburgh at ny giants time...1
1199992Nets get Carter from RaptorsINDIANAPOLIS -- All-Star Vince Carter was trad...nets get carter from raptors indianapolis -- a...1
\n", 1114 | "

120000 rows × 5 columns

\n", 1115 | "
" 1116 | ], 1117 | "text/plain": [ 1118 | " Class Index ... labels\n", 1119 | "0 3 ... 2\n", 1120 | "1 3 ... 2\n", 1121 | "2 3 ... 2\n", 1122 | "3 3 ... 2\n", 1123 | "4 3 ... 2\n", 1124 | "... ... ... ...\n", 1125 | "119995 1 ... 0\n", 1126 | "119996 2 ... 1\n", 1127 | "119997 2 ... 1\n", 1128 | "119998 2 ... 1\n", 1129 | "119999 2 ... 1\n", 1130 | "\n", 1131 | "[120000 rows x 5 columns]" 1132 | ] 1133 | }, 1134 | "metadata": { 1135 | "tags": [] 1136 | } 1137 | }, 1138 | { 1139 | "output_type": "display_data", 1140 | "data": { 1141 | "text/html": [ 1142 | "
\n", 1143 | "\n", 1156 | "\n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | " \n", 1207 | " \n", 1208 | " \n", 1209 | " \n", 1210 | " \n", 1211 | " \n", 1212 | " \n", 1213 | " \n", 1214 | " \n", 1215 | " \n", 1216 | " \n", 1217 | " \n", 1218 | " \n", 1219 | " \n", 1220 | " \n", 1221 | " \n", 1222 | " \n", 1223 | " \n", 1224 | " \n", 1225 | " \n", 1226 | " \n", 1227 | " \n", 1228 | " \n", 1229 | " \n", 1230 | " \n", 1231 | " \n", 1232 | " \n", 1233 | " \n", 1234 | " \n", 1235 | " \n", 1236 | " \n", 1237 | " \n", 1238 | " \n", 1239 | " \n", 1240 | " \n", 1241 | " \n", 1242 | " \n", 1243 | " \n", 1244 | " \n", 1245 | " \n", 1246 | " \n", 1247 | " \n", 1248 | " \n", 1249 | " \n", 1250 | " \n", 1251 | " \n", 1252 | " \n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | "
Class IndexTitleDescriptiontextlabels
03Fears for T N pension after talksUnions representing workers at Turner Newall...fears for t n pension after talks unions repre...2
14The Race is On: Second Private Team Sets Launc...SPACE.com - TORONTO, Canada -- A second\\team o...the race is on: second private team sets launc...3
24Ky. Company Wins Grant to Study Peptides (AP)AP - A company founded by a chemistry research...ky. company wins grant to study peptides (ap) ...3
34Prediction Unit Helps Forecast Wildfires (AP)AP - It's barely dawn when Mike Fitzpatrick st...prediction unit helps forecast wildfires (ap) ...3
44Calif. Aims to Limit Farm-Related Smog (AP)AP - Southern California's smog-fighting agenc...calif. aims to limit farm-related smog (ap) ap...3
..................
75951Around the worldUkrainian presidential candidate Viktor Yushch...around the world ukrainian presidential candid...0
75962Void is filled with ClementWith the supply of attractive pitching options...void is filled with clement with the supply of...1
75972Martinez leaves bitterLike Roger Clemens did almost exactly eight ye...martinez leaves bitter like roger clemens did ...1
759835 of arthritis patients in Singapore take Bext...SINGAPORE : Doctors in the United States have ...5 of arthritis patients in singapore take bext...2
75993EBay gets into rentalsEBay plans to buy the apartment and home renta...ebay gets into rentals ebay plans to buy the a...2
\n", 1258 | "

7600 rows × 5 columns

\n", 1259 | "
" 1260 | ], 1261 | "text/plain": [ 1262 | " Class Index ... labels\n", 1263 | "0 3 ... 2\n", 1264 | "1 4 ... 3\n", 1265 | "2 4 ... 3\n", 1266 | "3 4 ... 3\n", 1267 | "4 4 ... 3\n", 1268 | "... ... ... ...\n", 1269 | "7595 1 ... 0\n", 1270 | "7596 2 ... 1\n", 1271 | "7597 2 ... 1\n", 1272 | "7598 3 ... 2\n", 1273 | "7599 3 ... 2\n", 1274 | "\n", 1275 | "[7600 rows x 5 columns]" 1276 | ] 1277 | }, 1278 | "metadata": { 1279 | "tags": [] 1280 | } 1281 | } 1282 | ] 1283 | }, 1284 | { 1285 | "cell_type": "code", 1286 | "metadata": { 1287 | "id": "iXLji3JeVikX" 1288 | }, 1289 | "source": [ 1290 | "np.random.seed(100)\n", 1291 | "train_idx = np.random.choice(df.index, size=int(df.shape[0]*0.8), replace=False)\n", 1292 | "valid_idx = set(df.index) - set(train_idx)\n", 1293 | "\n", 1294 | "train_df = df[df.index.isin(train_idx)]\n", 1295 | "valid_df = df[df.index.isin(valid_idx)]" 1296 | ], 1297 | "execution_count": null, 1298 | "outputs": [] 1299 | }, 1300 | { 1301 | "cell_type": "code", 1302 | "metadata": { 1303 | "colab": { 1304 | "base_uri": "https://localhost:8080/" 1305 | }, 1306 | "id": "0Fm5RGxGyUOx", 1307 | "outputId": "0b775f9f-a550-4953-a7fd-3509cca2987b" 1308 | }, 1309 | "source": [ 1310 | "train_df.shape[0] + valid_df.shape[0] == df.shape[0]" 1311 | ], 1312 | "execution_count": null, 1313 | "outputs": [ 1314 | { 1315 | "output_type": "execute_result", 1316 | "data": { 1317 | "text/plain": [ 1318 | "True" 1319 | ] 1320 | }, 1321 | "metadata": { 1322 | "tags": [] 1323 | }, 1324 | "execution_count": 8 1325 | } 1326 | ] 1327 | }, 1328 | { 1329 | "cell_type": "code", 1330 | "metadata": { 1331 | "id": "UU8EXVsX8c4j" 1332 | }, 1333 | "source": [ 1334 | "os.mkdir(\"/content/drive/MyDrive/fsdl_project/data\")" 1335 | ], 1336 | "execution_count": null, 1337 | "outputs": [] 1338 | }, 1339 | { 1340 | "cell_type": "code", 1341 | "metadata": { 1342 | "id": "UYcQPYI68o5a" 1343 | }, 1344 | "source": [ 1345 | "os.mkdir(\"/content/drive/MyDrive/fsdl_project/data/baseline\")" 1346 
| ], 1347 | "execution_count": null, 1348 | "outputs": [] 1349 | }, 1350 | { 1351 | "cell_type": "code", 1352 | "metadata": { 1353 | "id": "Oy6-nLIO8xEk" 1354 | }, 1355 | "source": [ 1356 | "os.mkdir(\"/content/drive/MyDrive/fsdl_project/data/active_learning\")" 1357 | ], 1358 | "execution_count": null, 1359 | "outputs": [] 1360 | }, 1361 | { 1362 | "cell_type": "code", 1363 | "metadata": { 1364 | "id": "hu3jpJ7R8jti" 1365 | }, 1366 | "source": [ 1367 | "train_df.to_csv(\"/content/drive/MyDrive/fsdl_project/data/baseline/train.csv.gz\", sep=\"|\", index=False, compression=\"gzip\")\n", 1368 | "valid_df.to_csv(\"/content/drive/MyDrive/fsdl_project/data/baseline/valid.csv.gz\", sep=\"|\", index=False, compression=\"gzip\")\n", 1369 | "test_df.to_csv(\"/content/drive/MyDrive/fsdl_project/data/baseline/test.csv.gz\", sep=\"|\", index=False, compression=\"gzip\")" 1370 | ], 1371 | "execution_count": null, 1372 | "outputs": [] 1373 | }, 1374 | { 1375 | "cell_type": "code", 1376 | "metadata": { 1377 | "id": "BeaHRTtrTGiG" 1378 | }, 1379 | "source": [ 1380 | "train_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/baseline/train.csv.gz\", sep=\"|\", index_col=False)\n", 1381 | "valid_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/baseline/valid.csv.gz\", sep=\"|\", index_col=False)\n", 1382 | "test_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/baseline/test.csv.gz\", sep=\"|\", index_col=False)" 1383 | ], 1384 | "execution_count": null, 1385 | "outputs": [] 1386 | }, 1387 | { 1388 | "cell_type": "code", 1389 | "metadata": { 1390 | "colab": { 1391 | "base_uri": "https://localhost:8080/", 1392 | "height": 204 1393 | }, 1394 | "id": "G4Tq7Gt9tDIq", 1395 | "outputId": "90b3aedb-d143-4c6d-8eea-a7dcbda456f8" 1396 | }, 1397 | "source": [ 1398 | "train_df.head()" 1399 | ], 1400 | "execution_count": null, 1401 | "outputs": [ 1402 | { 1403 | "output_type": "execute_result", 1404 | "data": { 1405 | "text/html": [ 1406 | "
\n", 1407 | "\n", 1420 | "\n", 1421 | " \n", 1422 | " \n", 1423 | " \n", 1424 | " \n", 1425 | " \n", 1426 | " \n", 1427 | " \n", 1428 | " \n", 1429 | " \n", 1430 | " \n", 1431 | " \n", 1432 | " \n", 1433 | " \n", 1434 | " \n", 1435 | " \n", 1436 | " \n", 1437 | " \n", 1438 | " \n", 1439 | " \n", 1440 | " \n", 1441 | " \n", 1442 | " \n", 1443 | " \n", 1444 | " \n", 1445 | " \n", 1446 | " \n", 1447 | " \n", 1448 | " \n", 1449 | " \n", 1450 | " \n", 1451 | " \n", 1452 | " \n", 1453 | " \n", 1454 | " \n", 1455 | " \n", 1456 | " \n", 1457 | " \n", 1458 | " \n", 1459 | " \n", 1460 | " \n", 1461 | " \n", 1462 | " \n", 1463 | " \n", 1464 | " \n", 1465 | " \n", 1466 | " \n", 1467 | " \n", 1468 | " \n", 1469 | " \n", 1470 | " \n", 1471 | " \n", 1472 | " \n", 1473 | "
Class IndexTitleDescriptiontextlabels
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...reuters - short-sellers, wall street's dwindli...2
13Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...reuters - soaring crude prices plus worries\\ab...2
23Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\\f...reuters - authorities have halted oil export\\f...2
33Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...afp - tearaway world oil prices, toppling reco...2
43Stocks End Up, But Near Year Lows (Reuters)Reuters - Stocks ended slightly higher on Frid...reuters - stocks ended slightly higher on frid...2
\n", 1474 | "
" 1475 | ], 1476 | "text/plain": [ 1477 | " Class Index ... labels\n", 1478 | "0 3 ... 2\n", 1479 | "1 3 ... 2\n", 1480 | "2 3 ... 2\n", 1481 | "3 3 ... 2\n", 1482 | "4 3 ... 2\n", 1483 | "\n", 1484 | "[5 rows x 5 columns]" 1485 | ] 1486 | }, 1487 | "metadata": { 1488 | "tags": [] 1489 | }, 1490 | "execution_count": 7 1491 | } 1492 | ] 1493 | }, 1494 | { 1495 | "cell_type": "code", 1496 | "metadata": { 1497 | "id": "FcCGWUbJTa97" 1498 | }, 1499 | "source": [ 1500 | "train_df['text'] = train_df['Description'].str.lower()\n", 1501 | "valid_df['text'] = valid_df['Description'].str.lower()\n", 1502 | "test_df['text'] = test_df['Description'].str.lower()" 1503 | ], 1504 | "execution_count": null, 1505 | "outputs": [] 1506 | }, 1507 | { 1508 | "cell_type": "markdown", 1509 | "metadata": { 1510 | "id": "75vkSUMSqxEw" 1511 | }, 1512 | "source": [ 1513 | "# Model" 1514 | ] 1515 | }, 1516 | { 1517 | "cell_type": "code", 1518 | "metadata": { 1519 | "id": "bdTIFCxhycuf" 1520 | }, 1521 | "source": [ 1522 | "## Training RoBERTa on full training set to obtain baseline accuracy on test data\n", 1523 | "\n", 1524 | "# Optional model configuration\n", 1525 | "\n", 1526 | "model_args = ClassificationArgs(num_train_epochs=5, \n", 1527 | " overwrite_output_dir= True, \n", 1528 | " train_batch_size=16,\n", 1529 | " max_seq_length=250, \n", 1530 | " wandb_project= 'active_learning_baseline_v2', \n", 1531 | " best_model_dir=\"/content/drive/MyDrive/fsdl_project/model/baseline/best_model/20210418\",\n", 1532 | " cache_dir=\"/content/drive/MyDrive/fsdl_project/cache/baseline/20210418\",\n", 1533 | " eval_batch_size=16,\n", 1534 | " evaluate_during_training=True,\n", 1535 | " evaluate_during_training_verbose=True,\n", 1536 | " manual_seed=100,\n", 1537 | " output_dir=\"content/drive/MyDrive/fsdl_project/output/baseline/20210418\",\n", 1538 | " use_early_stopping=True,\n", 1539 | " early_stopping_patience=3,\n", 1540 | " )\n", 1541 | "\n", 1542 | "\n", 1543 | "# Create a 
ClassificationModel\n", 1544 | "model = ClassificationModel(\n", 1545 | "\"roberta\", \"roberta-base\", args=model_args, use_cuda=True, num_labels=4,\n", 1546 | ")" 1547 | ], 1548 | "execution_count": null, 1549 | "outputs": [] 1550 | }, 1551 | { 1552 | "cell_type": "code", 1553 | "metadata": { 1554 | "id": "FvN6-Txc0OKy" 1555 | }, 1556 | "source": [ 1557 | "model.train_model(train_df=train_df, eval_df=valid_df, accuracy=accuracy_score)" 1558 | ], 1559 | "execution_count": null, 1560 | "outputs": [] 1561 | }, 1562 | { 1563 | "cell_type": "code", 1564 | "metadata": { 1565 | "id": "ZEnJVeDV3Qhv" 1566 | }, 1567 | "source": [ 1568 | "## Loading model for inference\n", 1569 | "# Create a ClassificationModel\n", 1570 | "model = ClassificationModel(\n", 1571 | "\"roberta\", \"/content/drive/MyDrive/fsdl_project\" ,\n", 1572 | ")" 1573 | ], 1574 | "execution_count": null, 1575 | "outputs": [] 1576 | }, 1577 | { 1578 | "cell_type": "code", 1579 | "metadata": { 1580 | "id": "-WP9tSaA1iqD" 1581 | }, 1582 | "source": [ 1583 | "valid_result, valid_model_outputs, valid_wrong_predictions = model.eval_model(valid_df, accuracy=accuracy_score)" 1584 | ], 1585 | "execution_count": null, 1586 | "outputs": [] 1587 | }, 1588 | { 1589 | "cell_type": "code", 1590 | "metadata": { 1591 | "id": "TrNd911FDH42" 1592 | }, 1593 | "source": [ 1594 | "train_result, train_model_outputs, train_wrong_predictions = model.eval_model(train_df, accuracy=accuracy_score)" 1595 | ], 1596 | "execution_count": null, 1597 | "outputs": [] 1598 | }, 1599 | { 1600 | "cell_type": "code", 1601 | "metadata": { 1602 | "id": "PzQzwZwL8dQr" 1603 | }, 1604 | "source": [ 1605 | "sf = nn.Softmax(dim=1)" 1606 | ], 1607 | "execution_count": null, 1608 | "outputs": [] 1609 | }, 1610 | { 1611 | "cell_type": "code", 1612 | "metadata": { 1613 | "colab": { 1614 | "base_uri": "https://localhost:8080/" 1615 | }, 1616 | "id": "cjbARyuq8Nck", 1617 | "outputId": "cc41fec1-a71f-4fac-ad7e-6063d156d848" 1618 | }, 1619 | "source": [ 1620 
| "np.mean(torch.argmax(sf(torch.tensor(model_outputs)), dim=1).numpy() == valid_df['labels'].values.ravel())" 1621 | ], 1622 | "execution_count": null, 1623 | "outputs": [ 1624 | { 1625 | "output_type": "execute_result", 1626 | "data": { 1627 | "text/plain": [ 1628 | "0.941625" 1629 | ] 1630 | }, 1631 | "metadata": { 1632 | "tags": [] 1633 | }, 1634 | "execution_count": 28 1635 | } 1636 | ] 1637 | }, 1638 | { 1639 | "cell_type": "code", 1640 | "metadata": { 1641 | "id": "DnwEsTa37hdX" 1642 | }, 1643 | "source": [ 1644 | "wandb.log({'best_train_accuracy': })" 1645 | ], 1646 | "execution_count": null, 1647 | "outputs": [] 1648 | }, 1649 | { 1650 | "cell_type": "code", 1651 | "metadata": { 1652 | "colab": { 1653 | "base_uri": "https://localhost:8080/", 1654 | "height": 306 1655 | }, 1656 | "id": "iYw9JK3Z_zk2", 1657 | "outputId": "b65eb5d6-65c3-440a-bb2f-5f11e3979e1f" 1658 | }, 1659 | "source": [ 1660 | "test_df.head()" 1661 | ], 1662 | "execution_count": null, 1663 | "outputs": [ 1664 | { 1665 | "output_type": "execute_result", 1666 | "data": { 1667 | "text/html": [ 1668 | "
\n", 1669 | "\n", 1682 | "\n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | " \n", 1721 | " \n", 1722 | " \n", 1723 | " \n", 1724 | " \n", 1725 | " \n", 1726 | " \n", 1727 | " \n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | " \n", 1733 | " \n", 1734 | " \n", 1735 | "
Class IndexTitleDescriptiontextlabels
03Fears for T N pension after talksUnions representing workers at Turner Newall...fears for t n pension after talks unions repre...2
14The Race is On: Second Private Team Sets Launc...SPACE.com - TORONTO, Canada -- A second\\team o...the race is on: second private team sets launc...3
24Ky. Company Wins Grant to Study Peptides (AP)AP - A company founded by a chemistry research...ky. company wins grant to study peptides (ap) ...3
34Prediction Unit Helps Forecast Wildfires (AP)AP - It's barely dawn when Mike Fitzpatrick st...prediction unit helps forecast wildfires (ap) ...3
44Calif. Aims to Limit Farm-Related Smog (AP)AP - Southern California's smog-fighting agenc...calif. aims to limit farm-related smog (ap) ap...3
\n", 1736 | "
" 1737 | ], 1738 | "text/plain": [ 1739 | " Class Index ... labels\n", 1740 | "0 3 ... 2\n", 1741 | "1 4 ... 3\n", 1742 | "2 4 ... 3\n", 1743 | "3 4 ... 3\n", 1744 | "4 4 ... 3\n", 1745 | "\n", 1746 | "[5 rows x 5 columns]" 1747 | ] 1748 | }, 1749 | "metadata": { 1750 | "tags": [] 1751 | }, 1752 | "execution_count": 29 1753 | } 1754 | ] 1755 | }, 1756 | { 1757 | "cell_type": "code", 1758 | "metadata": { 1759 | "id": "VMfm5GpE_2hn" 1760 | }, 1761 | "source": [ 1762 | "\n", 1763 | "test_result, test_model_outputs, test_wrong_predictions = model.eval_model(test_df, accuracy = accuracy_score)" 1764 | ], 1765 | "execution_count": null, 1766 | "outputs": [] 1767 | }, 1768 | { 1769 | "cell_type": "code", 1770 | "metadata": { 1771 | "colab": { 1772 | "base_uri": "https://localhost:8080/" 1773 | }, 1774 | "id": "NSkVsMSKAlvM", 1775 | "outputId": "0956f20c-501d-4ab9-e243-b150ade44e94" 1776 | }, 1777 | "source": [ 1778 | "log_to_wandb" 1779 | ], 1780 | "execution_count": null, 1781 | "outputs": [ 1782 | { 1783 | "output_type": "execute_result", 1784 | "data": { 1785 | "text/plain": [ 1786 | "{'test_accuracy': 0.9388157894736842,\n", 1787 | " 'test_eval_loss': 0.24408445681959978,\n", 1788 | " 'test_mcc': 0.9184682069651259,\n", 1789 | " 'train_accuracy': 0.96821875,\n", 1790 | " 'train_eval_loss': 0.11901225684736952,\n", 1791 | " 'train_mcc': 0.9576493149383133,\n", 1792 | " 'valid_eval_loss': 0.22651522111527933,\n", 1793 | " 'valid_mcc': 0.9222124561368541}" 1794 | ] 1795 | }, 1796 | "metadata": { 1797 | "tags": [] 1798 | }, 1799 | "execution_count": 44 1800 | } 1801 | ] 1802 | }, 1803 | { 1804 | "cell_type": "code", 1805 | "metadata": { 1806 | "id": "2ot8BAIbII1P" 1807 | }, 1808 | "source": [ 1809 | "log_to_wandb = {f'test_{key}': item for key, item in test_result.items()}\n", 1810 | "log_to_wandb.update({f'train_{key}': item for key, item in train_result.items()})\n", 1811 | "log_to_wandb.update({f'valid_{key}': item for key, item in result.items()})" 1812 | ], 1813 | 
"execution_count": null, 1814 | "outputs": [] 1815 | }, 1816 | { 1817 | "cell_type": "code", 1818 | "metadata": { 1819 | "id": "lxnr63qRJnS1" 1820 | }, 1821 | "source": [ 1822 | "import json" 1823 | ], 1824 | "execution_count": null, 1825 | "outputs": [] 1826 | }, 1827 | { 1828 | "cell_type": "code", 1829 | "metadata": { 1830 | "id": "i1j--jrGJofS" 1831 | }, 1832 | "source": [ 1833 | "import os" 1834 | ], 1835 | "execution_count": null, 1836 | "outputs": [] 1837 | }, 1838 | { 1839 | "cell_type": "code", 1840 | "metadata": { 1841 | "id": "IaWAeRepJptX" 1842 | }, 1843 | "source": [ 1844 | "os.mkdir(\"/content/drive/MyDrive/fsdl_project/baseline\")" 1845 | ], 1846 | "execution_count": null, 1847 | "outputs": [] 1848 | }, 1849 | { 1850 | "cell_type": "code", 1851 | "metadata": { 1852 | "id": "HKNluWxeJvGO" 1853 | }, 1854 | "source": [ 1855 | "with open(\"/content/drive/MyDrive/fsdl_project/baseline/exp_stats.json\", 'w') as outfile:\n", 1856 | " json.dump(log_to_wandb, outfile, indent=4)" 1857 | ], 1858 | "execution_count": null, 1859 | "outputs": [] 1860 | }, 1861 | { 1862 | "cell_type": "code", 1863 | "metadata": { 1864 | "id": "whSuMKRdIPTs" 1865 | }, 1866 | "source": [ 1867 | "wandb.log(log_to_wandb)" 1868 | ], 1869 | "execution_count": null, 1870 | "outputs": [] 1871 | }, 1872 | { 1873 | "cell_type": "code", 1874 | "metadata": { 1875 | "id": "4zUIXkzWWyhx" 1876 | }, 1877 | "source": [ 1878 | "train_df = df.sample(int(df.shape[0]*0.1))" 1879 | ], 1880 | "execution_count": null, 1881 | "outputs": [] 1882 | }, 1883 | { 1884 | "cell_type": "markdown", 1885 | "metadata": { 1886 | "id": "swUNHGDM7SAk" 1887 | }, 1888 | "source": [ 1889 | "" 1890 | ] 1891 | }, 1892 | { 1893 | "cell_type": "code", 1894 | "metadata": { 1895 | "colab": { 1896 | "base_uri": "https://localhost:8080/", 1897 | "height": 204 1898 | }, 1899 | "id": "Auh0UMgoXHfX", 1900 | "outputId": "d4b66311-4725-467a-ea7f-1e1a191fa609" 1901 | }, 1902 | "source": [ 1903 | "train_df.head()" 1904 | ], 1905 | 
"execution_count": null, 1906 | "outputs": [ 1907 | { 1908 | "output_type": "execute_result", 1909 | "data": { 1910 | "text/html": [ 1911 | "
\n", 1912 | "\n", 1925 | "\n", 1926 | " \n", 1927 | " \n", 1928 | " \n", 1929 | " \n", 1930 | " \n", 1931 | " \n", 1932 | " \n", 1933 | " \n", 1934 | " \n", 1935 | " \n", 1936 | " \n", 1937 | " \n", 1938 | " \n", 1939 | " \n", 1940 | " \n", 1941 | " \n", 1942 | " \n", 1943 | " \n", 1944 | " \n", 1945 | " \n", 1946 | " \n", 1947 | " \n", 1948 | " \n", 1949 | " \n", 1950 | " \n", 1951 | " \n", 1952 | " \n", 1953 | " \n", 1954 | " \n", 1955 | " \n", 1956 | " \n", 1957 | " \n", 1958 | " \n", 1959 | " \n", 1960 | " \n", 1961 | " \n", 1962 | " \n", 1963 | " \n", 1964 | " \n", 1965 | " \n", 1966 | "
Class IndexTitleDescription
407601Saudi Troops, Gunmen Clash in RiyadhSaudi security forces, battling a wave of terr...
856701Rebels Kill 45 in Attacks in Iraq #39;s BaqubaInsurgent attacks and clashes killed 45 people...
1146081Arab press roundup: December 13, 2004Arab newspaper highlighted and commented on PL...
915442NCAA Wins Right to Limit TournamentsNew Mexico #39;s Mark Walters (5) is almost tr...
1131984Netflix CEO Rates Blockbuster, Amazon Threats ...Reuters - Netflix Inc chief\\executive Reed Has...
\n", 1967 | "
" 1968 | ], 1969 | "text/plain": [ 1970 | " Class Index ... Description\n", 1971 | "40760 1 ... Saudi security forces, battling a wave of terr...\n", 1972 | "85670 1 ... Insurgent attacks and clashes killed 45 people...\n", 1973 | "114608 1 ... Arab newspaper highlighted and commented on PL...\n", 1974 | "91544 2 ... New Mexico #39;s Mark Walters (5) is almost tr...\n", 1975 | "113198 4 ... Reuters - Netflix Inc chief\\executive Reed Has...\n", 1976 | "\n", 1977 | "[5 rows x 3 columns]" 1978 | ] 1979 | }, 1980 | "metadata": { 1981 | "tags": [] 1982 | }, 1983 | "execution_count": 7 1984 | } 1985 | ] 1986 | }, 1987 | { 1988 | "cell_type": "code", 1989 | "metadata": { 1990 | "id": "tNzv1VD09VTm" 1991 | }, 1992 | "source": [ 1993 | "" 1994 | ], 1995 | "execution_count": null, 1996 | "outputs": [] 1997 | }, 1998 | { 1999 | "cell_type": "markdown", 2000 | "metadata": { 2001 | "id": "82jF0DEL9W8o" 2002 | }, 2003 | "source": [ 2004 | "Training on reduced dataset for active learning" 2005 | ] 2006 | }, 2007 | { 2008 | "cell_type": "code", 2009 | "metadata": { 2010 | "id": "TKJhkEdlXLBa" 2011 | }, 2012 | "source": [ 2013 | "np.random.seed(100)\n", 2014 | "train_al_idx = np.random.choice(train_df.index, size=int(train_df.shape[0]*0.3), replace=False)\n", 2015 | "annotate_idx = list(set(train_df.index) - set(train_al_idx))\n", 2016 | "train_df_al = train_df[train_df.index.isin(train_al_idx)]\n", 2017 | "annotate_df = train_df[train_df.index.isin(annotate_idx)]" 2018 | ], 2019 | "execution_count": null, 2020 | "outputs": [] 2021 | }, 2022 | { 2023 | "cell_type": "code", 2024 | "metadata": { 2025 | "colab": { 2026 | "base_uri": "https://localhost:8080/", 2027 | "height": 51 2028 | }, 2029 | "id": "6V2yXZT_aC3f", 2030 | "outputId": "45ae184e-00bb-4700-c574-123c31304bf9" 2031 | }, 2032 | "source": [ 2033 | "display(set(annotate_idx).intersection(train_al_idx))\n", 2034 | "display(annotate_df.shape[0] + train_df_al.shape[0] == train_df.shape[0])" 2035 | ], 2036 | "execution_count": 
null, 2037 | "outputs": [ 2038 | { 2039 | "output_type": "display_data", 2040 | "data": { 2041 | "text/plain": [ 2042 | "set()" 2043 | ] 2044 | }, 2045 | "metadata": { 2046 | "tags": [] 2047 | } 2048 | }, 2049 | { 2050 | "output_type": "display_data", 2051 | "data": { 2052 | "text/plain": [ 2053 | "True" 2054 | ] 2055 | }, 2056 | "metadata": { 2057 | "tags": [] 2058 | } 2059 | } 2060 | ] 2061 | }, 2062 | { 2063 | "cell_type": "code", 2064 | "metadata": { 2065 | "id": "Gf1vl-hnT4Or", 2066 | "colab": { 2067 | "base_uri": "https://localhost:8080/" 2068 | }, 2069 | "outputId": "d566ed6e-e2b5-4f59-c0c4-ac6e9ebb3b5a" 2070 | }, 2071 | "source": [ 2072 | "annotate_df['idx'] = annotate_df.index" 2073 | ], 2074 | "execution_count": null, 2075 | "outputs": [ 2076 | { 2077 | "output_type": "stream", 2078 | "text": [ 2079 | "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:1: SettingWithCopyWarning: \n", 2080 | "A value is trying to be set on a copy of a slice from a DataFrame.\n", 2081 | "Try using .loc[row_indexer,col_indexer] = value instead\n", 2082 | "\n", 2083 | "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", 2084 | " \"\"\"Entry point for launching an IPython kernel.\n" 2085 | ], 2086 | "name": "stderr" 2087 | } 2088 | ] 2089 | }, 2090 | { 2091 | "cell_type": "code", 2092 | "metadata": { 2093 | "id": "yU3q5BVj-bwx", 2094 | "colab": { 2095 | "base_uri": "https://localhost:8080/", 2096 | "height": 266 2097 | }, 2098 | "outputId": "79ba8bd4-d623-42c8-d866-61afcce12ede" 2099 | }, 2100 | "source": [ 2101 | "annotate_df[annotate_df['text'].str.contains(\"#name?\")].sort_values(\"text\")" 2102 | ], 2103 | "execution_count": null, 2104 | "outputs": [ 2105 | { 2106 | "output_type": "execute_result", 2107 | "data": { 2108 | "text/html": [ 2109 | "
\n", 2110 | "\n", 2123 | "\n", 2124 | " \n", 2125 | " \n", 2126 | " \n", 2127 | " \n", 2128 | " \n", 2129 | " \n", 2130 | " \n", 2131 | " \n", 2132 | " \n", 2133 | " \n", 2134 | " \n", 2135 | " \n", 2136 | " \n", 2137 | " \n", 2138 | " \n", 2139 | " \n", 2140 | " \n", 2141 | " \n", 2142 | " \n", 2143 | " \n", 2144 | " \n", 2145 | " \n", 2146 | " \n", 2147 | " \n", 2148 | " \n", 2149 | " \n", 2150 | " \n", 2151 | " \n", 2152 | " \n", 2153 | " \n", 2154 | " \n", 2155 | " \n", 2156 | " \n", 2157 | " \n", 2158 | " \n", 2159 | " \n", 2160 | " \n", 2161 | " \n", 2162 | " \n", 2163 | " \n", 2164 | " \n", 2165 | " \n", 2166 | " \n", 2167 | " \n", 2168 | " \n", 2169 | " \n", 2170 | " \n", 2171 | " \n", 2172 | " \n", 2173 | " \n", 2174 | " \n", 2175 | " \n", 2176 | " \n", 2177 | " \n", 2178 | " \n", 2179 | " \n", 2180 | " \n", 2181 | " \n", 2182 | " \n", 2183 | " \n", 2184 | " \n", 2185 | " \n", 2186 | " \n", 2187 | " \n", 2188 | " \n", 2189 | " \n", 2190 | " \n", 2191 | " \n", 2192 | " \n", 2193 | " \n", 2194 | " \n", 2195 | " \n", 2196 | " \n", 2197 | " \n", 2198 | " \n", 2199 | " \n", 2200 | "
Class IndexTitleDescriptiontextlabelsidx
142452Sabres agree to terms with 2003 first-round pi...#NAME?#name?114245
182962Top of 3rd#NAME?#name?118296
191092Blues re-sign D Backman, four others#NAME?#name?119109
221402Wild re-sign D Schultz#NAME?#name?122140
231742Predators re-sign D Zidlicky#NAME?#name?123174
366652- UMPIRES: Home,Andy Fletcher; First, Tim Welk...#NAME?#name?136665
808941Lynching of agents puts Mexico focus on vigila...#NAME?#name?080894
\n", 2201 | "
" 2202 | ], 2203 | "text/plain": [ 2204 | " Class Index ... idx\n", 2205 | "14245 2 ... 14245\n", 2206 | "18296 2 ... 18296\n", 2207 | "19109 2 ... 19109\n", 2208 | "22140 2 ... 22140\n", 2209 | "23174 2 ... 23174\n", 2210 | "36665 2 ... 36665\n", 2211 | "80894 1 ... 80894\n", 2212 | "\n", 2213 | "[7 rows x 6 columns]" 2214 | ] 2215 | }, 2216 | "metadata": { 2217 | "tags": [] 2218 | }, 2219 | "execution_count": 39 2220 | } 2221 | ] 2222 | }, 2223 | { 2224 | "cell_type": "code", 2225 | "metadata": { 2226 | "id": "En1lrNP27e1E", 2227 | "colab": { 2228 | "base_uri": "https://localhost:8080/", 2229 | "height": 37 2230 | }, 2231 | "outputId": "cfa39faa-8c8a-4064-e952-e5a7fc655791" 2232 | }, 2233 | "source": [ 2234 | "annotate_df.loc[annotate_df.idx == 9566, 'text'].values[0]" 2235 | ], 2236 | "execution_count": null, 2237 | "outputs": [ 2238 | { 2239 | "output_type": "execute_result", 2240 | "data": { 2241 | "application/vnd.google.colaboratory.intrinsic+json": { 2242 | "type": "string" 2243 | }, 2244 | "text/plain": [ 2245 | "'coach joins the s p 500, and others stand to benefit from the leather in the weather.'" 2246 | ] 2247 | }, 2248 | "metadata": { 2249 | "tags": [] 2250 | }, 2251 | "execution_count": 35 2252 | } 2253 | ] 2254 | }, 2255 | { 2256 | "cell_type": "code", 2257 | "metadata": { 2258 | "id": "lIQVGP79XWCf" 2259 | }, 2260 | "source": [ 2261 | "annotate_df[annotate_df['text'].str.contains()]" 2262 | ], 2263 | "execution_count": null, 2264 | "outputs": [] 2265 | }, 2266 | { 2267 | "cell_type": "code", 2268 | "metadata": { 2269 | "id": "gqdIPvA1Xnxu", 2270 | "colab": { 2271 | "base_uri": "https://localhost:8080/", 2272 | "height": 409 2273 | }, 2274 | "outputId": "38f44f53-8f33-4e6c-f355-b974b81ea9a2" 2275 | }, 2276 | "source": [ 2277 | "## Loading all data to get accuracy scores on them, logits, and probability\n", 2278 | "\n", 2279 | "train_df_al = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/train.csv.gz\", sep=\"|\", 
index_col=False)\n", 2280 | "valid_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/valid.csv.gz\", sep=\"|\", index_col=False)\n", 2281 | "test_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/test.csv.gz\", sep=\"|\", index_col=False)\n", 2282 | "annotate_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/annotate.csv.gz\", sep=\"|\", index_col=False)" 2283 | ], 2284 | "execution_count": null, 2285 | "outputs": [ 2286 | { 2287 | "output_type": "error", 2288 | "ename": "FileNotFoundError", 2289 | "evalue": "ignored", 2290 | "traceback": [ 2291 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 2292 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 2293 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mvalid_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/content/drive/MyDrive/fsdl_project/data/active_learning/valid.csv.gz\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msep\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"|\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex_col\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtest_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/content/drive/MyDrive/fsdl_project/data/active_learning/test.csv.gz\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msep\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"|\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex_col\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mannotate_df\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/content/drive/MyDrive/fsdl_project/data/active_learning/annotate.csv.gz\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msep\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"|\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex_col\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 2294 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, squeeze, prefix, mangle_dupe_cols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, dialect, error_bad_lines, warn_bad_lines, delim_whitespace, low_memory, memory_map, float_precision)\u001b[0m\n\u001b[1;32m 686\u001b[0m )\n\u001b[1;32m 687\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 688\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_read\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath_or_buffer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 690\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 2295 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0;31m# Create the 
parser.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 454\u001b[0;31m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTextFileReader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp_or_buf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 455\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 456\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mchunksize\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2296 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 946\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"has_index_names\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 947\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 948\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_engine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 949\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 950\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2297 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m_make_engine\u001b[0;34m(self, engine)\u001b[0m\n\u001b[1;32m 1178\u001b[0m 
\u001b[0;32mdef\u001b[0m \u001b[0m_make_engine\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mengine\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"c\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1179\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mengine\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"c\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1180\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCParserWrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1181\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1182\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mengine\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"python\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2298 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/io/parsers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, src, **kwds)\u001b[0m\n\u001b[1;32m 2008\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"usecols\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0musecols\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2009\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2010\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reader\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mparsers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTextReader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2011\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munnamed_cols\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munnamed_cols\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2012\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 2299 | "\u001b[0;32mpandas/_libs/parsers.pyx\u001b[0m in \u001b[0;36mpandas._libs.parsers.TextReader.__cinit__\u001b[0;34m()\u001b[0m\n", 2300 | "\u001b[0;32mpandas/_libs/parsers.pyx\u001b[0m in \u001b[0;36mpandas._libs.parsers.TextReader._setup_parser_source\u001b[0;34m()\u001b[0m\n", 2301 | "\u001b[0;32m/usr/lib/python3.7/gzip.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, filename, mode, compresslevel, fileobj, mtime)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m'b'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfileobj\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0mfileobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmyfileobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbuiltins\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;32mis\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfileobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'name'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 2302 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/content/drive/MyDrive/fsdl_project/data/active_learning/annotate.csv.gz'" 2303 | ] 2304 | } 2305 | ] 2306 | }, 2307 | { 2308 | "cell_type": "code", 2309 | "metadata": { 2310 | "colab": { 2311 | "base_uri": "https://localhost:8080/", 2312 | "height": 204 2313 | }, 2314 | "id": "lrGhdaaM9_gx", 2315 | "outputId": "8dfa3e97-3eb3-408f-ffc6-e21bbf7d5965" 2316 | }, 2317 | "source": [ 2318 | "train_df_al.head()" 2319 | ], 2320 | "execution_count": null, 2321 | "outputs": [ 2322 | { 2323 | "output_type": "execute_result", 2324 | "data": { 2325 | "text/html": [ 2326 | "
\n", 2327 | "\n", 2340 | "\n", 2341 | " \n", 2342 | " \n", 2343 | " \n", 2344 | " \n", 2345 | " \n", 2346 | " \n", 2347 | " \n", 2348 | " \n", 2349 | " \n", 2350 | " \n", 2351 | " \n", 2352 | " \n", 2353 | " \n", 2354 | " \n", 2355 | " \n", 2356 | " \n", 2357 | " \n", 2358 | " \n", 2359 | " \n", 2360 | " \n", 2361 | " \n", 2362 | " \n", 2363 | " \n", 2364 | " \n", 2365 | " \n", 2366 | " \n", 2367 | " \n", 2368 | " \n", 2369 | " \n", 2370 | " \n", 2371 | " \n", 2372 | " \n", 2373 | " \n", 2374 | " \n", 2375 | " \n", 2376 | " \n", 2377 | " \n", 2378 | " \n", 2379 | " \n", 2380 | " \n", 2381 | " \n", 2382 | " \n", 2383 | " \n", 2384 | " \n", 2385 | " \n", 2386 | " \n", 2387 | " \n", 2388 | " \n", 2389 | " \n", 2390 | " \n", 2391 | " \n", 2392 | " \n", 2393 | "
Class IndexTitleDescriptiontextlabels
03Money Funds Fell in Latest Week (AP)AP - Assets of the nation's retail money marke...ap - assets of the nation's retail money marke...2
13Fed minutes show dissent over inflation (USATO...USATODAY.com - Retail sales bounced back a bit...usatoday.com - retail sales bounced back a bit...2
23Safety Net (Forbes.com)Forbes.com - After earning a PH.D. in Sociolog...forbes.com - after earning a ph.d. in sociolog...2
33No Need for OPEC to Pump More-Iran GovTEHRAN (Reuters) - OPEC can do nothing to dou...tehran (reuters) - opec can do nothing to dou...2
43Shell 'could be target for Total'Oil giant Shell could be bracing itself for a ...oil giant shell could be bracing itself for a ...2
\n", 2394 | "
" 2395 | ], 2396 | "text/plain": [ 2397 | " Class Index ... labels\n", 2398 | "0 3 ... 2\n", 2399 | "1 3 ... 2\n", 2400 | "2 3 ... 2\n", 2401 | "3 3 ... 2\n", 2402 | "4 3 ... 2\n", 2403 | "\n", 2404 | "[5 rows x 5 columns]" 2405 | ] 2406 | }, 2407 | "metadata": { 2408 | "tags": [] 2409 | }, 2410 | "execution_count": 4 2411 | } 2412 | ] 2413 | }, 2414 | { 2415 | "cell_type": "code", 2416 | "metadata": { 2417 | "colab": { 2418 | "base_uri": "https://localhost:8080/", 2419 | "height": 204 2420 | }, 2421 | "id": "Vfdz5kCSQrMT", 2422 | "outputId": "3cfae7d4-98ab-4258-e71d-9d1eba264e1f" 2423 | }, 2424 | "source": [ 2425 | "annotate_df.head()" 2426 | ], 2427 | "execution_count": null, 2428 | "outputs": [ 2429 | { 2430 | "output_type": "execute_result", 2431 | "data": { 2432 | "text/html": [ 2433 | "
\n", 2434 | "\n", 2447 | "\n", 2448 | " \n", 2449 | " \n", 2450 | " \n", 2451 | " \n", 2452 | " \n", 2453 | " \n", 2454 | " \n", 2455 | " \n", 2456 | " \n", 2457 | " \n", 2458 | " \n", 2459 | " \n", 2460 | " \n", 2461 | " \n", 2462 | " \n", 2463 | " \n", 2464 | " \n", 2465 | " \n", 2466 | " \n", 2467 | " \n", 2468 | " \n", 2469 | " \n", 2470 | " \n", 2471 | " \n", 2472 | " \n", 2473 | " \n", 2474 | " \n", 2475 | " \n", 2476 | " \n", 2477 | " \n", 2478 | " \n", 2479 | " \n", 2480 | " \n", 2481 | " \n", 2482 | " \n", 2483 | " \n", 2484 | " \n", 2485 | " \n", 2486 | " \n", 2487 | " \n", 2488 | " \n", 2489 | " \n", 2490 | " \n", 2491 | " \n", 2492 | " \n", 2493 | " \n", 2494 | " \n", 2495 | " \n", 2496 | " \n", 2497 | " \n", 2498 | " \n", 2499 | " \n", 2500 | " \n", 2501 | " \n", 2502 | " \n", 2503 | " \n", 2504 | " \n", 2505 | " \n", 2506 | "
Class IndexTitleDescriptiontextlabelsidx
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...reuters - short-sellers, wall street's dwindli...20
13Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...reuters - soaring crude prices plus worries\\ab...21
23Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\\f...reuters - authorities have halted oil export\\f...22
33Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...afp - tearaway world oil prices, toppling reco...23
43Stocks End Up, But Near Year Lows (Reuters)Reuters - Stocks ended slightly higher on Frid...reuters - stocks ended slightly higher on frid...24
\n", 2507 | "
" 2508 | ], 2509 | "text/plain": [ 2510 | " Class Index Title ... labels idx\n", 2511 | "0 3 Wall St. Bears Claw Back Into the Black (Reuters) ... 2 0\n", 2512 | "1 3 Oil and Economy Cloud Stocks' Outlook (Reuters) ... 2 1\n", 2513 | "2 3 Iraq Halts Oil Exports from Main Southern Pipe... ... 2 2\n", 2514 | "3 3 Oil prices soar to all-time record, posing new... ... 2 3\n", 2515 | "4 3 Stocks End Up, But Near Year Lows (Reuters) ... 2 4\n", 2516 | "\n", 2517 | "[5 rows x 6 columns]" 2518 | ] 2519 | }, 2520 | "metadata": { 2521 | "tags": [] 2522 | }, 2523 | "execution_count": 5 2524 | } 2525 | ] 2526 | }, 2527 | { 2528 | "cell_type": "code", 2529 | "metadata": { 2530 | "id": "uiCPrbuzQvOW" 2531 | }, 2532 | "source": [ 2533 | "## We are going to take only 1000 randomly choosen examples from above train data to train our model\n", 2534 | "train_df = train_df_al.sample(1000, random_state=100)" 2535 | ], 2536 | "execution_count": null, 2537 | "outputs": [] 2538 | }, 2539 | { 2540 | "cell_type": "code", 2541 | "metadata": { 2542 | "id": "SwRK6iaPQ8GP" 2543 | }, 2544 | "source": [ 2545 | "train_df.to_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/train_1000.csv.gz\", sep=\"|\", index=False, compression=\"gzip\")" 2546 | ], 2547 | "execution_count": null, 2548 | "outputs": [] 2549 | }, 2550 | { 2551 | "cell_type": "code", 2552 | "metadata": { 2553 | "id": "jwuSGGTR8r-F" 2554 | }, 2555 | "source": [ 2556 | "train_df_6000 = train_df_al.sample(6000, random_state=100)" 2557 | ], 2558 | "execution_count": null, 2559 | "outputs": [] 2560 | }, 2561 | { 2562 | "cell_type": "code", 2563 | "metadata": { 2564 | "id": "duT-MkmJ82wt" 2565 | }, 2566 | "source": [ 2567 | "train_df_6000.to_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/train_6000.csv.gz\", sep=\"|\", index=False, compression=\"gzip\")" 2568 | ], 2569 | "execution_count": null, 2570 | "outputs": [] 2571 | }, 2572 | { 2573 | "cell_type": "code", 2574 | "metadata": { 2575 | "id": "ICMFRUOnRkQO" 
2576 | }, 2577 | "source": [ 2578 | "# train_df_al.to_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/train.csv.gz\", sep=\"|\", index=False, compression='gzip')\n", 2579 | "# annotate_df.to_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/annotate.csv.gz\", sep=\"|\", index=False, compression='gzip')\n", 2580 | "# valid_df.to_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/valid.csv.gz\", sep=\"|\", index=False, compression='gzip')\n", 2581 | "# test_df.to_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/test.csv.gz\", sep=\"|\", index=False, compression='gzip')" 2582 | ], 2583 | "execution_count": null, 2584 | "outputs": [] 2585 | }, 2586 | { 2587 | "cell_type": "code", 2588 | "metadata": { 2589 | "id": "1-NCKkt29Dfp" 2590 | }, 2591 | "source": [ 2592 | "## Training RoBERTa on 1000 training sample to obtain baseline accuracy on test data\n", 2593 | "\n", 2594 | "# Optional model configuration\n", 2595 | "\n", 2596 | "mid_model_args = ClassificationArgs(num_train_epochs=5, \n", 2597 | " overwrite_output_dir= True, \n", 2598 | " train_batch_size=16,\n", 2599 | " max_seq_length=256, \n", 2600 | " wandb_project= 'active_learning_6000_20210512', \n", 2601 | " best_model_dir=\"/content/drive/MyDrive/fsdl_project/model/active_learning/6000/best_model/20210512\",\n", 2602 | " cache_dir=\"/content/drive/MyDrive/fsdl_project/cache/active_learning/6000/20210512\",\n", 2603 | " eval_batch_size=16,\n", 2604 | " evaluate_during_training=True,\n", 2605 | " evaluate_during_training_verbose=True,\n", 2606 | " manual_seed=100,\n", 2607 | " output_dir=\"content/drive/MyDrive/fsdl_project/output/active_learning/6000/20210512\",\n", 2608 | " # no_cache = True,\n", 2609 | " use_early_stopping=True,\n", 2610 | " early_stopping_patience=3,\n", 2611 | " )\n", 2612 | "\n", 2613 | "\n", 2614 | "# Create a ClassificationModel\n", 2615 | "mid_model = ClassificationModel(\n", 2616 | "\"roberta\", \"roberta-base\", args=mid_model_args, 
use_cuda=True, num_labels=4,\n", 2617 | ")\n", 2618 | "\n", 2619 | "mid_model.train_model(train_df=train_df_6000, eval_df=valid_df, accuracy=accuracy_score)\n" 2620 | ], 2621 | "execution_count": null, 2622 | "outputs": [] 2623 | }, 2624 | { 2625 | "cell_type": "code", 2626 | "metadata": { 2627 | "id": "1dEnyQBm9shD" 2628 | }, 2629 | "source": [ 2630 | "test_result, test_model_outputs, test_wrong_predictions = mid_model.eval_model(test_df, accuracy = accuracy_score)\n", 2631 | "valid_result, valid_model_outputs, valid_wrong_predictions = mid_model.eval_model(valid_df, accuracy = accuracy_score)\n", 2632 | "train_result, train_model_outputs, train_wrong_predictions = mid_model.eval_model(train_df_6000, accuracy = accuracy_score)\n", 2633 | "log_to_wandb = {f'test_{key}': item for key, item in test_result.items()}\n", 2634 | "log_to_wandb.update({f'train_{key}': item for key, item in train_result.items()})\n", 2635 | "log_to_wandb.update({f'valid_{key}': item for key, item in valid_result.items()})\n", 2636 | "log_to_wandb\n" 2637 | ], 2638 | "execution_count": null, 2639 | "outputs": [] 2640 | }, 2641 | { 2642 | "cell_type": "code", 2643 | "metadata": { 2644 | "id": "5qBX9wr097ex" 2645 | }, 2646 | "source": [ 2647 | "os.mkdir(\"/content/drive/MyDrive/fsdl_project/result/active_learning/20210512\")" 2648 | ], 2649 | "execution_count": null, 2650 | "outputs": [] 2651 | }, 2652 | { 2653 | "cell_type": "code", 2654 | "metadata": { 2655 | "colab": { 2656 | "base_uri": "https://localhost:8080/" 2657 | }, 2658 | "id": "nTYmzEGbNZmn", 2659 | "outputId": "b0c78e5d-631e-4c9b-9eea-2341c7ac03c5" 2660 | }, 2661 | "source": [ 2662 | "log_to_wandb" 2663 | ], 2664 | "execution_count": null, 2665 | "outputs": [ 2666 | { 2667 | "output_type": "execute_result", 2668 | "data": { 2669 | "text/plain": [ 2670 | "{'test_accuracy': 0.9052631578947369,\n", 2671 | " 'test_eval_loss': 0.46176427139533,\n", 2672 | " 'test_mcc': 0.8736933468410781,\n", 2673 | " 'train_accuracy': 
0.9883333333333333,\n", 2674 | " 'train_eval_loss': 0.046081207289981344,\n", 2675 | " 'train_mcc': 0.9844555672177369,\n", 2676 | " 'valid_accuracy': 0.907125,\n", 2677 | " 'valid_eval_loss': 0.4330419914466329,\n", 2678 | " 'valid_mcc': 0.8761829392366339}" 2679 | ] 2680 | }, 2681 | "metadata": { 2682 | "tags": [] 2683 | }, 2684 | "execution_count": 13 2685 | } 2686 | ] 2687 | }, 2688 | { 2689 | "cell_type": "code", 2690 | "metadata": { 2691 | "id": "wXiaUd7e98Lo" 2692 | }, 2693 | "source": [ 2694 | "import json\n", 2695 | "with open(\"/content/drive/MyDrive/fsdl_project/result/active_learning/20210512/initial_train_stats_6000_20210512.json\", 'w') as outfile:\n", 2696 | " json.dump(log_to_wandb, outfile, indent=4)" 2697 | ], 2698 | "execution_count": null, 2699 | "outputs": [] 2700 | }, 2701 | { 2702 | "cell_type": "code", 2703 | "metadata": { 2704 | "id": "o463wnL_Rg04" 2705 | }, 2706 | "source": [ 2707 | "# ## Training RoBERTa on 1000 training sample to obtain baseline accuracy on test data\n", 2708 | "\n", 2709 | "# # Optional model configuration\n", 2710 | "\n", 2711 | "# al_model_args = ClassificationArgs(num_train_epochs=5, \n", 2712 | "# overwrite_output_dir= True, \n", 2713 | "# train_batch_size=16,\n", 2714 | "# max_seq_length=256, \n", 2715 | "# wandb_project= 'active_learning_1000_20210510', \n", 2716 | "# best_model_dir=\"/content/drive/MyDrive/fsdl_project/model/active_learning/1000/best_model/20210510\",\n", 2717 | "# cache_dir=\"/content/drive/MyDrive/fsdl_project/cache/active_learning/1000/20210510\",\n", 2718 | "# eval_batch_size=16,\n", 2719 | "# evaluate_during_training=True,\n", 2720 | "# evaluate_during_training_verbose=True,\n", 2721 | "# manual_seed=100,\n", 2722 | "# output_dir=\"content/drive/MyDrive/fsdl_project/output/active_learning/1000/20210510\",\n", 2723 | "# # no_cache = True,\n", 2724 | "# use_early_stopping=True,\n", 2725 | "# early_stopping_patience=3,\n", 2726 | "# )\n", 2727 | "\n", 2728 | "\n", 2729 | "# # Create a 
ClassificationModel\n", 2730 | "# al_model = ClassificationModel(\n", 2731 | "# \"roberta\", \"roberta-base\", args=al_model_args, use_cuda=True, num_labels=4,\n", 2732 | "# )\n", 2733 | "\n", 2734 | "# al_model.train_model(train_df=train_df, eval_df=valid_df, accuracy=accuracy_score)\n" 2735 | ], 2736 | "execution_count": null, 2737 | "outputs": [] 2738 | }, 2739 | { 2740 | "cell_type": "code", 2741 | "metadata": { 2742 | "id": "BmWvDbchMXcb" 2743 | }, 2744 | "source": [ 2745 | "### Initial trained model on large training data\n", 2746 | "##-----------------------------------------------###\n", 2747 | "\n", 2748 | "\n", 2749 | "## Training RoBERTa on truncated training set to obtain baseline accuracy on test data\n", 2750 | "\n", 2751 | "# Optional model configuration\n", 2752 | "\n", 2753 | "# al_model_args = ClassificationArgs(num_train_epochs=5, \n", 2754 | "# overwrite_output_dir= True, \n", 2755 | "# train_batch_size=16,\n", 2756 | "# max_seq_length=256, \n", 2757 | "# wandb_project= 'active_learning_20210510', \n", 2758 | "# best_model_dir=\"/content/drive/MyDrive/fsdl_project/model/active_learning/best_model/20210510\",\n", 2759 | "# cache_dir=\"/content/drive/MyDrive/fsdl_project/cache/active_learning/20210510\",\n", 2760 | "# eval_batch_size=16,\n", 2761 | "# evaluate_during_training=True,\n", 2762 | "# evaluate_during_training_verbose=True,\n", 2763 | "# manual_seed=100,\n", 2764 | "# output_dir=\"content/drive/MyDrive/fsdl_project/output/active_learning/20210510\",\n", 2765 | "# # no_cache = True,\n", 2766 | "# use_early_stopping=True,\n", 2767 | "# early_stopping_patience=3,\n", 2768 | "# )\n", 2769 | "\n", 2770 | "\n", 2771 | "# # Create a ClassificationModel\n", 2772 | "# al_model = ClassificationModel(\n", 2773 | "# \"roberta\", \"roberta-base\", args=al_model_args, use_cuda=True, num_labels=4,\n", 2774 | "# )\n", 2775 | "\n", 2776 | "# al_model.train_model(train_df=train_df_al, eval_df=valid_df, accuracy=accuracy_score)\n", 2777 | "\n", 2778 | 
"##--------------------------------" 2779 | ], 2780 | "execution_count": null, 2781 | "outputs": [] 2782 | }, 2783 | { 2784 | "cell_type": "code", 2785 | "metadata": { 2786 | "id": "MpS9XBi-MOEU" 2787 | }, 2788 | "source": [ 2789 | "# ## loading model\n", 2790 | "al_model = ClassificationModel(\n", 2791 | "\"roberta\", al_model_args.best_model_dir\n", 2792 | ")" 2793 | ], 2794 | "execution_count": null, 2795 | "outputs": [] 2796 | }, 2797 | { 2798 | "cell_type": "code", 2799 | "metadata": { 2800 | "id": "G4I6L_ZvMuDE" 2801 | }, 2802 | "source": [ 2803 | "test_result, test_model_outputs, test_wrong_predictions = al_model.eval_model(test_df, accuracy = accuracy_score)\n", 2804 | "valid_result, valid_model_outputs, valid_wrong_predictions = al_model.eval_model(valid_df, accuracy = accuracy_score)\n", 2805 | "train_result, train_model_outputs, train_wrong_predictions = al_model.eval_model(train_df, accuracy = accuracy_score)" 2806 | ], 2807 | "execution_count": null, 2808 | "outputs": [] 2809 | }, 2810 | { 2811 | "cell_type": "code", 2812 | "metadata": { 2813 | "colab": { 2814 | "base_uri": "https://localhost:8080/" 2815 | }, 2816 | "id": "RylCMhLsUH0I", 2817 | "outputId": "0602eb33-11e7-472e-efb8-bcced32da185" 2818 | }, 2819 | "source": [ 2820 | "test_result" 2821 | ], 2822 | "execution_count": null, 2823 | "outputs": [ 2824 | { 2825 | "output_type": "execute_result", 2826 | "data": { 2827 | "text/plain": [ 2828 | "{'accuracy': 0.8688157894736842,\n", 2829 | " 'eval_loss': 0.42330945989803265,\n", 2830 | " 'mcc': 0.8268612843070937}" 2831 | ] 2832 | }, 2833 | "metadata": { 2834 | "tags": [] 2835 | }, 2836 | "execution_count": 18 2837 | } 2838 | ] 2839 | }, 2840 | { 2841 | "cell_type": "code", 2842 | "metadata": { 2843 | "id": "9424lETxQrtj" 2844 | }, 2845 | "source": [ 2846 | "log_to_wandb = {f'test_{key}': item for key, item in test_result.items()}\n", 2847 | "log_to_wandb.update({f'train_{key}': item for key, item in train_result.items()})\n", 2848 | 
"log_to_wandb.update({f'valid_{key}': item for key, item in valid_result.items()})" 2849 | ], 2850 | "execution_count": null, 2851 | "outputs": [] 2852 | }, 2853 | { 2854 | "cell_type": "code", 2855 | "metadata": { 2856 | "colab": { 2857 | "base_uri": "https://localhost:8080/" 2858 | }, 2859 | "id": "Y7r_xzsucYcp", 2860 | "outputId": "5a2433e8-c8e9-4973-8949-077370e3ce86" 2861 | }, 2862 | "source": [ 2863 | "log_to_wandb" 2864 | ], 2865 | "execution_count": null, 2866 | "outputs": [ 2867 | { 2868 | "output_type": "execute_result", 2869 | "data": { 2870 | "text/plain": [ 2871 | "{'test_accuracy': 0.9230263157894737,\n", 2872 | " 'test_eval_loss': 0.3959990933331612,\n", 2873 | " 'test_mcc': 0.8974095014721511,\n", 2874 | " 'train_accuracy': 0.9876736111111111,\n", 2875 | " 'train_eval_loss': 0.053259740840294396,\n", 2876 | " 'train_mcc': 0.9835782903636949,\n", 2877 | " 'valid_accuracy': 0.9295,\n", 2878 | " 'valid_eval_loss': 0.3732603291405637,\n", 2879 | " 'valid_mcc': 0.9060107576776867}" 2880 | ] 2881 | }, 2882 | "metadata": { 2883 | "tags": [] 2884 | }, 2885 | "execution_count": 10 2886 | } 2887 | ] 2888 | }, 2889 | { 2890 | "cell_type": "code", 2891 | "metadata": { 2892 | "id": "jKASciqOQ6J1" 2893 | }, 2894 | "source": [ 2895 | "# os.mkdir(\"/content/drive/MyDrive/fsdl_project/result/\")\n", 2896 | "# os.mkdir(\"/content/drive/MyDrive/fsdl_project/result/active_learning/\")\n", 2897 | "os.mkdir(\"/content/drive/MyDrive/fsdl_project/result/active_learning/20210421\")\n", 2898 | "# os.mkdir(\"/content/drive/MyDrive/fsdl_project/result/baseline\")" 2899 | ], 2900 | "execution_count": null, 2901 | "outputs": [] 2902 | }, 2903 | { 2904 | "cell_type": "code", 2905 | "metadata": { 2906 | "id": "SuAdKtnMRS7v" 2907 | }, 2908 | "source": [ 2909 | "" 2910 | ], 2911 | "execution_count": null, 2912 | "outputs": [] 2913 | }, 2914 | { 2915 | "cell_type": "code", 2916 | "metadata": { 2917 | "id": "BQnGqQTwQ5SH" 2918 | }, 2919 | "source": [ 2920 | "import json\n", 2921 | 
"with open(\"/content/drive/MyDrive/fsdl_project/result/active_learning/20210410/initial_train_stats_1000_20210510.json\", 'w') as outfile:\n", 2922 | " json.dump(log_to_wandb, outfile, indent=4)" 2923 | ], 2924 | "execution_count": null, 2925 | "outputs": [] 2926 | }, 2927 | { 2928 | "cell_type": "code", 2929 | "metadata": { 2930 | "id": "egPN-zZ7UBqJ" 2931 | }, 2932 | "source": [ 2933 | "## loading annotate_df\n", 2934 | "\n", 2935 | "annotate_df = pd.read_csv(\"/content/drive/MyDrive/fsdl_project/data/active_learning/annotate.csv.gz\", sep=\"|\", index_col=False)" 2936 | ], 2937 | "execution_count": null, 2938 | "outputs": [] 2939 | }, 2940 | { 2941 | "cell_type": "code", 2942 | "metadata": { 2943 | "colab": { 2944 | "base_uri": "https://localhost:8080/", 2945 | "height": 111 2946 | }, 2947 | "id": "865ZLVMLUEpD", 2948 | "outputId": "06a59624-3565-416c-f723-3e2ee0727f2e" 2949 | }, 2950 | "source": [ 2951 | "display(annotate_df.head(2))" 2952 | ], 2953 | "execution_count": null, 2954 | "outputs": [ 2955 | { 2956 | "output_type": "display_data", 2957 | "data": { 2958 | "text/html": [ 2959 | "
\n", 2960 | "\n", 2973 | "\n", 2974 | " \n", 2975 | " \n", 2976 | " \n", 2977 | " \n", 2978 | " \n", 2979 | " \n", 2980 | " \n", 2981 | " \n", 2982 | " \n", 2983 | " \n", 2984 | " \n", 2985 | " \n", 2986 | " \n", 2987 | " \n", 2988 | " \n", 2989 | " \n", 2990 | " \n", 2991 | " \n", 2992 | " \n", 2993 | " \n", 2994 | " \n", 2995 | " \n", 2996 | " \n", 2997 | " \n", 2998 | " \n", 2999 | " \n", 3000 | " \n", 3001 | " \n", 3002 | " \n", 3003 | " \n", 3004 | " \n", 3005 | "
Class IndexTitleDescriptiontextlabelsidx
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...reuters - short-sellers, wall street's dwindli...20
13Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...reuters - soaring crude prices plus worries\\ab...21
\n", 3006 | "
" 3007 | ], 3008 | "text/plain": [ 3009 | " Class Index Title ... labels idx\n", 3010 | "0 3 Wall St. Bears Claw Back Into the Black (Reuters) ... 2 0\n", 3011 | "1 3 Oil and Economy Cloud Stocks' Outlook (Reuters) ... 2 1\n", 3012 | "\n", 3013 | "[2 rows x 6 columns]" 3014 | ] 3015 | }, 3016 | "metadata": { 3017 | "tags": [] 3018 | } 3019 | } 3020 | ] 3021 | }, 3022 | { 3023 | "cell_type": "code", 3024 | "metadata": { 3025 | "id": "0ouZ7nznOy76", 3026 | "colab": { 3027 | "base_uri": "https://localhost:8080/", 3028 | "height": 1000, 3029 | "referenced_widgets": [ 3030 | "e9b2650e436a4a3d9e0aa4cbb7c3c679", 3031 | "bc17c094006a48d29d6e2cc708a35d0b", 3032 | "cde8d171084446e58e8a064e01a7a501", 3033 | "b7d5e68eda2c492a9e80525b69d6dd21", 3034 | "0de1f93062784a38be5e1016bbf29c05", 3035 | "25b6388694ff46eca2c9dbcf49b14ac7", 3036 | "1cf29525e90c45f3bda46b4d4ea0e9d3", 3037 | "2910f4ee601241408647ff845c2713d4", 3038 | "0a817039c2cf4f23be507e0b0b7df501", 3039 | "6af6b17b562141f59a0b4fae42b8c1b4", 3040 | "1b4aeeac358841a4884e02ac756f8a3c", 3041 | "87b7ef958f914aab8e6213d3f2176300", 3042 | "b206209e6ef7409aae13bb41baf7c7bc", 3043 | "4c42bed0b1524e699d58e8785e4c1c3c", 3044 | "8881de3a8e1f44bab6fe8c01f11cc7a7", 3045 | "fb536d3431d9479eb912eb0c113662a4", 3046 | "704ea72ddf774d5dba11d98ca9abf584", 3047 | "e6d4f400022344f99053b6521e505933", 3048 | "b6488a6d0ec646f1bd16bfdaf2ba9828", 3049 | "0e8f39bc412d425e98327d74e5403854", 3050 | "d120778b937c4c97806532b2db11aada", 3051 | "c6e4edff105640baa81a42c97c78b578" 3052 | ] 3053 | }, 3054 | "outputId": "11455bf6-51a0-44a4-8257-481a5a290041" 3055 | }, 3056 | "source": [ 3057 | "### Making prediction on annotation dataset and saving it to output/active_learning/20210410 for annotation\n", 3058 | "annotate_text = annotate_df['text'].values.tolist()\n", 3059 | "annotate_predictions, annotate_raw_output = al_model.predict(annotate_text)" 3060 | ], 3061 | "execution_count": null, 3062 | "outputs": [ 3063 | { 3064 | "output_type": "stream", 3065 | 
"text": [ 3066 | "INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n" 3067 | ], 3068 | "name": "stderr" 3069 | }, 3070 | { 3071 | "output_type": "display_data", 3072 | "data": { 3073 | "application/vnd.jupyter.widget-view+json": { 3074 | "model_id": "e9b2650e436a4a3d9e0aa4cbb7c3c679", 3075 | "version_minor": 0, 3076 | "version_major": 2 3077 | }, 3078 | "text/plain": [ 3079 | " 0%| | 0/67200 [00:00\n", 3090 | "Traceback (most recent call last):\n", 3091 | " File \"/usr/lib/python3.7/weakref.py\", line 572, in __call__\n", 3092 | " return info.func(*info.args, **(info.kwargs or {}))\n", 3093 | " File \"/usr/lib/python3.7/tempfile.py\", line 936, in _cleanup\n", 3094 | " _rmtree(name)\n", 3095 | " File \"/usr/lib/python3.7/shutil.py\", line 485, in rmtree\n", 3096 | " onerror(os.lstat, path, sys.exc_info())\n", 3097 | " File \"/usr/lib/python3.7/shutil.py\", line 483, in rmtree\n", 3098 | " orig_st = os.lstat(path)\n", 3099 | "FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmp5y2g38ta'\n", 3100 | "Exception ignored in: \n", 3101 | "Traceback (most recent call last):\n", 3102 | " File \"/usr/lib/python3.7/weakref.py\", line 572, in __call__\n", 3103 | " return info.func(*info.args, **(info.kwargs or {}))\n", 3104 | " File \"/usr/lib/python3.7/tempfile.py\", line 936, in _cleanup\n", 3105 | " _rmtree(name)\n", 3106 | " File \"/usr/lib/python3.7/shutil.py\", line 485, in rmtree\n", 3107 | " onerror(os.lstat, path, sys.exc_info())\n", 3108 | " File \"/usr/lib/python3.7/shutil.py\", line 483, in rmtree\n", 3109 | " orig_st = os.lstat(path)\n", 3110 | "FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmppdp9cwu1'\n", 3111 | "Exception ignored in: \n", 3112 | "Traceback (most recent call last):\n", 3113 | " File \"/usr/lib/python3.7/weakref.py\", line 572, in __call__\n", 3114 | " return info.func(*info.args, **(info.kwargs or {}))\n", 3115 | " File \"/usr/lib/python3.7/tempfile.py\", 
line 936, in _cleanup\n", 3116 | " _rmtree(name)\n", 3117 | " File \"/usr/lib/python3.7/shutil.py\", line 485, in rmtree\n", 3118 | " onerror(os.lstat, path, sys.exc_info())\n", 3119 | " File \"/usr/lib/python3.7/shutil.py\", line 483, in rmtree\n", 3120 | " orig_st = os.lstat(path)\n", 3121 | "FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmp15pve8hz'\n", 3122 | "Exception ignored in: \n", 3123 | "Traceback (most recent call last):\n", 3124 | " File \"/usr/lib/python3.7/weakref.py\", line 572, in __call__\n", 3125 | " return info.func(*info.args, **(info.kwargs or {}))\n", 3126 | " File \"/usr/lib/python3.7/tempfile.py\", line 936, in _cleanup\n", 3127 | " _rmtree(name)\n", 3128 | " File \"/usr/lib/python3.7/shutil.py\", line 485, in rmtree\n", 3129 | " onerror(os.lstat, path, sys.exc_info())\n", 3130 | " File \"/usr/lib/python3.7/shutil.py\", line 483, in rmtree\n", 3131 | " orig_st = os.lstat(path)\n", 3132 | "FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmp6o4cludb'\n" 3133 | ], 3134 | "name": "stderr" 3135 | }, 3136 | { 3137 | "output_type": "display_data", 3138 | "data": { 3139 | "application/vnd.jupyter.widget-view+json": { 3140 | "model_id": "87b7ef958f914aab8e6213d3f2176300", 3141 | "version_minor": 0, 3142 | "version_major": 2 3143 | }, 3144 | "text/plain": [ 3145 | " 0%| | 0/4200 [00:00\n", 3156 | "Traceback (most recent call last):\n", 3157 | " File \"/usr/lib/python3.7/weakref.py\", line 572, in __call__\n", 3158 | " return info.func(*info.args, **(info.kwargs or {}))\n", 3159 | " File \"/usr/lib/python3.7/tempfile.py\", line 936, in _cleanup\n", 3160 | " _rmtree(name)\n", 3161 | " File \"/usr/lib/python3.7/shutil.py\", line 485, in rmtree\n", 3162 | " onerror(os.lstat, path, sys.exc_info())\n", 3163 | " File \"/usr/lib/python3.7/shutil.py\", line 483, in rmtree\n", 3164 | " orig_st = os.lstat(path)\n", 3165 | "FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmp5y2g38ta'\n" 3166 | ], 3167 | "name": 
"stderr" 3168 | } 3169 | ] 3170 | }, 3171 | { 3172 | "cell_type": "code", 3173 | "metadata": { 3174 | "id": "2-IPs2d5UMu6" 3175 | }, 3176 | "source": [ 3177 | "import torch.nn as nn\n", 3178 | "import torch" 3179 | ], 3180 | "execution_count": null, 3181 | "outputs": [] 3182 | }, 3183 | { 3184 | "cell_type": "code", 3185 | "metadata": { 3186 | "id": "ilmmqBTjUQvu" 3187 | }, 3188 | "source": [ 3189 | "sfm = nn.Softmax(dim=1)" 3190 | ], 3191 | "execution_count": null, 3192 | "outputs": [] 3193 | }, 3194 | { 3195 | "cell_type": "code", 3196 | "metadata": { 3197 | "id": "uyL3r8GcSFUM" 3198 | }, 3199 | "source": [ 3200 | "annotate_raw_output_tensor = torch.from_numpy(annotate_raw_output)\n", 3201 | "annotate_class_prob = sfm(annotate_raw_output_tensor)\n", 3202 | "max_prob = torch.max(annotate_class_prob, dim=1)\n", 3203 | "annotate_class_prob = annotate_class_prob.numpy()\n", 3204 | "max_prob = max_prob.values.numpy()" 3205 | ], 3206 | "execution_count": null, 3207 | "outputs": [] 3208 | }, 3209 | { 3210 | "cell_type": "code", 3211 | "metadata": { 3212 | "colab": { 3213 | "base_uri": "https://localhost:8080/", 3214 | "height": 111 3215 | }, 3216 | "id": "0hmWfhcGdUkG", 3217 | "outputId": "92a20213-be1b-495d-cd66-33d937161a12" 3218 | }, 3219 | "source": [ 3220 | "annotate_df.head(2)" 3221 | ], 3222 | "execution_count": null, 3223 | "outputs": [ 3224 | { 3225 | "output_type": "execute_result", 3226 | "data": { 3227 | "text/html": [ 3228 | "
\n", 3229 | "\n", 3242 | "\n", 3243 | " \n", 3244 | " \n", 3245 | " \n", 3246 | " \n", 3247 | " \n", 3248 | " \n", 3249 | " \n", 3250 | " \n", 3251 | " \n", 3252 | " \n", 3253 | " \n", 3254 | " \n", 3255 | " \n", 3256 | " \n", 3257 | " \n", 3258 | " \n", 3259 | " \n", 3260 | " \n", 3261 | " \n", 3262 | " \n", 3263 | " \n", 3264 | " \n", 3265 | " \n", 3266 | " \n", 3267 | " \n", 3268 | " \n", 3269 | " \n", 3270 | " \n", 3271 | " \n", 3272 | " \n", 3273 | " \n", 3274 | "
Class IndexTitleDescriptiontextlabelsidx
03Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...reuters - short-sellers, wall street's dwindli...20
13Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...reuters - soaring crude prices plus worries\\ab...21
\n", 3275 | "
" 3276 | ], 3277 | "text/plain": [ 3278 | " Class Index Title ... labels idx\n", 3279 | "0 3 Wall St. Bears Claw Back Into the Black (Reuters) ... 2 0\n", 3280 | "1 3 Oil and Economy Cloud Stocks' Outlook (Reuters) ... 2 1\n", 3281 | "\n", 3282 | "[2 rows x 6 columns]" 3283 | ] 3284 | }, 3285 | "metadata": { 3286 | "tags": [] 3287 | }, 3288 | "execution_count": 19 3289 | } 3290 | ] 3291 | }, 3292 | { 3293 | "cell_type": "code", 3294 | "metadata": { 3295 | "id": "-j_ijy_1XA_H" 3296 | }, 3297 | "source": [ 3298 | "annotate_df_with_pred = np.hstack((\n", 3299 | " annotate_df.idx.values.reshape(-1,1),\n", 3300 | " annotate_df.Title.values.reshape(-1,1),\n", 3301 | " annotate_df.Description.values.reshape(-1,1),\n", 3302 | " annotate_df.text.values.reshape(-1,1), \n", 3303 | " annotate_raw_output,\n", 3304 | " annotate_class_prob,\n", 3305 | " max_prob.reshape(-1,1),\n", 3306 | " np.array(annotate_predictions).reshape(-1,1)\n", 3307 | " ))" 3308 | ], 3309 | "execution_count": null, 3310 | "outputs": [] 3311 | }, 3312 | { 3313 | "cell_type": "code", 3314 | "metadata": { 3315 | "id": "1x2DfEtyXSqG" 3316 | }, 3317 | "source": [ 3318 | "col_names = ['idx', \n", 3319 | " 'text', \n", 3320 | " 'title',\n", 3321 | " 'description',\n", 3322 | " 'logit_0', 'logit_1', 'logit_2', 'logit_3', \n", 3323 | " 'prob_0', 'prob_1', 'prob_2', 'prob_3',\n", 3324 | " 'max_prob',\n", 3325 | " 'label_pred'\n", 3326 | " ]" 3327 | ], 3328 | "execution_count": null, 3329 | "outputs": [] 3330 | }, 3331 | { 3332 | "cell_type": "code", 3333 | "metadata": { 3334 | "id": "CBXUZ2hdaUhr" 3335 | }, 3336 | "source": [ 3337 | "\n", 3338 | "annotate_df_with_pred = pd.DataFrame(annotate_df_with_pred, columns=col_names)" 3339 | ], 3340 | "execution_count": null, 3341 | "outputs": [] 3342 | }, 3343 | { 3344 | "cell_type": "code", 3345 | "metadata": { 3346 | "id": "RMY3DojLQ1l_" 3347 | }, 3348 | "source": [ 3349 | "annotate_df_with_pred['annotated_labels'] = ''\n", 3350 | 
"annotate_df_with_pred['sampling_method'] = ''" 3351 | ], 3352 | "execution_count": null, 3353 | "outputs": [] 3354 | }, 3355 | { 3356 | "cell_type": "code", 3357 | "metadata": { 3358 | "id": "wIg_fRNua0Ni" 3359 | }, 3360 | "source": [ 3361 | "" 3362 | ], 3363 | "execution_count": null, 3364 | "outputs": [] 3365 | }, 3366 | { 3367 | "cell_type": "code", 3368 | "metadata": { 3369 | "colab": { 3370 | "base_uri": "https://localhost:8080/", 3371 | "height": 289 3372 | }, 3373 | "id": "7zNfz4TAdnKr", 3374 | "outputId": "a9d993ba-cefd-40bb-8df2-7a2f8eb32e46" 3375 | }, 3376 | "source": [ 3377 | "annotate_df_with_pred.head()" 3378 | ], 3379 | "execution_count": null, 3380 | "outputs": [ 3381 | { 3382 | "output_type": "execute_result", 3383 | "data": { 3384 | "text/html": [ 3385 | "
\n", 3386 | "\n", 3399 | "\n", 3400 | " \n", 3401 | " \n", 3402 | " \n", 3403 | " \n", 3404 | " \n", 3405 | " \n", 3406 | " \n", 3407 | " \n", 3408 | " \n", 3409 | " \n", 3410 | " \n", 3411 | " \n", 3412 | " \n", 3413 | " \n", 3414 | " \n", 3415 | " \n", 3416 | " \n", 3417 | " \n", 3418 | " \n", 3419 | " \n", 3420 | " \n", 3421 | " \n", 3422 | " \n", 3423 | " \n", 3424 | " \n", 3425 | " \n", 3426 | " \n", 3427 | " \n", 3428 | " \n", 3429 | " \n", 3430 | " \n", 3431 | " \n", 3432 | " \n", 3433 | " \n", 3434 | " \n", 3435 | " \n", 3436 | " \n", 3437 | " \n", 3438 | " \n", 3439 | " \n", 3440 | " \n", 3441 | " \n", 3442 | " \n", 3443 | " \n", 3444 | " \n", 3445 | " \n", 3446 | " \n", 3447 | " \n", 3448 | " \n", 3449 | " \n", 3450 | " \n", 3451 | " \n", 3452 | " \n", 3453 | " \n", 3454 | " \n", 3455 | " \n", 3456 | " \n", 3457 | " \n", 3458 | " \n", 3459 | " \n", 3460 | " \n", 3461 | " \n", 3462 | " \n", 3463 | " \n", 3464 | " \n", 3465 | " \n", 3466 | " \n", 3467 | " \n", 3468 | " \n", 3469 | " \n", 3470 | " \n", 3471 | " \n", 3472 | " \n", 3473 | " \n", 3474 | " \n", 3475 | " \n", 3476 | " \n", 3477 | " \n", 3478 | " \n", 3479 | " \n", 3480 | " \n", 3481 | " \n", 3482 | " \n", 3483 | " \n", 3484 | " \n", 3485 | " \n", 3486 | " \n", 3487 | " \n", 3488 | " \n", 3489 | " \n", 3490 | " \n", 3491 | " \n", 3492 | " \n", 3493 | " \n", 3494 | " \n", 3495 | " \n", 3496 | " \n", 3497 | " \n", 3498 | " \n", 3499 | " \n", 3500 | " \n", 3501 | " \n", 3502 | " \n", 3503 | " \n", 3504 | " \n", 3505 | " \n", 3506 | " \n", 3507 | " \n", 3508 | " \n", 3509 | " \n", 3510 | " \n", 3511 | " \n", 3512 | " \n", 3513 | " \n", 3514 | " \n", 3515 | " \n", 3516 | " \n", 3517 | " \n", 3518 | "
idxtexttitledescriptionlogit_0logit_1logit_2logit_3prob_0prob_1prob_2prob_3max_problabel_predannotated_labelssampling_method
00Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...reuters - short-sellers, wall street's dwindli...-1.64648-1.900392.115231.566410.01435770.01113820.6177030.3568010.6177032
11Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...reuters - soaring crude prices plus worries\\ab...0.0162964-2.003913.34375-1.15430.03411790.004525010.9507740.01058280.9507742
22Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\\f...reuters - authorities have halted oil export\\f...3.80469-1.46680.199097-1.953120.9657490.004960070.02624070.003049850.9657490
33Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...afp - tearaway world oil prices, toppling reco...1.12891-2.02932.77734-1.557620.1584920.006736230.8239760.0107960.8239762
44Stocks End Up, But Near Year Lows (Reuters)Reuters - Stocks ended slightly higher on Frid...reuters - stocks ended slightly higher on frid...0.172119-1.751953.11914-1.287110.04895350.007147750.9325210.01137750.9325212
\n", 3519 | "
" 3520 | ], 3521 | "text/plain": [ 3522 | " idx ... sampling_method\n", 3523 | "0 0 ... \n", 3524 | "1 1 ... \n", 3525 | "2 2 ... \n", 3526 | "3 3 ... \n", 3527 | "4 4 ... \n", 3528 | "\n", 3529 | "[5 rows x 16 columns]" 3530 | ] 3531 | }, 3532 | "metadata": { 3533 | "tags": [] 3534 | }, 3535 | "execution_count": 31 3536 | } 3537 | ] 3538 | }, 3539 | { 3540 | "cell_type": "code", 3541 | "metadata": { 3542 | "id": "-2-oJE7qcrbU" 3543 | }, 3544 | "source": [ 3545 | "# os.mkdir(\"/content/drive/MyDrive/fsdl_project/output/active_learning\")\n", 3546 | "os.mkdir(\"/content/drive/MyDrive/fsdl_project/output/active_learning/20210510\")" 3547 | ], 3548 | "execution_count": null, 3549 | "outputs": [] 3550 | }, 3551 | { 3552 | "cell_type": "code", 3553 | "metadata": { 3554 | "id": "p6dsnjmBc7Aw" 3555 | }, 3556 | "source": [ 3557 | "annotate_df_with_pred.to_csv(\"/content/drive/MyDrive/fsdl_project/output/active_learning/20210510/annotate.csv.gz\", index=False,\n", 3558 | " compression=\"gzip\")" 3559 | ], 3560 | "execution_count": null, 3561 | "outputs": [] 3562 | }, 3563 | { 3564 | "cell_type": "code", 3565 | "metadata": { 3566 | "id": "JR5dL6pqdIXE" 3567 | }, 3568 | "source": [ 3569 | "annotate_df_with_pred.head()" 3570 | ], 3571 | "execution_count": null, 3572 | "outputs": [] 3573 | }, 3574 | { 3575 | "cell_type": "code", 3576 | "metadata": { 3577 | "colab": { 3578 | "base_uri": "https://localhost:8080/" 3579 | }, 3580 | "id": "rKZfzquRdLiT", 3581 | "outputId": "4af429d1-1177-4a31-dd9f-d3b369c0f82a" 3582 | }, 3583 | "source": [ 3584 | "log_to_wandb" 3585 | ], 3586 | "execution_count": null, 3587 | "outputs": [ 3588 | { 3589 | "output_type": "execute_result", 3590 | "data": { 3591 | "text/plain": [ 3592 | "{'test_accuracy': 0.9292105263157895,\n", 3593 | " 'test_eval_loss': 0.37604381556634936,\n", 3594 | " 'test_mcc': 0.9056558484190013,\n", 3595 | " 'train_accuracy': 0.9859375,\n", 3596 | " 'train_eval_loss': 0.06025019081414535,\n", 3597 | " 'train_mcc': 
0.9812553280559401,\n", 3598 | " 'valid_accuracy': 0.9294583333333334,\n", 3599 | " 'valid_eval_loss': 0.3652915442798403,\n", 3600 | " 'valid_mcc': 0.9059680833454649}" 3601 | ] 3602 | }, 3603 | "metadata": { 3604 | "tags": [] 3605 | }, 3606 | "execution_count": 28 3607 | } 3608 | ] 3609 | }, 3610 | { 3611 | "cell_type": "code", 3612 | "metadata": { 3613 | "id": "cjOV8vD-acs0" 3614 | }, 3615 | "source": [ 3616 | "" 3617 | ], 3618 | "execution_count": null, 3619 | "outputs": [] 3620 | } 3621 | ] 3622 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.mypy] 2 | ignore_missing_imports = true 3 | 4 | # Black formatting 5 | [tool.black] 6 | line-length = 100 7 | include = '\.pyi?$' 8 | exclude = ''' 9 | /( 10 | \.eggs # exclude a few common directories in the 11 | | \.git # root of the project 12 | | \.hg 13 | | \.mypy_cache 14 | | \.tox 15 | | \.venv 16 | | _build 17 | | buck-out 18 | | build 19 | | dist 20 | | wandb 21 | )/ 22 | ''' 23 | 24 | # iSort 25 | [tool.isort] 26 | profile = "black" 27 | line_length = 79 28 | multi_line_output = 3 29 | include_trailing_comma = true 30 | skip_gitignore = true 31 | virtual_env = "venv" 32 | 33 | # Pytest 34 | [tool.pytest.ini_options] 35 | testpaths = ["tests"] 36 | python_files = "test_*.py" 37 | addopts = "--strict-markers --disable-pytest-warnings" 38 | #markers 39 | 40 | # Pytest coverage 41 | [tool.coverage.run] 42 | omit = ["app/"] 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | pandas==1.2.4 3 | simpletransformers 4 | torch==2.2.0 5 | wandb 6 | pytest-cov 7 | black==21.5b1 8 | isort==5.8.0 9 | flake8==3.9.2 10 | pre-commit==2.12.1 11 | dash==1.20.0 12 | docker==5.0.0 13 | dash-bootstrap-components==0.12.2 14 | 
-------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AbinayaM02/Active_Learning_in_NLP/9c6bb281c5508a6117fa3c829c54562b3d58253a/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/annotator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | import warnings 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from scripts.config import DATA_DIR 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | 15 | def get_sampling_method_map(): 16 | return { 17 | "random": random_sampling, 18 | "least": least_confidence_sampling, 19 | "margin": margin_sampling, 20 | "entropy": entropy_sampling, 21 | } 22 | 23 | 24 | def clear(): 25 | if os.name == "nt": 26 | _ = os.system("cls") 27 | else: 28 | _ = os.system("clear") 29 | 30 | 31 | def random_sampling(raw_data: pd.DataFrame, size: int = 1000, random_seed=100): 32 | raw_data = raw_data[raw_data["annotated_labels"].isna()].copy() 33 | all_idx = raw_data["idx"].values.tolist() 34 | np.random.seed(random_seed) 35 | idx = np.random.choice(all_idx, size=size, replace=False) 36 | return idx 37 | 38 | 39 | def least_confidence_sampling(raw_data: pd.DataFrame, size: int = 1000): 40 | raw_data = raw_data.copy() 41 | raw_data = raw_data[raw_data["annotated_labels"].isna()].copy() 42 | raw_data.sort_values("max_prob", inplace=True) 43 | return raw_data.head(size).idx.values 44 | 45 | 46 | def margin_sampling(raw_data: pd.DataFrame, size: int = 1000): 47 | raw_data = raw_data[raw_data["annotated_labels"].isna()].copy() 48 | 49 | prob_cols = ["prob_0", "prob_1", "prob_2", "prob_3"] 50 | prob_sorted = np.sort(raw_data[prob_cols], axis=1) 51 | margin = prob_sorted[:, -1] - prob_sorted[:, -2] 52 | raw_data["margin"] = 
margin 53 | raw_data.sort_values("margin", ascending=True, inplace=True) 54 | return raw_data.head(size).idx.values 55 | 56 | 57 | def entropy_sampling(raw_data: pd.DataFrame, size: int = 1000): 58 | raw_data = raw_data[raw_data["annotated_labels"].isna()].copy() 59 | 60 | prob_cols = ["prob_0", "prob_1", "prob_2", "prob_3"] 61 | raw_data["entropy"] = -1 * np.sum(raw_data[prob_cols] * np.log(raw_data[prob_cols]), axis=1) 62 | raw_data.sort_values("entropy", ascending=False, inplace=True) 63 | return raw_data.head(size).idx.values 64 | 65 | 66 | def annotation_message(): 67 | annotation_instruction = ( 68 | "Please provide input as per the instruction given in box to annotate \n" 69 | ) 70 | annotation_instruction += "---------------------------------------\n" 71 | annotation_instruction += "| 1: World News |\n" 72 | annotation_instruction += "| 2: Sports |\n" 73 | annotation_instruction += "| 3: Business |\n" 74 | annotation_instruction += "| 4: Sci/Tech |\n" 75 | annotation_instruction += "| 0: Not Sure |\n" 76 | annotation_instruction += "---------------------------------------\n" 77 | annotation_instruction += "| save: to save the results |\n" 78 | annotation_instruction += "| f: full instruction |\n" 79 | annotation_instruction += "| u: go back to last text (undo) |\n" 80 | annotation_instruction += "---------------------------------------\n" 81 | 82 | full_instruction = "Please provide input as per the instruction given in box to see examples \n" 83 | full_instruction += "---------------------------------------------\n" 84 | full_instruction += "| u: go back to last text (undo) |\n" 85 | full_instruction += "| w: World News examples |\n" 86 | full_instruction += "| s: Sports News examples |\n" 87 | full_instruction += "| b: Business News examples |\n" 88 | full_instruction += "| t: Science/ Technology News examples |\n" 89 | full_instruction += "| f: full instruction |\n" 90 | full_instruction += "---------------------------------------------\n" 91 | 
full_instruction += annotation_instruction 92 | 93 | return full_instruction, annotation_instruction 94 | 95 | 96 | def get_examples(exp_data, label=0): 97 | label_data = exp_data[exp_data["labels"] == label] 98 | return label_data.head(1) 99 | 100 | 101 | def get_data(annotation_data_path, sampling_method, sample_size): 102 | 103 | df_for_annotation = pd.read_csv(annotation_data_path, index_col=False) 104 | df_for_annotation["sampling_method"].fillna("", inplace=True) 105 | sampling_method_in_data = df_for_annotation.sampling_method.unique() 106 | if len(sampling_method_in_data) > 1: 107 | if sampling_method not in sampling_method_in_data: 108 | raise ValueError( 109 | f"""Warning: data of other sampling {sampling_method_in_data} method is being used, when actual sampling method is {sampling_method}. This might corrupt data""" 110 | ) 111 | sampling_method_map = get_sampling_method_map() 112 | sample_idx = sampling_method_map[sampling_method](df_for_annotation, size=sample_size) 113 | data = df_for_annotation[df_for_annotation["idx"].isin(sample_idx)] 114 | remaining_data = df_for_annotation[~(df_for_annotation["idx"].isin(sample_idx))] 115 | return (data, remaining_data) 116 | 117 | 118 | def get_countdown(): 119 | 120 | for i in range(5, 0, -1): 121 | sys.stdout.write(str(i) + " ") 122 | sys.stdout.flush() 123 | time.sleep(1) 124 | 125 | 126 | def get_annotation(data, exp_data, sample_size): 127 | label_map = {"w": 0, "s": 1, "b": 2, "t": 3} 128 | label_desc_map = { 129 | "w": "World News examples", 130 | "s": "Sports News examples", 131 | "b": "Business News example", 132 | "t": "Science/Tech example", 133 | } 134 | data.reset_index(drop=True, inplace=True) 135 | data["annotated_labels"] = np.nan 136 | num_ex = data.shape[0] 137 | clear() 138 | print(f"Number of examples to annotate are {num_ex}") 139 | full_instruction, annotation_instruction = annotation_message() 140 | print(full_instruction) 141 | input("Press enter to start annotation >") 142 | clear() 143 
| ind = 0 144 | while ind < num_ex: 145 | if ind < 0: 146 | ind = 0 147 | if ind > 0: 148 | clear() 149 | # textId = data.loc[ind, "idx"] 150 | title = data.loc[ind, "title"] 151 | description = data.loc[ind, "description"] 152 | 153 | print(annotation_instruction) 154 | print("*" * 100) 155 | print(f"{ind + 1} / {sample_size}") 156 | print(f"Title: {title}") 157 | print(f"Description: {description}") 158 | label = str(input("> ")) 159 | if label in ["0", "1", "2", "3", "4"]: 160 | data.loc[ind, "annotated_labels"] = int(label) - 1 161 | ind += 1 162 | elif label == "u": 163 | ind -= 1 if ind > 0 else 0 164 | elif label in ["w", "s", "b", "t"]: 165 | clear() 166 | exp = get_examples(exp_data, label_map.get(label, 0)) 167 | print("*" * 100) 168 | print(label_desc_map.get(label, "w")) 169 | print(f"\n Title: {exp['Title'].values[0]}") 170 | print(f"\n Description: {exp['Description'].values[0]}") 171 | print("*" * 100) 172 | print("Continue annotation [y/n]") 173 | instruction = str(input(">")) 174 | if instruction == "n": 175 | print("saving all data and exiting in") 176 | get_countdown() 177 | clear() 178 | # time.sleep(10) 179 | return data 180 | elif instruction == "y": 181 | print("continuing in") 182 | get_countdown() 183 | clear() 184 | # time.sleep(10) 185 | continue 186 | else: 187 | print("invalid response, continuing annotation in") 188 | get_countdown() 189 | clear() 190 | # time.sleep(10) 191 | elif label == "f": 192 | print(full_instruction) 193 | get_countdown() 194 | clear() 195 | elif label == "save": 196 | print("saving all data and exiting") 197 | return data 198 | print("-" * 100) 199 | print("Done with presentation set of annotation") 200 | return data 201 | 202 | 203 | def main(): 204 | parser = argparse.ArgumentParser(description="Annotate data using uncertainty sampling methods") 205 | parser.add_argument("annotation_data", help="data to be annotated") 206 | parser.add_argument( 207 | "sampling_method", 208 | default="random", 209 | 
help="method to use to sample data for annotation, options are `random`, `least`, `margin`, `entropy`", 210 | ) 211 | parser.add_argument("sample_size", default=100, help="Number of samples to be annotated") 212 | parser.add_argument("output_location", help="location to write annotated data") 213 | parser.add_argument("--example_data", help="data from which we pick example for each class") 214 | args = parser.parse_args() 215 | annotation_data_path = args.annotation_data 216 | sampling_method = args.sampling_method 217 | sample_size = int(args.sample_size) 218 | output_location = args.output_location 219 | example_data_path = args.example_data 220 | if example_data_path is None: 221 | example_data_path = DATA_DIR / "exp_data.csv.gz" 222 | 223 | df_for_annotation = pd.read_csv(annotation_data_path, index_col=False) 224 | df_for_annotation["sampling_method"].fillna("", inplace=True) 225 | sampling_method_in_data = df_for_annotation.sampling_method.unique() 226 | if len(sampling_method_in_data) > 1: 227 | if sampling_method not in sampling_method_in_data: 228 | raise ValueError( 229 | f"""Warning: data of other sampling {sampling_method_in_data} method is being used, when actual sampling method is {sampling_method}. 
This might corrupt data""" 230 | ) 231 | 232 | exp_data = pd.read_csv(example_data_path) 233 | sampling_method_map = { 234 | "random": random_sampling, 235 | "least": least_confidence_sampling, 236 | "margin": margin_sampling, 237 | "entropy": entropy_sampling, 238 | } 239 | if sampling_method not in list(sampling_method_map.keys()): 240 | raise ValueError("Sampling method has to be one of `random`, `least`, `margin`, `entropy` ") 241 | sample_idx = sampling_method_map[sampling_method](df_for_annotation, size=sample_size) 242 | 243 | data = df_for_annotation[df_for_annotation["idx"].isin(sample_idx)] 244 | reamining_data = df_for_annotation[~(df_for_annotation["idx"].isin(sample_idx))] 245 | 246 | annotated_data = get_annotation(data, exp_data, sample_size) 247 | if not os.path.exists(output_location): 248 | os.mkdir( 249 | output_location, 250 | ) 251 | annotated_data.loc[ 252 | ~(annotated_data["annotated_labels"].isna()), "sampling_method" 253 | ] = sampling_method 254 | tot_annotated = annotated_data[~(annotated_data.annotated_labels.isna())].shape[0] 255 | clear() 256 | print(f"Total annotation required {sample_size}, total annotated {tot_annotated}") 257 | annotated_data = pd.concat((annotated_data, reamining_data), axis=0) 258 | # today = datetime.today().strftime("%Y%m%d") 259 | out_file = os.path.join(output_location, f"annotated_data_{sampling_method}.csv.gz") 260 | print(f"Writing file to {out_file}") 261 | annotated_data.to_csv(out_file, index=False, compression="gzip") 262 | 263 | 264 | if __name__ == "__main__": 265 | main() 266 | -------------------------------------------------------------------------------- /scripts/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Apr 10 17:31:43 2021 4 | 5 | Config file to all the configurations 6 | 7 | @author: Abinaya.M02 8 | """ 9 | 10 | # Import necessary libraries 11 | import logging 12 | import logging.config 13 
| import sys 14 | from pathlib import Path 15 | 16 | from rich.logging import RichHandler 17 | from simpletransformers.classification import ClassificationArgs 18 | 19 | # Directories 20 | BASE_DIR = Path(__file__).parent.parent.absolute() 21 | CONFIG_DIR = Path(BASE_DIR, "config") 22 | LOGS_DIR = Path(BASE_DIR, "logs") 23 | DATA_DIR = Path(BASE_DIR, "data") 24 | MODEL_DIR = Path(BASE_DIR, "model") 25 | STORES_DIR = Path(BASE_DIR, "stores") 26 | CACHE_DIR = Path(DATA_DIR, "cache") 27 | BEST_MODEL_DIR = Path(MODEL_DIR, "{}/best_model") 28 | OUTPUT_DIR = Path(DATA_DIR, "output") 29 | 30 | # Local stores 31 | BLOB_STORE = Path(STORES_DIR, "blob_store") 32 | FEATURE_STORE = Path(STORES_DIR, "feature_store") 33 | MODEL_REGISTRY = Path(STORES_DIR, "model_store") 34 | 35 | # Create dirs 36 | LOGS_DIR.mkdir(parents=True, exist_ok=True) 37 | DATA_DIR.mkdir(parents=True, exist_ok=True) 38 | MODEL_DIR.mkdir(parents=True, exist_ok=True) 39 | STORES_DIR.mkdir(parents=True, exist_ok=True) 40 | BLOB_STORE.mkdir(parents=True, exist_ok=True) 41 | FEATURE_STORE.mkdir(parents=True, exist_ok=True) 42 | MODEL_REGISTRY.mkdir(parents=True, exist_ok=True) 43 | 44 | # Set up logging 45 | # Logger 46 | logging_config = { 47 | "version": 1, 48 | "disable_existing_loggers": False, 49 | "formatters": { 50 | "minimal": {"format": "%(message)s"}, 51 | "detailed": { 52 | "format": "%(levelname)s %(asctime)s [%(filename)s:%(funcName)s:%(lineno)d]\n%(message)s\n" 53 | }, 54 | }, 55 | "handlers": { 56 | "console": { 57 | "class": "logging.StreamHandler", 58 | "stream": sys.stdout, 59 | "formatter": "minimal", 60 | "level": logging.DEBUG, 61 | }, 62 | "info": { 63 | "class": "logging.handlers.RotatingFileHandler", 64 | "filename": Path(LOGS_DIR, "info.log"), 65 | "maxBytes": 10485760, # 1 MB 66 | "backupCount": 10, 67 | "formatter": "detailed", 68 | "level": logging.INFO, 69 | }, 70 | "error": { 71 | "class": "logging.handlers.RotatingFileHandler", 72 | "filename": Path(LOGS_DIR, "error.log"), 73 | 
"maxBytes": 10485760, # 1 MB 74 | "backupCount": 10, 75 | "formatter": "detailed", 76 | "level": logging.ERROR, 77 | }, 78 | }, 79 | "loggers": { 80 | "root": { 81 | "handlers": ["console", "info", "error"], 82 | "level": logging.INFO, 83 | "propagate": True, 84 | }, 85 | }, 86 | } 87 | logging.config.dictConfig(logging_config) 88 | logger = logging.getLogger("root") 89 | logger.handlers[0] = RichHandler(markup=True) 90 | 91 | # Random seed 92 | RANDOM_SEED = 100 93 | 94 | # wandb directory (change the names as per need) 95 | WANDB_PROJ_COMPLETE_DATA = "model_complete_data" 96 | WANDB_PROJ_AL_BASELINE = "model_al_baseline" 97 | WANDB_PROJ_AL_EXP = "model_al_experiments" 98 | 99 | # Model args for the simpletransformer model 100 | # Add or modify parameters based on experiment 101 | BEST_MODEL_SPEC_DIR = str(BEST_MODEL_DIR).format(WANDB_PROJ_AL_EXP) 102 | MODEL_ARGS = ClassificationArgs( 103 | num_train_epochs=5, 104 | overwrite_output_dir=True, 105 | train_batch_size=16, 106 | max_seq_length=250, 107 | # modify based on the experiment 108 | wandb_project=WANDB_PROJ_AL_EXP, 109 | best_model_dir=BEST_MODEL_SPEC_DIR, 110 | cache_dir=str(CACHE_DIR), 111 | eval_batch_size=16, 112 | evaluate_during_training=True, 113 | evaluate_during_training_verbose=True, 114 | manual_seed=100, 115 | output_dir=str(OUTPUT_DIR), 116 | use_early_stopping=True, 117 | early_stopping_patience=3, 118 | reprocess_input_data=True, 119 | ) 120 | 121 | # Model name (roberta-base, roberta-base-uncased, etc) 122 | MODEL_NAME = "roberta" 123 | MODEL_TYPE = "roberta-base" 124 | 125 | # Labels for classification 126 | LABELS = {"0": "Not sure", "1": "World", "2": "Sports", "3": "Business", "4": "Sci/Tech"} 127 | TEST_SPLIT = 0.2 128 | -------------------------------------------------------------------------------- /scripts/data.py: -------------------------------------------------------------------------------- 1 | """ Data creation for Active Learning 2 | From original training data, get 3 | i. 
training data: random sample of 20% of the original training data 4 | ii. validation data: random sample of 10% of the original training data or 12.5% of data remaining after getting training data 5 | iii. data for annotation: 70% of the data 6 | """ 7 | 8 | import json 9 | import os 10 | 11 | import numpy as np 12 | import pandas as pd 13 | 14 | # import logging 15 | # import datetime 16 | 17 | __author__ = "Pawan Kumar Singh" 18 | 19 | 20 | def prepare_data(file_path, file_name, out_file_name): 21 | """Simple function to 22 | i. Read in data 23 | ii. change title and description to lower case 24 | iii. concatenate title and description to create text 25 | iv. add index[summary] 26 | 27 | Parameters 28 | ---------- 29 | file_path : [type] 30 | [description] 31 | file_name : [type] 32 | [description] 33 | out_file_name : [type] 34 | [description] 35 | """ 36 | data = pd.read_csv(os.path.join(file_path, file_name), index_col=False) 37 | data["Title"] = data["Title"].str.lower() 38 | data["Description"] = data["Description"].str.lower() 39 | data["text"] = data["Title"] + " " + data["Description"] 40 | data["labels"] = data["Class Index"] 41 | data.drop("Class Index", axis=1, inplace=True) 42 | 43 | data["id"] = data.index 44 | if "gz" not in out_file_name: 45 | out_file_name += ".gz" 46 | print(out_file_name) 47 | data.to_csv(os.path.join(file_path, out_file_name), index=False, compression="gzip") 48 | 49 | 50 | def split_data( 51 | file_path, 52 | train_file_name, 53 | test_file_name, 54 | train_size=0.2, 55 | valid_size=0.1, 56 | random_seed=100, 57 | ): 58 | """Split original training data into training data and annotation data and 59 | save them in `train` and `annotation` directory in `file_path` 60 | 61 | Parameters 62 | ---------- 63 | file_path : [type] 64 | [description] 65 | train_file_name : [type] 66 | [description] 67 | test_file_name : [type] 68 | [description] 69 | train_size : float, optional 70 | [description], by default 0.2 71 | valid_size : 
float, optional 72 | [description], by default 0.1 73 | random_seed : int, optional 74 | [description], by default 100 75 | """ 76 | data = pd.read_csv(os.path.join(file_path, train_file_name), index_col=False) 77 | test_data = pd.read_csv(os.path.join(file_path, test_file_name), index_col=False) 78 | train_size = int(data.shape[0] * train_size) 79 | valid_size = int(data.shape[0] * valid_size) 80 | np.random.seed(100) 81 | all_idx = data["id"] 82 | train_idx = np.random.choice(all_idx, size=train_size, replace=False) 83 | remain_idx = list(set(all_idx.tolist()) - set(train_idx.tolist())) 84 | valid_idx = np.random.choice(remain_idx, size=valid_size, replace=False) 85 | annotate_idx = list(set(remain_idx) - set(valid_idx)) 86 | 87 | data_track_dict = {} 88 | data_track_dict["train_idx"] = train_idx.tolist() 89 | data_track_dict["valid_idx"] = valid_idx.tolist() 90 | data_track_dict["annotate_idx"] = annotate_idx 91 | 92 | check_and_create_dir(".data_tracking") 93 | print(data_track_dict.keys()) 94 | with open(os.path.join(".data_tracking", "data_info.txt"), "w") as outfile: 95 | json.dump(data_track_dict, outfile, indent=4) 96 | # json.dumps(data_track_dict, ".data_tracking") 97 | 98 | train_data = data[data["id"].isin(train_idx)] 99 | valid_data = data[data["id"].isin(valid_idx)] 100 | annotate_data = data[data["id"].isin(annotate_idx)] 101 | annotate_data.drop("labels", axis=1, inplace=True) 102 | 103 | print(f"Size of train data {train_data.shape[0] / data.shape[0]}") 104 | print(f"Size of valid data {valid_data.shape[0] / data.shape[0]}") 105 | print(f"Size of annotate data {annotate_data.shape[0] / data.shape[0]}") 106 | 107 | check_and_create_dir("train_data") 108 | write_csv(train_data, "train_data", "train.csv") 109 | check_and_create_dir("test_data") 110 | write_csv(test_data, "test_data", "test.csv") 111 | check_and_create_dir("annotate_data") 112 | write_csv(annotate_data, "annotate_data", "annotate.csv") 113 | check_and_create_dir("valid_data") 114 | 
write_csv(valid_data, "valid_data", "valid.csv") 115 | 116 | 117 | def write_csv(data, out_path, file_name): 118 | if "gz" not in file_name: 119 | file_name += ".gz" 120 | data.to_csv(os.path.join(out_path, file_name)) 121 | 122 | 123 | def check_and_create_dir(path): 124 | if not os.path.exists(path): 125 | os.mkdir(path) 126 | -------------------------------------------------------------------------------- /scripts/download_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Apr 2 14:23:30 2021 4 | 5 | Script to download ag_news dataset 6 | 7 | @author: Abinaya Mahendiran 8 | """ 9 | 10 | from pathlib import Path 11 | 12 | from config import DATA_DIR, logger 13 | 14 | # Import necessary libraries 15 | from datasets import load_dataset 16 | 17 | 18 | # Load data and convert it to dataframe 19 | def load_data(dataset_name: str, split: str) -> object: 20 | """ 21 | Load the data from datasets library and convert to dataframe 22 | 23 | Parameters 24 | ---------- 25 | dataset_name : str 26 | name of the dataset to be downloaded. 27 | split : str 28 | type of split (train or test). 29 | Returns 30 | ------- 31 | object 32 | dataframe. 33 | 34 | """ 35 | data = load_dataset(dataset_name, split=split) 36 | logger.info(split + " dataset downloaded!") 37 | return data 38 | 39 | 40 | # Save teh data locally 41 | def save_data(path: str, dataframe: object) -> None: 42 | """ 43 | Save the dataframe to a local folder 44 | 45 | Parameters 46 | ---------- 47 | path : str 48 | path of the folder. 49 | dataframe : object 50 | dataframe object. 51 | 52 | Returns 53 | ------- 54 | None 55 | None. 
56 | 57 | """ 58 | dataframe.to_csv(path) 59 | logger.info("dataset saved!") 60 | 61 | 62 | if __name__ == "__main__": 63 | train_data = load_data("ag_news", "train") 64 | save_data(Path(DATA_DIR, "train.csv"), train_data) 65 | test_data = load_data("ag_news", "test") 66 | save_data(Path(DATA_DIR, "test.csv"), test_data) 67 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Apr 2 14:23:30 2021 4 | 5 | Script to train simpletransformer model 6 | 7 | @author: Abinaya Mahendiran 8 | """ 9 | 10 | 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | # Import necessary libraries 17 | import torch 18 | import torch.nn as nn 19 | from simpletransformers.classification import ClassificationModel 20 | from sklearn.metrics import accuracy_score 21 | 22 | from scripts import config 23 | from scripts.config import logger 24 | 25 | 26 | class NewsClassification: 27 | def __init__(self): 28 | self.model_name = config.MODEL_NAME 29 | self.model_type = config.MODEL_TYPE 30 | self.train_data = pd.read_csv(Path(config.DATA_DIR, "train.csv")) 31 | self.test_data = pd.read_csv(Path(config.DATA_DIR, "test.csv")) 32 | self.cuda = torch.cuda.is_available() 33 | self.model_args = config.MODEL_ARGS 34 | self.labels = config.LABELS 35 | 36 | def preprocess_data(self, data: object, column_name: str) -> object: 37 | """ 38 | Perform preprocessing on the text data 39 | 40 | Parameters 41 | ---------- 42 | data : object 43 | dataframe. 44 | column_name : str 45 | name of the column in the dataframe 46 | 47 | Returns 48 | ------- 49 | object 50 | pre-processed dataframe. 
51 | 52 | """ 53 | data.rename(columns={"Unnamed: 0": "idx"}, inplace=True) 54 | if column_name == "text": 55 | data[column_name] = data[column_name].str.lower() 56 | if column_name == "label": 57 | data[column_name] = data[column_name].apply(int) - 1 58 | data.rename(columns={"label": "labels"}, inplace=True) 59 | return data 60 | 61 | def split_data(self, data: object, random_seed: int) -> (object, object): 62 | """ 63 | Split the dataset into train and eval 64 | 65 | Parameters 66 | ---------- 67 | data : object 68 | dataframe containing training data. 69 | random_seed : int 70 | integer to set the random seed 71 | 72 | Returns 73 | ------- 74 | (object, object) 75 | train split, eval split. 76 | 77 | """ 78 | np.random.seed(random_seed) 79 | train_idx = np.random.choice( 80 | data.index, size=int(data.shape[0] * config.TEST_SPLIT), replace=False 81 | ) 82 | valid_idx = set(data.index) - set(train_idx) 83 | 84 | train_data = data[data.index.isin(train_idx)] 85 | eval_data = data[data.index.isin(valid_idx)] 86 | return (train_data, eval_data) 87 | 88 | def train(self, train_data: object, eval_data: object) -> object: 89 | """ 90 | Create and train the chosen model based on the args 91 | 92 | Parameters 93 | ---------- 94 | train_data : object 95 | train split of the train_data. 96 | eval_data : object 97 | validation split of the train_data. 98 | 99 | Returns 100 | ------- 101 | object 102 | model. 
103 | 104 | """ 105 | 106 | # Create a ClassificationModel 107 | model = ClassificationModel( 108 | self.model_name, 109 | self.model_type, 110 | args=self.model_args, 111 | use_cuda=self.cuda, 112 | num_labels=len(self.labels) - 1, 113 | ) 114 | # Train the model 115 | model.train_model(train_df=train_data, eval_df=eval_data, accuracy=accuracy_score) 116 | return model 117 | 118 | def load_model(self, model_type: str) -> object: 119 | """ 120 | Load the specified model 121 | 122 | Parameters 123 | ---------- 124 | model_type : str 125 | path or model type to be loaded. 126 | 127 | Returns 128 | ------- 129 | object 130 | model. 131 | 132 | """ 133 | model = ClassificationModel( 134 | self.model_name, 135 | model_type, 136 | args=self.model_args, 137 | use_cuda=self.cuda, 138 | num_labels=len(self.labels) - 1, 139 | ) 140 | return model 141 | 142 | def format_output(self, predictions: object, raw_output: object) -> object: 143 | """ 144 | Format the output to the required format for annotation 145 | 146 | Parameters: 147 | ---------- 148 | predictions : object 149 | probabilities. 150 | raw_output : object 151 | logits. 
152 | 153 | Returns: 154 | ------- 155 | object 156 | Modified dataframe in the required format 157 | 158 | """ 159 | # Convert logits to labels 160 | sfm = nn.Softmax(dim=1) 161 | raw_output_tensor = torch.from_numpy(raw_output) 162 | annotate_class_prob = sfm(raw_output_tensor) 163 | max_prob = torch.max(annotate_class_prob, dim=1) 164 | annotate_class_prob = annotate_class_prob.numpy() 165 | max_prob = max_prob.values.numpy() 166 | 167 | # Reshape the data 168 | annotate_df_with_pred = self.test_data 169 | probabilities = pd.DataFrame( 170 | annotate_class_prob, columns=["prob_0", "prob_1", "prob_2", "prob_3"] 171 | ) 172 | annotate_df_with_pred = pd.concat([annotate_df_with_pred, probabilities], axis=1) 173 | annotate_df_with_pred["max_prob"] = max_prob 174 | annotate_df_with_pred["label_pred"] = predictions 175 | annotate_df_with_pred["annotated_labels"] = "" 176 | annotate_df_with_pred["sampling_method"] = "" 177 | return annotate_df_with_pred 178 | 179 | 180 | def main(): 181 | """ 182 | Run the news classification model 183 | 184 | Returns 185 | ------- 186 | None. 
187 | 188 | """ 189 | # Create classification object 190 | news_model = NewsClassification() 191 | logger.info("News classification model instantiated") 192 | 193 | # Preprocess and split data 194 | data = news_model.preprocess_data(news_model.train_data, "text") 195 | logger.info("Train data is pre-processed") 196 | train_data, eval_data = news_model.split_data(data, config.RANDOM_SEED) 197 | logger.info("Data is split") 198 | 199 | # Train model 200 | # train_model = news_model.train(train_data, eval_data) 201 | logger.info("Model is trained") 202 | 203 | # Load model from the best model directory 204 | loaded_model = news_model.load_model(config.BEST_MODEL_SPEC_DIR) 205 | logger.info("Model is loaded") 206 | 207 | # Eval model 208 | model_result, model_outputs, wrong_predictions = loaded_model.eval_model( 209 | eval_data, accuracy=accuracy_score 210 | ) 211 | logger.info("Model is evaluated") 212 | 213 | # Prediction 214 | news_model.test_data = news_model.preprocess_data(news_model.test_data, "text") 215 | predictions, raw_outputs = loaded_model.predict(news_model.test_data.text.values.tolist()) 216 | logger.info("Predictions completed") 217 | 218 | # Format output 219 | annotate_data = news_model.format_output(predictions, raw_outputs) 220 | annotate_data.to_csv(Path(config.DATA_DIR, "annotate_data.csv")) 221 | 222 | 223 | if __name__ == "__main__": 224 | main() 225 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | Add unit tests in this folder. 
2 | -------------------------------------------------------------------------------- /tests/annotation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AbinayaM02/Active_Learning_in_NLP/9c6bb281c5508a6117fa3c829c54562b3d58253a/tests/annotation/__init__.py -------------------------------------------------------------------------------- /tests/annotation/test_annotator.py: -------------------------------------------------------------------------------- 1 | # import os 2 | # import sys 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | 8 | from scripts.annotator import ( # random_sampling, 9 | entropy_sampling, 10 | least_confidence_sampling, 11 | margin_sampling, 12 | ) 13 | 14 | # PACKAGE_PARENT = ".." 15 | # SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd()))) 16 | # sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))) 17 | 18 | 19 | @pytest.fixture 20 | def test_data(): 21 | return pd.DataFrame( 22 | { 23 | "idx": [1, 2, 3], 24 | "prob_0": [0.1, 0, 0.25], 25 | "prob_1": [0.2, 0.1, 0.25], 26 | "prob_2": [0.2, 0.1, 0.25], 27 | "prob_3": [0.5, 0.8, 0.25], 28 | "max_prob": [0.5, 0.6, 0.2], 29 | "annotated_labels": [np.nan, np.nan, np.nan], 30 | } 31 | ) 32 | 33 | 34 | def test_least_confidence_sampling(test_data): 35 | actual_idx = least_confidence_sampling(test_data, 1) 36 | expected_idx = 3 37 | assert expected_idx == actual_idx 38 | 39 | 40 | def test_margin_sampling(test_data): 41 | actual_idx = margin_sampling(test_data, 2) 42 | expected_idx = np.array([3, 1]) 43 | assert (expected_idx == actual_idx).all() 44 | 45 | 46 | def test_entropy_sampling(test_data): 47 | actual_idx = entropy_sampling(test_data, 1) 48 | expected_idx = 3 49 | assert expected_idx == actual_idx 50 | --------------------------------------------------------------------------------