├── stabilizer
│   ├── __init__.py
│   ├── dataset.py
│   ├── trainer.py
│   ├── reproducibility.py
│   ├── model.py
│   ├── reinitialize.py
│   └── llrd.py
├── MANIFEST.in
├── requirements.txt
├── .flake8
├── pyproject.toml
├── tests
│   ├── test_llrd.py
│   └── test_reinitialize.py
├── setup.py
├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── README.md
├── main.py
├── LICENSE
└── examples
    ├── disaster-tweets-stabilizer.ipynb
    └── cola.ipynb

/stabilizer/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include README.md
include requirements.txt
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.9.0
pandas==1.3.1
scikit-learn==0.24.2
transformers==4.10.0
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
[flake8]
ignore = E203, E266, E501, W503, F403, F401
max-line-length = 79
max-complexity = 18
select = B,C,E,F,W,T4,B9
--------------------------------------------------------------------------------
/stabilizer/dataset.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset


class TextLabelDataset(Dataset):
    def __init__(self, text_excerpts, labels):
        self.text_excerpts = text_excerpts
        self.labels = labels

    def __len__(self):
        return len(self.text_excerpts)

    def __getitem__(self, idx):
        sample = {"text_excerpt": self.text_excerpts[idx], "label": self.labels[idx]}
        return sample
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
# Example configuration for Black.

# NOTE: you have to use single-quoted strings in TOML for regular expressions.
# It's the equivalent of r-strings in Python. Multiline strings are treated as
# verbose regular expressions by Black. Use [ ] to denote a significant space
# character.

[tool.black]
line-length = 119
target-version = ['py37', 'py38']
include = '\.pyi?$'
exclude = '''
/(
    \.eggs
  | \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''
--------------------------------------------------------------------------------
/stabilizer/trainer.py:
--------------------------------------------------------------------------------
import torch


def train_step(model, inputs, targets, loss_fn, optimizer, scheduler):
    model.train()
    # forward pass
    predictions = model(inputs)
    loss = loss_fn(predictions, targets)
    # backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    return {"loss": loss}


def evaluate_step(model, inputs, targets, loss_fn):
    model.eval()
    # forward pass
    with torch.no_grad():
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)
    return {"loss": loss, "targets": targets, "predictions": predictions}
--------------------------------------------------------------------------------
/stabilizer/reproducibility.py:
--------------------------------------------------------------------------------
import os
import random
import numpy as np
import torch


def seed_torch(seed: int):
    """
    More information on PyTorch reproducibility can be found here:
    https://pytorch.org/docs/stable/notes/randomness.html
    Args:
        seed (int): desired seed
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def seed_everything(seed: int):
    """
    Call this function at the beginning of your script to ensure reproducibility
    Args:
        seed (int): desired seed
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    seed_torch(seed)
--------------------------------------------------------------------------------
/stabilizer/model.py:
--------------------------------------------------------------------------------
from torch import nn
from stabilizer.reproducibility import seed_torch


class PoolerClassifier(nn.Module):
    def __init__(
        self,
        transformer,
        transformer_output_size,
        transformer_output_dropout_prob,
        num_classes,
        task_specific_layer_seed=None,
    ):
        super(PoolerClassifier, self).__init__()
        self.transformer = transformer
        self.transformer_output_size = transformer_output_size
        self.transformer_output_dropout = nn.Dropout(p=transformer_output_dropout_prob)
        if task_specific_layer_seed is not None:
            seed_torch(task_specific_layer_seed)
        self.task_specific_layer = nn.Linear(self.transformer_output_size, num_classes)

    def forward(self, inputs):
        transformer_outputs = self.transformer(**inputs)
        vectors = self.transformer_output_dropout(transformer_outputs.pooler_output)
        logits = self.task_specific_layer(vectors)
        return logits
--------------------------------------------------------------------------------
/tests/test_llrd.py:
--------------------------------------------------------------------------------
import unittest
from transformers import AutoModel
from stabilizer.model import PoolerClassifier
from stabilizer.llrd import get_optimizer_parameters_with_llrd
from transformers import logging

logging.set_verbosity_error()


class TestLlrd(unittest.TestCase):
    def test_get_optimizer_parameters_with_llrd(self):
        transformer = AutoModel.from_pretrained(
            pretrained_model_name_or_path="bert-base-uncased",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
        )
        model = PoolerClassifier(
            transformer=transformer,
            transformer_output_size=transformer.config.hidden_size,
            transformer_output_dropout_prob=0.1,
            num_classes=1,
        )
        optimizer_parameters = get_optimizer_parameters_with_llrd(
            model=model, peak_lr=2e-5, multiplicative_factor=0.95
        )
        self.assertEqual(len(optimizer_parameters), 14)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
import setuptools

current_dir = os.path.dirname(os.path.abspath(__file__))


# Get the long description from the README file
with open(os.path.join(current_dir, "README.md"), encoding="utf-8") as f:
    long_description = f.read()


# What packages are required for this module to be executed?
try:
    with open(os.path.join(current_dir, "requirements.txt"), encoding="utf-8") as f:
        required = f.read().splitlines()
except FileNotFoundError:
    required = []

setuptools.setup(
    name="stabilizer",
    version="1.0.2",
    author="flowerpot-ai",
    author_email="vignesh.sbaskaran@gmail.com",
    description="Stabilize and achieve excellent performance with transformers",
    long_description_content_type="text/markdown",
    long_description=long_description,
    url="https://github.com/flowerpot-ai/stabilizer",
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.7",
    install_requires=required,
    include_package_data=True,
)
--------------------------------------------------------------------------------
/stabilizer/reinitialize.py:
--------------------------------------------------------------------------------
import logging
import transformers
from torch import nn

AUTOENCODINGMODELS = [
    "Bert",
    "Roberta",
    "DistilBert",
    "Albert",
    "XLMRoberta",
    "BertModel",
]


def reinit_autoencoder_model(encoder, reinit_num_layers=0):
    """reinitialize the last `reinit_num_layers` layers of an autoencoding model's encoder"""

    if not any(k in encoder.config.architectures[0] for k in AUTOENCODINGMODELS):
        logging.error("Model not Autoencoding based")
        return encoder

    if reinit_num_layers:
        for layer in encoder.layer[-reinit_num_layers:]:

            for module in layer.modules():

                if isinstance(module, nn.Linear):
                    module.weight.data.normal_(mean=0.0, std=encoder.config.initializer_range)
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif isinstance(module, nn.Embedding):
                    module.weight.data.normal_(mean=0.0, std=encoder.config.initializer_range)
                    if module.padding_idx is not None:
                        module.weight.data[module.padding_idx].zero_()
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)

    return encoder
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Stabilizer

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]
jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.7, 3.8]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v1.1.1
      env:
        ACTIONS_ALLOW_UNSECURE_COMMANDS: true
      with:
        python-version: ${{ matrix.python-version }}
    - name: Cache pip
      uses: actions/cache@v1
      with:
        path: ~/.cache/pip # This path is specific to Ubuntu
        # Look to see if there is a cache hit for the corresponding requirements file
        key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
        restore-keys: |
          ${{ runner.os }}-pip-
          ${{ runner.os }}-
    # You can test your matrix by printing the current Python version
    - name: Display Python version
      run: python -c "import sys; print(sys.version)"
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt
        pip install black flake8 mypy pytest hypothesis
    - name: Run black
      run: black --check .
    - name: Run flake8
      run: flake8
    - name: tests
      run: python -m unittest tests/*.py
--------------------------------------------------------------------------------
/tests/test_reinitialize.py:
--------------------------------------------------------------------------------
import torch
import unittest
from torch import nn as nn
from transformers import AutoModel
from stabilizer.model import PoolerClassifier
from stabilizer.reinitialize import reinit_autoencoder_model
from transformers import logging

logging.set_verbosity_error()


class TestReinitialize(unittest.TestCase):
    def test_reinit_autoencoder_model(self):
        transformer = AutoModel.from_pretrained(
            pretrained_model_name_or_path="bert-base-uncased",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
        )
        transformer.encoder = reinit_autoencoder_model(transformer.encoder, reinit_num_layers=1)
        model = PoolerClassifier(
            transformer=transformer,
            transformer_output_size=transformer.config.hidden_size,
            transformer_output_dropout_prob=0.1,
            num_classes=1,
        )

        vals = []
        for module in model.transformer.encoder.layer[-1].modules():

            if isinstance(module, nn.Linear):
                k = torch.isclose(module.weight.data.mean(), torch.tensor(0.0), atol=1e-4)
                v = torch.isclose(
                    module.weight.data.std().detach().cpu(),
                    torch.tensor(0.02),
                    atol=1e-4,
                )
                vals.append(k.cpu().numpy())
                vals.append(v.cpu().numpy())
                vals.append(all(module.bias.data.detach().cpu() == torch.zeros(module.bias.data.shape)))

        self.assertEqual(all(vals), True)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/stabilizer/llrd.py:
--------------------------------------------------------------------------------
def get_optimizer_parameters_with_llrd(model, peak_lr, multiplicative_factor):
    num_encoder_layers = len(model.transformer.encoder.layer)
    # The task specific layer gets the peak_lr
    tsl_parameters = [
        {
            "params": [param for name, param in model.task_specific_layer.named_parameters()],
            "param_names": [name for name, param in model.task_specific_layer.named_parameters()],
            "lr": peak_lr,
            "name": "tsl_param_group",
        }
    ]

    # Starting from the last encoder layer, each encoder layer gets a lr defined by
    # current_layer_lr = prev_layer_lr * multiplicative_factor
    # so the last encoder layer lr = peak_lr * multiplicative_factor
    encoder_parameters = [
        {
            "params": [param for name, param in layer.named_parameters()],
            "param_names": [name for name, param in layer.named_parameters()],
            "lr": peak_lr * (multiplicative_factor ** (num_encoder_layers - layer_num)),
            "name": f"transformer.encoder.layer.{layer_num}",
        }
        for layer_num, layer in enumerate(model.transformer.encoder.layer)
    ]

    # The embedding layer gets lr = first encoder layer lr * multiplicative_factor
    embedding_parameters = [
        {
            "params": [param for name, param in model.transformer.embeddings.named_parameters()],
            "param_names": [name for name, param in model.transformer.embeddings.named_parameters()],
            "lr": peak_lr * (multiplicative_factor ** (num_encoder_layers + 1)),
            "name": "embeddings_param_group",
        }
    ]
    return tsl_parameters + encoder_parameters + embedding_parameters
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# VScode
vscode/*
.vscode/settings.json
.vscode/tasks.json
.vscode/launch.json
.vscode/extensions.json
*.code-workspace

# Data files
data

# Model files
models

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Stabilizer
### Stabilize and achieve excellent performance with transformers.
The stabilizer library offers solutions to one of the biggest challenges that comes with training state-of-the-art transformer models: **unstable training**.

## Unstable training
Unstable training is the phenomenon in which trivial changes, such as changing the random seed, drastically change the performance of a large transformer model. Here is a screenshot of finetuning on the CoLA dataset from the GLUE tasks with two different random seeds applied only to the dropout of the transformer model.

![dropout_random_seed](https://i.ibb.co/jyx3tLT/baseline-dropout-seed.png)

## Installation
`pip install stabilizer`


## Techniques currently implemented in this library
1. Reinitialization
2. Layerwise Learning Rate Decay


### Reinitialization
Reinitialize the last `n` layers of the transformer encoder. This technique works well because it reinitializes the task-specific parameters that the pretrained model learned for its pretraining task.
```python
from stabilizer.reinitialize import reinit_autoencoder_model
from transformers import AutoModel

transformer = AutoModel.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
transformer.encoder = reinit_autoencoder_model(
    transformer.encoder, reinit_num_layers=1
)
```
Here is the result of the same model on the CoLA dataset with the last 4 layers reinitialized. You can see that the model converges to almost the same performance with reinitialization.
![reinit_random_seed](https://i.ibb.co/1MyPbfG/reinit-dropout-seed.png)



### Layerwise Learning Rate Decay
Apply a layerwise learning rate to the transformer layers. Starting from the task-specific layer, every layer before it gets an exponentially decreasing learning rate.
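For intuition, here is a minimal sketch of the schedule this produces, assuming a 12-layer encoder such as `bert-base-uncased` with `peak_lr=2e-5` and `multiplicative_factor=0.95` (the decay formula follows `get_optimizer_parameters_with_llrd`):

```python
# Minimal sketch of the LLRD schedule, assuming a 12-layer encoder,
# peak_lr=2e-5 and multiplicative_factor=0.95.
peak_lr, factor, num_layers = 2e-5, 0.95, 12

print(f"task specific layer lr: {peak_lr:.2e}")  # 2.00e-05
for layer_num in reversed(range(num_layers)):
    # each layer gets the lr of the (higher) layer after it, times the factor
    lr = peak_lr * factor ** (num_layers - layer_num)
    print(f"encoder layer {layer_num:2d} lr: {lr:.2e}")  # layer 11: 1.90e-05 ... layer 0: 1.08e-05
print(f"embeddings lr: {peak_lr * factor ** (num_layers + 1):.2e}")  # 1.03e-05
```

The task-specific head therefore trains fastest while the embeddings get the smallest updates, the idea being that the lower, more general-purpose layers should change the least.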


```python

from stabilizer.llrd import get_optimizer_parameters_with_llrd
from stabilizer.model import PoolerClassifier

from transformers import AdamW, AutoModel


transformer = AutoModel.from_pretrained(
    pretrained_model_name_or_path=config["pretrained_model_name_or_path"],
    hidden_dropout_prob=config["dropout_prob"],
    attention_probs_dropout_prob=config["dropout_prob"],
)

model = PoolerClassifier(
    transformer=transformer,
    transformer_output_size=transformer.config.hidden_size,
    transformer_output_dropout_prob=config["dropout_prob"],
    num_classes=config["num_classes"],
    task_specific_layer_seed=config["layer_initialization_seed"],
)

model_parameters = get_optimizer_parameters_with_llrd(
    model=model,
    peak_lr=config["lr"],
    multiplicative_factor=config["multiplicative_factor"],
)
optimizer = AdamW(params=model_parameters, lr=config["lr"])


```

Here is the result of the same model but with LLRD applied on the CoLA dataset. Here you can see that the model has diverged quite a lot with LLRD. As discussed earlier, there is no universal remedy yet; some techniques work well on some datasets and not on others.
![llrd_random_seed](https://i.ibb.co/jkLJSP0/llrd-dropout-seed.png)

## Contributing to this repository
There are three possible ways to contribute to this repository:
1. Add more techniques for stabilizing training
2. Write example scripts that train new datasets with the already implemented stabilization techniques
3. Add TensorFlow/JAX equivalents of the existing PyTorch code

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import random
import logging
import argparse
import numpy as np
import pandas as pd

from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import matthews_corrcoef

from stabilizer.model import PoolerClassifier
from stabilizer.dataset import TextLabelDataset
from stabilizer.reproducibility import seed_torch
from stabilizer.trainer import train_step, evaluate_step
from stabilizer.reinitialize import reinit_autoencoder_model
from stabilizer.llrd import get_optimizer_parameters_with_llrd

from transformers import get_scheduler, AdamW, AutoModel, AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

"""
config = {'train_data_path': 'data/glue/cola/train.jsonl',
          'valid_data_path': 'data/glue/cola/valid.jsonl',
          'batch_size': 32,
          'pretrained_tokenizer_name_or_path': 'models/bert-base-uncased',
          'pretrained_model_name_or_path': 'models/bert-base-uncased',
          'device_name': 'cpu',
          'dropout_prob': 0.1,
          'num_classes': 1,
          'lr': 2e-5,
          'num_epochs': 3,
          'validate_every_n_iteration': 10,
          'dataloader_seed': 41,
          'layer_initialization_seed': 1000,
          'dropout_seed': 1234,
          'reinit_encoder': True,
          'reinit_num_layers': 4,
          'apply_llrd': True,
          'multiplicative_factor': 0.95}

python main.py --train_data_path data/glue/cola/train.jsonl --valid_data_path data/glue/cola/valid.jsonl --batch_size 32 \
    --pretrained_tokenizer_name_or_path models/bert-base-uncased --pretrained_model_name_or_path models/bert-base-uncased \
    --device_name cuda --dropout_prob 0.1 --num_classes 1 --lr 2e-5 --num_epochs 3 --validate_every_n_iteration 10 \
    --dataloader_seed 41 --layer_initialization_seed 1000 --dropout_seed 1234 --reinit_encoder True --reinit_num_layers 4 \
    --apply_llrd True --multiplicative_factor 0.95
"""


def str_to_bool(flag):
    # argparse's type=bool treats every non-empty string (including "False") as True,
    # so boolean CLI flags are parsed explicitly instead
    return str(flag).lower() in ("true", "1", "yes")


def parse_args():
    parser = argparse.ArgumentParser(description="Train a Pooler classifier on CoLA dataset")
    parser.add_argument("--train_data_path", type=str, help="Path of the training data file")
    parser.add_argument("--valid_data_path", type=str, help="Path of the validation data file")
    parser.add_argument(
        "--pretrained_tokenizer_name_or_path",
        type=str,
        help="Path of the directory that contains the tokenizer",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        help="Path of the directory that contains the pretrained model",
    )
    parser.add_argument("--batch_size", type=int, help="Batch size")
    parser.add_argument(
        "--device_name",
        type=str,
        choices=("cpu", "cuda"),
        help="Device to train the algorithm",
    )
    parser.add_argument("--dropout_prob", type=float, help="Value of dropout")
    parser.add_argument(
        "--num_classes",
        type=int,
        help="Num of output classes for the classification task",
    )
    parser.add_argument("--lr", type=float, help="Learning rate")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs")
    parser.add_argument("--validate_every_n_iteration", type=int, help="How often to validate")
    parser.add_argument("--dropout_seed", type=int, default=random.randint(a=0, b=10000))
    parser.add_argument("--layer_initialization_seed", type=int, default=random.randint(a=0, b=10000))
    parser.add_argument("--dataloader_seed", type=int, default=random.randint(a=0, b=10000))
    parser.add_argument(
        "--reinit_encoder",
        type=str_to_bool,
        help="Should the transformer encoder be reinitialized",
    )
    parser.add_argument(
        "--reinit_num_layers",
        type=int,
        help="Number of transformer encoder layers to be reinitialized",
    )
    parser.add_argument(
        "--apply_llrd",
        type=str_to_bool,
        help="Should layerwise learning rate decay be applied to the model parameters during optimization",
    )
    parser.add_argument(
        "--multiplicative_factor",
        type=float,
        help="Factor with which the learning rate should decrease for successive layers",
    )
    args = parser.parse_args()
    return args


def post_process_targets(targets):
    targets = targets.type(torch.int)
    targets = targets.cpu().detach().numpy().reshape(-1)
    return targets


def post_process_predictions(predictions):
    predictions = torch.sigmoid(predictions)
    predictions = (predictions >= 0.5).type(torch.int)
    predictions = predictions.cpu().detach().numpy().reshape(-1)
    return predictions


def compute_matthews_corrcoef(targets, predictions):
    if len(np.unique(predictions)) > 1 and len(np.unique(targets)) > 1:
        score = matthews_corrcoef(y_true=targets, y_pred=predictions)
    else:
        score = 0.0
    return score


def run():
    # Read configuration
    config = parse_args().__dict__

    # Read training data
    train_data = pd.read_json(path_or_buf=config["train_data_path"], lines=True).set_index("idx")
    valid_data = pd.read_json(path_or_buf=config["valid_data_path"], lines=True).set_index("idx")

    # Prepare data to create dataset
    train_text_excerpts = train_data["text"].tolist()
    valid_text_excerpts = valid_data["text"].tolist()
    train_labels = torch.from_numpy(train_data["label"].to_numpy().reshape(-1, 1)).type(torch.float32)
    valid_labels = torch.from_numpy(valid_data["label"].to_numpy().reshape(-1, 1)).type(torch.float32)

    # Create Dataset
    train_dataset = TextLabelDataset(text_excerpts=train_text_excerpts, labels=train_labels)
    valid_dataset = TextLabelDataset(text_excerpts=valid_text_excerpts, labels=valid_labels)

    # Create DataLoader
    generator = torch.Generator(device="cpu")
    _ = generator.manual_seed(config["dataloader_seed"])
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        generator=generator,
    )
    valid_dataloader = DataLoader(
        dataset=valid_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        generator=generator,
    )

    # Create tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(config["pretrained_tokenizer_name_or_path"])
    transformer = AutoModel.from_pretrained(
        pretrained_model_name_or_path=config["pretrained_model_name_or_path"],
        hidden_dropout_prob=config["dropout_prob"],
        attention_probs_dropout_prob=config["dropout_prob"],
    )

    # Reinitialize
    if config["reinit_encoder"]:
        seed_torch(config["layer_initialization_seed"])
        transformer.encoder = reinit_autoencoder_model(
            transformer.encoder, reinit_num_layers=config["reinit_num_layers"]
        )

    model = PoolerClassifier(
        transformer=transformer,
        transformer_output_size=transformer.config.hidden_size,
        transformer_output_dropout_prob=config["dropout_prob"],
        num_classes=config["num_classes"],
        task_specific_layer_seed=config["layer_initialization_seed"],
    )

    device = torch.device(config["device_name"])
    _ = model.to(device)

    # Define loss
    loss_fn = nn.BCEWithLogitsLoss()

    # Create optimizer
    if config["apply_llrd"]:
        model_parameters = get_optimizer_parameters_with_llrd(
            model=model,
            peak_lr=config["lr"],
            multiplicative_factor=config["multiplicative_factor"],
        )
    else:
        model_parameters = model.parameters()
    optimizer = AdamW(params=model_parameters, lr=config["lr"])

    # Create scheduler
    num_training_steps = config["num_epochs"] * len(train_dataloader)
    num_warmup_steps = num_training_steps // 10
    logger.info(f"Number of training steps: {num_training_steps}")
    logger.info(f"Number of warmup steps: {num_warmup_steps}")

    scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    # Add dropout seed
    seed_torch(config["dropout_seed"])

    # Start training
    iteration_num = 0
    for epoch in range(config["num_epochs"]):
        for batch in train_dataloader:
            batch_inputs = tokenizer(
                text=batch["text_excerpt"],
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            batch_targets = batch["label"].to(device)
            train_outputs = train_step(
                model=model,
                inputs=batch_inputs,
                targets=batch_targets,
                loss_fn=loss_fn,
                optimizer=optimizer,
                scheduler=scheduler,
            )
            if iteration_num % config["validate_every_n_iteration"] == 0:
                valid_targets, valid_predictions = [], []
                for batch in valid_dataloader:
                    batch_inputs = tokenizer(
                        text=batch["text_excerpt"],
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                    ).to(device)
                    batch_targets = batch["label"].to(device)
                    valid_outputs = evaluate_step(
                        model=model,
                        inputs=batch_inputs,
                        targets=batch_targets,
                        loss_fn=loss_fn,
                    )
                    valid_targets.extend(valid_outputs["targets"])
                    valid_predictions.extend(valid_outputs["predictions"])
                valid_targets = torch.vstack(valid_targets)
                valid_predictions = torch.vstack(valid_predictions)
                valid_loss = loss_fn(valid_predictions, valid_targets)
                valid_targets = post_process_targets(valid_targets)
                valid_predictions = post_process_predictions(valid_predictions)
                valid_score = compute_matthews_corrcoef(targets=valid_targets, predictions=valid_predictions)
                logger.info(f"Iteration num: {iteration_num}, Train loss: {train_outputs['loss']}")
                logger.info(f"Iteration num: {iteration_num}, Valid loss: {valid_loss}, Valid score: {valid_score}")
            iteration_num += 1


if __name__ == "__main__":
    run()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below).
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /examples/disaster-tweets-stabilizer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "## Imports" 7 | ], 8 | "metadata": {} 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "source": [ 14 | "from stabilizer.llrd import get_optimizer_parameters_with_llrd\n", 15 | "from stabilizer.reinitialize import reinit_autoencoder_model\n", 16 | "from stabilizer.trainer import train_step,evaluate_step\n", 17 | "from sklearn.model_selection import train_test_split\n", 18 | "from transformers import AutoModel,AutoTokenizer\n", 19 | "from stabilizer.dataset import TextLabelDataset\n", 20 | "from torch.utils.data import DataLoader,Dataset\n", 21 | "from stabilizer.model import PoolerClassifier\n", 22 | "from transformers import get_scheduler\n", 23 | "from sklearn.metrics import f1_score\n", 24 | "from torch.optim import AdamW\n", 25 | "from torch import nn\n", 26 | "import pandas as pd\n", 27 | "import numpy as np\n", 28 | "import random\n", 29 | "import torch\n" 30 | ], 31 | "outputs": [ 32 | { 33 | "output_type": "stream", 34 | "name": "stderr", 35 | "text": [ 36 | "2021-09-27 09:48:57.887343: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0\n" 37 | ] 38 | } 39 | ], 40 | "metadata": { 41 | "execution": { 42 | "iopub.execute_input": "2021-09-27T09:48:55.519700Z", 43 | "iopub.status.busy": "2021-09-27T09:48:55.519062Z", 44 | "iopub.status.idle": "2021-09-27T09:49:02.035181Z", 45 | "shell.execute_reply": "2021-09-27T09:49:02.034101Z", 46 | "shell.execute_reply.started": "2021-09-27T07:20:23.599516Z" 47 | }, 48 | "papermill": { 49 | "duration": 6.54162, 50 | "end_time": "2021-09-27T09:49:02.035359", 51 | "exception": false, 52 | "start_time": "2021-09-27T09:48:55.493739", 53 | "status": "completed" 54 | }, 55 | "tags": [] 56 | } 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "source": [ 61 | "## Setting SEED" 62 | ], 63 | "metadata": {} 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "source": [ 69 | "import os\n", 70 | "def seed_everything(seed):\n", 71 | "    random.seed(seed)\n", 72 | "    os.environ['PYTHONHASHSEED'] = str(seed)\n", 73 | "    np.random.seed(seed)\n", 74 | "    torch.manual_seed(seed)\n", 75 | "    torch.cuda.manual_seed(seed)\n", 76 | "    torch.cuda.manual_seed_all(seed)\n", 77 | "    torch.backends.cudnn.deterministic = True\n", 78 | "    torch.backends.cudnn.benchmark = False" 79 | ], 80 | "outputs": [], 81 | "metadata": { 82 | "execution": { 83 | "iopub.execute_input": "2021-09-27T09:49:02.071943Z", 84 | "iopub.status.busy": "2021-09-27T09:49:02.069729Z", 85 | "iopub.status.idle": "2021-09-27T09:49:02.072706Z", 86 | "shell.execute_reply": "2021-09-27T09:49:02.073215Z", 87 | "shell.execute_reply.started": "2021-09-27T07:20:28.823549Z" 88 | }, 89 | "papermill": { 90 | "duration": 0.023611, 91 | "end_time": "2021-09-27T09:49:02.073361", 92 | "exception": false, 93 | "start_time": "2021-09-27T09:49:02.049750", 94 | "status": "completed" 95 | }, 96 | "tags": [] 97 | } 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "source": [ 102 | "## Training configurations" 103 | ], 104 | "metadata": {} 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "source": [ 110 | "config = {\n", 111 | "    'pretrained_model':'roberta-base',\n", 112 | "    'num_classes':1,\n", 113 | "    
'batch_size':32,\n", 114 | "    'device_name':torch.device('cuda'),\n", 115 | "    'lr':1e-5,\n", 116 | "    'multiplicative_lr':0.95,\n", 117 | "    'llrd':False,\n", 118 | "    'reinit':2,\n", 119 | "    'epochs':3,\n", 120 | "    'valid_every':15,\n", 121 | "    'scheduler':'linear',\n", 122 | "    'seed':1000\n", 123 | "}\n", 124 | "\n", 125 | "seed_everything(config['seed'])" 126 | ], 127 | "outputs": [], 128 | "metadata": { 129 | "execution": { 130 | "iopub.execute_input": "2021-09-27T09:49:02.113687Z", 131 | "iopub.status.busy": "2021-09-27T09:49:02.112553Z", 132 | "iopub.status.idle": "2021-09-27T09:49:02.536091Z", 133 | "shell.execute_reply": "2021-09-27T09:49:02.535419Z", 134 | "shell.execute_reply.started": "2021-09-27T07:20:31.309831Z" 135 | }, 136 | "papermill": { 137 | "duration": 0.448079, 138 | "end_time": "2021-09-27T09:49:02.536387", 139 | "exception": true, 140 | "start_time": "2021-09-27T09:49:02.088308", 141 | "status": "failed" 142 | }, 143 | "tags": [] 144 | } 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "source": [ 149 | "## Read data" 150 | ], 151 | "metadata": {} 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "source": [ 157 | "train_df = pd.read_csv(\"../input/nlp-getting-started/train.csv\")\n", 158 | "test_df = pd.read_csv(\"../input/nlp-getting-started/test.csv\")" 159 | ], 160 | "outputs": [], 161 | "metadata": { 162 | "execution": { 163 | "iopub.execute_input": "2021-09-27T07:20:45.899662Z", 164 | "iopub.status.busy": "2021-09-27T07:20:45.899198Z", 165 | "iopub.status.idle": "2021-09-27T07:20:45.99456Z", 166 | "shell.execute_reply": "2021-09-27T07:20:45.99378Z", 167 | "shell.execute_reply.started": "2021-09-27T07:20:45.899619Z" 168 | }, 169 | "papermill": { 170 | "duration": null, 171 | "end_time": null, 172 | "exception": null, 173 | "start_time": null, 174 | "status": "pending" 175 | }, 176 | "tags": [] 177 | } 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "source": [ 183 | "print(f\"train dataset has {train_df.shape[0]} samples\")\n", 184 | "train_df.head(5)" 185 | ], 186 | "outputs": [], 187 | "metadata": { 188 | "execution": { 189 | "iopub.execute_input": "2021-09-27T07:20:46.000825Z", 190 | "iopub.status.busy": "2021-09-27T07:20:45.998807Z", 191 | "iopub.status.idle": "2021-09-27T07:20:46.03393Z", 192 | "shell.execute_reply": "2021-09-27T07:20:46.033253Z", 193 | "shell.execute_reply.started": "2021-09-27T07:20:46.000787Z" 194 | }, 195 | "papermill": { 196 | "duration": null, 197 | "end_time": null, 198 | "exception": null, 199 | "start_time": null, 200 | "status": "pending" 201 | }, 202 | "tags": [] 203 | } 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "source": [ 208 | "## Data preparation" 209 | ], 210 | "metadata": {} 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "source": [ 216 | "train,valid = train_test_split(train_df,test_size=0.2,stratify=train_df['target'])" 217 | ], 218 | "outputs": [], 219 | "metadata": { 220 | "execution": { 221 | "iopub.execute_input": "2021-09-27T07:20:46.040371Z", 222 | "iopub.status.busy": "2021-09-27T07:20:46.038407Z", 223 | "iopub.status.idle": "2021-09-27T07:20:46.064552Z", 224 | "shell.execute_reply": "2021-09-27T07:20:46.06369Z", 225 | "shell.execute_reply.started": "2021-09-27T07:20:46.040334Z" 226 | }, 227 | "papermill": { 228 | "duration": null, 229 | "end_time": null, 230 | "exception": null, 231 | "start_time": null, 232 | "status": "pending" 233 | }, 234 | "tags": [] 235 | } 236 | }, 237 | { 238 | "cell_type": 
"code", 239 | "execution_count": null, 240 | "source": [ 241 | "# Prepate data to create dataset\n", 242 | "train_tweets = train_df['text'].tolist()\n", 243 | "valid_tweets = train_df['text'].tolist()\n", 244 | "train_targets = torch.from_numpy(train_df['target'].to_numpy().reshape(-1, 1)).type(torch.float32)\n", 245 | "valid_targets = torch.from_numpy(train_df['target'].to_numpy().reshape(-1, 1)).type(torch.float32)" 246 | ], 247 | "outputs": [], 248 | "metadata": { 249 | "execution": { 250 | "iopub.execute_input": "2021-09-27T07:20:46.070923Z", 251 | "iopub.status.busy": "2021-09-27T07:20:46.068631Z", 252 | "iopub.status.idle": "2021-09-27T07:20:46.081818Z", 253 | "shell.execute_reply": "2021-09-27T07:20:46.080783Z", 254 | "shell.execute_reply.started": "2021-09-27T07:20:46.070887Z" 255 | }, 256 | "papermill": { 257 | "duration": null, 258 | "end_time": null, 259 | "exception": null, 260 | "start_time": null, 261 | "status": "pending" 262 | }, 263 | "tags": [] 264 | } 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "source": [ 270 | "# Create Dataset and DataLoader\n", 271 | "train_dataset = TextLabelDataset(text_excerpts=train_tweets, labels=train_targets)\n", 272 | "valid_dataset = TextLabelDataset(text_excerpts=valid_tweets, labels=valid_targets)\n", 273 | "\n", 274 | "train_dataloader = DataLoader(dataset=train_dataset, batch_size=config['batch_size'], shuffle=True)\n", 275 | "valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=config['batch_size'], shuffle=False)\n" 276 | ], 277 | "outputs": [], 278 | "metadata": { 279 | "execution": { 280 | "iopub.execute_input": "2021-09-27T07:20:46.090211Z", 281 | "iopub.status.busy": "2021-09-27T07:20:46.087395Z", 282 | "iopub.status.idle": "2021-09-27T07:20:46.098858Z", 283 | "shell.execute_reply": "2021-09-27T07:20:46.097786Z", 284 | "shell.execute_reply.started": "2021-09-27T07:20:46.090162Z" 285 | }, 286 | "papermill": { 287 | "duration": null, 288 | "end_time": null, 289 | "exception": null, 290 | "start_time": null, 291 | "status": "pending" 292 | }, 293 | "tags": [] 294 | } 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "source": [ 299 | "## Loss function and competition metric" 300 | ], 301 | "metadata": {} 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "source": [ 307 | "loss_fn = nn.BCEWithLogitsLoss()\n", 308 | "metric = f1_score\n" 309 | ], 310 | "outputs": [], 311 | "metadata": { 312 | "execution": { 313 | "iopub.execute_input": "2021-09-27T07:20:46.107595Z", 314 | "iopub.status.busy": "2021-09-27T07:20:46.104697Z", 315 | "iopub.status.idle": "2021-09-27T07:20:46.11443Z", 316 | "shell.execute_reply": "2021-09-27T07:20:46.11349Z", 317 | "shell.execute_reply.started": "2021-09-27T07:20:46.107554Z" 318 | }, 319 | "papermill": { 320 | "duration": null, 321 | "end_time": null, 322 | "exception": null, 323 | "start_time": null, 324 | "status": "pending" 325 | }, 326 | "tags": [] 327 | } 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "source": [ 332 | "## Train and Evaluate" 333 | ], 334 | "metadata": {} 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "source": [ 340 | "def train_and_eval(train_dataloader,valid_dataloader):\n", 341 | " \n", 342 | " tokenizer = AutoTokenizer.from_pretrained(config['pretrained_model'])\n", 343 | " transformer = AutoModel.from_pretrained(config['pretrained_model'])\n", 344 | " model = PoolerClassifier(transformer=transformer,\n", 345 | " 
transformer_output_size=transformer.config.hidden_size,\n", 346 | "                             transformer_output_dropout_prob=transformer.config.hidden_dropout_prob,\n", 347 | "                             num_classes=config['num_classes']\n", 348 | "                            )\n", 349 | "    device = torch.device(config['device_name'])\n", 350 | "    _ = model.to(device)\n", 351 | "    \n", 352 | "    \n", 353 | "    if config['llrd']:\n", 354 | "        parameters = get_optimizer_parameters_with_llrd(model,config['lr'],config['multiplicative_lr'])\n", 355 | "    \n", 356 | "    else:\n", 357 | "        no_decay = ['bias','LayerNorm.weight']\n", 358 | "        parameters = [{'params':[p for n,p in model.named_parameters() if not any(k in n for k in no_decay)],\n", 359 | "                       'weight_decay':0.01,'lr':config['lr']},\n", 360 | "                      {'params':[p for n,p in model.named_parameters() if any(k in n for k in no_decay)],\n", 361 | "                       'weight_decay':0.00,'lr':config['lr']}]\n", 362 | "    \n", 363 | "    if config['reinit']:\n", 364 | "        model.transformer.encoder = reinit_autoencoder_model(model.transformer.encoder,config['reinit'])\n", 365 | "    \n", 366 | "    optimizer = AdamW(parameters)\n", 367 | "\n", 368 | "    num_training_steps = config['epochs'] * len(train_dataloader)\n", 369 | "    scheduler = get_scheduler(name=config['scheduler'],num_training_steps=num_training_steps,\n", 370 | "                              num_warmup_steps=int(0.1*num_training_steps),optimizer=optimizer)\n", 371 | "    num_iter=0\n", 372 | "    for epoch in range(config['epochs']):\n", 373 | "        train_f1,train_loss = 0.0,0.0\n", 374 | "        for batch in train_dataloader:\n", 375 | "            inputs = tokenizer(batch['text_excerpt'],padding=True, truncation=True,return_tensors='pt').to(config['device_name'])\n", 376 | "            targets = batch['label'].to(config['device_name'])\n", 377 | "            train_outputs = train_step(model=model, inputs=inputs, targets=targets, loss_fn=loss_fn, optimizer=optimizer,\n", 378 | "                                       scheduler=scheduler)\n", 379 | "            #train_f1 += metric(targets.detach().cpu().numpy(),train_outputs['predictions'].detach().cpu().numpy())\n", 380 | "            train_loss += train_outputs['loss']\n", 381 | "            if num_iter % config['valid_every'] == 0:\n", 382 | "                valid_f1,valid_loss = 0.0,0.0\n", 383 | "                for valid_batch in valid_dataloader:\n", 384 | "                    inputs = tokenizer(valid_batch['text_excerpt'],padding=True, truncation=True,return_tensors='pt').to(config['device_name'])\n", 385 | "                    targets = valid_batch['label'].to(config['device_name'])\n", 386 | "                    valid_outputs = evaluate_step(model=model, inputs=inputs, targets=targets, loss_fn=loss_fn)\n", 387 | "                    predictions = (nn.Sigmoid()(valid_outputs['predictions'])).detach().cpu().numpy()\n", 388 | "                    \n", 389 | "                    valid_f1 += metric(targets.detach().cpu().numpy(),np.round(predictions),average='micro')\n", 390 | "                    valid_loss += valid_outputs['loss']\n", 391 | "                print(\"validation f1 score\",valid_f1/len(valid_dataloader))\n", 392 | "                print(\"validation loss\",valid_loss.item()/len(valid_dataloader))\n", 393 | "            num_iter+=1\n", 394 | "\n", 395 | "        \n", 396 | "        print(f\"Train epoch {epoch} loss {train_loss.item()/len(train_dataloader)}\")\n", 397 | "        #print(f\"Train epoch {epoch} f1 score {train_f1/len(train_dataloader)}\") " 398 | ], 399 | "outputs": [], 400 | "metadata": { 401 | "execution": { 402 | "iopub.execute_input": "2021-09-27T07:20:46.122399Z", 403 | "iopub.status.busy": "2021-09-27T07:20:46.120121Z", 404 | "iopub.status.idle": "2021-09-27T07:20:46.151613Z", 405 | "shell.execute_reply": "2021-09-27T07:20:46.15077Z", 406 | "shell.execute_reply.started": "2021-09-27T07:20:46.122364Z" 407 | }, 408 | "papermill": { 409 | "duration": null, 410 | "end_time": null, 411 | "exception": null, 412 | "start_time": null, 413
| "status": "pending" 414 | }, 415 | "tags": [] 416 | } 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "source": [ 422 | "train_and_eval(train_dataloader,valid_dataloader)" 423 | ], 424 | "outputs": [], 425 | "metadata": { 426 | "execution": { 427 | "iopub.execute_input": "2021-09-27T07:20:46.160515Z", 428 | "iopub.status.busy": "2021-09-27T07:20:46.1574Z" 429 | }, 430 | "papermill": { 431 | "duration": null, 432 | "end_time": null, 433 | "exception": null, 434 | "start_time": null, 435 | "status": "pending" 436 | }, 437 | "tags": [] 438 | } 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "source": [], 444 | "outputs": [], 445 | "metadata": { 446 | "papermill": { 447 | "duration": null, 448 | "end_time": null, 449 | "exception": null, 450 | "start_time": null, 451 | "status": "pending" 452 | }, 453 | "tags": [] 454 | } 455 | } 456 | ], 457 | "metadata": { 458 | "kernelspec": { 459 | "display_name": "Python 3", 460 | "language": "python", 461 | "name": "python3" 462 | }, 463 | "language_info": { 464 | "codemirror_mode": { 465 | "name": "ipython", 466 | "version": 3 467 | }, 468 | "file_extension": ".py", 469 | "mimetype": "text/x-python", 470 | "name": "python", 471 | "nbconvert_exporter": "python", 472 | "pygments_lexer": "ipython3", 473 | "version": "3.7.10" 474 | }, 475 | "papermill": { 476 | "default_parameters": {}, 477 | "duration": 162.769315, 478 | "end_time": "2021-09-27T09:49:05.713195", 479 | "environment_variables": {}, 480 | "exception": true, 481 | "input_path": "__notebook__.ipynb", 482 | "output_path": "__notebook__.ipynb", 483 | "parameters": {}, 484 | "start_time": "2021-09-27T09:46:22.943880", 485 | "version": "2.3.3" 486 | } 487 | }, 488 | "nbformat": 4, 489 | "nbformat_minor": 5 490 | } -------------------------------------------------------------------------------- /examples/cola.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "4711a347-ce50-4a05-a829-18ef64f17f15", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# Replace this with pip install stabilizer\n", 11 | "import sys\n", 12 | "sys.path.insert(0, '..')" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "6031bb52-3b5a-41ca-94c1-8adf03117872", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import torch\n", 23 | "import logging\n", 24 | "import numpy as np\n", 25 | "import pandas as pd\n", 26 | "\n", 27 | "from torch import nn\n", 28 | "from torch.utils.data import DataLoader\n", 29 | "from sklearn.metrics import matthews_corrcoef\n", 30 | "\n", 31 | "from stabilizer.model import PoolerClassifier\n", 32 | "from stabilizer.dataset import TextLabelDataset\n", 33 | "from stabilizer.trainer import train_step, evaluate_step\n", 34 | "\n", 35 | "from transformers import get_scheduler, AdamW, AutoModel, AutoTokenizer" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "id": "1836f0fe-1995-410e-83e3-c54061ce3aa6", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "logging.basicConfig(\n", 46 | " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n", 47 | " datefmt=\"%m/%d/%Y %H:%M:%S\",\n", 48 | " level=logging.INFO,\n", 49 | ")\n", 50 | "logger = logging.getLogger(__name__)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "id": "cab9eb42-2a5f-43c9-ac47-dad2cfbb0184", 57 | 
"metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def post_process_targets(targets):\n", 61 | " targets = targets.type(torch.int)\n", 62 | " targets = targets.cpu().detach().numpy().reshape(-1)\n", 63 | " return targets\n", 64 | "\n", 65 | "\n", 66 | "def post_process_predictions(predictions):\n", 67 | " predictions = torch.sigmoid(predictions)\n", 68 | " predictions = (predictions >= 0.5).type(torch.int)\n", 69 | " predictions = predictions.cpu().detach().numpy().reshape(-1)\n", 70 | " return predictions\n", 71 | "\n", 72 | "\n", 73 | "def compute_matthews_corrcoef(targets, predictions):\n", 74 | " if len(np.unique(predictions)) > 1 and len(np.unique(targets)) > 1:\n", 75 | " score = matthews_corrcoef(y_true=targets, y_pred=predictions)\n", 76 | " else:\n", 77 | " score = 0.0\n", 78 | " return score" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 5, 84 | "id": "fac70052-c951-47cb-8bb7-8aeb07d34bfd", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "config = {'train_data_path': '../data/glue/cola/train.jsonl',\n", 89 | " 'valid_data_path': '../data/glue/cola/valid.jsonl',\n", 90 | " 'batch_size': 32,\n", 91 | " 'pretrained_tokenizer_name_or_path': '../models/bert-base-uncased/',\n", 92 | " 'pretrained_model_name_or_path': '../models/bert-base-uncased/',\n", 93 | " 'device_name': 'cpu',\n", 94 | " 'dropout_prob': 0.1,\n", 95 | " 'num_classes': 1,\n", 96 | " 'lr': 2e-5,\n", 97 | " 'num_epochs': 3,\n", 98 | " 'validate_every_n_iteration': 10}" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 6, 104 | "id": "fb4e29f2-94f4-4f2f-b00c-de359c9797cc", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "# Read training data\n", 109 | "train_data = pd.read_json(path_or_buf=config['train_data_path'], lines=True).set_index('idx')\n", 110 | "valid_data = pd.read_json(path_or_buf=config['valid_data_path'], lines=True).set_index('idx')" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 7, 116 | "id": "c3f58b6b-4a23-4da3-854c-23815f42a3d3", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# Show a small snippet and Give a small explanation of the data" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 8, 126 | "id": "3972aa0f-3dbb-46c0-b90b-d958a2b10530", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# Prepate data to create dataset\n", 131 | "train_text_excerpts = train_data['text'].tolist()\n", 132 | "valid_text_excerpts = valid_data['text'].tolist()\n", 133 | "train_labels = torch.from_numpy(train_data['label'].to_numpy().reshape(-1, 1)).type(torch.float32)\n", 134 | "valid_labels = torch.from_numpy(valid_data['label'].to_numpy().reshape(-1, 1)).type(torch.float32)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 9, 140 | "id": "6f232b62-0368-445b-a987-48636baf082d", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# Create Dataset and DataLoader\n", 145 | "train_dataset = TextLabelDataset(text_excerpts=train_text_excerpts, labels=train_labels)\n", 146 | "valid_dataset = TextLabelDataset(text_excerpts=valid_text_excerpts, labels=valid_labels)\n", 147 | "\n", 148 | "train_dataloader = DataLoader(dataset=train_dataset, batch_size=config['batch_size'], shuffle=True)\n", 149 | "valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=config['batch_size'], shuffle=False)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 10, 155 | "id": 
"92fe18da-411a-4b4c-8cfb-01a1afa070ee", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "# Create tokenizer and model\n", 160 | "tokenizer = AutoTokenizer.from_pretrained(config['pretrained_tokenizer_name_or_path'])\n", 161 | "transformer = AutoModel.from_pretrained(pretrained_model_name_or_path=config['pretrained_model_name_or_path'],\n", 162 | " hidden_dropout_prob=config['dropout_prob'],\n", 163 | " attention_probs_dropout_prob=config['dropout_prob'])\n", 164 | "model = PoolerClassifier(transformer=transformer,\n", 165 | " transformer_output_size=transformer.config.hidden_size,\n", 166 | " transformer_output_dropout_prob=config['dropout_prob'],\n", 167 | " num_classes=config['num_classes'])\n", 168 | "device = torch.device(config['device_name'])\n", 169 | "_ = model.to(device)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 11, 175 | "id": "2e5c40d6-11c8-463d-a9b7-494b8ba573e9", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "# Define loss\n", 180 | "loss_fn = nn.BCEWithLogitsLoss()\n", 181 | "\n", 182 | "# Create optimizer and scheduler\n", 183 | "model_parameters = model.parameters()\n", 184 | "optimizer = AdamW(params=model_parameters, lr=config['lr'])" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 12, 190 | "id": "7498e99d-979a-4467-b0bf-ade36bd39535", 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "name": "stderr", 195 | "output_type": "stream", 196 | "text": [ 197 | "09/18/2021 14:09:22 - INFO - __main__ - Number of training steps: 804\n", 198 | "09/18/2021 14:09:22 - INFO - __main__ - Number of warmup steps: 80\n" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "num_training_steps = config['num_epochs'] * len(train_dataloader)\n", 204 | "num_warmup_steps = num_training_steps // 10\n", 205 | "logger.info(f'Number of training steps: {num_training_steps}')\n", 206 | "logger.info(f'Number of warmup steps: {num_warmup_steps}')\n", 207 | "\n", 208 | "scheduler = get_scheduler(name='linear',\n", 209 | " optimizer=optimizer,\n", 210 | " num_warmup_steps=num_warmup_steps,\n", 211 | " num_training_steps=num_training_steps)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 13, 217 | "id": "90803ea3-fb0f-449b-938e-c010bdca46e7", 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", 225 | "To disable this warning, you can either:\n", 226 | "\t- Avoid using `tokenizers` before the fork if possible\n", 227 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", 228 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", 229 | "To disable this warning, you can either:\n", 230 | "\t- Avoid using `tokenizers` before the fork if possible\n", 231 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", 232 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", 233 | "To disable this warning, you can either:\n", 234 | "\t- Avoid using `tokenizers` before the fork if possible\n", 235 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", 236 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", 261 | "To disable this warning, you can either:\n", 262 | "\t- Avoid using `tokenizers` before the fork if possible\n", 263 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" 264 | ] 265 | }, 266 | { 267 | "name": "stderr", 268 | "output_type": "stream", 269 | "text": [ 270 | "09/18/2021 14:09:58 - INFO - __main__ - Iteration num: 0, Train loss: 0.9233994483947754\n", 271 | "09/18/2021 14:09:58 - INFO - __main__ - Iteration num: 0, Valid loss: 0.8915260434150696, Valid score: -0.018148342420931135\n", 272 | "09/18/2021 14:11:06 - INFO - __main__ - Iteration num: 10, Train loss: 0.7844147682189941\n", 273 | "09/18/2021 14:11:06 - INFO - __main__ - Iteration num: 10, Valid loss: 0.8481836318969727, Valid score: -0.018148342420931135\n", 274 | "09/18/2021 14:12:29 - INFO - __main__ - Iteration num: 20, Train loss: 0.7572133541107178\n", 275 | "09/18/2021 14:12:29 - INFO - __main__ - Iteration num: 20, Valid loss: 0.7248084545135498, Valid score: 0.02069796861945717\n", 276 | "09/18/2021 14:13:36 - INFO - __main__ - Iteration num: 30, Train loss: 0.6081583499908447\n", 277 | "09/18/2021 14:13:36 - INFO - __main__ - Iteration num: 30, Valid loss: 0.6268794536590576, Valid score: 0.00286100001416597\n", 278 | "09/18/2021 14:14:41 - INFO - __main__ - Iteration num: 40, Train loss: 0.5490421056747437\n", 279 | "09/18/2021 14:14:41 - INFO - __main__ - Iteration num: 40, Valid loss: 0.6158878207206726, Valid score: 0.0\n", 280 | "09/18/2021 14:15:47 - INFO - __main__ - Iteration num: 50, Train loss: 0.6187964677810669\n", 281 | "09/18/2021 14:15:47 - INFO - __main__ - Iteration num: 50, Valid loss: 0.6144923567771912, Valid score: 0.0\n", 282 | "09/18/2021 14:16:54 - INFO - __main__ - Iteration num: 60, Train loss: 0.6215643882751465\n", 283 | "09/18/2021 14:16:54 - INFO - __main__ - Iteration num: 60, Valid loss: 0.6129640936851501, Valid score: 0.0\n", 284 | "09/18/2021 14:18:00 - INFO - __main__ - Iteration num: 70, Train loss: 0.6671042442321777\n", 285 | "09/18/2021 14:18:00 - INFO - __main__ - Iteration num: 70, Valid loss: 0.6125308871269226, Valid score: 0.0\n", 286 | "09/18/2021 14:19:06 - INFO - __main__ - Iteration num: 80, Train loss: 0.49621421098709106\n", 287 | "09/18/2021 14:19:06 - INFO - __main__ - Iteration num: 80, Valid loss: 0.6097213625907898, Valid score: 0.0\n", 288 | "09/18/2021 14:20:17 - INFO - __main__ - Iteration num: 90, Train loss: 0.6173084378242493\n", 289 | "09/18/2021 14:20:17 - INFO - __main__ - Iteration num: 90, Valid loss: 0.601161003112793, Valid score: 0.0\n", 290 | "09/18/2021 14:21:25 - INFO - __main__ - Iteration num: 100, Train loss: 0.5134186148643494\n", 291 | "09/18/2021 14:21:25 - INFO - __main__ - Iteration num: 100, Valid loss: 0.5948840379714966, Valid score: 0.0\n", 292 | "09/18/2021 14:22:33 - INFO - __main__ - Iteration num: 110, Train loss: 0.5372046232223511\n", 293 | "09/18/2021 14:22:33 - INFO - __main__ - Iteration num: 110, Valid loss: 0.5822965502738953, Valid score: 0.0\n", 294 | "09/18/2021 14:23:42 - INFO - __main__ - Iteration num: 120, Train loss: 0.49894315004348755\n", 295 | "09/18/2021 14:23:42 - INFO - __main__ - Iteration num: 120, Valid loss: 0.5462360978126526, Valid score: 0.06558874629318973\n", 296 | "09/18/2021 14:24:50 - INFO - __main__ - Iteration num: 130, Train loss: 0.500604510307312\n", 297 | "09/18/2021 14:24:50 - INFO - __main__ - Iteration num: 130, Valid loss: 0.5303254723548889, Valid score: 0.3420819161583976\n", 298 | 
"09/18/2021 14:25:57 - INFO - __main__ - Iteration num: 140, Train loss: 0.48256492614746094\n", 299 | "09/18/2021 14:25:57 - INFO - __main__ - Iteration num: 140, Valid loss: 0.4842092990875244, Valid score: 0.4820573358802571\n", 300 | "09/18/2021 14:27:08 - INFO - __main__ - Iteration num: 150, Train loss: 0.573493480682373\n", 301 | "09/18/2021 14:27:08 - INFO - __main__ - Iteration num: 150, Valid loss: 0.48082056641578674, Valid score: 0.4518417226348695\n", 302 | "09/18/2021 14:28:17 - INFO - __main__ - Iteration num: 160, Train loss: 0.6603621244430542\n", 303 | "09/18/2021 14:28:17 - INFO - __main__ - Iteration num: 160, Valid loss: 0.5515894889831543, Valid score: 0.30352336473584907\n", 304 | "09/18/2021 14:29:24 - INFO - __main__ - Iteration num: 170, Train loss: 0.529253363609314\n", 305 | "09/18/2021 14:29:24 - INFO - __main__ - Iteration num: 170, Valid loss: 0.4677261710166931, Valid score: 0.4783289990633755\n", 306 | "09/18/2021 14:30:35 - INFO - __main__ - Iteration num: 180, Train loss: 0.5332573056221008\n", 307 | "09/18/2021 14:30:35 - INFO - __main__ - Iteration num: 180, Valid loss: 0.49773919582366943, Valid score: 0.4626168523165551\n", 308 | "09/18/2021 14:31:40 - INFO - __main__ - Iteration num: 190, Train loss: 0.5212928652763367\n", 309 | "09/18/2021 14:31:40 - INFO - __main__ - Iteration num: 190, Valid loss: 0.4595946669578552, Valid score: 0.5062636834188224\n", 310 | "09/18/2021 14:32:49 - INFO - __main__ - Iteration num: 200, Train loss: 0.44723019003868103\n", 311 | "09/18/2021 14:32:49 - INFO - __main__ - Iteration num: 200, Valid loss: 0.515070915222168, Valid score: 0.46936417904934724\n", 312 | "09/18/2021 14:33:58 - INFO - __main__ - Iteration num: 210, Train loss: 0.49733683466911316\n", 313 | "09/18/2021 14:33:58 - INFO - __main__ - Iteration num: 210, Valid loss: 0.4593220353126526, Valid score: 0.5078785333305506\n", 314 | "09/18/2021 14:35:08 - INFO - __main__ - Iteration num: 220, Train loss: 0.32779011130332947\n", 315 | "09/18/2021 14:35:08 - INFO - __main__ - Iteration num: 220, Valid loss: 0.46277913451194763, Valid score: 0.466399008874559\n", 316 | "09/18/2021 14:36:17 - INFO - __main__ - Iteration num: 230, Train loss: 0.5029074549674988\n", 317 | "09/18/2021 14:36:17 - INFO - __main__ - Iteration num: 230, Valid loss: 0.47985273599624634, Valid score: 0.4347815008814297\n", 318 | "09/18/2021 14:37:23 - INFO - __main__ - Iteration num: 240, Train loss: 0.6127015352249146\n", 319 | "09/18/2021 14:37:23 - INFO - __main__ - Iteration num: 240, Valid loss: 0.4389609694480896, Valid score: 0.5032312800977334\n", 320 | "09/18/2021 14:38:30 - INFO - __main__ - Iteration num: 250, Train loss: 0.4096861481666565\n", 321 | "09/18/2021 14:38:30 - INFO - __main__ - Iteration num: 250, Valid loss: 0.47394126653671265, Valid score: 0.46652373409575754\n", 322 | "09/18/2021 14:39:37 - INFO - __main__ - Iteration num: 260, Train loss: 0.4090423285961151\n", 323 | "09/18/2021 14:39:37 - INFO - __main__ - Iteration num: 260, Valid loss: 0.46983644366264343, Valid score: 0.46763521355858595\n", 324 | "09/18/2021 14:40:45 - INFO - __main__ - Iteration num: 270, Train loss: 0.46004629135131836\n", 325 | "09/18/2021 14:40:45 - INFO - __main__ - Iteration num: 270, Valid loss: 0.4244503378868103, Valid score: 0.5286324175580216\n", 326 | "09/18/2021 14:41:55 - INFO - __main__ - Iteration num: 280, Train loss: 0.2546089291572571\n", 327 | "09/18/2021 14:41:55 - INFO - __main__ - Iteration num: 280, Valid loss: 0.47313734889030457, Valid score: 
0.5108235781406687\n", 328 | "09/18/2021 14:43:04 - INFO - __main__ - Iteration num: 290, Train loss: 0.33392155170440674\n", 329 | "09/18/2021 14:43:04 - INFO - __main__ - Iteration num: 290, Valid loss: 0.4313860237598419, Valid score: 0.5604772031378175\n", 330 | "09/18/2021 14:44:12 - INFO - __main__ - Iteration num: 300, Train loss: 0.4947934150695801\n", 331 | "09/18/2021 14:44:12 - INFO - __main__ - Iteration num: 300, Valid loss: 0.47867104411125183, Valid score: 0.5126228485857701\n", 332 | "09/18/2021 14:45:21 - INFO - __main__ - Iteration num: 310, Train loss: 0.33134156465530396\n", 333 | "09/18/2021 14:45:21 - INFO - __main__ - Iteration num: 310, Valid loss: 0.4790351390838623, Valid score: 0.4913288678758369\n", 334 | "09/18/2021 14:46:27 - INFO - __main__ - Iteration num: 320, Train loss: 0.21783889830112457\n", 335 | "09/18/2021 14:46:27 - INFO - __main__ - Iteration num: 320, Valid loss: 0.4453355371952057, Valid score: 0.510581671760424\n", 336 | "09/18/2021 14:47:38 - INFO - __main__ - Iteration num: 330, Train loss: 0.3152202367782593\n", 337 | "09/18/2021 14:47:38 - INFO - __main__ - Iteration num: 330, Valid loss: 0.4633635878562927, Valid score: 0.5132479296361289\n", 338 | "09/18/2021 14:48:46 - INFO - __main__ - Iteration num: 340, Train loss: 0.27159416675567627\n", 339 | "09/18/2021 14:48:46 - INFO - __main__ - Iteration num: 340, Valid loss: 0.4482996463775635, Valid score: 0.51909383850845\n", 340 | "09/18/2021 14:49:50 - INFO - __main__ - Iteration num: 350, Train loss: 0.4444682002067566\n", 341 | "09/18/2021 14:49:50 - INFO - __main__ - Iteration num: 350, Valid loss: 0.4257673919200897, Valid score: 0.5556088865196797\n", 342 | "09/18/2021 14:50:54 - INFO - __main__ - Iteration num: 360, Train loss: 0.27038827538490295\n", 343 | "09/18/2021 14:50:54 - INFO - __main__ - Iteration num: 360, Valid loss: 0.48734891414642334, Valid score: 0.4902718526664373\n", 344 | "09/18/2021 14:52:02 - INFO - __main__ - Iteration num: 370, Train loss: 0.28203272819519043\n", 345 | "09/18/2021 14:52:02 - INFO - __main__ - Iteration num: 370, Valid loss: 0.4324546456336975, Valid score: 0.537031102939668\n", 346 | "09/18/2021 14:53:12 - INFO - __main__ - Iteration num: 380, Train loss: 0.42943158745765686\n", 347 | "09/18/2021 14:53:12 - INFO - __main__ - Iteration num: 380, Valid loss: 0.4202009439468384, Valid score: 0.5567273065308361\n", 348 | "09/18/2021 14:54:20 - INFO - __main__ - Iteration num: 390, Train loss: 0.23938967287540436\n", 349 | "09/18/2021 14:54:20 - INFO - __main__ - Iteration num: 390, Valid loss: 0.4735446870326996, Valid score: 0.5021501283308087\n", 350 | "09/18/2021 14:55:29 - INFO - __main__ - Iteration num: 400, Train loss: 0.28121834993362427\n", 351 | "09/18/2021 14:55:29 - INFO - __main__ - Iteration num: 400, Valid loss: 0.43185245990753174, Valid score: 0.5474865115851942\n", 352 | "09/18/2021 14:56:39 - INFO - __main__ - Iteration num: 410, Train loss: 0.6848282814025879\n", 353 | "09/18/2021 14:56:39 - INFO - __main__ - Iteration num: 410, Valid loss: 0.49203166365623474, Valid score: 0.5179780196184617\n", 354 | "09/18/2021 14:57:46 - INFO - __main__ - Iteration num: 420, Train loss: 0.43221786618232727\n", 355 | "09/18/2021 14:57:46 - INFO - __main__ - Iteration num: 420, Valid loss: 0.4674992561340332, Valid score: 0.5234928415614652\n", 356 | "09/18/2021 14:58:53 - INFO - __main__ - Iteration num: 430, Train loss: 0.46589425206184387\n", 357 | "09/18/2021 14:58:53 - INFO - __main__ - Iteration num: 430, Valid loss: 0.4731331765651703, 
Valid score: 0.5077854223436783\n", 358 | "09/18/2021 15:00:00 - INFO - __main__ - Iteration num: 440, Train loss: 0.28378668427467346\n", 359 | "09/18/2021 15:00:00 - INFO - __main__ - Iteration num: 440, Valid loss: 0.3970543444156647, Valid score: 0.5943843817638454\n", 360 | "09/18/2021 15:01:15 - INFO - __main__ - Iteration num: 450, Train loss: 0.2512344717979431\n", 361 | "09/18/2021 15:01:15 - INFO - __main__ - Iteration num: 450, Valid loss: 0.4284090995788574, Valid score: 0.5391948418977317\n", 362 | "09/18/2021 15:02:23 - INFO - __main__ - Iteration num: 460, Train loss: 0.20361068844795227\n", 363 | "09/18/2021 15:02:23 - INFO - __main__ - Iteration num: 460, Valid loss: 0.421588271856308, Valid score: 0.5496311083440725\n", 364 | "09/18/2021 15:03:34 - INFO - __main__ - Iteration num: 470, Train loss: 0.3156574070453644\n", 365 | "09/18/2021 15:03:34 - INFO - __main__ - Iteration num: 470, Valid loss: 0.4350147545337677, Valid score: 0.5339501336198276\n", 366 | "09/18/2021 15:04:43 - INFO - __main__ - Iteration num: 480, Train loss: 0.32582053542137146\n", 367 | "09/18/2021 15:04:43 - INFO - __main__ - Iteration num: 480, Valid loss: 0.47398841381073, Valid score: 0.5161596159953151\n", 368 | "09/18/2021 15:05:57 - INFO - __main__ - Iteration num: 490, Train loss: 0.2616956830024719\n", 369 | "09/18/2021 15:05:57 - INFO - __main__ - Iteration num: 490, Valid loss: 0.42927974462509155, Valid score: 0.5661576952804931\n", 370 | "09/18/2021 15:07:10 - INFO - __main__ - Iteration num: 500, Train loss: 0.4593585729598999\n", 371 | "09/18/2021 15:07:10 - INFO - __main__ - Iteration num: 500, Valid loss: 0.44366273283958435, Valid score: 0.544509467622167\n", 372 | "09/18/2021 15:08:22 - INFO - __main__ - Iteration num: 510, Train loss: 0.24688488245010376\n", 373 | "09/18/2021 15:08:22 - INFO - __main__ - Iteration num: 510, Valid loss: 0.46253588795661926, Valid score: 0.5233331906511532\n", 374 | "09/18/2021 15:09:40 - INFO - __main__ - Iteration num: 520, Train loss: 0.457802951335907\n", 375 | "09/18/2021 15:09:40 - INFO - __main__ - Iteration num: 520, Valid loss: 0.451452374458313, Valid score: 0.5364214937932232\n", 376 | "09/18/2021 15:10:54 - INFO - __main__ - Iteration num: 530, Train loss: 0.5930119156837463\n", 377 | "09/18/2021 15:10:54 - INFO - __main__ - Iteration num: 530, Valid loss: 0.40512996912002563, Valid score: 0.5871330527374325\n", 378 | "09/18/2021 15:12:07 - INFO - __main__ - Iteration num: 540, Train loss: 0.29273754358291626\n", 379 | "09/18/2021 15:12:07 - INFO - __main__ - Iteration num: 540, Valid loss: 0.4500102400779724, Valid score: 0.5181984088997607\n", 380 | "09/18/2021 15:13:20 - INFO - __main__ - Iteration num: 550, Train loss: 0.13098815083503723\n", 381 | "09/18/2021 15:13:20 - INFO - __main__ - Iteration num: 550, Valid loss: 0.4583445191383362, Valid score: 0.5338774230813111\n", 382 | "09/18/2021 15:14:34 - INFO - __main__ - Iteration num: 560, Train loss: 0.15455476939678192\n", 383 | "09/18/2021 15:14:34 - INFO - __main__ - Iteration num: 560, Valid loss: 0.5114794969558716, Valid score: 0.523501779881147\n", 384 | "09/18/2021 15:15:49 - INFO - __main__ - Iteration num: 570, Train loss: 0.16482765972614288\n", 385 | "09/18/2021 15:15:49 - INFO - __main__ - Iteration num: 570, Valid loss: 0.4923681914806366, Valid score: 0.5649960305097885\n", 386 | "09/18/2021 15:16:55 - INFO - __main__ - Iteration num: 580, Train loss: 0.17552119493484497\n", 387 | "09/18/2021 15:16:55 - INFO - __main__ - Iteration num: 580, Valid loss: 
0.4413779377937317, Valid score: 0.5992421641912468\n", 388 | "09/18/2021 15:18:06 - INFO - __main__ - Iteration num: 590, Train loss: 0.3282647728919983\n", 389 | "09/18/2021 15:18:06 - INFO - __main__ - Iteration num: 590, Valid loss: 0.46630725264549255, Valid score: 0.5899314497879129\n", 390 | "09/18/2021 15:19:22 - INFO - __main__ - Iteration num: 600, Train loss: 0.20983141660690308\n", 391 | "09/18/2021 15:19:22 - INFO - __main__ - Iteration num: 600, Valid loss: 0.507980465888977, Valid score: 0.5546576947515798\n", 392 | "09/18/2021 15:20:31 - INFO - __main__ - Iteration num: 610, Train loss: 0.1106065958738327\n", 393 | "09/18/2021 15:20:31 - INFO - __main__ - Iteration num: 610, Valid loss: 0.47413283586502075, Valid score: 0.5760780218044883\n", 394 | "09/18/2021 15:21:43 - INFO - __main__ - Iteration num: 620, Train loss: 0.1978704035282135\n", 395 | "09/18/2021 15:21:43 - INFO - __main__ - Iteration num: 620, Valid loss: 0.5025636553764343, Valid score: 0.544301235611677\n", 396 | "09/18/2021 15:22:54 - INFO - __main__ - Iteration num: 630, Train loss: 0.05193217843770981\n", 397 | "09/18/2021 15:22:54 - INFO - __main__ - Iteration num: 630, Valid loss: 0.5107348561286926, Valid score: 0.5364214937932232\n", 398 | "09/18/2021 15:24:10 - INFO - __main__ - Iteration num: 640, Train loss: 0.09921319782733917\n", 399 | "09/18/2021 15:24:10 - INFO - __main__ - Iteration num: 640, Valid loss: 0.533574104309082, Valid score: 0.5261452181661114\n", 400 | "09/18/2021 15:25:21 - INFO - __main__ - Iteration num: 650, Train loss: 0.20304258167743683\n", 401 | "09/18/2021 15:25:21 - INFO - __main__ - Iteration num: 650, Valid loss: 0.5007405281066895, Valid score: 0.5573424050983508\n", 402 | "09/18/2021 15:26:31 - INFO - __main__ - Iteration num: 660, Train loss: 0.27831006050109863\n", 403 | "09/18/2021 15:26:31 - INFO - __main__ - Iteration num: 660, Valid loss: 0.4885517358779907, Valid score: 0.5653604748370356\n", 404 | "09/18/2021 15:27:47 - INFO - __main__ - Iteration num: 670, Train loss: 0.39951032400131226\n", 405 | "09/18/2021 15:27:47 - INFO - __main__ - Iteration num: 670, Valid loss: 0.47766637802124023, Valid score: 0.5732046470010711\n", 406 | "09/18/2021 15:28:53 - INFO - __main__ - Iteration num: 680, Train loss: 0.24070191383361816\n", 407 | "09/18/2021 15:28:53 - INFO - __main__ - Iteration num: 680, Valid loss: 0.4810181260108948, Valid score: 0.5469587051515413\n", 408 | "09/18/2021 15:30:01 - INFO - __main__ - Iteration num: 690, Train loss: 0.3604811131954193\n", 409 | "09/18/2021 15:30:01 - INFO - __main__ - Iteration num: 690, Valid loss: 0.4702230989933014, Valid score: 0.554912808282685\n", 410 | "09/18/2021 15:31:11 - INFO - __main__ - Iteration num: 700, Train loss: 0.08144819736480713\n", 411 | "09/18/2021 15:31:11 - INFO - __main__ - Iteration num: 700, Valid loss: 0.47994065284729004, Valid score: 0.5575034099514093\n", 412 | "09/18/2021 15:32:18 - INFO - __main__ - Iteration num: 710, Train loss: 0.2222006767988205\n", 413 | "09/18/2021 15:32:18 - INFO - __main__ - Iteration num: 710, Valid loss: 0.49694788455963135, Valid score: 0.5598395777855655\n", 414 | "09/18/2021 15:33:23 - INFO - __main__ - Iteration num: 720, Train loss: 0.17715145647525787\n", 415 | "09/18/2021 15:33:23 - INFO - __main__ - Iteration num: 720, Valid loss: 0.500529944896698, Valid score: 0.5598239422810317\n", 416 | "09/18/2021 15:34:30 - INFO - __main__ - Iteration num: 730, Train loss: 0.05942519009113312\n", 417 | "09/18/2021 15:34:30 - INFO - __main__ - Iteration num: 730, 
Valid loss: 0.4944283664226532, Valid score: 0.5523181756655187\n", 418 | "09/18/2021 15:35:35 - INFO - __main__ - Iteration num: 740, Train loss: 0.16485676169395447\n", 419 | "09/18/2021 15:35:35 - INFO - __main__ - Iteration num: 740, Valid loss: 0.49027585983276367, Valid score: 0.5575034099514093\n", 420 | "09/18/2021 15:36:47 - INFO - __main__ - Iteration num: 750, Train loss: 0.13922181725502014\n", 421 | "09/18/2021 15:36:47 - INFO - __main__ - Iteration num: 750, Valid loss: 0.48830530047416687, Valid score: 0.5601977832642419\n", 422 | "09/18/2021 15:37:52 - INFO - __main__ - Iteration num: 760, Train loss: 0.0857522040605545\n", 423 | "09/18/2021 15:37:52 - INFO - __main__ - Iteration num: 760, Valid loss: 0.4938272535800934, Valid score: 0.547116568580723\n", 424 | "09/18/2021 15:38:59 - INFO - __main__ - Iteration num: 770, Train loss: 0.2457312047481537\n", 425 | "09/18/2021 15:38:59 - INFO - __main__ - Iteration num: 770, Valid loss: 0.49165835976600647, Valid score: 0.5550196600183235\n", 426 | "09/18/2021 15:40:14 - INFO - __main__ - Iteration num: 780, Train loss: 0.2179664671421051\n", 427 | "09/18/2021 15:40:14 - INFO - __main__ - Iteration num: 780, Valid loss: 0.4865102171897888, Valid score: 0.5654865997646326\n", 428 | "09/18/2021 15:41:27 - INFO - __main__ - Iteration num: 790, Train loss: 0.42092329263687134\n", 429 | "09/18/2021 15:41:27 - INFO - __main__ - Iteration num: 790, Valid loss: 0.48256343603134155, Valid score: 0.5680628967969402\n", 430 | "09/18/2021 15:42:37 - INFO - __main__ - Iteration num: 800, Train loss: 0.4516393840312958\n", 431 | "09/18/2021 15:42:37 - INFO - __main__ - Iteration num: 800, Valid loss: 0.48235321044921875, Valid score: 0.5680628967969402\n" 432 | ] 433 | } 434 | ], 435 | "source": [ 436 | "iteration_num = 0\n", 437 | "for epoch in range(config['num_epochs']):\n", 438 | " for batch in train_dataloader:\n", 439 | " batch_inputs = tokenizer(text=batch['text_excerpt'], padding=True, truncation=True, return_tensors='pt').to(device)\n", 440 | " batch_targets = batch['label'].to(device)\n", 441 | " train_outputs = train_step(model=model, inputs=batch_inputs, targets=batch_targets, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler)\n", 442 | " if iteration_num % config['validate_every_n_iteration'] == 0:\n", 443 | " valid_targets, valid_predictions = [], []\n", 444 | " for batch in valid_dataloader:\n", 445 | " batch_inputs = tokenizer(text=batch['text_excerpt'], padding=True, truncation=True, return_tensors='pt').to(device)\n", 446 | " batch_targets = batch['label'].to(device)\n", 447 | " valid_outputs = evaluate_step(model=model, inputs=batch_inputs, targets=batch_targets, loss_fn=loss_fn)\n", 448 | " valid_targets.extend(valid_outputs['targets'])\n", 449 | " valid_predictions.extend(valid_outputs['predictions'])\n", 450 | " valid_targets = torch.vstack(valid_targets)\n", 451 | " valid_predictions = torch.vstack(valid_predictions)\n", 452 | " valid_loss = loss_fn(valid_predictions, valid_targets)\n", 453 | " valid_targets = post_process_targets(valid_targets)\n", 454 | " valid_predictions = post_process_predictions(valid_predictions)\n", 455 | " valid_score = compute_matthews_corrcoef(targets=valid_targets, predictions=valid_predictions)\n", 456 | " logger.info(f\"Iteration num: {iteration_num}, Train loss: {train_outputs['loss']}\")\n", 457 | " logger.info(f\"Iteration num: {iteration_num}, Valid loss: {valid_loss}, Valid score: {valid_score}\")\n", 458 | " iteration_num += 1" 459 | ] 460 | } 461 | ], 462 | "metadata": { 
463 | "kernelspec": { 464 | "display_name": "stabilizer", 465 | "language": "python", 466 | "name": "stabilizer" 467 | }, 468 | "language_info": { 469 | "codemirror_mode": { 470 | "name": "ipython", 471 | "version": 3 472 | }, 473 | "file_extension": ".py", 474 | "mimetype": "text/x-python", 475 | "name": "python", 476 | "nbconvert_exporter": "python", 477 | "pygments_lexer": "ipython3", 478 | "version": "3.7.10" 479 | } 480 | }, 481 | "nbformat": 4, 482 | "nbformat_minor": 5 483 | } 484 | --------------------------------------------------------------------------------