├── train
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── data.cpython-36.pyc
│   │   ├── optim.cpython-36.pyc
│   │   ├── trainer.cpython-36.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── pretrain_data.cpython-36.pyc
│   │   ├── pretrain_trainer.cpython-36.pyc
│   │   ├── classification_data.cpython-36.pyc
│   │   └── classification_trainer.cpython-36.pyc
│   ├── .ipynb_checkpoints
│   │   ├── train-checkpoint.py.gz
│   │   ├── pretrain_trainer-checkpoint.py
│   │   ├── classification_trainer-checkpoint.py
│   │   ├── pretrain_data-checkpoint.py
│   │   ├── optim-checkpoint.py
│   │   └── classification_data-checkpoint.py
│   ├── pretrain_trainer.py
│   ├── classification_trainer.py
│   ├── pretrain_data.py
│   ├── optim.py
│   └── classification_data.py
├── utils
│   ├── __init__.py
│   ├── utils.py
│   ├── load_weights.py
│   ├── file_utils.py
│   └── tokenization.py
├── experiments
│   ├── README.md
│   ├── mrpc
│   │   └── new
│   │       └── events.out.tfevents.1556401093.jupyter-mcdanel
│   └── pretrain
│       ├── first-test
│       │   └── events.out.tfevents.1556405292.jupyter-mcdanel
│       └── test-finetuning-from-pretrained-weights
│           └── events.out.tfevents.1556412794.jupyter-mcdanel
├── config
│   ├── .ipynb_checkpoints
│   │   ├── pretrain-checkpoint.json.gz
│   │   ├── bert_configs-checkpoint.json.gz
│   │   ├── finetune_mrpc-checkpoint.json.gz
│   │   ├── pretrain-checkpoint.json
│   │   ├── bert-large-uncased-checkpoint.json
│   │   └── bert-base-uncased-checkpoint.json
│   ├── finetune_mrpc.json
│   ├── pretrain.json
│   ├── bert-large-uncased.json
│   └── bert-base-uncased.json
├── README.md
├── scripts
│   ├── run-mrpc.sh
│   ├── .ipynb_checkpoints
│   │   ├── run-mrpc-checkpoint.sh
│   │   └── pretrain-checkpoint.sh
│   ├── pretrain.sh
│   └── download-glue.py
├── models
│   ├── __init__.py
│   ├── common_layers.py
│   ├── heads.py
│   ├── transformer.py
│   └── lightweight.py
├── .gitignore
├── pretrain.py
└── classify.py

/train/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
1 | A directory for output files/checkpoints/etc.
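As a rough sketch (directory contents assumed from the run scripts and the trainers' save()), a single run under experiments/<task>/<experiment>/ ends up holding:

    run-mrpc.sh or pretrain.sh   # copy of the launch script (cp $0 $OUTPUT)
    bert-base-uncased.json       # copy of the model config (cp $CONFIG $OUTPUT)
    events.out.tfevents.*        # TensorBoard event files
    model-<global_step>.pth      # checkpoints written by the trainers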
-------------------------------------------------------------------------------- /train/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/optim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/optim.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /train/.ipynb_checkpoints/train-checkpoint.py.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/.ipynb_checkpoints/train-checkpoint.py.gz -------------------------------------------------------------------------------- /train/__pycache__/pretrain_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/pretrain_data.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/pretrain_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/pretrain_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /config/.ipynb_checkpoints/pretrain-checkpoint.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/config/.ipynb_checkpoints/pretrain-checkpoint.json.gz -------------------------------------------------------------------------------- /train/__pycache__/classification_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/classification_data.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/classification_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/train/__pycache__/classification_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /config/.ipynb_checkpoints/bert_configs-checkpoint.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/config/.ipynb_checkpoints/bert_configs-checkpoint.json.gz 
-------------------------------------------------------------------------------- /config/.ipynb_checkpoints/finetune_mrpc-checkpoint.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/config/.ipynb_checkpoints/finetune_mrpc-checkpoint.json.gz -------------------------------------------------------------------------------- /experiments/mrpc/new/events.out.tfevents.1556401093.jupyter-mcdanel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/experiments/mrpc/new/events.out.tfevents.1556401093.jupyter-mcdanel -------------------------------------------------------------------------------- /config/finetune_mrpc.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 42, 3 | "batch_size": 32, 4 | "lr": 2e-5, 5 | "n_epochs": 3, 6 | "warmup": 0.1, 7 | "save_steps": 100, 8 | "total_steps": 345 9 | } -------------------------------------------------------------------------------- /config/pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 3431, 3 | "batch_size": 96, 4 | "lr": 1e-4, 5 | "n_epochs": 25, 6 | "warmup": 0.1, 7 | "save_steps": 10000, 8 | "total_steps": 1000000 9 | } -------------------------------------------------------------------------------- /experiments/pretrain/first-test/events.out.tfevents.1556405292.jupyter-mcdanel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/experiments/pretrain/first-test/events.out.tfevents.1556405292.jupyter-mcdanel -------------------------------------------------------------------------------- /config/.ipynb_checkpoints/pretrain-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 3431, 3 | "batch_size": 96, 4 | "lr": 1e-4, 5 | "n_epochs": 25, 6 | "warmup": 0.1, 7 | "save_steps": 10000, 8 | "total_steps": 1000000 9 | } -------------------------------------------------------------------------------- /config/bert-large-uncased.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 1024, 3 | "dim_ff": 4096, 4 | "n_layers": 24, 5 | "n_heads": 16, 6 | "p_drop_attn": 0.1, 7 | "p_drop_hidden": 0.1, 8 | "max_len": 512, 9 | "vocab_size": 30522 10 | } 11 | -------------------------------------------------------------------------------- /config/bert-base-uncased.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 768, 3 | "dim_ff": 3072, 4 | "n_layers": 12, 5 | "p_drop_attn": 0.1, 6 | "n_heads": 12, 7 | "p_drop_hidden": 0.1, 8 | "max_len": 512, 9 | "n_segments": 2, 10 | "vocab_size": 30522 11 | } -------------------------------------------------------------------------------- /experiments/pretrain/test-finetuning-from-pretrained-weights/events.out.tfevents.1556412794.jupyter-mcdanel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/simple-bert/HEAD/experiments/pretrain/test-finetuning-from-pretrained-weights/events.out.tfevents.1556412794.jupyter-mcdanel -------------------------------------------------------------------------------- /config/.ipynb_checkpoints/bert-large-uncased-checkpoint.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dim": 1024, 3 | "dim_ff": 4096, 4 | "n_layers": 24, 5 | "n_heads": 16, 6 | "p_drop_attn": 0.1, 7 | "p_drop_hidden": 0.1, 8 | "max_len": 512, 9 | "vocab_size": 30522 10 | } 11 | -------------------------------------------------------------------------------- /config/.ipynb_checkpoints/bert-base-uncased-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "dim": 768, 3 | "dim_ff": 3072, 4 | "n_layers": 12, 5 | "p_drop_attn": 0.1, 6 | "n_heads": 12, 7 | "p_drop_hidden": 0.1, 8 | "max_len": 512, 9 | "n_segments": 2, 10 | "vocab_size": 30522 11 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## BERT 2 | This repository is a simple, easy-to-use PyTorch implementation of BERT. It is based on Dong-Hyun Lee's [pytorchic-bert](#somelink), which is in turn based off of HuggingFace's [implementation](#somelink). 3 | 4 | Key features: 5 | * Load pre-trained weights from TensorFlow 6 | * Finetune BERT for text classification 7 | * Pretrain BERT from scratch on your own text data 8 | 9 | More details coming soon! -------------------------------------------------------------------------------- /scripts/run-mrpc.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT="new" 2 | DATA="data/GLUE/MRPC" 3 | TASK="MRPC" 4 | CONFIG="config/bert-base-uncased.json" 5 | WEIGHTS="pretrained/uncased_L-12_H-768_A-12/bert_model.ckpt" 6 | VOCAB="config/bert-uncased-vocab.txt" 7 | 8 | # Prepare experiment 9 | OUTPUT="experiments/${TASK,,}/$EXPERIMENT" 10 | echo "Removing $OUTPUT if it exists" 11 | if [ -d "$OUTPUT" ]; then rm -r $OUTPUT; fi 12 | mkdir -p $OUTPUT 13 | 14 | # Copy this script and the model config to the experiment directory 15 | cp $0 $OUTPUT 16 | cp $CONFIG $OUTPUT 17 | 18 | CUDA_VISIBLE_DEVICES=0 python classify.py \ 19 | --model bert \ 20 | --cfg $CONFIG \ 21 | --load_weights $WEIGHTS \ 22 | --exp_name $EXPERIMENT \ 23 | --task_name $TASK \ 24 | --data_dir $DATA \ 25 | --vocab $VOCAB \ 26 | --do_lower_case \ 27 | --seed 42 \ 28 | --val_every 1 \ 29 | --max_seq_length 128 \ 30 | --train_batch_size 40 \ 31 | --num_train_epochs 4.0 \ 32 | --learning_rate 2e-5 33 | -------------------------------------------------------------------------------- /scripts/.ipynb_checkpoints/run-mrpc-checkpoint.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT="new" 2 | DATA="data/GLUE/MRPC" 3 | TASK="MRPC" 4 | CONFIG="config/bert-base-uncased.json" 5 | WEIGHTS="pretrained/uncased_L-12_H-768_A-12/bert_model.ckpt" 6 | VOCAB="config/bert-uncased-vocab.txt" 7 | 8 | # Prepare experiment 9 | OUTPUT="experiments/${TASK,,}/$EXPERIMENT" 10 | echo "Removing $OUTPUT if it exists" 11 | if [ -d "$OUTPUT" ]; then rm -r $OUTPUT; fi 12 | mkdir -p $OUTPUT 13 | 14 | # Copy this script and the model config to the experiment directory 15 | cp $0 $OUTPUT 16 | cp $CONFIG $OUTPUT 17 | 18 | CUDA_VISIBLE_DEVICES=0 python classify.py \ 19 | --model bert \ 20 | --cfg $CONFIG \ 21 | --load_weights $WEIGHTS \ 22 | --exp_name $EXPERIMENT \ 23 | --task_name $TASK \ 24 | --data_dir $DATA \ 25 | --vocab $VOCAB \ 26 | --do_lower_case \ 27 | --seed 42 \ 28 | --val_every 1 \ 29 | --max_seq_length 128 \ 30 | --train_batch_size 40 \ 31 | --num_train_epochs 4.0 \ 32 | --learning_rate 2e-5 33 | 
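# A minimal prerequisite sketch (assumed commands, left as comments): the MRPC data can be
# fetched with the repo's scripts/download-glue.py, and the WEIGHTS path above expects
# Google's uncased_L-12_H-768_A-12 TensorFlow release. The flag names and the URL below
# are assumptions; check scripts/download-glue.py before running them.
#   python scripts/download-glue.py --data_dir data/GLUE --tasks MRPC
#   wget -P pretrained/ https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
#   unzip pretrained/uncased_L-12_H-768_A-12.zip -d pretrained/
#   cp pretrained/uncased_L-12_H-768_A-12/vocab.txt config/bert-uncased-vocab.txt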
-------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT="test-finetuning-from-pretrained-weights" 2 | DATA="data/Wiki/wiki.train.processed" 3 | CONFIG="config/bert-base-uncased.json" 4 | TASK="pretrain" 5 | WEIGHTS="pretrained/uncased_L-12_H-768_A-12/bert_model.ckpt" 6 | VOCAB="config/bert-uncased-vocab.txt" 7 | 8 | # Prepare experiment 9 | OUTPUT="experiments/${TASK,,}/$EXPERIMENT" 10 | echo "Removing $OUTPUT if it exists" 11 | if [ -d "$OUTPUT" ]; then rm -r $OUTPUT; fi 12 | mkdir -p $OUTPUT 13 | 14 | # Copy this script and the model config to the experiment directory 15 | cp $0 $OUTPUT 16 | cp $CONFIG $OUTPUT 17 | 18 | CUDA_VISIBLE_DEVICES=0,1 python pretrain.py \ 19 | --model bert \ 20 | --text_file $DATA \ 21 | --cfg $CONFIG \ 22 | --exp_name $EXPERIMENT \ 23 | --vocab $VOCAB \ 24 | --do_lower_case \ 25 | --seed 42 \ 26 | --val_every 1 \ 27 | --max_seq_length 256 \ 28 | --train_batch_size 32 \ 29 | --total_iterations 100000 \ 30 | --learning_rate 5e-6 \ 31 | --load_weights $WEIGHTS 32 | -------------------------------------------------------------------------------- /scripts/.ipynb_checkpoints/pretrain-checkpoint.sh: -------------------------------------------------------------------------------- 1 | EXPERIMENT="test-finetuning-from-pretrained-weights" 2 | DATA="data/Wiki/wiki.train.processed" 3 | CONFIG="config/bert-base-uncased.json" 4 | TASK="pretrain" 5 | WEIGHTS="pretrained/uncased_L-12_H-768_A-12/bert_model.ckpt" 6 | VOCAB="config/bert-uncased-vocab.txt" 7 | 8 | # Prepare experiment 9 | OUTPUT="experiments/${TASK,,}/$EXPERIMENT" 10 | echo "Removing $OUTPUT if it exists" 11 | if [ -d "$OUTPUT" ]; then rm -r $OUTPUT; fi 12 | mkdir -p $OUTPUT 13 | 14 | # Copy this script and the model config to the experiment directory 15 | cp $0 $OUTPUT 16 | cp $CONFIG $OUTPUT 17 | 18 | CUDA_VISIBLE_DEVICES=0,1 python pretrain.py \ 19 | --model bert \ 20 | --text_file $DATA \ 21 | --cfg $CONFIG \ 22 | --exp_name $EXPERIMENT \ 23 | --vocab $VOCAB \ 24 | --do_lower_case \ 25 | --seed 42 \ 26 | --val_every 1 \ 27 | --max_seq_length 256 \ 28 | --train_batch_size 32 \ 29 | --total_iterations 100000 \ 30 | --learning_rate 5e-6 \ 31 | --load_weights $WEIGHTS 32 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper Functions 3 | """ 4 | 5 | import os 6 | import random 7 | import numpy as np 8 | import torch 9 | 10 | def set_seeds(seed, multi_gpu=True): 11 | '''Set all random seeds''' 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | if multi_gpu: 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | def split_last(x, shape): 19 | "split the last dimension to given shape" 20 | shape = list(shape) 21 | assert shape.count(-1) <= 1 22 | if -1 in shape: 23 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 24 | return x.view(*x.size()[:-1], *shape) 25 | 26 | def merge_last(x, n_dims): 27 | "merge the last n_dims to a dimension" 28 | s = x.size() 29 | assert n_dims > 1 and n_dims < len(s) 30 | return x.view(*s[:-n_dims], -1) 31 | 32 | def truncate_tokens_pair(tokens_a, tokens_b, max_len): 33 | '''Removes tokens until inputs have the same length''' 34 | while True: 35 | if len(tokens_a) + len(tokens_b) <= max_len: 36 | break 37 | if len(tokens_a) > len(tokens_b): 38 | tokens_a.pop() 39 | 
else: 40 | tokens_b.pop() 41 | 42 | def get_random_word(vocab_words): 43 | '''Unform random word from vocab''' 44 | i = random.randint(0, len(vocab_words)-1) 45 | return vocab_words[i] 46 | 47 | def get_tensorboard_logger(args): 48 | '''Gets a TensorBoard logger or creates a fallback''' 49 | print(f"Logging to {args.output_dir}") 50 | try: 51 | from tensorboardX import SummaryWriter 52 | logger = SummaryWriter(log_dir=args.output_dir) # this crashes my VM 53 | print(f"Connect with: \n\t tensorboard --logdir {args.output_dir} --port 6001") 54 | except: 55 | print('NOTE: TensorBoardX is not installed. Logging to console.') 56 | class NotSummaryWriter(object): pass; 57 | logger = NotSummaryWriter() 58 | nothing_function = lambda s, *args, **kw: None 59 | logger.add_text = logger.add_scalar = nothing_function 60 | logger.info = print # could also write to a file here but this is fine for now 61 | return logger 62 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | 4 | from . import transformer, heads 5 | 6 | def get_model_for_classification(args): 7 | ''' Load a model in full or half precision with pretrained weights. ''' 8 | 9 | # Load a BERT model 10 | if args.model == 'bert': 11 | cfg = transformer.TransformerConfig.from_json(args.cfg) 12 | body = transformer.Transformer(cfg) 13 | model = heads.TransformerForClassification(cfg, body, args.num_labels) 14 | 15 | # Load pretrained weights 16 | if args.load_weights: 17 | if '.pth' in args.load_weights: # PyTorch file 18 | model.load_state_dict(torch.load(args.load_weights)) 19 | elif '.ckpt' in args.load_weights: # TensorFlow file 20 | from utils.load_weights import load_weights_for_classification 21 | load_weights_for_classification(model, args.load_weights) 22 | 23 | # CUDA / half-precision / distributed training 24 | model = distribute_and_fp16(args, model) 25 | return model 26 | 27 | def get_model_for_pretrain(args): 28 | ''' Load a model in full or half precision with pretrained weights. 
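Mirrors get_model_for_classification, but attaches the pretraining heads (masked LM and next-sentence prediction) via heads.TransformerForPretrain.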
''' 29 | 30 | # Load a BERT model 31 | if args.model == 'bert': 32 | cfg = transformer.TransformerConfig.from_json(args.cfg) 33 | body = transformer.Transformer(cfg) 34 | model = heads.TransformerForPretrain(cfg, body) 35 | 36 | # Load pretrained weights 37 | if args.load_weights: 38 | if '.pth' in args.load_weights: # PyTorch file 39 | model.load_state_dict(torch.load(args.load_weights)) 40 | elif '.ckpt' in args.load_weights: # TensorFlow file 41 | from utils.load_weights import load_weights_for_pretrain 42 | load_weights_for_pretrain(model, args.load_weights) 43 | 44 | # CUDA / half-precision / distributed training 45 | model = distribute_and_fp16(args, model) 46 | return model 47 | 48 | def distribute_and_fp16(args, model): 49 | ''' Multi-GPU and half-precision ''' 50 | 51 | if args.fp16: 52 | model.half() 53 | model.to(args.device) 54 | if args.local_rank != -1: 55 | try: 56 | from apex.parallel import DistributedDataParallel as DDP 57 | except ImportError: 58 | raise ImportError("To use FP16, install apex from https://www.github.com/nvidia/apex") 59 | model = DDP(model) 60 | elif args.n_gpu > 1: 61 | model = torch.nn.DataParallel(model) 62 | return model -------------------------------------------------------------------------------- /models/common_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # GELU Activation: https://arxiv.org/abs/1606.08415 8 | gelu = lambda x : x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 9 | 10 | # LayerNorm 11 | try: 12 | from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm 13 | except ImportError: 14 | class LayerNorm(nn.Module): 15 | "Layer normalization in the TF style (epsilon inside the square root)." 16 | def __init__(self, cfg, variance_epsilon=1e-12): 17 | super().__init__() 18 | self.gamma = nn.Parameter(torch.ones(cfg.dim)) 19 | self.beta = nn.Parameter(torch.zeros(cfg.dim)) 20 | self.variance_epsilon = variance_epsilon 21 | 22 | def forward(self, x): 23 | u = x.mean(-1, keepdim=True) 24 | s = (x - u).pow(2).mean(-1, keepdim=True) 25 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 26 | return self.gamma * x + self.beta 27 | 28 | def Linear(in_features, out_features, bias=True): 29 | ''' Wrapper for nn.Linear ''' 30 | m = nn.Linear(in_features, out_features, bias) 31 | nn.init.xavier_uniform_(m.weight) 32 | if bias: 33 | nn.init.constant_(m.bias, 0.) 
34 | return m 35 | 36 | def Embedding(num_embeddings, embedding_dim, padding_idx=None): 37 | ''' Wrapper for nn.Embedding ''' 38 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) 39 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) 40 | nn.init.constant_(m.weight[padding_idx], 0) 41 | return m 42 | 43 | class Embeddings(nn.Module): 44 | '''Embedding with optional position and segment type embeddings.''' 45 | def __init__(self, cfg, position_embeds=True, segment_embeds=True): 46 | super().__init__() 47 | self.tok_embed = Embedding(cfg.vocab_size, cfg.dim) # token embedding 48 | self.pos_embed = Embedding(cfg.max_len, cfg.dim) if position_embeds else None # position embedding 49 | self.seg_embed = Embedding(cfg.n_segments, cfg.dim) if segment_embeds else None # segment(token type) embedding 50 | self.norm = LayerNorm(cfg) 51 | self.drop = nn.Dropout(cfg.p_drop_hidden) 52 | 53 | def forward(self, x, seg): 54 | e = self.tok_embed(x) 55 | if self.pos_embed is not None: 56 | pos = torch.arange(x.size(1), dtype=torch.long, device=x.device) # x.size(1) = seq_len 57 | pos = pos.unsqueeze(0).expand_as(x) # (S,) -> (B, S) 58 | e = e + self.pos_embed(pos) 59 | if self.seg_embed is not None: 60 | e = e + self.seg_embed(seg) 61 | return self.drop(self.norm(e)) 62 | 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## CUSTOM 2 | data 3 | data/GLUE/* 4 | !data/GLUE/placeholder.txt 5 | experiments/*/*/* 6 | !experiments/placeholder.txt 7 | *.out 8 | pretrained 9 | 10 | # disabling this for now, as I am testing -- re-enable soon 11 | !experiments/*/*/*tfevents* 12 | 13 | run.sh 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 103 | # install all needed dependencies. 
104 | #Pipfile.lock 105 | 106 | # celery beat schedule file 107 | celerybeat-schedule 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | -------------------------------------------------------------------------------- /train/pretrain_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | 8 | class Trainer(): 9 | 10 | def __init__(self, logger=None): 11 | ''' The trainer simply holds the global training step and the logger. ''' 12 | self.logger = logger 13 | self.global_step = 0 14 | 15 | def train(self, args, model, dataloader, optimizer, epoch): 16 | '''Train for a single epoch on a training dataset''' 17 | model.train() 18 | total_loss = 0 19 | for step, batch in enumerate(tqdm(dataloader, desc=f"[Epoch {epoch+1:3d}] Batch ")): 20 | batch = tuple(t.to(args.device) for t in batch) 21 | input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_not_next = batch 22 | 23 | # Forward 24 | logits_lm, logits_sc = model(input_ids, segment_ids, input_mask, masked_pos) 25 | 26 | # Masked LM and sequence classification losses 27 | loss_lm = F.cross_entropy(logits_lm.transpose(1, 2), masked_ids, reduction='none') 28 | loss_lm = (loss_lm * masked_weights.float()).mean() 29 | loss_sc = F.cross_entropy(logits_sc, is_not_next) 30 | loss = loss_lm + loss_sc 31 | 32 | # Multi-gpu / gradient accumulation 33 | if args.n_gpu > 1: # note: use .mean() to average on multi-gpu 34 | loss = loss.mean() 35 | if args.gradient_accumulation_steps > 1: # accumulate gradient for small batch sizes 36 | loss = loss / args.gradient_accumulation_steps 37 | 38 | # Backward 39 | if args.fp16: 40 | optimizer.backward(loss) 41 | else: 42 | loss.backward() 43 | total_loss += loss.item() 44 | if (step + 1) % args.gradient_accumulation_steps == 0: 45 | if args.fp16: # modify l.r. 
with warmup (if args.fp16 is False, this is automatic) 46 | lr_this_step = args.learning_rate * \ 47 | warmup_linear(model.global_step/num_train_optimization_steps, args.warmup_proportion) 48 | for param_group in optimizer.param_groups: 49 | param_group['lr'] = lr_this_step 50 | optimizer.step() 51 | optimizer.zero_grad() 52 | self.global_step += 1 53 | 54 | if self.logger: 55 | # TODO: log learning rate 56 | self.logger.add_scalar('train/loss_total', loss.item(), self.global_step) 57 | self.logger.add_scalar('train/loss_lm', loss_lm.item(), self.global_step) 58 | self.logger.add_scalar('train/loss_sc', loss_sc.item(), self.global_step) 59 | 60 | if (self.global_step + 1) % args.checkpoint_every == 0: 61 | self.save(args, model) 62 | 63 | if self.logger: 64 | self.logger.info(f'Train loss: {total_loss/len(dataloader.dataset):.3f}') 65 | 66 | def evaluate(self): 67 | # Validation is not implemented -- pretrain for as long as possible 68 | raise NotImplementedError() 69 | 70 | def save(self, args, model, name=None): 71 | ''' Save a trained model and the associated configuration ''' 72 | model_name = f"model-{self.global_step}.pth" if name is None else name 73 | model_to_save = model.module if hasattr(model, 'module') else model # for nn.DataParallel 74 | model_file = os.path.join(args.output_dir, model_name) 75 | torch.save(model_to_save.state_dict(), model_file) 76 | return model_file 77 | 78 | -------------------------------------------------------------------------------- /train/.ipynb_checkpoints/pretrain_trainer-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | 8 | class Trainer(): 9 | 10 | def __init__(self, logger=None): 11 | ''' The trainer simply holds the global training step and the logger. ''' 12 | self.logger = logger 13 | self.global_step = 0 14 | 15 | def train(self, args, model, dataloader, optimizer, epoch): 16 | '''Train for a single epoch on a training dataset''' 17 | model.train() 18 | total_loss = 0 19 | for step, batch in enumerate(tqdm(dataloader, desc=f"[Epoch {epoch+1:3d}] Batch ")): 20 | batch = tuple(t.to(args.device) for t in batch) 21 | input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_not_next = batch 22 | 23 | # Forward 24 | logits_lm, logits_sc = model(input_ids, segment_ids, input_mask, masked_pos) 25 | 26 | # Masked LM and sequence classification losses 27 | loss_lm = F.cross_entropy(logits_lm.transpose(1, 2), masked_ids, reduction='none') 28 | loss_lm = (loss_lm * masked_weights.float()).mean() 29 | loss_sc = F.cross_entropy(logits_sc, is_not_next) 30 | loss = loss_lm + loss_sc 31 | 32 | # Multi-gpu / gradient accumulation 33 | if args.n_gpu > 1: # note: use .mean() to average on multi-gpu 34 | loss = loss.mean() 35 | if args.gradient_accumulation_steps > 1: # accumulate gradient for small batch sizes 36 | loss = loss / args.gradient_accumulation_steps 37 | 38 | # Backward 39 | if args.fp16: 40 | optimizer.backward(loss) 41 | else: 42 | loss.backward() 43 | total_loss += loss.item() 44 | if (step + 1) % args.gradient_accumulation_steps == 0: 45 | if args.fp16: # modify l.r. 
with warmup (if args.fp16 is False, this is automatic) 46 | lr_this_step = args.learning_rate * \ 47 | warmup_linear(model.global_step/num_train_optimization_steps, args.warmup_proportion) 48 | for param_group in optimizer.param_groups: 49 | param_group['lr'] = lr_this_step 50 | optimizer.step() 51 | optimizer.zero_grad() 52 | self.global_step += 1 53 | 54 | if self.logger: 55 | # TODO: log learning rate 56 | self.logger.add_scalar('train/loss_total', loss.item(), self.global_step) 57 | self.logger.add_scalar('train/loss_lm', loss_lm.item(), self.global_step) 58 | self.logger.add_scalar('train/loss_sc', loss_sc.item(), self.global_step) 59 | 60 | if (self.global_step + 1) % args.checkpoint_every == 0: 61 | self.save(args, model) 62 | 63 | if self.logger: 64 | self.logger.info(f'Train loss: {total_loss/len(dataloader.dataset):.3f}') 65 | 66 | def evaluate(self): 67 | # Validation is not implemented -- pretrain for as long as possible 68 | raise NotImplementedError() 69 | 70 | def save(self, args, model, name=None): 71 | ''' Save a trained model and the associated configuration ''' 72 | model_name = f"model-{self.global_step}.pth" if name is None else name 73 | model_to_save = model.module if hasattr(model, 'module') else model # for nn.DataParallel 74 | model_file = os.path.join(args.output_dir, model_name) 75 | torch.save(model_to_save.state_dict(), model_file) 76 | return model_file 77 | 78 | -------------------------------------------------------------------------------- /models/heads.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .common_layers import LayerNorm, gelu 7 | 8 | class TransformerForPretrain(nn.Module): 9 | """A model in the style of BERT for pretraining. 
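On top of any encoder that exposes `embed.tok_embed`, this adds a next-sentence (`seq_relationship`) head and a masked-LM decoder whose output weights are tied to the token embedding.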
10 | Example (1): 11 | ``` 12 | transformer = models.Transformer(cfg) 13 | bert = models.TransformerForPretraining(cfg, transformer) 14 | ``` 15 | Example (2): 16 | ``` 17 | lwdc_transformer = models.LightweightTransformer(cfg) 18 | bert = models.TransformerForPretraining(cfg, lwdc_transformer) 19 | ``` 20 | """ 21 | def __init__(self, cfg, transformer): 22 | super().__init__() 23 | self.transformer = transformer 24 | 25 | # For sentence classification 26 | self.pooler = nn.Linear(cfg.dim, cfg.dim) 27 | self.pooler_activation = nn.Tanh() 28 | self.seq_relationship = nn.Linear(cfg.dim, 2) 29 | 30 | # For masked LM 31 | embed_weight = self.transformer.embed.tok_embed.weight 32 | n_vocab, n_dim = embed_weight.size() 33 | self.decoder_linear = nn.Linear(cfg.dim, cfg.dim) 34 | self.decoder_norm = LayerNorm(cfg) 35 | self.decoder_output = nn.Linear(n_dim, n_vocab, bias=False) 36 | self.decoder_output_bias = nn.Parameter(torch.zeros(n_vocab)) 37 | 38 | # Tie weights 39 | self.decoder_output.weight = self.transformer.embed.tok_embed.weight 40 | 41 | def forward(self, input_ids, segment_ids=None, input_mask=None, masked_pos=None): 42 | 43 | # Allow for null inputs 44 | segment_ids = torch.zeros_like(input_ids) if segment_ids is None else segment_ids 45 | input_mask = torch.ones_like(input_ids) if input_mask is None else input_mask 46 | 47 | # Transformer 48 | h = self.transformer(input_ids, segment_ids, input_mask) 49 | 50 | # For sentence classification 51 | pooled_h = self.pooler_activation(self.pooler(h[:, 0])) 52 | logits_clsf = self.seq_relationship(pooled_h) 53 | 54 | # For masked LM # NOTE: be careful about this masked_pos stuff 55 | masked_pos = masked_pos[:, :, None].expand(-1, -1, h.size(-1)) 56 | h_masked = h if masked_pos is None else torch.gather(h, 1, masked_pos) 57 | h_masked = self.decoder_norm(gelu(self.decoder_linear(h_masked))) 58 | logits_lm = self.decoder_output(h_masked) + self.decoder_output_bias 59 | return logits_lm, logits_clsf 60 | 61 | class TransformerForClassification(nn.Module): 62 | """A model in the style of BERT for classification. 63 | Example: 64 | ``` 65 | transformer = models.Transformer(cfg) 66 | bert_for_mrpc_finetuning = models.TransformerForPretraining(cfg, transformer, 7) 67 | ``` 68 | """ 69 | def __init__(self, cfg, transformer, num_classes): 70 | super().__init__() 71 | self.transformer = transformer 72 | 73 | # Pooling --> Dropout --> Linear 74 | self.pooler = nn.Linear(cfg.dim, cfg.dim) 75 | self.pooler_activation = nn.Tanh() 76 | self.dropout = nn.Dropout(cfg.p_drop_hidden) 77 | self.classifier = nn.Linear(cfg.dim, num_classes) 78 | 79 | def forward(self, input_ids, segment_ids=None, input_mask=None, masked_pos=None): 80 | 81 | # Allow for null inputs 82 | segment_ids = torch.zeros_like(input_ids) if segment_ids is None else segment_ids 83 | input_mask = torch.ones_like(input_ids) if input_mask is None else input_mask 84 | 85 | # Transformer --> Pooler --> Dropout --> Linear 86 | h = self.transformer(input_ids, segment_ids, input_mask) 87 | pooled_h = self.pooler_activation(self.pooler(h[:, 0])) 88 | pooled_h = self.dropout(pooled_h) 89 | logits = self.classifier(pooled_h) 90 | return logits 91 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain. 
2 | # (Strongly inspired by original Google BERT code and Hugging Face's code) 3 | 4 | """ Transformer Model Classes & Config Class """ 5 | 6 | import math 7 | import json 8 | from typing import NamedTuple 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from utils.utils import split_last, merge_last 16 | from .common_layers import Embeddings, LayerNorm, gelu 17 | 18 | class TransformerConfig(NamedTuple): 19 | "Configuration for BERT model" 20 | vocab_size: int = None # Size of Vocabulary 21 | dim: int = 768 # Dimension of Hidden Layer in Transformer Encoder 22 | n_layers: int = 12 # Numher of Hidden Layers 23 | n_heads: int = 12 # Numher of Heads in Multi-Headed Attention Layers 24 | dim_ff: int = 768*4 # Dimension of Intermediate Layers in Positionwise Feedforward Net 25 | p_drop_hidden: float = 0.1 # Probability of Dropout of various Hidden Layers 26 | p_drop_attn: float = 0.1 # Probability of Dropout of Attention Layers 27 | max_len: int = 512 # Maximum Length for Positional Embeddings 28 | n_segments: int = 2 # Number of Sentence Segments 29 | 30 | @classmethod 31 | def from_json(cls, file): 32 | return cls(**json.load(open(file, "r"))) 33 | 34 | @classmethod 35 | def check(cfg): 36 | assert cfg.kernel_list is None 37 | assert cfg.conv_type is None 38 | assert cfg.dim % n_heads == 0 39 | 40 | class MultiHeadedSelfAttention(nn.Module): 41 | """ Multi-Headed Dot Product Attention """ 42 | def __init__(self, cfg): 43 | super().__init__() 44 | self.proj_q = nn.Linear(cfg.dim, cfg.dim) 45 | self.proj_k = nn.Linear(cfg.dim, cfg.dim) 46 | self.proj_v = nn.Linear(cfg.dim, cfg.dim) 47 | self.drop = nn.Dropout(cfg.p_drop_attn) 48 | self.scores = None # for visualization 49 | self.n_heads = cfg.n_heads 50 | 51 | def forward(self, x, mask): 52 | """ 53 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 54 | mask : (B(batch_size) x S(seq_len)) 55 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 56 | """ 57 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 58 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 59 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) 60 | for x in [q, k, v]) 61 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 62 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 63 | if mask is not None: 64 | mask = mask[:, None, None, :].float() 65 | scores -= 10000.0 * (1.0 - mask) 66 | scores = self.drop(F.softmax(scores, dim=-1)) 67 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 68 | h = (scores @ v).transpose(1, 2).contiguous() 69 | # -merge-> (B, S, D) 70 | h = merge_last(h, 2) 71 | self.scores = scores 72 | return h 73 | 74 | 75 | class PositionWiseFeedForward(nn.Module): 76 | """ FeedForward Neural Networks for each position """ 77 | def __init__(self, cfg): 78 | super().__init__() 79 | self.fc1 = nn.Linear(cfg.dim, cfg.dim_ff) 80 | self.fc2 = nn.Linear(cfg.dim_ff, cfg.dim) 81 | #self.activ = lambda x: activ_fn(cfg.activ_fn, x) 82 | 83 | def forward(self, x): 84 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 85 | return self.fc2(gelu(self.fc1(x))) 86 | 87 | 88 | class Block(nn.Module): 89 | """ Transformer Block """ 90 | def __init__(self, cfg): 91 | super().__init__() 92 | self.attn = MultiHeadedSelfAttention(cfg) 93 | self.proj = nn.Linear(cfg.dim, cfg.dim) 94 | self.norm1 = LayerNorm(cfg) 95 | self.pwff = PositionWiseFeedForward(cfg) 96 | self.norm2 = LayerNorm(cfg) 97 | 
self.drop = nn.Dropout(cfg.p_drop_hidden) 98 | 99 | def forward(self, x, mask): 100 | h = self.attn(x, mask) 101 | h = self.norm1(x + self.drop(self.proj(h))) 102 | h = self.norm2(h + self.drop(self.pwff(h))) 103 | return h 104 | 105 | 106 | class Transformer(nn.Module): 107 | """ Transformer with Self-Attentive Blocks""" 108 | def __init__(self, cfg): 109 | super().__init__() 110 | self.embed = Embeddings(cfg) 111 | self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)]) 112 | 113 | def forward(self, x, seg, mask): 114 | h = self.embed(x, seg) 115 | for block in self.blocks: 116 | h = block(h, mask) 117 | return h 118 | -------------------------------------------------------------------------------- /utils/load_weights.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load TensorFlow checkpoints into PyTorch model. 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | import torch 8 | 9 | def load_param(checkpoint_file, conversion_table): 10 | """ 11 | Load parameters according to conversion_table. 12 | Args: 13 | checkpoint_file (string): pretrained checkpoint model file in tensorflow 14 | conversion_table (dict): { pytorch tensor in a model : checkpoint variable name } 15 | """ 16 | for pyt_param, tf_param_name in conversion_table.items(): 17 | tf_param = tf.train.load_variable(checkpoint_file, tf_param_name) 18 | 19 | # for weight(kernel), we should do transpose 20 | if tf_param_name.endswith('kernel'): 21 | tf_param = np.transpose(tf_param) 22 | 23 | assert pyt_param.size() == tf_param.shape, \ 24 | 'Dim Mismatch: %s vs %s ; %s' % \ 25 | (tuple(pyt_param.size()), tf_param.shape, tf_param_name) 26 | 27 | # assign pytorch tensor from tensorflow param 28 | pyt_param.data = torch.from_numpy(tf_param) 29 | 30 | 31 | def load_transformer(model, checkpoint_file): 32 | """ 33 | Load transformer, ** not heads ** , into PyTorch model. 34 | """ 35 | 36 | # Embedding layer 37 | e, p = model.embed, 'bert/embeddings/' 38 | load_param(checkpoint_file, { 39 | e.tok_embed.weight: p+"word_embeddings", 40 | e.pos_embed.weight: p+"position_embeddings", 41 | e.seg_embed.weight: p+"token_type_embeddings", 42 | e.norm.gamma: p+"LayerNorm/gamma", 43 | e.norm.beta: p+"LayerNorm/beta" 44 | }) 45 | 46 | # Transformer blocks 47 | for i in range(len(model.blocks)): 48 | b, p = model.blocks[i], "bert/encoder/layer_%d/"%i 49 | load_param(checkpoint_file, { 50 | b.attn.proj_q.weight: p+"attention/self/query/kernel", 51 | b.attn.proj_q.bias: p+"attention/self/query/bias", 52 | b.attn.proj_k.weight: p+"attention/self/key/kernel", 53 | b.attn.proj_k.bias: p+"attention/self/key/bias", 54 | b.attn.proj_v.weight: p+"attention/self/value/kernel", 55 | b.attn.proj_v.bias: p+"attention/self/value/bias", 56 | b.proj.weight: p+"attention/output/dense/kernel", 57 | b.proj.bias: p+"attention/output/dense/bias", 58 | b.pwff.fc1.weight: p+"intermediate/dense/kernel", 59 | b.pwff.fc1.bias: p+"intermediate/dense/bias", 60 | b.pwff.fc2.weight: p+"output/dense/kernel", 61 | b.pwff.fc2.bias: p+"output/dense/bias", 62 | b.norm1.gamma: p+"attention/output/LayerNorm/gamma", 63 | b.norm1.beta: p+"attention/output/LayerNorm/beta", 64 | b.norm2.gamma: p+"output/LayerNorm/gamma", 65 | b.norm2.beta: p+"output/LayerNorm/beta", 66 | }) 67 | 68 | def load_weights_for_pretrain(model, checkpoint_file): 69 | ''' 70 | Load parameters of model for pretraining (i.e. masked LM and sequence classifier) 71 | from TensorFlow model checkpoint file onto PyTorch model. 
72 | ''' 73 | 74 | # Load transformer body 75 | load_transformer(model.transformer, checkpoint_file) 76 | 77 | # Sequence classification (+ pooler) and masked language model (decoder) 78 | conversion_table = { 79 | model.pooler.weight: 'bert/pooler/dense/kernel', 80 | model.pooler.bias: 'bert/pooler/dense/bias', 81 | model.seq_relationship.weight: 'cls/seq_relationship/output_weights', 82 | model.seq_relationship.bias: 'cls/seq_relationship/output_bias', 83 | model.decoder_linear.weight: 'cls/predictions/transform/dense/kernel', 84 | model.decoder_linear.bias: 'cls/predictions/transform/dense/bias', 85 | model.decoder_norm.gamma: 'cls/predictions/transform/LayerNorm/gamma', 86 | model.decoder_norm.beta: 'cls/predictions/transform/LayerNorm/beta', 87 | model.decoder_output_bias: 'cls/predictions/output_bias', 88 | } 89 | load_param(checkpoint_file, conversion_table) 90 | 91 | def load_weights_for_classification(model, checkpoint_file): 92 | ''' 93 | Load parameters of model for classification (i.e. pooler) 94 | from TensorFlow model checkpoint file onto PyTorch model. 95 | ''' 96 | # Load transformer body 97 | load_transformer(model.transformer, checkpoint_file) 98 | 99 | # Sequence classification (+ pooler) and masked language model (decoder) 100 | conversion_table = { 101 | model.pooler.weight: 'bert/pooler/dense/kernel', 102 | model.pooler.bias: 'bert/pooler/dense/bias', 103 | } 104 | load_param(checkpoint_file, conversion_table) 105 | -------------------------------------------------------------------------------- /train/classification_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | 7 | from .classification_data import compute_metrics 8 | 9 | class Trainer(): 10 | 11 | def __init__(self, logger=None): 12 | ''' The trainer simply holds the global training step and the logger. ''' 13 | self.logger = logger 14 | self.global_step = 0 15 | 16 | def evaluate(self, args, model, dataloader, criterion): 17 | ''' Evaluate model on the dev/test set ''' 18 | model.eval() 19 | total_loss = 0 20 | all_logits = None 21 | all_labels = None 22 | for step, batch in enumerate(tqdm(dataloader, desc=f"Validation")): 23 | input_ids, segment_ids, input_mask, labels = tuple(t.to(args.device) for t in batch) 24 | 25 | # Forward 26 | with torch.no_grad(): 27 | logits = model(input_ids, segment_ids, input_mask) 28 | 29 | # Calculate loss 30 | labels = labels.view(-1) 31 | logits = logits.view(-1) if args.output_mode == 'regression' else logits.view(-1, args.num_labels) 32 | loss = criterion(logits, labels) 33 | 34 | # Statistics 35 | total_loss += loss.mean().item() 36 | logits = logits.detach().cpu().numpy() 37 | labels = labels.detach().cpu().numpy() 38 | all_logits = logits if all_logits is None else np.append(all_logits, logits, axis=0) 39 | all_labels = labels if all_labels is None else np.append(all_labels, labels, axis=0) 40 | 41 | # Calculate prediction metrics (i.e. 
accuracy) and log 42 | average_loss = total_loss / len(dataloader.dataset) 43 | predictions = np.squeeze(all_logits) if args.output_mode == 'regression' else np.argmax(all_logits, axis=1) 44 | result = compute_metrics(args.task_name, predictions, all_labels, 45 | logits=all_logits / all_logits.sum(axis=1, keepdims=True)) 46 | 47 | # Log 48 | if self.logger: 49 | self.logger.add_scalar('val/val_loss', average_loss, self.global_step) 50 | self.logger.info(f"Val loss: {average_loss:.3f}") 51 | for key in sorted(result.keys()): 52 | self.logger.add_scalar(f'val/{key}', result[key], self.global_step) 53 | self.logger.info(f"Val {key}: {result[key]:.3f}") 54 | return average_loss, result 55 | 56 | def train(self, args, model, dataloader, criterion, optimizer, epoch): 57 | '''Train for a single epoch on a training dataset''' 58 | model.train() 59 | total_loss = 0 60 | for step, batch in enumerate(tqdm(dataloader, desc=f"[Epoch {epoch+1:3d}] Iteration")): 61 | input_ids, segment_ids, input_mask, labels = tuple(t.to(args.device) for t in batch) 62 | 63 | # Forward 64 | logits = model(input_ids, segment_ids, input_mask) 65 | 66 | # Loss 67 | logits = logits.view(-1) if args.output_mode == 'regression' else logits.view(-1, args.num_labels) 68 | loss = criterion(logits, labels) 69 | if args.n_gpu > 1: # note: use .mean() to average on multi-gpu 70 | loss = loss.mean() 71 | if args.gradient_accumulation_steps > 1: # accumulate gradient for small batch sizes 72 | loss = loss / args.gradient_accumulation_steps 73 | 74 | # Backward 75 | if args.fp16: 76 | optimizer.backward(loss) 77 | else: 78 | loss.backward() 79 | total_loss += loss.item() 80 | if (step + 1) % args.gradient_accumulation_steps == 0: 81 | if args.fp16: 82 | # modify learning rate with special warm up BERT uses 83 | # if args.fp16 is False, BertAdam is used that handles this modification automatically 84 | lr_this_step = args.learning_rate * \ 85 | warmup_linear(model.global_step/num_train_optimization_steps, args.warmup_proportion) 86 | for param_group in optimizer.param_groups: 87 | param_group['lr'] = lr_this_step 88 | optimizer.step() 89 | optimizer.zero_grad() 90 | self.global_step += 1 91 | 92 | # TODO: log learning rate 93 | if self.logger: 94 | self.logger.add_scalar('train/lr', loss.item(), self.global_step) 95 | self.logger.add_scalar('train/loss', loss.item(), self.global_step) 96 | if self.logger: 97 | self.logger.info(f'Train loss: {total_loss/len(dataloader.dataset):.3f}') 98 | 99 | def save(self, args, model, name=None): 100 | ''' Save a trained model and the associated configuration ''' 101 | model_name = f"model-{self.global_step}.pth" if name is None else name 102 | model_to_save = model.module if hasattr(model, 'module') else model # for nn.DataParallel 103 | model_file = os.path.join(args.output_dir, model_name) 104 | torch.save(model_to_save.state_dict(), model_file) 105 | return model_file 106 | 107 | -------------------------------------------------------------------------------- /train/.ipynb_checkpoints/classification_trainer-checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | 7 | from .classification_data import compute_metrics 8 | 9 | class Trainer(): 10 | 11 | def __init__(self, logger=None): 12 | ''' The trainer simply holds the global training step and the logger. 
''' 13 | self.logger = logger 14 | self.global_step = 0 15 | 16 | def evaluate(self, args, model, dataloader, criterion): 17 | ''' Evaluate model on the dev/test set ''' 18 | model.eval() 19 | total_loss = 0 20 | all_logits = None 21 | all_labels = None 22 | for step, batch in enumerate(tqdm(dataloader, desc=f"Validation")): 23 | input_ids, segment_ids, input_mask, labels = tuple(t.to(args.device) for t in batch) 24 | 25 | # Forward 26 | with torch.no_grad(): 27 | logits = model(input_ids, segment_ids, input_mask) 28 | 29 | # Calculate loss 30 | labels = labels.view(-1) 31 | logits = logits.view(-1) if args.output_mode == 'regression' else logits.view(-1, args.num_labels) 32 | loss = criterion(logits, labels) 33 | 34 | # Statistics 35 | total_loss += loss.mean().item() 36 | logits = logits.detach().cpu().numpy() 37 | labels = labels.detach().cpu().numpy() 38 | all_logits = logits if all_logits is None else np.append(all_logits, logits, axis=0) 39 | all_labels = labels if all_labels is None else np.append(all_labels, labels, axis=0) 40 | 41 | # Calculate prediction metrics (i.e. accuracy) and log 42 | average_loss = total_loss / len(dataloader.dataset) 43 | predictions = np.squeeze(all_logits) if args.output_mode == 'regression' else np.argmax(all_logits, axis=1) 44 | result = compute_metrics(args.task_name, predictions, all_labels, 45 | logits=all_logits / all_logits.sum(axis=1, keepdims=True)) 46 | 47 | # Log 48 | if self.logger: 49 | self.logger.add_scalar('val/val_loss', average_loss, self.global_step) 50 | self.logger.info(f"Val loss: {average_loss:.3f}") 51 | for key in sorted(result.keys()): 52 | self.logger.add_scalar(f'val/{key}', result[key], self.global_step) 53 | self.logger.info(f"Val {key}: {result[key]:.3f}") 54 | return average_loss, result 55 | 56 | def train(self, args, model, dataloader, criterion, optimizer, epoch): 57 | '''Train for a single epoch on a training dataset''' 58 | model.train() 59 | total_loss = 0 60 | for step, batch in enumerate(tqdm(dataloader, desc=f"[Epoch {epoch+1:3d}] Iteration")): 61 | input_ids, segment_ids, input_mask, labels = tuple(t.to(args.device) for t in batch) 62 | 63 | # Forward 64 | logits = model(input_ids, segment_ids, input_mask) 65 | 66 | # Loss 67 | logits = logits.view(-1) if args.output_mode == 'regression' else logits.view(-1, args.num_labels) 68 | loss = criterion(logits, labels) 69 | if args.n_gpu > 1: # note: use .mean() to average on multi-gpu 70 | loss = loss.mean() 71 | if args.gradient_accumulation_steps > 1: # accumulate gradient for small batch sizes 72 | loss = loss / args.gradient_accumulation_steps 73 | 74 | # Backward 75 | if args.fp16: 76 | optimizer.backward(loss) 77 | else: 78 | loss.backward() 79 | total_loss += loss.item() 80 | if (step + 1) % args.gradient_accumulation_steps == 0: 81 | if args.fp16: 82 | # modify learning rate with special warm up BERT uses 83 | # if args.fp16 is False, BertAdam is used that handles this modification automatically 84 | lr_this_step = args.learning_rate * \ 85 | warmup_linear(model.global_step/num_train_optimization_steps, args.warmup_proportion) 86 | for param_group in optimizer.param_groups: 87 | param_group['lr'] = lr_this_step 88 | optimizer.step() 89 | optimizer.zero_grad() 90 | self.global_step += 1 91 | 92 | # TODO: log learning rate 93 | if self.logger: 94 | self.logger.add_scalar('train/lr', loss.item(), self.global_step) 95 | self.logger.add_scalar('train/loss', loss.item(), self.global_step) 96 | if self.logger: 97 | self.logger.info(f'Train loss: 
{total_loss/len(dataloader.dataset):.3f}') 98 | 99 | def save(self, args, model, name=None): 100 | ''' Save a trained model and the associated configuration ''' 101 | model_name = f"model-{self.global_step}.pth" if name is None else name 102 | model_to_save = model.module if hasattr(model, 'module') else model # for nn.DataParallel 103 | model_file = os.path.join(args.output_dir, model_name) 104 | torch.save(model_to_save.state_dict(), model_file) 105 | return model_file 106 | 107 | -------------------------------------------------------------------------------- /train/pretrain_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloaders for pretraining a BERT-style masked language model. 3 | """ 4 | 5 | from random import randint, shuffle 6 | from random import random as rand 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from utils.utils import get_random_word, truncate_tokens_pair 12 | 13 | def seek_random_offset(f, back_margin=2000): 14 | """ Seek random offset of file pointer """ 15 | f.seek(0, 2) 16 | max_offset = f.tell() - back_margin 17 | f.seek(randint(0, max_offset), 0) 18 | f.readline() # throw away an incomplete sentence 19 | 20 | class SentencePairDataLoader(): 21 | """ 22 | Load sentence pair (sequential or random order) from corpus. 23 | Input file format : 24 | 1. One sentence per line. These should ideally be actual sentences, 25 | not entire paragraphs or arbitrary spans of text. (Because we use 26 | the sentence boundaries for the "next sentence prediction" task). 27 | 2. Blank lines between documents. Document boundaries are needed 28 | so that the "next sentence prediction" task doesn't span between documents. 29 | """ 30 | def __init__(self, file, batch_size, tokenize, max_len, short_sampling_prob=0.1, pipeline=[]): 31 | super().__init__() 32 | self.f_pos = open(file, "r", encoding='utf-8', errors='ignore') # for a positive sample 33 | self.f_neg = open(file, "r", encoding='utf-8', errors='ignore') # for a negative (random) sample 34 | self.tokenize = tokenize # tokenize function 35 | self.max_len = max_len # maximum length of tokens 36 | self.short_sampling_prob = short_sampling_prob 37 | self.pipeline = pipeline 38 | self.batch_size = batch_size 39 | 40 | def read_tokens(self, f, length, discard_last_and_restart=True): 41 | """ Read tokens from file pointer with limited length """ 42 | tokens = [] 43 | while len(tokens) < length: 44 | line = f.readline() 45 | if not line: # end of file 46 | return None 47 | if not line.strip(): # blank line (delimiter of documents) 48 | if discard_last_and_restart: 49 | tokens = [] # throw all and restart 50 | continue 51 | else: 52 | return tokens # return last tokens in the document 53 | tokens.extend(self.tokenize(line.strip())) 54 | return tokens 55 | 56 | def __iter__(self): # iterator to load data 57 | while True: 58 | batch = [] 59 | for i in range(self.batch_size): 60 | # sampling length of each tokens_a and tokens_b 61 | # sometimes sample a short sentence to match between train and test sequences 62 | len_tokens = randint(1, int(self.max_len / 2)) \ 63 | if rand() < self.short_sampling_prob \ 64 | else int(self.max_len / 2) 65 | 66 | is_not_next = rand() < 0.5 # whether token_b is next to token_a or not 67 | 68 | tokens_a = self.read_tokens(self.f_pos, len_tokens, True) 69 | seek_random_offset(self.f_neg) 70 | f_next = self.f_neg if is_not_next else self.f_pos 71 | tokens_b = self.read_tokens(f_next, len_tokens, False) 72 | 73 | if tokens_a is None or tokens_b 
is None: # end of file 74 | self.f_pos.seek(0, 0) # reset file pointer 75 | return 76 | 77 | instance = (is_not_next, tokens_a, tokens_b) 78 | for proc in self.pipeline: 79 | instance = proc(instance) 80 | 81 | batch.append(instance) 82 | 83 | # To Tensor 84 | batch_tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*batch)] 85 | yield batch_tensors 86 | 87 | class PipelineForPretrain(): 88 | """ 89 | Pre-processing steps for pretraining transformer 90 | """ 91 | def __init__(self, max_pred, mask_prob, vocab_words, indexer, max_len=512): 92 | super().__init__() 93 | self.max_len = max_len 94 | self.max_pred = max_pred # max tokens of prediction 95 | self.mask_prob = mask_prob # masking probability 96 | self.vocab_words = vocab_words # vocabulary (sub)words 97 | self.indexer = indexer # function from token to token index 98 | self.max_len = max_len 99 | 100 | def __call__(self, instance): 101 | is_not_next, tokens_a, tokens_b = instance 102 | 103 | # -3 for special tokens [CLS], [SEP], [SEP] 104 | truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3) 105 | 106 | # Add Special Tokens 107 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] 108 | segment_ids = [0]*(len(tokens_a)+2) + [1]*(len(tokens_b)+1) 109 | input_mask = [1]*len(tokens) 110 | 111 | # For masked Language Models 112 | masked_tokens, masked_pos = [], [] 113 | # the number of prediction is sometimes less than max_pred when sequence is short 114 | n_pred = min(self.max_pred, max(1, int(round(len(tokens)*self.mask_prob)))) 115 | # candidate positions of masked tokens 116 | cand_pos = [i for i, token in enumerate(tokens) 117 | if token != '[CLS]' and token != '[SEP]'] 118 | shuffle(cand_pos) 119 | for pos in cand_pos[:n_pred]: 120 | masked_tokens.append(tokens[pos]) 121 | masked_pos.append(pos) 122 | if rand() < 0.8: # 80% 123 | tokens[pos] = '[MASK]' 124 | elif rand() < 0.5: # 10% 125 | tokens[pos] = get_random_word(self.vocab_words) 126 | # when n_pred < max_pred, we only calculate loss within n_pred 127 | masked_weights = [1]*len(masked_tokens) 128 | 129 | # Token Indexing 130 | input_ids = self.indexer(tokens) 131 | masked_ids = self.indexer(masked_tokens) 132 | 133 | # Zero Padding 134 | n_pad = self.max_len - len(input_ids) 135 | input_ids.extend([0]*n_pad) 136 | segment_ids.extend([0]*n_pad) 137 | input_mask.extend([0]*n_pad) 138 | 139 | # Zero Padding for masked target 140 | if self.max_pred > n_pred: 141 | n_pad = self.max_pred - n_pred 142 | masked_ids.extend([0]*n_pad) 143 | masked_pos.extend([0]*n_pad) 144 | masked_weights.extend([0]*n_pad) 145 | 146 | return (input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_not_next) -------------------------------------------------------------------------------- /train/.ipynb_checkpoints/pretrain_data-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloaders for pretraining a BERT-style masked language model. 
3 | """ 4 | 5 | from random import randint, shuffle 6 | from random import random as rand 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from utils.utils import get_random_word, truncate_tokens_pair 12 | 13 | def seek_random_offset(f, back_margin=2000): 14 | """ Seek random offset of file pointer """ 15 | f.seek(0, 2) 16 | max_offset = f.tell() - back_margin 17 | f.seek(randint(0, max_offset), 0) 18 | f.readline() # throw away an incomplete sentence 19 | 20 | class SentencePairDataLoader(): 21 | """ 22 | Load sentence pair (sequential or random order) from corpus. 23 | Input file format : 24 | 1. One sentence per line. These should ideally be actual sentences, 25 | not entire paragraphs or arbitrary spans of text. (Because we use 26 | the sentence boundaries for the "next sentence prediction" task). 27 | 2. Blank lines between documents. Document boundaries are needed 28 | so that the "next sentence prediction" task doesn't span between documents. 29 | """ 30 | def __init__(self, file, batch_size, tokenize, max_len, short_sampling_prob=0.1, pipeline=[]): 31 | super().__init__() 32 | self.f_pos = open(file, "r", encoding='utf-8', errors='ignore') # for a positive sample 33 | self.f_neg = open(file, "r", encoding='utf-8', errors='ignore') # for a negative (random) sample 34 | self.tokenize = tokenize # tokenize function 35 | self.max_len = max_len # maximum length of tokens 36 | self.short_sampling_prob = short_sampling_prob 37 | self.pipeline = pipeline 38 | self.batch_size = batch_size 39 | 40 | def read_tokens(self, f, length, discard_last_and_restart=True): 41 | """ Read tokens from file pointer with limited length """ 42 | tokens = [] 43 | while len(tokens) < length: 44 | line = f.readline() 45 | if not line: # end of file 46 | return None 47 | if not line.strip(): # blank line (delimiter of documents) 48 | if discard_last_and_restart: 49 | tokens = [] # throw all and restart 50 | continue 51 | else: 52 | return tokens # return last tokens in the document 53 | tokens.extend(self.tokenize(line.strip())) 54 | return tokens 55 | 56 | def __iter__(self): # iterator to load data 57 | while True: 58 | batch = [] 59 | for i in range(self.batch_size): 60 | # sampling length of each tokens_a and tokens_b 61 | # sometimes sample a short sentence to match between train and test sequences 62 | len_tokens = randint(1, int(self.max_len / 2)) \ 63 | if rand() < self.short_sampling_prob \ 64 | else int(self.max_len / 2) 65 | 66 | is_not_next = rand() < 0.5 # whether token_b is next to token_a or not 67 | 68 | tokens_a = self.read_tokens(self.f_pos, len_tokens, True) 69 | seek_random_offset(self.f_neg) 70 | f_next = self.f_neg if is_not_next else self.f_pos 71 | tokens_b = self.read_tokens(f_next, len_tokens, False) 72 | 73 | if tokens_a is None or tokens_b is None: # end of file 74 | self.f_pos.seek(0, 0) # reset file pointer 75 | return 76 | 77 | instance = (is_not_next, tokens_a, tokens_b) 78 | for proc in self.pipeline: 79 | instance = proc(instance) 80 | 81 | batch.append(instance) 82 | 83 | # To Tensor 84 | batch_tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*batch)] 85 | yield batch_tensors 86 | 87 | class PipelineForPretrain(): 88 | """ 89 | Pre-processing steps for pretraining transformer 90 | """ 91 | def __init__(self, max_pred, mask_prob, vocab_words, indexer, max_len=512): 92 | super().__init__() 93 | self.max_len = max_len 94 | self.max_pred = max_pred # max tokens of prediction 95 | self.mask_prob = mask_prob # masking probability 96 | self.vocab_words = vocab_words 
# vocabulary (sub)words 97 | self.indexer = indexer # function from token to token index 98 | self.max_len = max_len 99 | 100 | def __call__(self, instance): 101 | is_not_next, tokens_a, tokens_b = instance 102 | 103 | # -3 for special tokens [CLS], [SEP], [SEP] 104 | truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3) 105 | 106 | # Add Special Tokens 107 | tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]'] 108 | segment_ids = [0]*(len(tokens_a)+2) + [1]*(len(tokens_b)+1) 109 | input_mask = [1]*len(tokens) 110 | 111 | # For masked Language Models 112 | masked_tokens, masked_pos = [], [] 113 | # the number of prediction is sometimes less than max_pred when sequence is short 114 | n_pred = min(self.max_pred, max(1, int(round(len(tokens)*self.mask_prob)))) 115 | # candidate positions of masked tokens 116 | cand_pos = [i for i, token in enumerate(tokens) 117 | if token != '[CLS]' and token != '[SEP]'] 118 | shuffle(cand_pos) 119 | for pos in cand_pos[:n_pred]: 120 | masked_tokens.append(tokens[pos]) 121 | masked_pos.append(pos) 122 | if rand() < 0.8: # 80% 123 | tokens[pos] = '[MASK]' 124 | elif rand() < 0.5: # 10% 125 | tokens[pos] = get_random_word(self.vocab_words) 126 | # when n_pred < max_pred, we only calculate loss within n_pred 127 | masked_weights = [1]*len(masked_tokens) 128 | 129 | # Token Indexing 130 | input_ids = self.indexer(tokens) 131 | masked_ids = self.indexer(masked_tokens) 132 | 133 | # Zero Padding 134 | n_pad = self.max_len - len(input_ids) 135 | input_ids.extend([0]*n_pad) 136 | segment_ids.extend([0]*n_pad) 137 | input_mask.extend([0]*n_pad) 138 | 139 | # Zero Padding for masked target 140 | if self.max_pred > n_pred: 141 | n_pad = self.max_pred - n_pred 142 | masked_ids.extend([0]*n_pad) 143 | masked_pos.extend([0]*n_pad) 144 | masked_weights.extend([0]*n_pad) 145 | 146 | return (input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_not_next) -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os, sys, argparse 2 | import time, datetime 3 | import random 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | 10 | from train.pretrain_data import PipelineForPretrain, SentencePairDataLoader 11 | from train.optim import get_optimizer 12 | from train.pretrain_trainer import Trainer 13 | from models import get_model_for_pretrain 14 | from utils import utils, tokenization 15 | 16 | # Required parameters 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--text_file", 19 | default=None, type=str, required=True, 20 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 21 | parser.add_argument("--model", 22 | default=None, type=str, required=True, 23 | help="Model name, for example: 'bert' ") 24 | parser.add_argument("--cfg", 25 | default=None, type=str, required=True, 26 | help="Model configuration file") 27 | parser.add_argument("--exp_name", 28 | default=None, type=str, required=True, 29 | help="Experiment output directory") 30 | 31 | # Other parameters 32 | parser.add_argument("--load_weights", 33 | type=str, default=None, 34 | help="A .ckpt or .pth file with a pretrained model.") 35 | parser.add_argument("--max_seq_length", 36 | type=int, default=512, 37 | help="Sequences longer than this will be truncated.") 38 | parser.add_argument("--val_every", 39 | type=int, default=-1, 40 | help="Validate on dev set every [X] iterations. Default is -1 (never).") 41 | parser.add_argument("--checkpoint_every", 42 | type=int, default=1000, 43 | help="Save model checkpoint every [X] iterations.") 44 | parser.add_argument("--do_lower_case", 45 | action='store_true', 46 | help="Set this flag if you are using an uncased model.") 47 | parser.add_argument("--vocab", 48 | type=str, default='config/bert-uncased-vocab.txt', 49 | help="File containing a BERT vocabulary.") 50 | parser.add_argument("--train_batch_size", 51 | type=int, default=32, 52 | help="Total batch size for training.") 53 | parser.add_argument("--val_batch_size", 54 | type=int, default=8, 55 | help="Validation/test batch size.") 56 | parser.add_argument("--learning_rate", 57 | type=float, default=2e-5, 58 | help="The initial learning rate for Adam.") 59 | parser.add_argument("--total_iterations", 60 | type=float, default=100000, # 1000000 61 | help="Total number of training iterations to perform.") 62 | parser.add_argument("--warmup_proportion", 63 | type=float, default=0.1, 64 | help="Linear training warmup proportion.") 65 | parser.add_argument("--local_rank", 66 | default=-1, type=int, 67 | help="For distributed training on gpus") 68 | parser.add_argument('--seed', 69 | default=42, type=int, 70 | help="Random seed") 71 | parser.add_argument('--gradient_accumulation_steps', 72 | default=1, type=int, 73 | help="Number of updates steps to accumulate before backprop.") 74 | parser.add_argument('--fp16', 75 | action='store_true', 76 | help="Whether to use 16-bit float precision instead of 32-bit") 77 | parser.add_argument('--loss_scale', 78 | type=float, default=0, 79 | help="Loss scaling to improve fp16 numeric stability. 
Only used when fp16 set to True.\n" 80 | "0 (default value): dynamic loss scaling.\n" 81 | "Positive power of 2: static loss scaling value.\n") 82 | parser.add_argument("--no_cuda", action='store_true', 83 | help="Disable GPUs and run on CPU.") 84 | parser.add_argument('--no_tensorboard', action='store_true', 85 | help="Disable tensorboard") 86 | 87 | # Parse and check args 88 | start_time = time.time() 89 | args = parser.parse_args() 90 | 91 | # Create output directory for saving models and logs 92 | args.output_dir = os.path.join('experiments', 'pretrain', args.exp_name) 93 | if not os.path.exists(args.output_dir): 94 | os.makedirs(args.output_dir) 95 | logger = utils.get_tensorboard_logger(args) 96 | 97 | # Select device 98 | if args.local_rank == -1 or args.no_cuda: 99 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 100 | args.n_gpu = torch.cuda.device_count() 101 | else: 102 | torch.cuda.set_device(args.local_rank) 103 | args.device = torch.device("cuda", args.local_rank) 104 | args.n_gpu = 1 105 | 106 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 107 | torch.distributed.init_process_group(backend='nccl') 108 | 109 | # Log GPU information 110 | logger.add_text('info', f"args: {args}") 111 | 112 | # Modify batch size if accumulating gradients 113 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 114 | 115 | # Reproducibility 116 | utils.set_seeds(args.seed, multi_gpu=args.n_gpu > 0) 117 | 118 | # Build dataloaders 119 | tokenizer = tokenization.FullTokenizer(args.vocab, do_lower_case=args.do_lower_case) 120 | tokenize = lambda x: tokenizer.tokenize(tokenizer.convert_to_unicode(x)) 121 | pipeline = [PipelineForPretrain(max_pred=20, # what is this? 
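                                # (max_pred caps the number of masked-LM predictions per sequence;
                                #  the masked_ids / masked_pos arrays in PipelineForPretrain are padded to this length)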
122 |                                 mask_prob=0.15,  # fraction of tokens selected for masked-LM prediction (0.15 as in BERT)
123 |                                 vocab_words=list(tokenizer.vocab.keys()),  # vocabulary used for the 10% random-token replacement
124 |                                 indexer=tokenizer.convert_tokens_to_ids,
125 |                                 max_len=args.max_seq_length)]
126 | dataloader = SentencePairDataLoader(args.text_file,
127 |                                     batch_size=args.train_batch_size,
128 |                                     tokenize=tokenize,
129 |                                     max_len=args.max_seq_length,
130 |                                     pipeline=pipeline)
131 | 
132 | # Model, optimizer
133 | model = get_model_for_pretrain(args)
134 | optimizer = get_optimizer(args, model, t_total=1000000)  # LR-schedule horizon kept at BERT's 1M steps
135 | 
136 | # Train
137 | epoch = 0
138 | trainer = Trainer(logger)
139 | while trainer.global_step < args.total_iterations:
140 | 
141 |     # # Validation is not yet implemented for pretraining
142 |     # if args.val_every > 0 and trainer.global_step % args.val_every == 0:
143 |     #     trainer.evaluate(args, model, val_dataloader, criterion) # TODO: add val
144 | 
145 |     # Train for one epoch
146 |     trainer.train(args, model, dataloader, optimizer, epoch)
147 |     epoch += 1
148 | 
149 | # Save trained model
150 | trainer.save(args, model)
151 | logger.info('Done pretraining!')
152 | 
--------------------------------------------------------------------------------
/models/lightweight.py:
--------------------------------------------------------------------------------
1 | import math
2 | import json
3 | from typing import NamedTuple
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | from models.common_layers import Linear, LayerNorm, Embedding, gelu
11 | 
12 | class LightweightConfig(NamedTuple):
13 |     "Configuration for the lightweight/dynamic convolution (LWDC) model"
14 |     vocab_size: int = None # Size of Vocabulary
15 |     dim: int = 512 # 768 # Dimension of Hidden Layer in Transformer Encoder
16 |     n_layers: int = 6 # 12 # Number of Hidden Layers
17 |     n_heads: int = 8 # 12 # Number of Heads in Multi-Headed Attention Layers
18 |     dim_ff: int = 512*4 # 768*4 # Dimension of Intermediate Layers in Positionwise Feedforward Net
19 |     p_drop_hidden: float = 0.1 # Probability of Dropout of various Hidden Layers
20 |     p_drop_conv: float = 0.0 # 0.1 # Probability of Dropout of Convolution Layers
21 |     max_len: int = 128 # 512 # Maximum Length for Positional Embeddings
22 |     n_segments: int = 2 # Number of Sentence Segments
23 |     kernel_list: list = [3, 7, 15, 31, 31, 31] # convolution kernel sizes, one per layer
24 |     conv_type: str = 'lightweight' # either 'lightweight' or 'dynamic'
25 |     glu_in_conv: bool = True # include a gated linear unit (GLU) in the conv layer
26 |     norm_before_conv: bool = True # layer norm before conv and after conv, not just after
27 |     weight_softmax: bool = True # softmax the convolutional layer weights
28 |     # tie_weights: bool = True # input and output embeddings are always shared, so this is not configurable
29 | 
30 |     @classmethod
31 |     def from_json(cls, file):
32 |         return cls(**json.load(open(file, "r")))
33 | 
34 |     @classmethod
35 |     def check(cls, cfg):
36 |         assert len(cfg.kernel_list) in [1, cfg.n_layers]  # one shared kernel size, or one per layer
37 |         assert cfg.conv_type in ['lightweight', 'dynamic']
38 |         assert cfg.dim % cfg.n_heads == 0
39 | 
40 | class LightweightConv(nn.Module):
41 |     '''Lightweight convolution from fairseq.
42 |     Args:
43 |         input_size: # of channels of the input and output
44 |         kernel_size: size (length) of the convolution kernel
45 |         padding: padding
46 |         n_heads: number of heads used. The weight is of shape (n_heads, 1, kernel_size)
47 |         weight_softmax: normalize the weight with softmax before the convolution
48 |         dropout: dropout probability
49 |     Forward:
50 |         Input: BxCxT, i.e. (batch_size, input_size, timesteps)
51 |         Output: BxCxT, i.e. 
(batch_size, input_size, timesteps)
52 |     Attributes:
53 |         weight: learnable weights of shape `(n_heads, 1, kernel_size)`
54 |         bias: learnable bias of shape `(input_size)`
55 |     '''
56 |     def __init__(self, input_size, kernel_size=1, padding=0, n_heads=1,
57 |                  weight_softmax=True, bias=False, dropout=0.0):
58 |         super().__init__()
59 |         self.input_size = input_size
60 |         self.kernel_size = kernel_size
61 |         self.n_heads = n_heads
62 |         self.padding = padding
63 |         self.weight_softmax = weight_softmax
64 |         self.weight = nn.Parameter(torch.Tensor(n_heads, 1, kernel_size))
65 |         self.bias = nn.Parameter(torch.Tensor(input_size)) if bias else None
66 |         self.dropout = dropout
67 |         nn.init.xavier_uniform_(self.weight)  # initialize the per-head conv weights
68 | 
69 |     def forward(self, input):
70 |         '''Takes input (B x C x T) to output (B x C x T)'''
71 | 
72 |         # Prepare weight (take softmax over the kernel dimension)
73 |         B, C, T = input.size()
74 |         H = self.n_heads
75 |         weight = F.softmax(self.weight, dim=-1) if self.weight_softmax else self.weight
76 |         weight = F.dropout(weight, self.dropout, training=self.training)
77 | 
78 |         # Merge every C/H entries into the batch dimension (C = self.input_size)
79 |         # B x C x T -> (B * C/H) x H x T
80 |         # One could instead expand the weight to C x 1 x K by a factor of C/H
81 |         # and leave the input unreshaped, but that is slower
82 |         input = input.view(-1, H, T)
83 |         output = F.conv1d(input, weight, padding=self.padding, groups=H)
84 |         output = output.view(B, C, T)
85 |         if self.bias is not None:
86 |             output = output + self.bias.view(1, -1, 1)
87 |         return output
88 | 
89 | class ConvBlock(nn.Module):
90 |     """Lightweight or dynamic convolutional layer"""
91 |     def __init__(self, cfg, kernel_size):
92 |         super().__init__()
93 |         self.norm_before_conv = cfg.norm_before_conv
94 |         # Initial fully connected layer or GLU
95 |         self.linear_1 = Linear(cfg.dim, cfg.dim * (2 if cfg.glu_in_conv else 1))
96 |         self.glu = nn.GLU() if cfg.glu_in_conv else None
97 | 
98 |         # Lightweight or dynamic convolution (DynamicConv is not defined in this file)
99 |         assert cfg.conv_type in ['lightweight', 'dynamic']
100 |         Conv = LightweightConv if cfg.conv_type == 'lightweight' else DynamicConv
101 |         self.conv = Conv(cfg.dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2,  # same-length output for the odd kernel sizes used here
102 |                          weight_softmax=cfg.weight_softmax, n_heads=cfg.n_heads, dropout=cfg.p_drop_conv)
103 | 
104 |         # A second linear layer after the conv, as in fairseq's decoder layer (currently unused below)
105 |         self.linear_2 = nn.Linear(cfg.dim, cfg.dim)
106 | 
107 |         # Dropout and layer normalization
108 |         self.dropout = nn.Dropout(cfg.p_drop_hidden)
109 |         self.conv_layer_norm = LayerNorm(cfg.dim)
110 | 
111 |         # NOTE: This is where the encoder attention would go if there were any
112 | 
113 |         # Final position-wise feed-forward layers: see Figure 2 in the LWDC paper
114 |         self.fc1 = Linear(cfg.dim, cfg.dim_ff)
115 |         self.fc2 = Linear(cfg.dim_ff, cfg.dim)
116 |         self.final_layer_norm = LayerNorm(cfg.dim)
117 | 
118 |     def forward(self, x, mask=None):
119 |         '''See Figure 2(b) in the paper; mask is accepted for interface compatibility but unused'''
120 | 
121 |         # Linear and GLU
122 |         res = x
123 |         if self.norm_before_conv:
124 |             x = self.conv_layer_norm(x)
125 |         x = self.dropout(self.linear_1(x))
126 |         x = x if self.glu is None else self.glu(x)
127 | 
128 |         # Conv (LightweightConv expects B x C x T, so transpose around it)
129 |         x = self.conv(x.transpose(1, 2)).transpose(1, 2)
130 |         # x = self.linear_2(x) # fairseq applies this projection here, but it is skipped in this implementation
131 |         x = self.dropout(x)
132 |         x = res + x
133 |         x = self.conv_layer_norm(x)
134 | 
135 |         # Position-wise feed-forward
136 |         res = x
137 |         if self.norm_before_conv:
138 |             x = self.final_layer_norm(x)
139 |         x = self.fc2(self.dropout(F.relu(self.fc1(x))))  # project back to cfg.dim; gelu could be used instead of relu
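        # Residual connection around the position-wise feed-forward sub-block, mirroring the
        # convolution sub-block above. Note that with norm_before_conv=True, LayerNorm is applied
        # both before each sub-block and again after its residual sum.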
140 | x = res + x 141 | x = self.final_layer_norm(x) 142 | return x 143 | 144 | class LightweightTransformer(nn.Module): 145 | """A LWDC model in the style of a transformer with Self-Attentive Blocks""" 146 | def __init__(self, cfg): 147 | super().__init__() 148 | self.embed = Embeddings(cfg, position_embeds=False, segment_embeds=True) 149 | kernel_list = cfg.kernel_list if len(cfg.kernel_list) > 1 else cfg.kernel_list * cfg.n_layers 150 | # ConvLayer = LightweightConvLayer if cfg.conv_type == 'lightweight' else DynamicConvLayer 151 | self.blocks = nn.ModuleList([ConvBlock(cfg, kernel_size=k) for k in kernel_list]) 152 | 153 | def forward(self, x, seg, mask): 154 | h = self.embed(x, seg) 155 | for block in self.blocks: 156 | h = block(h, mask) 157 | return h -------------------------------------------------------------------------------- /classify.py: -------------------------------------------------------------------------------- 1 | import os, sys, argparse 2 | import time, datetime 3 | import random 4 | import logging 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | 10 | from train import classification_data as data 11 | from train.optim import get_optimizer 12 | from train.classification_trainer import Trainer 13 | from models import get_model_for_classification 14 | from utils import utils, tokenization 15 | 16 | # Required parameters 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--data_dir", 19 | default=None, type=str, required=True, 20 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 21 | parser.add_argument("--model", 22 | default=None, type=str, required=True, 23 | help="Model name, for example: 'bert' ") 24 | parser.add_argument("--cfg", 25 | default=None, type=str, required=True, 26 | help="Model configuration file") 27 | parser.add_argument("--task_name", 28 | default=None, type=str, required=True, 29 | help="The name of the task to train.") 30 | parser.add_argument("--exp_name", 31 | default=None, type=str, required=True, 32 | help="The name of the experiment output directory where the model predictions and checkpoints will be written.") 33 | 34 | # Other parameters 35 | parser.add_argument("--load_weights", 36 | type=str, default=None, 37 | help="A .ckpt or .pth file with a pretrained model.") 38 | parser.add_argument("--max_seq_length", 39 | type=int, default=128, 40 | help="Sequences longer than this will be truncated.") 41 | parser.add_argument("--val_every", 42 | type=int, default=-1, 43 | help="Validate on dev set every [X] epochs while training. 
Default is -1 (never).") 44 | parser.add_argument("--do_lower_case", 45 | action='store_true', 46 | help="Set this flag if you are using an uncased model.") 47 | parser.add_argument("--vocab", 48 | type=str, default='config/bert-uncased-vocab.txt', 49 | help="File containing a BERT vocabulary.") 50 | parser.add_argument("--train_batch_size", 51 | type=int, default=32, 52 | help="Total batch size for training.") 53 | parser.add_argument("--val_batch_size", 54 | type=int, default=8, 55 | help="Validation/test batch size.") 56 | parser.add_argument("--learning_rate", 57 | type=float, default=2e-5, 58 | help="The initial learning rate for Adam.") 59 | parser.add_argument("--num_train_epochs", 60 | type=float, default=3.0, 61 | help="Total number of training epochs to perform.") 62 | parser.add_argument("--warmup_proportion", 63 | type=float, default=0.1, 64 | help="Linear training warmup proportion.") 65 | parser.add_argument("--local_rank", 66 | default=-1, type=int, 67 | help="For distributed training on gpus") 68 | parser.add_argument('--seed', 69 | default=42, type=int, 70 | help="Random seed") 71 | parser.add_argument('--gradient_accumulation_steps', 72 | default=1, type=int, 73 | help="Number of updates steps to accumulate before backprop.") 74 | parser.add_argument('--fp16', 75 | action='store_true', 76 | help="Whether to use 16-bit float precision instead of 32-bit") 77 | parser.add_argument('--loss_scale', 78 | type=float, default=0, 79 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" 80 | "0 (default value): dynamic loss scaling.\n" 81 | "Positive power of 2: static loss scaling value.\n") 82 | parser.add_argument("--no_cuda", action='store_true', 83 | help="Disable GPUs and run on CPU.") 84 | parser.add_argument('--no_tensorboard', action='store_true', 85 | help="Disable tensorboard") 86 | parser.add_argument('--do_test', action='store_true', 87 | help="[DeepMoji only for now] Evaluate on test after training") 88 | 89 | # Parse and check args 90 | start_time = time.time() 91 | args = parser.parse_args() 92 | 93 | # Create output directory for saving models and logs 94 | args.task_name = args.task_name.lower() 95 | args.output_dir = os.path.join('experiments', args.task_name, args.exp_name) 96 | if not os.path.exists(args.output_dir): 97 | os.makedirs(args.output_dir) 98 | logger = utils.get_tensorboard_logger(args) 99 | 100 | # Select device 101 | if args.local_rank == -1 or args.no_cuda: 102 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 103 | args.n_gpu = torch.cuda.device_count() 104 | else: 105 | torch.cuda.set_device(args.local_rank) 106 | args.device = torch.device("cuda", args.local_rank) 107 | args.n_gpu = 1 108 | 109 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 110 | torch.distributed.init_process_group(backend='nccl') 111 | 112 | # Log GPU information 113 | logger.add_text('info', f"args: {args}") 114 | 115 | # Modify batch size if accumulating gradients 116 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps 117 | 118 | # Reproducibility 119 | utils.set_seeds(args.seed, multi_gpu=args.n_gpu > 0) 120 | 121 | # Loss function 122 | args.output_mode = data.output_modes[args.task_name] 123 | if args.output_mode == 'classification': 124 | args.output_dtype = torch.long 125 | criterion = nn.CrossEntropyLoss() 126 | elif args.output_mode == "regression": 127 | args.output_dtype = torch.float 128 | criterion = 
nn.MSELoss() 129 | 130 | # Get the number of labels in our data in order to build the model 131 | args.label_list = data.processors[args.task_name]().get_labels() 132 | args.num_labels = len(args.label_list) 133 | 134 | # Build tokenizer 135 | tokenizer = tokenization.FullTokenizer(args.vocab, do_lower_case=args.do_lower_case) 136 | 137 | # Build model 138 | model = get_model_for_classification(args) 139 | 140 | # Build dataloaders 141 | if args.do_test: 142 | train_dataloader, val_dataloader, test_dataloader = data.prepare_dataloader(args, tokenizer, test=True) 143 | else: 144 | train_dataloader, val_dataloader = data.prepare_dataloader(args, tokenizer, test=False) 145 | logger.info(f"***** Loaded train [{len(train_dataloader)}] and val data [{len(val_dataloader)}]*****") 146 | 147 | # Build optimizer 148 | optimizer = get_optimizer(args, model, train_dataloader) 149 | 150 | # Trainer 151 | trainer = Trainer(logger) 152 | 153 | # Train 154 | for epoch in range(int(args.num_train_epochs)): 155 | 156 | # Validation 157 | if args.val_every > 0 and epoch % args.val_every == 0: 158 | trainer.evaluate(args, model, val_dataloader, criterion) 159 | 160 | # Train for one epoch 161 | trainer.train(args, model, train_dataloader, criterion, optimizer, epoch) 162 | 163 | # Save trained model 164 | trainer.save(args, model) 165 | 166 | # Finally, evaluate the model again 167 | loss, result = trainer.evaluate(args, model, val_dataloader, criterion) 168 | 169 | # If test 170 | if args.do_test: 171 | logger.info('******* Test evaluation *******') 172 | logger.info("... is not yet implemented") 173 | 174 | -------------------------------------------------------------------------------- /train/optim.py: -------------------------------------------------------------------------------- 1 | """ 2 | A slightly modified version of Hugging Face's BERTAdam class 3 | """ 4 | 5 | import math 6 | import torch 7 | from torch.optim import Optimizer 8 | from torch.nn.utils import clip_grad_norm_ 9 | 10 | def warmup_cosine(x, warmup=0.002): 11 | if x < warmup: 12 | return x/warmup 13 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 14 | 15 | def warmup_constant(x, warmup=0.002): 16 | if x < warmup: 17 | return x/warmup 18 | return 1.0 19 | 20 | def warmup_linear(x, warmup=0.002): 21 | if x < warmup: 22 | return x/warmup 23 | return 1.0 - x 24 | 25 | SCHEDULES = { 26 | 'warmup_cosine':warmup_cosine, 27 | 'warmup_constant':warmup_constant, 28 | 'warmup_linear':warmup_linear, 29 | } 30 | 31 | class BertAdam(Optimizer): 32 | """Implements BERT version of Adam algorithm with weight decay fix. 33 | Params: 34 | lr: learning rate 35 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 36 | t_total: total number of training steps for the learning 37 | rate schedule, -1 means constant learning rate. Default: -1 38 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 39 | b1: Adams b1. Default: 0.9 40 | b2: Adams b2. Default: 0.999 41 | e: Adams epsilon. Default: 1e-6 42 | weight_decay_rate: Weight decay. Default: 0.01 43 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 44 | """ 45 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 46 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 47 | max_grad_norm=1.0): 48 | assert lr > 0.0, "Learning rate: %f - should be > 0.0" % (lr) 49 | assert schedule in SCHEDULES, "Invalid schedule : %s" % (schedule) 50 | assert 0.0 <= warmup < 1.0 or warmup == -1.0, \ 51 | "Warmup %f - should be in 0.0 ~ 1.0 or -1 (no warm up)" % (warmup) 52 | assert 0.0 <= b1 < 1.0, "b1: %f - should be in 0.0 ~ 1.0" % (b1) 53 | assert 0.0 <= b2 < 1.0, "b2: %f - should be in 0.0 ~ 1.0" % (b2) 54 | assert e > 0.0, "epsilon: %f - should be > 0.0" % (e) 55 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 56 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 57 | max_grad_norm=max_grad_norm) 58 | super(BertAdam, self).__init__(params, defaults) 59 | 60 | def get_lr(self): 61 | """ get learning rate in training """ 62 | lr = [] 63 | for group in self.param_groups: 64 | for p in group['params']: 65 | state = self.state[p] 66 | if not state: 67 | return [0] 68 | if group['t_total'] != -1: 69 | schedule_fct = SCHEDULES[group['schedule']] 70 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 71 | else: 72 | lr_scheduled = group['lr'] 73 | lr.append(lr_scheduled) 74 | return lr 75 | 76 | def step(self, closure=None): 77 | """Performs a single optimization step. 78 | 79 | Arguments: 80 | closure (callable, optional): A closure that reevaluates the model 81 | and returns the loss. 82 | """ 83 | loss = None 84 | if closure is not None: 85 | loss = closure() 86 | 87 | for group in self.param_groups: 88 | for p in group['params']: 89 | if p.grad is None: 90 | continue 91 | grad = p.grad.data 92 | if grad.is_sparse: 93 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 94 | 95 | state = self.state[p] 96 | 97 | # State initialization 98 | if not state: 99 | state['step'] = 0 100 | # Exponential moving average of gradient values 101 | state['next_m'] = torch.zeros_like(p.data) 102 | # Exponential moving average of squared gradient values 103 | state['next_v'] = torch.zeros_like(p.data) 104 | 105 | next_m, next_v = state['next_m'], state['next_v'] 106 | beta1, beta2 = group['b1'], group['b2'] 107 | 108 | # Add grad clipping 109 | if group['max_grad_norm'] > 0: 110 | clip_grad_norm_(p, group['max_grad_norm']) 111 | 112 | # Decay the first and second moment running average coefficient 113 | # In-place operations to update the averages at the same time 114 | next_m.mul_(beta1).add_(1 - beta1, grad) 115 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 116 | update = next_m / (next_v.sqrt() + group['e']) 117 | 118 | # Just adding the square of the weights to the loss function is *not* 119 | # the correct way of using L2 regularization/weight decay with Adam, 120 | # since that will interact with the m and v parameters in strange ways. 121 | # 122 | # Instead we want to decay the weights in a manner that doesn't interact 123 | # with the m/v parameters. This is equivalent to adding the square 124 | # of the weights to the loss with plain (non-momentum) SGD. 
125 | if group['weight_decay_rate'] > 0.0: 126 | update += group['weight_decay_rate'] * p.data 127 | 128 | if group['t_total'] != -1: 129 | schedule_fct = SCHEDULES[group['schedule']] 130 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 131 | else: 132 | lr_scheduled = group['lr'] 133 | 134 | update_with_lr = lr_scheduled * update 135 | p.data.add_(-update_with_lr) 136 | state['step'] += 1 137 | 138 | return loss 139 | 140 | def get_optimizer(args, model, train_dataloader=None, t_total=None): 141 | '''A helper function for getting a full-precision or half-precision optimizer''' 142 | assert train_dataloader or t_total 143 | 144 | # Get number of training steps for calculating optimizer warmup 145 | if not t_total: 146 | num_train_batches = len(train_dataloader.dataset) / args.train_batch_size 147 | num_train_optimization_steps = int(num_train_batches / args.gradient_accumulation_steps) * args.num_train_epochs 148 | if args.local_rank != -1: 149 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 150 | t_total = num_train_optimization_steps 151 | 152 | # Apply weight decay to all but a few parameters 153 | param_optimizer = list(model.named_parameters()) 154 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 155 | optimizer_grouped_parameters = [ 156 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 157 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 158 | ] 159 | 160 | # Half-precision 161 | if args.fp16: 162 | try: 163 | from apex.optimizers import FP16_Optimizer, FusedAdam 164 | except ImportError: 165 | raise ImportError("To use fp16, install apex from https://www.github.com/nvidia/apex") 166 | optimizer = FusedAdam(optimizer_grouped_parameters, 167 | lr=args.learning_rate, 168 | bias_correction=False, 169 | max_grad_norm=1.0) 170 | if args.loss_scale == 0: 171 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 172 | else: 173 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 174 | 175 | # Adam optimizer with Transformer warmup steps 176 | else: 177 | optimizer = BertAdam(optimizer_grouped_parameters, 178 | lr=args.learning_rate, 179 | warmup=args.warmup_proportion, 180 | t_total=t_total) 181 | return optimizer -------------------------------------------------------------------------------- /train/.ipynb_checkpoints/optim-checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | A slightly modified version of Hugging Face's BERTAdam class 3 | """ 4 | 5 | import math 6 | import torch 7 | from torch.optim import Optimizer 8 | from torch.nn.utils import clip_grad_norm_ 9 | 10 | def warmup_cosine(x, warmup=0.002): 11 | if x < warmup: 12 | return x/warmup 13 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 14 | 15 | def warmup_constant(x, warmup=0.002): 16 | if x < warmup: 17 | return x/warmup 18 | return 1.0 19 | 20 | def warmup_linear(x, warmup=0.002): 21 | if x < warmup: 22 | return x/warmup 23 | return 1.0 - x 24 | 25 | SCHEDULES = { 26 | 'warmup_cosine':warmup_cosine, 27 | 'warmup_constant':warmup_constant, 28 | 'warmup_linear':warmup_linear, 29 | } 30 | 31 | class BertAdam(Optimizer): 32 | """Implements BERT version of Adam algorithm with weight decay fix. 33 | Params: 34 | lr: learning rate 35 | warmup: portion of t_total for the warmup, -1 means no warmup. 
Default: -1 36 | t_total: total number of training steps for the learning 37 | rate schedule, -1 means constant learning rate. Default: -1 38 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 39 | b1: Adams b1. Default: 0.9 40 | b2: Adams b2. Default: 0.999 41 | e: Adams epsilon. Default: 1e-6 42 | weight_decay_rate: Weight decay. Default: 0.01 43 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 44 | """ 45 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 46 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 47 | max_grad_norm=1.0): 48 | assert lr > 0.0, "Learning rate: %f - should be > 0.0" % (lr) 49 | assert schedule in SCHEDULES, "Invalid schedule : %s" % (schedule) 50 | assert 0.0 <= warmup < 1.0 or warmup == -1.0, \ 51 | "Warmup %f - should be in 0.0 ~ 1.0 or -1 (no warm up)" % (warmup) 52 | assert 0.0 <= b1 < 1.0, "b1: %f - should be in 0.0 ~ 1.0" % (b1) 53 | assert 0.0 <= b2 < 1.0, "b2: %f - should be in 0.0 ~ 1.0" % (b2) 54 | assert e > 0.0, "epsilon: %f - should be > 0.0" % (e) 55 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 56 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 57 | max_grad_norm=max_grad_norm) 58 | super(BertAdam, self).__init__(params, defaults) 59 | 60 | def get_lr(self): 61 | """ get learning rate in training """ 62 | lr = [] 63 | for group in self.param_groups: 64 | for p in group['params']: 65 | state = self.state[p] 66 | if not state: 67 | return [0] 68 | if group['t_total'] != -1: 69 | schedule_fct = SCHEDULES[group['schedule']] 70 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 71 | else: 72 | lr_scheduled = group['lr'] 73 | lr.append(lr_scheduled) 74 | return lr 75 | 76 | def step(self, closure=None): 77 | """Performs a single optimization step. 78 | 79 | Arguments: 80 | closure (callable, optional): A closure that reevaluates the model 81 | and returns the loss. 82 | """ 83 | loss = None 84 | if closure is not None: 85 | loss = closure() 86 | 87 | for group in self.param_groups: 88 | for p in group['params']: 89 | if p.grad is None: 90 | continue 91 | grad = p.grad.data 92 | if grad.is_sparse: 93 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 94 | 95 | state = self.state[p] 96 | 97 | # State initialization 98 | if not state: 99 | state['step'] = 0 100 | # Exponential moving average of gradient values 101 | state['next_m'] = torch.zeros_like(p.data) 102 | # Exponential moving average of squared gradient values 103 | state['next_v'] = torch.zeros_like(p.data) 104 | 105 | next_m, next_v = state['next_m'], state['next_v'] 106 | beta1, beta2 = group['b1'], group['b2'] 107 | 108 | # Add grad clipping 109 | if group['max_grad_norm'] > 0: 110 | clip_grad_norm_(p, group['max_grad_norm']) 111 | 112 | # Decay the first and second moment running average coefficient 113 | # In-place operations to update the averages at the same time 114 | next_m.mul_(beta1).add_(1 - beta1, grad) 115 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 116 | update = next_m / (next_v.sqrt() + group['e']) 117 | 118 | # Just adding the square of the weights to the loss function is *not* 119 | # the correct way of using L2 regularization/weight decay with Adam, 120 | # since that will interact with the m and v parameters in strange ways. 121 | # 122 | # Instead we want to decay the weights in a manner that doesn't interact 123 | # with the m/v parameters. 
This is equivalent to adding the square 124 | # of the weights to the loss with plain (non-momentum) SGD. 125 | if group['weight_decay_rate'] > 0.0: 126 | update += group['weight_decay_rate'] * p.data 127 | 128 | if group['t_total'] != -1: 129 | schedule_fct = SCHEDULES[group['schedule']] 130 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 131 | else: 132 | lr_scheduled = group['lr'] 133 | 134 | update_with_lr = lr_scheduled * update 135 | p.data.add_(-update_with_lr) 136 | state['step'] += 1 137 | 138 | return loss 139 | 140 | def get_optimizer(args, model, train_dataloader=None, t_total=None): 141 | '''A helper function for getting a full-precision or half-precision optimizer''' 142 | assert train_dataloader or t_total 143 | 144 | # Get number of training steps for calculating optimizer warmup 145 | if not t_total: 146 | num_train_batches = len(train_dataloader.dataset) / args.train_batch_size 147 | num_train_optimization_steps = int(num_train_batches / args.gradient_accumulation_steps) * args.num_train_epochs 148 | if args.local_rank != -1: 149 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() 150 | t_total = num_train_optimization_steps 151 | 152 | # Apply weight decay to all but a few parameters 153 | param_optimizer = list(model.named_parameters()) 154 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 155 | optimizer_grouped_parameters = [ 156 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 157 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 158 | ] 159 | 160 | # Half-precision 161 | if args.fp16: 162 | try: 163 | from apex.optimizers import FP16_Optimizer, FusedAdam 164 | except ImportError: 165 | raise ImportError("To use fp16, install apex from https://www.github.com/nvidia/apex") 166 | optimizer = FusedAdam(optimizer_grouped_parameters, 167 | lr=args.learning_rate, 168 | bias_correction=False, 169 | max_grad_norm=1.0) 170 | if args.loss_scale == 0: 171 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) 172 | else: 173 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) 174 | 175 | # Adam optimizer with Transformer warmup steps 176 | else: 177 | optimizer = BertAdam(optimizer_grouped_parameters, 178 | lr=args.learning_rate, 179 | warmup=args.warmup_proportion, 180 | t_total=t_total) 181 | return optimizer -------------------------------------------------------------------------------- /scripts/download-glue.py: -------------------------------------------------------------------------------- 1 | ''' Script for downloading all GLUE data. 2 | Note: for legal reasons, we are unable to host MRPC. 3 | You can either use the version hosted by the SentEval team, which is already tokenized, 4 | or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually. 5 | For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example). 6 | You should then rename and place specific files in a folder (see below for an example). 
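(The example commands below assume the 'cabextract' utility mentioned above is installed and on your PATH.)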
7 | mkdir MRPC 8 | cabextract MSRParaphraseCorpus.msi -d MRPC 9 | cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt 10 | cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt 11 | rm MRPC/_* 12 | rm MSRParaphraseCorpus.msi 13 | 1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now. 14 | 2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray! 15 | ''' 16 | 17 | import os 18 | import sys 19 | import shutil 20 | import argparse 21 | import tempfile 22 | import urllib.request 23 | import zipfile 24 | 25 | TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"] 26 | TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4', 27 | "SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8', 28 | "MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc', 29 | "QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5', 30 | "STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5', 31 | "MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce', 32 | "SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df', 33 | "QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601', 34 | "RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb', 35 | "WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf', 36 | "diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'} 37 | 38 | MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt' 39 | MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt' 40 | 41 | def download_and_extract(task, data_dir): 42 | print("Downloading and extracting %s..." 
% task) 43 | data_file = "%s.zip" % task 44 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 45 | with zipfile.ZipFile(data_file) as zip_ref: 46 | zip_ref.extractall(data_dir) 47 | os.remove(data_file) 48 | print("\tCompleted!") 49 | 50 | def format_mrpc(data_dir, path_to_data): 51 | print("Processing MRPC...") 52 | mrpc_dir = os.path.join(data_dir, "MRPC") 53 | if not os.path.isdir(mrpc_dir): 54 | os.mkdir(mrpc_dir) 55 | if path_to_data: 56 | mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt") 57 | mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt") 58 | else: 59 | print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN) 60 | mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt") 61 | mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt") 62 | urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file) 63 | urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file) 64 | assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file 65 | assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file 66 | urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv")) 67 | 68 | dev_ids = [] 69 | with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh: 70 | for row in ids_fh: 71 | dev_ids.append(row.strip().split('\t')) 72 | 73 | with open(mrpc_train_file, encoding="utf8") as data_fh, \ 74 | open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \ 75 | open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh: 76 | header = data_fh.readline() 77 | train_fh.write(header) 78 | dev_fh.write(header) 79 | for row in data_fh: 80 | label, id1, id2, s1, s2 = row.strip().split('\t') 81 | if [id1, id2] in dev_ids: 82 | dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 83 | else: 84 | train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2)) 85 | 86 | with open(mrpc_test_file, encoding="utf8") as data_fh, \ 87 | open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh: 88 | header = data_fh.readline() 89 | test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 90 | for idx, row in enumerate(data_fh): 91 | label, id1, id2, s1, s2 = row.strip().split('\t') 92 | test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2)) 93 | print("\tCompleted!") 94 | 95 | def download_diagnostic(data_dir): 96 | print("Downloading and extracting diagnostic...") 97 | if not os.path.isdir(os.path.join(data_dir, "diagnostic")): 98 | os.mkdir(os.path.join(data_dir, "diagnostic")) 99 | data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv") 100 | urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file) 101 | print("\tCompleted!") 102 | return 103 | 104 | def get_tasks(task_names): 105 | task_names = task_names.split(',') 106 | if "all" in task_names: 107 | tasks = TASKS 108 | else: 109 | tasks = [] 110 | for task_name in task_names: 111 | assert task_name in TASKS, "Task %s not found!" 
% task_name 112 | tasks.append(task_name) 113 | return tasks 114 | 115 | def main(arguments): 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data') 118 | parser.add_argument('--tasks', help='tasks to download data for as a comma separated string', 119 | type=str, default='all') 120 | parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt', 121 | type=str, default='') 122 | args = parser.parse_args(arguments) 123 | 124 | if not os.path.isdir(args.data_dir): 125 | os.mkdir(args.data_dir) 126 | tasks = get_tasks(args.tasks) 127 | 128 | for task in tasks: 129 | if task == 'MRPC': 130 | format_mrpc(args.data_dir, args.path_to_mrpc) 131 | elif task == 'diagnostic': 132 | download_diagnostic(args.data_dir) 133 | else: 134 | download_and_extract(task, args.data_dir) 135 | 136 | 137 | if __name__ == '__main__': 138 | sys.exit(main(sys.argv[1:])) 139 | -------------------------------------------------------------------------------- /utils/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from urllib.parse import urlparse 27 | except ImportError: 28 | from urlparse import urlparse 29 | 30 | try: 31 | from pathlib import Path 32 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 33 | Path.home() / '.pytorch_pretrained_bert')) 34 | except (AttributeError, ImportError): 35 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 36 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 37 | 38 | CONFIG_NAME = "config.json" 39 | WEIGHTS_NAME = "pytorch_model.bin" 40 | 41 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | def url_to_filename(url, etag=None): 45 | """ 46 | Convert `url` into a hashed filename in a repeatable way. 47 | If `etag` is specified, append its hash to the url's, delimited 48 | by a period. 49 | """ 50 | url_bytes = url.encode('utf-8') 51 | url_hash = sha256(url_bytes) 52 | filename = url_hash.hexdigest() 53 | 54 | if etag: 55 | etag_bytes = etag.encode('utf-8') 56 | etag_hash = sha256(etag_bytes) 57 | filename += '.' + etag_hash.hexdigest() 58 | 59 | return filename 60 | 61 | 62 | def filename_to_url(filename, cache_dir=None): 63 | """ 64 | Return the url and etag (which may be ``None``) stored for `filename`. 65 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
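    The (url, etag) pair is read from the sidecar ``<filename>.json`` metadata file
    written by ``get_from_cache``.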
66 | """ 67 | if cache_dir is None: 68 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 69 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 70 | cache_dir = str(cache_dir) 71 | 72 | cache_path = os.path.join(cache_dir, filename) 73 | if not os.path.exists(cache_path): 74 | raise EnvironmentError("file {} not found".format(cache_path)) 75 | 76 | meta_path = cache_path + '.json' 77 | if not os.path.exists(meta_path): 78 | raise EnvironmentError("file {} not found".format(meta_path)) 79 | 80 | with open(meta_path, encoding="utf-8") as meta_file: 81 | metadata = json.load(meta_file) 82 | url = metadata['url'] 83 | etag = metadata['etag'] 84 | 85 | return url, etag 86 | 87 | 88 | def cached_path(url_or_filename, cache_dir=None): 89 | """ 90 | Given something that might be a URL (or might be a local path), 91 | determine which. If it's a URL, download the file and cache it, and 92 | return the path to the cached file. If it's already a local path, 93 | make sure the file exists and then return the path. 94 | """ 95 | if cache_dir is None: 96 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 97 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 98 | url_or_filename = str(url_or_filename) 99 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 100 | cache_dir = str(cache_dir) 101 | 102 | parsed = urlparse(url_or_filename) 103 | 104 | if parsed.scheme in ('http', 'https', 's3'): 105 | # URL, so get it from the cache (downloading if necessary) 106 | return get_from_cache(url_or_filename, cache_dir) 107 | elif os.path.exists(url_or_filename): 108 | # File, and it exists. 109 | return url_or_filename 110 | elif parsed.scheme == '': 111 | # File, but it doesn't exist. 112 | raise EnvironmentError("file {} not found".format(url_or_filename)) 113 | else: 114 | # Something unknown 115 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 116 | 117 | 118 | def split_s3_path(url): 119 | """Split a full s3 path into the bucket name and path.""" 120 | parsed = urlparse(url) 121 | if not parsed.netloc or not parsed.path: 122 | raise ValueError("bad s3 path {}".format(url)) 123 | bucket_name = parsed.netloc 124 | s3_path = parsed.path 125 | # Remove '/' at beginning of path. 126 | if s3_path.startswith("/"): 127 | s3_path = s3_path[1:] 128 | return bucket_name, s3_path 129 | 130 | 131 | def s3_request(func): 132 | """ 133 | Wrapper function for s3 requests in order to create more helpful error 134 | messages. 
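    A 404 from S3 is re-raised as ``EnvironmentError`` so that missing objects are
    reported the same way as missing local or cached files.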
135 | """ 136 | 137 | @wraps(func) 138 | def wrapper(url, *args, **kwargs): 139 | try: 140 | return func(url, *args, **kwargs) 141 | except ClientError as exc: 142 | if int(exc.response["Error"]["Code"]) == 404: 143 | raise EnvironmentError("file {} not found".format(url)) 144 | else: 145 | raise 146 | 147 | return wrapper 148 | 149 | 150 | @s3_request 151 | def s3_etag(url): 152 | """Check ETag on S3 object.""" 153 | s3_resource = boto3.resource("s3") 154 | bucket_name, s3_path = split_s3_path(url) 155 | s3_object = s3_resource.Object(bucket_name, s3_path) 156 | return s3_object.e_tag 157 | 158 | 159 | @s3_request 160 | def s3_get(url, temp_file): 161 | """Pull a file directly from S3.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 165 | 166 | 167 | def http_get(url, temp_file): 168 | req = requests.get(url, stream=True) 169 | content_length = req.headers.get('Content-Length') 170 | total = int(content_length) if content_length is not None else None 171 | progress = tqdm(unit="B", total=total) 172 | for chunk in req.iter_content(chunk_size=1024): 173 | if chunk: # filter out keep-alive new chunks 174 | progress.update(len(chunk)) 175 | temp_file.write(chunk) 176 | progress.close() 177 | 178 | 179 | def get_from_cache(url, cache_dir=None): 180 | """ 181 | Given a URL, look for the corresponding dataset in the local cache. 182 | If it's not there, download it. Then return the path to the cached file. 183 | """ 184 | if cache_dir is None: 185 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 186 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 187 | cache_dir = str(cache_dir) 188 | 189 | if not os.path.exists(cache_dir): 190 | os.makedirs(cache_dir) 191 | 192 | # Get eTag to add to filename, if it exists. 193 | if url.startswith("s3://"): 194 | etag = s3_etag(url) 195 | else: 196 | try: 197 | response = requests.head(url, allow_redirects=True) 198 | if response.status_code != 200: 199 | etag = None 200 | else: 201 | etag = response.headers.get("ETag") 202 | except EnvironmentError: 203 | etag = None 204 | 205 | if sys.version_info[0] == 2 and etag is not None: 206 | etag = etag.decode('utf-8') 207 | filename = url_to_filename(url, etag) 208 | 209 | # get cache path to put the file 210 | cache_path = os.path.join(cache_dir, filename) 211 | 212 | # If we don't have a connection (etag is None) and can't identify the file 213 | # try to get the last downloaded one 214 | if not os.path.exists(cache_path) and etag is None: 215 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 216 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 217 | if matching_files: 218 | cache_path = os.path.join(cache_dir, matching_files[-1]) 219 | 220 | if not os.path.exists(cache_path): 221 | # Download to temporary file, then copy to cache dir once finished. 222 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
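        # NamedTemporaryFile is removed as soon as the with-block exits, so the download,
        # flush(), seek(0), and copy into the cache below all happen inside it.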
223 | with tempfile.NamedTemporaryFile() as temp_file: 224 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 225 | 226 | # GET file object 227 | if url.startswith("s3://"): 228 | s3_get(url, temp_file) 229 | else: 230 | http_get(url, temp_file) 231 | 232 | # we are copying the file before closing it, so flush to avoid truncation 233 | temp_file.flush() 234 | # shutil.copyfileobj() starts at the current position, so go to the start 235 | temp_file.seek(0) 236 | 237 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 238 | with open(cache_path, 'wb') as cache_file: 239 | shutil.copyfileobj(temp_file, cache_file) 240 | 241 | logger.info("creating metadata file for %s", cache_path) 242 | meta = {'url': url, 'etag': etag} 243 | meta_path = cache_path + '.json' 244 | with open(meta_path, 'w') as meta_file: 245 | output_string = json.dumps(meta) 246 | if sys.version_info[0] == 2 and isinstance(output_string, str): 247 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 248 | meta_file.write(output_string) 249 | 250 | logger.info("removing temp file %s", temp_file.name) 251 | 252 | return cache_path 253 | 254 | 255 | def read_set_from_file(filename): 256 | ''' 257 | Extract a de-duped collection (set) of text from a file. 258 | Expected file format is one item per line. 259 | ''' 260 | collection = set() 261 | with open(filename, 'r', encoding='utf-8') as file_: 262 | for line in file_: 263 | collection.add(line.rstrip()) 264 | return collection 265 | 266 | 267 | def get_file_extension(path, dot=True, lower=True): 268 | ext = os.path.splitext(path)[1] 269 | ext = ext if dot else ext[1:] 270 | return ext.lower() if lower else ext 271 | -------------------------------------------------------------------------------- /utils/tokenization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tokenization [from Google BERT] 3 | """ 4 | 5 | import collections 6 | import unicodedata 7 | import six 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class FullTokenizer(object): 13 | """Runs end-to-end tokenziation.""" 14 | 15 | def __init__(self, vocab_file, do_lower_case=True, max_len=512): 16 | self.vocab = load_vocab(vocab_file) 17 | self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) 18 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 19 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 20 | self.max_len = max_len 21 | 22 | def tokenize(self, text): 23 | split_tokens = [] 24 | for token in self.basic_tokenizer.tokenize(text): 25 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 26 | split_tokens.append(sub_token) 27 | return split_tokens 28 | 29 | def convert_tokens_to_ids(self, tokens): 30 | ids = [] 31 | for token in tokens: 32 | ids.append(self.vocab[token]) 33 | if len(ids) > self.max_len: 34 | logger.warning(f"[TOKENIZER WARNING] Token indices sequence length longer than {self.max_len}") 35 | return ids 36 | 37 | def convert_ids_to_tokens(self, ids): 38 | """Converts a sequence of ids in wordpiece tokens using the vocab.""" 39 | tokens = [] 40 | for i in ids: 41 | tokens.append(self.ids_to_tokens[i]) 42 | return tokens 43 | 44 | def convert_to_unicode(self, text): 45 | return convert_to_unicode(text) 46 | 47 | 48 | class BasicTokenizer(object): 49 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 50 | 51 | def __init__(self, do_lower_case=True): 52 
| """Constructs a BasicTokenizer. 53 | 54 | Args: 55 | do_lower_case: Whether to lower case the input. 56 | """ 57 | self.do_lower_case = do_lower_case 58 | 59 | def tokenize(self, text): 60 | """Tokenizes a piece of text.""" 61 | text = convert_to_unicode(text) 62 | text = self._clean_text(text) 63 | orig_tokens = whitespace_tokenize(text) 64 | split_tokens = [] 65 | for token in orig_tokens: 66 | if self.do_lower_case: 67 | token = token.lower() 68 | token = self._run_strip_accents(token) 69 | split_tokens.extend(self._run_split_on_punc(token)) 70 | 71 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 72 | return output_tokens 73 | 74 | def _run_strip_accents(self, text): 75 | """Strips accents from a piece of text.""" 76 | text = unicodedata.normalize("NFD", text) 77 | output = [] 78 | for char in text: 79 | cat = unicodedata.category(char) 80 | if cat == "Mn": 81 | continue 82 | output.append(char) 83 | return "".join(output) 84 | 85 | def _run_split_on_punc(self, text): 86 | """Splits punctuation on a piece of text.""" 87 | chars = list(text) 88 | i = 0 89 | start_new_word = True 90 | output = [] 91 | while i < len(chars): 92 | char = chars[i] 93 | if _is_punctuation(char): 94 | output.append([char]) 95 | start_new_word = True 96 | else: 97 | if start_new_word: 98 | output.append([]) 99 | start_new_word = False 100 | output[-1].append(char) 101 | i += 1 102 | 103 | return ["".join(x) for x in output] 104 | 105 | def _clean_text(self, text): 106 | """Performs invalid character removal and whitespace cleanup on text.""" 107 | output = [] 108 | for char in text: 109 | cp = ord(char) 110 | if cp == 0 or cp == 0xfffd or _is_control(char): 111 | continue 112 | if _is_whitespace(char): 113 | output.append(" ") 114 | else: 115 | output.append(char) 116 | return "".join(output) 117 | 118 | 119 | class WordpieceTokenizer(object): 120 | """Runs WordPiece tokenization.""" 121 | 122 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 123 | self.vocab = vocab 124 | self.unk_token = unk_token 125 | self.max_input_chars_per_word = max_input_chars_per_word 126 | 127 | def tokenize(self, text): 128 | """Tokenizes a piece of text into its word pieces. 129 | 130 | This uses a greedy longest-match-first algorithm to perform tokenization 131 | using the given vocabulary. 132 | 133 | For example: 134 | input = "unaffable" 135 | output = ["un", "##aff", "##able"] 136 | 137 | Args: 138 | text: A single token or whitespace separated tokens. This should have 139 | already been passed through `BasicTokenizer. 140 | 141 | Returns: 142 | A list of wordpiece tokens. 
143 | """ 144 | 145 | text = convert_to_unicode(text) 146 | 147 | output_tokens = [] 148 | for token in whitespace_tokenize(text): 149 | chars = list(token) 150 | if len(chars) > self.max_input_chars_per_word: 151 | output_tokens.append(self.unk_token) 152 | continue 153 | 154 | is_bad = False 155 | start = 0 156 | sub_tokens = [] 157 | while start < len(chars): 158 | end = len(chars) 159 | cur_substr = None 160 | while start < end: 161 | substr = "".join(chars[start:end]) 162 | if start > 0: 163 | substr = "##" + substr 164 | if substr in self.vocab: 165 | cur_substr = substr 166 | break 167 | end -= 1 168 | if cur_substr is None: 169 | is_bad = True 170 | break 171 | sub_tokens.append(cur_substr) 172 | start = end 173 | 174 | if is_bad: 175 | output_tokens.append(self.unk_token) 176 | else: 177 | output_tokens.extend(sub_tokens) 178 | return output_tokens 179 | 180 | 181 | def _is_whitespace(char): 182 | """Checks whether `chars` is a whitespace character.""" 183 | # \t, \n, and \r are technically contorl characters but we treat them 184 | # as whitespace since they are generally considered as such. 185 | if char == " " or char == "\t" or char == "\n" or char == "\r": 186 | return True 187 | cat = unicodedata.category(char) 188 | if cat == "Zs": 189 | return True 190 | return False 191 | 192 | 193 | def _is_control(char): 194 | """Checks whether `chars` is a control character.""" 195 | # These are technically control characters but we count them as whitespace 196 | # characters. 197 | if char == "\t" or char == "\n" or char == "\r": 198 | return False 199 | cat = unicodedata.category(char) 200 | if cat.startswith("C"): 201 | return True 202 | return False 203 | 204 | 205 | def _is_punctuation(char): 206 | """Checks whether `chars` is a punctuation character.""" 207 | cp = ord(char) 208 | # We treat all non-letter/number ASCII as punctuation. 209 | # Characters such as "^", "$", and "`" are not in the Unicode 210 | # Punctuation class but we treat them as punctuation anyways, for 211 | # consistency. 212 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 213 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 214 | return True 215 | cat = unicodedata.category(char) 216 | if cat.startswith("P"): 217 | return True 218 | return False 219 | 220 | def convert_to_unicode(text): 221 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 222 | if six.PY3: 223 | if isinstance(text, str): 224 | return text 225 | elif isinstance(text, bytes): 226 | return text.decode("utf-8", "ignore") 227 | else: 228 | raise ValueError("Unsupported string type: %s" % (type(text))) 229 | elif six.PY2: 230 | if isinstance(text, str): 231 | return text.decode("utf-8", "ignore") 232 | elif isinstance(text, unicode): 233 | return text 234 | else: 235 | raise ValueError("Unsupported string type: %s" % (type(text))) 236 | else: 237 | raise ValueError("Not running on Python2 or Python 3?") 238 | 239 | 240 | def printable_text(text): 241 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 242 | 243 | # These functions want `str` for both Python2 and Python3, but in one case 244 | # it's a Unicode string and in the other it's a byte string. 
245 | if six.PY3: 246 | if isinstance(text, str): 247 | return text 248 | elif isinstance(text, bytes): 249 | return text.decode("utf-8", "ignore") 250 | else: 251 | raise ValueError("Unsupported string type: %s" % (type(text))) 252 | elif six.PY2: 253 | if isinstance(text, str): 254 | return text 255 | elif isinstance(text, unicode): 256 | return text.encode("utf-8") 257 | else: 258 | raise ValueError("Unsupported string type: %s" % (type(text))) 259 | else: 260 | raise ValueError("Not running on Python2 or Python 3?") 261 | 262 | 263 | def load_vocab(vocab_file): 264 | """Loads a vocabulary file into a dictionary.""" 265 | vocab = collections.OrderedDict() 266 | index = 0 267 | with open(vocab_file, "r") as reader: 268 | while True: 269 | token = convert_to_unicode(reader.readline()) 270 | if not token: 271 | break 272 | token = token.strip() 273 | vocab[token] = index 274 | index += 1 275 | return vocab 276 | 277 | def whitespace_tokenize(text): 278 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 279 | text = text.strip() 280 | if not text: 281 | return [] 282 | tokens = text.split() 283 | return tokens -------------------------------------------------------------------------------- /train/classification_data.py: -------------------------------------------------------------------------------- 1 | ## Adapted from https://github.com/huggingface/pytorch-pretrained-BERT 2 | 3 | import os, sys, argparse 4 | import csv, pickle 5 | import random 6 | import logging 7 | 8 | import numpy as np 9 | from scipy.stats import pearsonr, spearmanr 10 | from sklearn.metrics import matthews_corrcoef, f1_score, roc_auc_score 11 | from sklearn.preprocessing import label_binarize 12 | 13 | import torch 14 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 15 | from torch.utils.data.distributed import DistributedSampler 16 | 17 | class InputExample(object): 18 | """A single training/test example for simple sequence classification.""" 19 | 20 | def __init__(self, guid, text_a, text_b=None, label=None): 21 | """Constructs a InputExample. 22 | Args: 23 | guid: Unique id for the example. 24 | text_a: string. The untokenized text of the first sequence. For single 25 | sequence tasks, only this sequence must be specified. 26 | text_b: (Optional) string. The untokenized text of the second sequence. 27 | Only must be specified for sequence pair tasks. 28 | label: (Optional) string. The label of the example. This should be 29 | specified for train and dev examples, but not for test examples. 
30 | """ 31 | self.guid = guid 32 | self.text_a = text_a 33 | self.text_b = text_b 34 | self.label = label 35 | 36 | class InputFeatures(object): 37 | """A single set of features of data.""" 38 | 39 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 40 | self.input_ids = input_ids 41 | self.input_mask = input_mask 42 | self.segment_ids = segment_ids 43 | self.label_id = label_id 44 | 45 | class DataProcessor(object): 46 | """Interface for sequence classification.""" 47 | 48 | def get_train_examples(self, data_dir): 49 | """Gets a collection of `InputExample`s for the train set.""" 50 | raise NotImplementedError() 51 | 52 | def get_dev_examples(self, data_dir): 53 | """Gets a collection of `InputExample`s for the dev set.""" 54 | raise NotImplementedError() 55 | 56 | def get_labels(self): 57 | """Gets the list of labels for this data set.""" 58 | raise NotImplementedError() 59 | 60 | class GLUEDataProcessor(DataProcessor): 61 | """Base class for data converters for GLUE sequence classification data sets.""" 62 | 63 | @classmethod 64 | def _read_tsv(cls, input_file, quotechar=None): 65 | """Reads a tab separated value file.""" 66 | with open(input_file, "r") as f: 67 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 68 | lines = [] 69 | for line in reader: 70 | if sys.version_info[0] == 2: 71 | line = list(unicode(cell, 'utf-8') for cell in line) 72 | lines.append(line) 73 | return lines 74 | 75 | class DeepMojiDataProcessor(DataProcessor): 76 | '''This dataset is small enough that we keep it in memory''' 77 | 78 | def __init__(self, num_labels): 79 | super(DeepMojiDataProcessor, self).__init__() 80 | self.num_labels = num_labels 81 | 82 | def _create_examples(self, split, raw_data): 83 | examples = [] 84 | for i, j in enumerate(raw_data[f'{split}_ind']): 85 | guid = f'{split}-{i}' 86 | text = raw_data['texts'][j] 87 | label = np.argmax(raw_data['info'][j]['label']) 88 | example = InputExample(guid=guid, text_a=text, label=label) 89 | examples.append(example) 90 | return examples 91 | 92 | def get_train_examples(self, data_dir): 93 | with open(os.path.join(data_dir, 'raw.pickle'), 'rb') as f: 94 | raw_data = pickle.load(f, encoding="latin1") 95 | return self._create_examples('train', raw_data) 96 | 97 | def get_dev_examples(self, data_dir): 98 | with open(os.path.join(data_dir, 'raw.pickle'), 'rb') as f: 99 | raw_data = pickle.load(f, encoding="latin1") 100 | return self._create_examples('val', raw_data) 101 | 102 | def get_test_examples(self, data_dir): 103 | with open(os.path.join(data_dir, 'raw.pickle'), 'rb') as f: 104 | raw_data = pickle.load(f, encoding="latin1") 105 | return self._create_examples('test', raw_data) 106 | 107 | def get_labels(self): 108 | '''For example, for PsychExp, according to https://github.com/TetsumichiUmada/text2emoji, 109 | the labels are ["joy", "fear", "anger", "sadness", "disgust", "shame", "guilt"]''' 110 | return np.arange(self.num_labels) 111 | 112 | class PsychExpProcessor(DeepMojiDataProcessor): 113 | def __init__(self, pickle_file='data/PsychExp'): 114 | super(PsychExpProcessor, self).__init__(num_labels=7) 115 | 116 | class MrpcProcessor(GLUEDataProcessor): 117 | """Processor for the MRPC data set (GLUE version).""" 118 | 119 | def get_train_examples(self, data_dir): 120 | """See base class.""" 121 | return self._create_examples( 122 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 123 | 124 | def get_dev_examples(self, data_dir): 125 | """See base class.""" 126 | return self._create_examples( 127 | 
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 128 | 129 | def get_labels(self): 130 | """See base class.""" 131 | return ["0", "1"] 132 | 133 | def _create_examples(self, lines, set_type): 134 | """Creates examples for the training and dev sets.""" 135 | examples = [] 136 | for (i, line) in enumerate(lines): 137 | if i == 0: 138 | continue 139 | guid = "%s-%s" % (set_type, i) 140 | text_a = line[3] 141 | text_b = line[4] 142 | label = line[0] 143 | examples.append( 144 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 145 | return examples 146 | 147 | 148 | class MnliProcessor(GLUEDataProcessor): 149 | """Processor for the MultiNLI data set (GLUE version).""" 150 | 151 | def get_train_examples(self, data_dir): 152 | """See base class.""" 153 | return self._create_examples( 154 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 155 | 156 | def get_dev_examples(self, data_dir): 157 | """See base class.""" 158 | return self._create_examples( 159 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 160 | "dev_matched") 161 | 162 | def get_labels(self): 163 | """See base class.""" 164 | return ["contradiction", "entailment", "neutral"] 165 | 166 | def _create_examples(self, lines, set_type): 167 | """Creates examples for the training and dev sets.""" 168 | examples = [] 169 | for (i, line) in enumerate(lines): 170 | if i == 0: 171 | continue 172 | guid = "%s-%s" % (set_type, line[0]) 173 | text_a = line[8] 174 | text_b = line[9] 175 | label = line[-1] 176 | examples.append( 177 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 178 | return examples 179 | 180 | 181 | class MnliMismatchedProcessor(MnliProcessor): 182 | """Processor for the MultiNLI Mismatched data set (GLUE version).""" 183 | 184 | def get_dev_examples(self, data_dir): 185 | """See base class.""" 186 | return self._create_examples( 187 | self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), 188 | "dev_matched") 189 | 190 | 191 | class ColaProcessor(GLUEDataProcessor): 192 | """Processor for the CoLA data set (GLUE version).""" 193 | 194 | def get_train_examples(self, data_dir): 195 | """See base class.""" 196 | return self._create_examples( 197 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 198 | 199 | def get_dev_examples(self, data_dir): 200 | """See base class.""" 201 | return self._create_examples( 202 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 203 | 204 | def get_labels(self): 205 | """See base class.""" 206 | return ["0", "1"] 207 | 208 | def _create_examples(self, lines, set_type): 209 | """Creates examples for the training and dev sets.""" 210 | examples = [] 211 | for (i, line) in enumerate(lines): 212 | guid = "%s-%s" % (set_type, i) 213 | text_a = line[3] 214 | label = line[1] 215 | examples.append( 216 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 217 | return examples 218 | 219 | 220 | class Sst2Processor(GLUEDataProcessor): 221 | """Processor for the SST-2 data set (GLUE version).""" 222 | 223 | def get_train_examples(self, data_dir): 224 | """See base class.""" 225 | return self._create_examples( 226 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 227 | 228 | def get_dev_examples(self, data_dir): 229 | """See base class.""" 230 | return self._create_examples( 231 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 232 | 233 | def get_labels(self): 234 | """See base class.""" 235 | return ["0", "1"] 236 | 237 | def _create_examples(self, lines, set_type): 
238 | """Creates examples for the training and dev sets.""" 239 | examples = [] 240 | for (i, line) in enumerate(lines): 241 | if i == 0: 242 | continue 243 | guid = "%s-%s" % (set_type, i) 244 | text_a = line[0] 245 | label = line[1] 246 | examples.append( 247 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 248 | return examples 249 | 250 | 251 | class StsbProcessor(GLUEDataProcessor): 252 | """Processor for the STS-B data set (GLUE version).""" 253 | 254 | def get_train_examples(self, data_dir): 255 | """See base class.""" 256 | return self._create_examples( 257 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 258 | 259 | def get_dev_examples(self, data_dir): 260 | """See base class.""" 261 | return self._create_examples( 262 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 263 | 264 | def get_labels(self): 265 | """See base class.""" 266 | return [None] 267 | 268 | def _create_examples(self, lines, set_type): 269 | """Creates examples for the training and dev sets.""" 270 | examples = [] 271 | for (i, line) in enumerate(lines): 272 | if i == 0: 273 | continue 274 | guid = "%s-%s" % (set_type, line[0]) 275 | text_a = line[7] 276 | text_b = line[8] 277 | label = line[-1] 278 | examples.append( 279 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 280 | return examples 281 | 282 | 283 | class QqpProcessor(GLUEDataProcessor): 284 | """Processor for the STS-B data set (GLUE version).""" 285 | 286 | def get_train_examples(self, data_dir): 287 | """See base class.""" 288 | return self._create_examples( 289 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 290 | 291 | def get_dev_examples(self, data_dir): 292 | """See base class.""" 293 | return self._create_examples( 294 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 295 | 296 | def get_labels(self): 297 | """See base class.""" 298 | return ["0", "1"] 299 | 300 | def _create_examples(self, lines, set_type): 301 | """Creates examples for the training and dev sets.""" 302 | examples = [] 303 | for (i, line) in enumerate(lines): 304 | if i == 0: 305 | continue 306 | guid = "%s-%s" % (set_type, line[0]) 307 | try: 308 | text_a = line[3] 309 | text_b = line[4] 310 | label = line[5] 311 | except IndexError: 312 | continue 313 | examples.append( 314 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 315 | return examples 316 | 317 | 318 | class QnliProcessor(GLUEDataProcessor): 319 | """Processor for the STS-B data set (GLUE version).""" 320 | 321 | def get_train_examples(self, data_dir): 322 | """See base class.""" 323 | return self._create_examples( 324 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 325 | 326 | def get_dev_examples(self, data_dir): 327 | """See base class.""" 328 | return self._create_examples( 329 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), 330 | "dev_matched") 331 | 332 | def get_labels(self): 333 | """See base class.""" 334 | return ["entailment", "not_entailment"] 335 | 336 | def _create_examples(self, lines, set_type): 337 | """Creates examples for the training and dev sets.""" 338 | examples = [] 339 | for (i, line) in enumerate(lines): 340 | if i == 0: 341 | continue 342 | guid = "%s-%s" % (set_type, line[0]) 343 | text_a = line[1] 344 | text_b = line[2] 345 | label = line[-1] 346 | examples.append( 347 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 348 | return examples 349 | 350 | 351 | class RteProcessor(GLUEDataProcessor): 352 | """Processor for the RTE data set 
(GLUE version).""" 353 | 354 | def get_train_examples(self, data_dir): 355 | """See base class.""" 356 | return self._create_examples( 357 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 358 | 359 | def get_dev_examples(self, data_dir): 360 | """See base class.""" 361 | return self._create_examples( 362 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 363 | 364 | def get_labels(self): 365 | """See base class.""" 366 | return ["entailment", "not_entailment"] 367 | 368 | def _create_examples(self, lines, set_type): 369 | """Creates examples for the training and dev sets.""" 370 | examples = [] 371 | for (i, line) in enumerate(lines): 372 | if i == 0: 373 | continue 374 | guid = "%s-%s" % (set_type, line[0]) 375 | text_a = line[1] 376 | text_b = line[2] 377 | label = line[-1] 378 | examples.append( 379 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 380 | return examples 381 | 382 | 383 | class WnliProcessor(GLUEDataProcessor): 384 | """Processor for the WNLI data set (GLUE version).""" 385 | 386 | def get_train_examples(self, data_dir): 387 | """See base class.""" 388 | return self._create_examples( 389 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 390 | 391 | def get_dev_examples(self, data_dir): 392 | """See base class.""" 393 | return self._create_examples( 394 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 395 | 396 | def get_labels(self): 397 | """See base class.""" 398 | return ["0", "1"] 399 | 400 | def _create_examples(self, lines, set_type): 401 | """Creates examples for the training and dev sets.""" 402 | examples = [] 403 | for (i, line) in enumerate(lines): 404 | if i == 0: 405 | continue 406 | guid = "%s-%s" % (set_type, line[0]) 407 | text_a = line[1] 408 | text_b = line[2] 409 | label = line[-1] 410 | examples.append( 411 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 412 | return examples 413 | 414 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_mode, logger=None): 415 | """Loads a data file into a list of `InputBatch`s.""" 416 | 417 | label_map = {label : i for i, label in enumerate(label_list)} 418 | 419 | features = [] 420 | for (ex_index, example) in enumerate(examples): 421 | # if logger is not None and ex_index % 10000 == 0: 422 | # logger.info("Writing example %d of %d" % (ex_index, len(examples))) 423 | 424 | tokens_a = tokenizer.tokenize(example.text_a) 425 | 426 | tokens_b = None 427 | if example.text_b: 428 | tokens_b = tokenizer.tokenize(example.text_b) 429 | # Modifies `tokens_a` and `tokens_b` in place so that the total 430 | # length is less than the specified length. 431 | # Account for [CLS], [SEP], [SEP] with "- 3" 432 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 433 | else: 434 | # Account for [CLS] and [SEP] with "- 2" 435 | if len(tokens_a) > max_seq_length - 2: 436 | tokens_a = tokens_a[:(max_seq_length - 2)] 437 | 438 | # The convention in BERT is: 439 | # (a) For sequence pairs: 440 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 441 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 442 | # (b) For single sequences: 443 | # tokens: [CLS] the dog is hairy . [SEP] 444 | # type_ids: 0 0 0 0 0 0 0 445 | # 446 | # Where "type_ids" are used to indicate whether this is the first 447 | # sequence or the second sequence. 
The embedding vectors for `type=0` and 448 | # `type=1` were learned during pre-training and are added to the wordpiece 449 | # embedding vector (and position vector). This is not *strictly* necessary 450 | # since the [SEP] token unambiguously separates the sequences, but it makes 451 | # it easier for the model to learn the concept of sequences. 452 | # 453 | # For classification tasks, the first vector (corresponding to [CLS]) is 454 | # used as the "sentence vector". Note that this only makes sense because 455 | # the entire model is fine-tuned. 456 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 457 | segment_ids = [0] * len(tokens) 458 | 459 | if tokens_b: 460 | tokens += tokens_b + ["[SEP]"] 461 | segment_ids += [1] * (len(tokens_b) + 1) 462 | 463 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 464 | 465 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 466 | # tokens are attended to. 467 | input_mask = [1] * len(input_ids) 468 | 469 | # Zero-pad up to the sequence length. 470 | padding = [0] * (max_seq_length - len(input_ids)) 471 | input_ids += padding 472 | input_mask += padding 473 | segment_ids += padding 474 | 475 | assert len(input_ids) == max_seq_length 476 | assert len(input_mask) == max_seq_length 477 | assert len(segment_ids) == max_seq_length 478 | 479 | if output_mode == "classification": 480 | label_id = label_map[example.label] 481 | elif output_mode == "regression": 482 | label_id = float(example.label) 483 | else: 484 | raise KeyError(output_mode) 485 | 486 | features.append( 487 | InputFeatures(input_ids=input_ids, 488 | input_mask=input_mask, 489 | segment_ids=segment_ids, 490 | label_id=label_id)) 491 | return features 492 | 493 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 494 | """Truncates a sequence pair in place to the maximum length.""" 495 | 496 | # This is a simple heuristic which will always truncate the longer sequence 497 | # one token at a time. This makes more sense than truncating an equal percent 498 | # of tokens from each, since if one sequence is very short then each token 499 | # that's truncated likely contains more information than a longer sequence.
500 | while True: 501 | total_length = len(tokens_a) + len(tokens_b) 502 | if total_length <= max_length: 503 | break 504 | if len(tokens_a) > len(tokens_b): 505 | tokens_a.pop() 506 | else: 507 | tokens_b.pop() 508 | 509 | def simple_accuracy(preds, labels): 510 | return (preds == labels).mean() 511 | 512 | def acc_and_f1(preds, labels): 513 | acc = simple_accuracy(preds, labels) 514 | f1 = f1_score(y_true=labels, y_pred=preds) 515 | return { 516 | "acc": acc, 517 | "f1": f1, 518 | "acc_and_f1": (acc + f1) / 2, 519 | } 520 | 521 | def pearson_and_spearman(preds, labels): 522 | pearson_corr = pearsonr(preds, labels)[0] 523 | spearman_corr = spearmanr(preds, labels)[0] 524 | return { 525 | "pearson": pearson_corr, 526 | "spearmanr": spearman_corr, 527 | "corr": (pearson_corr + spearman_corr) / 2, 528 | } 529 | 530 | def compute_metrics(task_name, preds, labels, logits=None): 531 | assert len(preds) == len(labels) 532 | if task_name == "cola": 533 | return {"mcc": matthews_corrcoef(labels, preds)} 534 | elif task_name == "sst-2": 535 | return {"acc": simple_accuracy(preds, labels)} 536 | elif task_name == "mrpc": 537 | return acc_and_f1(preds, labels) 538 | elif task_name == "sts-b": 539 | return pearson_and_spearman(preds, labels) 540 | elif task_name == "qqp": 541 | return acc_and_f1(preds, labels) 542 | elif task_name == "mnli": 543 | return {"acc": simple_accuracy(preds, labels)} 544 | elif task_name == "mnli-mm": 545 | return {"acc": simple_accuracy(preds, labels)} 546 | elif task_name == "qnli": 547 | return {"acc": simple_accuracy(preds, labels)} 548 | elif task_name == "rte": 549 | return {"acc": simple_accuracy(preds, labels)} 550 | elif task_name == "wnli": 551 | return {"acc": simple_accuracy(preds, labels)} 552 | elif task_name == "psychexp": 553 | return { 554 | "f1": f1_score(y_true=labels, y_pred=preds, average='weighted'), 555 | "auc": roc_auc_score(y_true=label_binarize(labels, classes=PsychExpProcessor().get_labels()), 556 | y_score=logits, average='weighted'), 557 | "acc": simple_accuracy(preds, labels), 558 | } 559 | else: 560 | raise KeyError(task_name) 561 | 562 | def evaluate_test_f1_deepmoji(val_logits, val_labels, test_preds, test_labels, average='weighted'): 563 | # Why does the DeepMoji file say 'weighted_f1'?? 564 | """For the DeepMoji datasets only. For why we do this, see 565 | [this paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/) 566 | # Arguments: 567 | preds: Outputs of val/test set. 568 | labels: Outputs of val/test set. 
569 | # Returns: 570 | f1_test: F1 score on the test set 571 | best_threshold: Best F1 threshold on validation set 572 | """ 573 | # Find best threshold on validation set 574 | f1_from_threshold = lambda threshold: f1_score(val_labels, (val_logits > threshold), average=average) # y_true, y_pred 575 | best_threshold = max(np.arange(0.0, 1.0, 0.01), key=f1_from_threshold) # fixed grid of candidate cutoffs (assumes scores lie in [0, 1]; grid spacing is an arbitrary choice) 576 | 577 | # Evaluate on test set 578 | f1_test = f1_score(test_labels, (test_preds > best_threshold), average=average) 579 | return f1_test, best_threshold 580 | 581 | processors = { 582 | "cola": ColaProcessor, 583 | "mnli": MnliProcessor, 584 | "mnli-mm": MnliMismatchedProcessor, 585 | "mrpc": MrpcProcessor, 586 | "sst-2": Sst2Processor, 587 | "sts-b": StsbProcessor, 588 | "qqp": QqpProcessor, 589 | "qnli": QnliProcessor, 590 | "rte": RteProcessor, 591 | "wnli": WnliProcessor, 592 | "psychexp": PsychExpProcessor, 593 | } 594 | 595 | output_modes = { 596 | "cola": "classification", 597 | "mnli": "classification", 598 | "mrpc": "classification", 599 | "sst-2": "classification", 600 | "sts-b": "regression", 601 | "qqp": "classification", 602 | "qnli": "classification", 603 | "rte": "classification", 604 | "wnli": "classification", 605 | "psychexp": "classification", 606 | } 607 | 608 | def prepare_dataloader(args, tokenizer, test=False): 609 | '''Return train, val (and optionally test) dataloader''' 610 | processor = processors[args.task_name]() 611 | train_dataloader = prepare_dataloader_split(args, processor, tokenizer, split='train') 612 | val_dataloader = prepare_dataloader_split(args, processor, tokenizer, split='val') 613 | if test: 614 | test_dataloader = prepare_dataloader_split(args, processor, tokenizer, split='test') 615 | return train_dataloader, val_dataloader, test_dataloader 616 | else: 617 | return train_dataloader, val_dataloader 618 | 619 | def prepare_dataloader_split(args, processor, tokenizer, split='train'): 620 | '''Load the train/val/test data''' 621 | if split == 'train': 622 | examples = processor.get_train_examples(args.data_dir) 623 | elif split == 'val': 624 | examples = processor.get_dev_examples(args.data_dir) 625 | elif split == 'test': 626 | examples = processor.get_test_examples(args.data_dir) 627 | features = convert_examples_to_features( 628 | examples, args.label_list, args.max_seq_length, tokenizer, args.output_mode) 629 | input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 630 | input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 631 | segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 632 | label_ids = torch.tensor([f.label_id for f in features], dtype=args.output_dtype) 633 | data = TensorDataset(input_ids, segment_ids, input_mask, label_ids) 634 | # NOTE: The tensor order above is _different_ from the original huggingface script; this ordering makes more sense 635 | if split == 'train': 636 | sampler = RandomSampler(data) if args.local_rank == -1 else DistributedSampler(data) 637 | batch_size = args.train_batch_size 638 | elif split == 'val' or split == 'test': 639 | sampler = SequentialSampler(data) 640 | batch_size = args.val_batch_size 641 | dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size) 642 | return dataloader --------------------------------------------------------------------------------
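A minimal usage sketch for the classification data pipeline above, assuming the repository root is on PYTHONPATH, that a BERT WordPiece vocab file exists at vocab.txt, and that the GLUE MRPC TSVs (train.tsv/dev.tsv) have been placed under data/MRPC; the paths and max_seq_length=128 are illustrative assumptions, not values taken from the repo's configs:

from utils.tokenization import FullTokenizer
from train.classification_data import processors, output_modes, convert_examples_to_features

task = "mrpc"
processor = processors[task]()            # MrpcProcessor
label_list = processor.get_labels()       # ["0", "1"]
output_mode = output_modes[task]          # "classification"

# Hypothetical paths: a BERT vocab file and a directory containing the MRPC TSVs
tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
examples = processor.get_dev_examples("data/MRPC")

# Build [CLS]/[SEP]-delimited, zero-padded feature tensors for each example pair
features = convert_examples_to_features(examples, label_list, max_seq_length=128,
                                        tokenizer=tokenizer, output_mode=output_mode)
print(len(features), features[0].input_ids[:10])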