├── .gitignore ├── LICENSE ├── README.md ├── configs ├── alert_config.json └── base_experiment_config.json ├── doc └── dispatcher_readme.md ├── rationale_net ├── __init__.py ├── datasets │ ├── __init__.py │ ├── abstract_dataset.py │ ├── factory.py │ ├── full_beer_dataset.py │ └── news_group_dataset.py ├── learn │ ├── __init__.py │ └── train.py ├── models │ ├── __init__.py │ ├── cnn.py │ ├── empty.py │ ├── encoder.py │ ├── generator.py │ └── tagger.py └── utils │ ├── __init__.py │ ├── embedding.py │ ├── generic.py │ ├── learn.py │ ├── metrics.py │ ├── model.py │ └── parsing.py ├── requirements.txt ├── requirements3.txt ├── scripts ├── dispatcher.py ├── main.py └── preprocess │ └── preprocess_snli.py └── tutorial └── TextCNN.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Data dirs 2 | data/ 3 | # Byte-compiled / optimized / DLL files 4 | venv3 5 | logs 6 | secrets 7 | raw_data 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | *.swp 12 | *.swo 13 | # C extensions 14 | *.so 15 | *.DS_Store 16 | #data 17 | beer_review 18 | snapshot 19 | #logs 20 | *.txt 21 | *.png 22 | *.results 23 | *.csv 24 | # Distribution / packaging 25 | .Python 26 | env/ 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | .hypothesis/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # dotenv 98 | .env 99 | 100 | # virtualenv 101 | .venv 102 | venv/ 103 | ENV/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Adam Yala 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# text_nn

Text Classification models. Used as a submodule for other projects.
Supports extractive rationale extraction in the style of Tao Lei's "Rationalizing Neural Predictions". This departs from Tao's
original framework in the following ways:

- I implement Generator training using the Gumbel Softmax instead of REINFORCE.
- I only implement the independent selector.

## Requirements
This repository assumes GloVe embeddings.
Download the GloVe embeddings from https://nlp.stanford.edu/projects/glove/
and place `glove.6B/glove.6B.300d.txt` at `data/embeddings/glove.6B/glove.6B.300d.txt`.

This code supports the NewsGroup dataset and the BeerReview dataset. For access to the BeerReview dataset and the corresponding embeddings, please contact Tao Lei. I've included the NewsGroup dataset, conveniently provided by SKLearn, so you can run the code out of the box.


## Usage:
Example run:
```
CUDA_VISIBLE_DEVICES=0 python -u scripts/main.py --batch_size 64 --cuda --dataset news_group --embedding glove \
    --dropout 0.05 --weight_decay 5e-06 --num_layers 1 --model_form cnn --hidden_dim 100 --epochs 50 --init_lr 0.0001 \
    --num_workers 0 --objective cross_entropy --patience 5 --save_dir snapshot --train --test \
    --results_path logs/demo_run.results --gumbel_decay 1e-5 --get_rationales --selection_lambda .001 --continuity_lambda 0
```
Use `--get_rationales` to enable extractive rationales.

The results and extracted rationales will be saved at `results_path`
and can be accessed as:

```
results = pickle.load(open(results_path, 'rb'))
rationales = results['test_stats']['rationales']
```

To run a grid search, see `doc/dispatcher_readme.md`.

Note: the rationale model is very sensitive to hyperparameters, and the example run has not been tuned.


## Base Models Supported:
- TextCNN (Yoon Kim, 2014)

## Extending:
### How to add a new dataset:
- Fork the repo.
- Add a PyTorch Dataset object to `rationale_net/datasets` and register it to the dataset factory. See the news_group and beer_review datasets for an example, and the sketch below.
- Add the corresponding import to `rationale_net/datasets/__init__.py`.
### How to add a new model base?
- Supported in the research version of this repo, but it's involved. If there is interest, please contact me.
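### Example: registering a new dataset (sketch)
The following is a minimal, illustrative sketch of the pattern described above. The dataset name `my_corpus`, the class name, and the toy data are all made up; the news_group and full_beer datasets are the real reference implementations.

```
# rationale_net/datasets/my_corpus_dataset.py  (hypothetical file)
from rationale_net.utils.embedding import get_indices_tensor
from rationale_net.datasets.factory import RegisterDataset
from rationale_net.datasets.abstract_dataset import AbstractDataset

# Toy in-memory data, just so the sketch is self-contained.
TOY_DATA = {
    'train': [("the plot was wonderful", 1), ("the plot was dull", 0)],
    'dev':   [("i loved it", 1)],
    'test':  [("i hated it", 0)],
}

@RegisterDataset('my_corpus')   # value you would pass as --dataset my_corpus
class MyCorpusDataset(AbstractDataset):
    def __init__(self, args, word_to_indx, name, max_length=80):
        # The factory builds one instance per split: name is 'train', 'dev' or 'test'.
        self.args = args
        self.args.num_class = 2        # tell the encoder how many classes this task has
        self.name = name
        self.word_to_indx = word_to_indx
        self.max_length = max_length
        self.dataset = []              # AbstractDataset's __len__/__getitem__ read this list

        for text, label in TOY_DATA[name]:
            text = " ".join(text.split()[:max_length])
            x = get_indices_tensor(text.split(), word_to_indx, max_length)
            # Samples follow the same dict format the existing datasets use.
            self.dataset.append({'text': text, 'x': x, 'y': label})
```

After adding `import rationale_net.datasets.my_corpus_dataset` to `rationale_net/datasets/__init__.py`, the dataset could be selected with `--dataset my_corpus`.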

--------------------------------------------------------------------------------
/configs/alert_config.json:
--------------------------------------------------------------------------------
{
    "path_to_twilio_secret": "secrets/twilio_info.json",
    "alert_nums": ["your-phone-number"],
    "suppress_alerts": false
}
--------------------------------------------------------------------------------
/configs/base_experiment_config.json:
--------------------------------------------------------------------------------
{
    "search_space": {
        "batch_size": [128],
        "cuda": [true],
        "dataset": ["pathology"],
        "embedding": ["pathology"],
        "aspect": ["DCIS", "IDC", "ADH"],
        "dropout": [0.25],
        "num_layers": [1],
        "epochs": [25],
        "init_lr": [1e-05],
        "num_workers": [8],
        "objective": ["cross_entropy"],
        "patience": [10],
        "plot_losses": [false],
        "save_dir": ["snapshot"],
        "train": [true],
        "test": [false],
        "get_rationales": [true],
        "use_gumbel": [true],
        "selection_lambda": [0.01],
        "continuity_lambda": [0.01],
        "gumbel_decay": [0.00001],
        "gumbel_temprature": [1]
    },
    "available_gpus": [0, 1, 2]
}
--------------------------------------------------------------------------------
/doc/dispatcher_readme.md:
--------------------------------------------------------------------------------
# Dispatcher Configs
- First, copy `base_experiment_config.json` and name it something meaningful.
- Change the flags to match your experimental setup. For example, you may want to change the dataset.
- Each arg in `search_space` is a list. To search over parameter settings, define the values you want to try in the respective lists, and the dispatcher will run over all combinations (see the sketch at the end of this file).
- Copy `alert_config` into your own and modify it for your needs (such as adding your phone number, or using a list of phone numbers). The base config will have the path to @yala's twilio auth token.
- You can change the GPUs you'll use by adding/removing gpu ids in `available_gpus`.
- Run `python scripts/dispatcher.py --experiment_config_path=[your_config] --alert_config_path=[your_alert_config] --result_path=[path_in_shared_storage_result_dir]` and the grid search will text you when it's done!

Please check that the number of jobs logged at the beginning of the run makes sense, to confirm you didn't make a mistake in your config file.

Please make sure you use your own alert_config so your job doesn't send @yala a million text messages.
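
## Example: how the grid is expanded
As a rough sketch (the flag values below are made up for illustration, not taken from a real config), `rationale_net/utils/parsing.parse_dispatcher_config` turns each combination of values in `search_space` into one flag string, which `scripts/dispatcher.py` then feeds to `scripts/main.py`:

```
import rationale_net.utils.parsing as parsing

# Hypothetical two-axis search space: 2 dropout values x 2 epoch values = 4 jobs.
config = {
    "search_space": {
        "train":   [True],         # single-valued booleans become a bare flag (--train)
        "dropout": [0.05, 0.25],   # axes with more than one value define the grid
        "epochs":  [25, 50],
    }
}

jobs, experiment_axies = parsing.parse_dispatcher_config(config)
print(len(jobs))          # 4 -- one job per (dropout, epochs) combination
print(experiment_axies)   # ['dropout', 'epochs']
for job in jobs:          # each job is a flag string passed to scripts/main.py,
    print(job)            # e.g. " --train --dropout 0.05 --epochs 25"
```

The dispatcher starts one worker per GPU listed in `available_gpus`, each worker consumes jobs from the queue one at a time, and the alert config is used to text you when the whole grid has finished.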
-------------------------------------------------------------------------------- /rationale_net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yala/text_nn/476d1336f5be7178bc13b70a569a1a0b964b8244/rationale_net/__init__.py -------------------------------------------------------------------------------- /rationale_net/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import rationale_net.datasets.full_beer_dataset 2 | import rationale_net.datasets.news_group_dataset 3 | -------------------------------------------------------------------------------- /rationale_net/datasets/abstract_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod, abstractproperty 2 | import torch.utils.data as data 3 | 4 | TRAIN_ONLY_ERR_MSG = "{} only supported for train dataset! Instead saw {}" 5 | 6 | class AbstractDataset(data.Dataset): 7 | __metaclass__ = ABCMeta 8 | 9 | def __len__(self): 10 | return len(self.dataset) 11 | 12 | def __getitem__(self,index): 13 | sample = self.dataset[index] 14 | return sample 15 | -------------------------------------------------------------------------------- /rationale_net/datasets/factory.py: -------------------------------------------------------------------------------- 1 | NO_DATASET_ERR = "Dataset {} not in DATASET_REGISTRY! Available datasets are {}" 2 | 3 | DATASET_REGISTRY = {} 4 | 5 | 6 | def RegisterDataset(dataset_name): 7 | """Registers a dataset.""" 8 | 9 | def decorator(f): 10 | DATASET_REGISTRY[dataset_name] = f 11 | return f 12 | 13 | return decorator 14 | 15 | 16 | # Depending on arg, build dataset 17 | def get_dataset(args, word_to_indx, truncate_train=False): 18 | if args.dataset not in DATASET_REGISTRY: 19 | raise Exception( 20 | NO_DATASET_ERR.format(args.dataset, DATASET_REGISTRY.keys())) 21 | 22 | if args.dataset in DATASET_REGISTRY: 23 | train = DATASET_REGISTRY[args.dataset](args, word_to_indx, 'train') 24 | dev = DATASET_REGISTRY[args.dataset](args, word_to_indx, 'dev') 25 | test = DATASET_REGISTRY[args.dataset](args, word_to_indx, 'test') 26 | 27 | return train, dev, test 28 | -------------------------------------------------------------------------------- /rationale_net/datasets/full_beer_dataset.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import tqdm 3 | from rationale_net.utils.embedding import get_indices_tensor 4 | from rationale_net.datasets.factory import RegisterDataset 5 | from rationale_net.datasets.abstract_dataset import AbstractDataset 6 | 7 | 8 | SMALL_TRAIN_SIZE = 800 9 | 10 | @RegisterDataset('full_beer') 11 | class FullBeerDataset(AbstractDataset): 12 | 13 | def __init__(self, args, word_to_indx, mode, max_length=250, stem='raw_data/beer_review/reviews.aspect'): 14 | aspect = args.aspect 15 | self.args= args 16 | self.name = mode 17 | self.objective = args.objective 18 | self.dataset = [] 19 | self.word_to_indx = word_to_indx 20 | self.max_length = max_length 21 | self.aspects_to_num = {'appearance':0, 'aroma':1, 'palate':2,'taste':3, 'overall':4} 22 | self.class_map = {0: 0, 1:0, 2:0, 3:0, 4:1, 5:1, 6:1, 7:1, 8:2, 9:2, 10:2} 23 | self.name_to_key = {'train':'train', 'dev':'heldout', 'test':'heldout'} 24 | self.class_balance = {} 25 | with gzip.open(stem+str(self.aspects_to_num[aspect])+'.'+self.name_to_key[self.name]+'.txt.gz') as gfile: 26 | lines = 
gfile.readlines() 27 | lines = list(zip( range(len(lines)), lines) ) 28 | if args.debug_mode: 29 | lines = lines[:SMALL_TRAIN_SIZE] 30 | elif self.name == 'dev': 31 | lines = lines[:5000] 32 | elif self.name == 'test': 33 | lines = lines[5000:10000] 34 | elif self.name == 'train': 35 | lines = lines[0:20000] 36 | 37 | for indx, line in tqdm.tqdm(enumerate(lines)): 38 | uid, line_content = line 39 | sample = self.processLine(line_content, self.aspects_to_num[aspect], indx) 40 | 41 | if not sample['y'] in self.class_balance: 42 | self.class_balance[ sample['y'] ] = 0 43 | self.class_balance[ sample['y'] ] += 1 44 | sample['uid'] = uid 45 | self.dataset.append(sample) 46 | gfile.close() 47 | print ("Class balance", self.class_balance) 48 | 49 | if args.class_balance: 50 | raise NotImplementedError("Beer review dataset doesn't support balanced sampling!") 51 | 52 | ## Convert one line from beer dataset to {Text, Tensor, Labels} 53 | def processLine(self, line, aspect_num, i): 54 | if isinstance(line, bytes): 55 | line = line.decode() 56 | labels = [ float(v) for v in line.split()[:5] ] 57 | if self.objective == 'mse': 58 | label = float(labels[aspect_num]) 59 | self.args.num_class = 1 60 | else: 61 | label = int(self.class_map[ int(labels[aspect_num] *10) ]) 62 | self.args.num_class = 3 63 | text_list = line.split('\t')[-1].split()[:self.max_length] 64 | text = " ".join(text_list) 65 | x = get_indices_tensor(text_list, self.word_to_indx, self.max_length) 66 | sample = {'text':text,'x':x, 'y':label, 'i':i} 67 | return sample 68 | -------------------------------------------------------------------------------- /rationale_net/datasets/news_group_dataset.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import re 3 | import tqdm 4 | from rationale_net.utils.embedding import get_indices_tensor 5 | from rationale_net.datasets.factory import RegisterDataset 6 | from rationale_net.datasets.abstract_dataset import AbstractDataset 7 | from sklearn.datasets import fetch_20newsgroups 8 | import random 9 | random.seed(0) 10 | 11 | 12 | SMALL_TRAIN_SIZE = 800 13 | CATEGORIES = ['alt.atheism', 14 | 'comp.graphics', 15 | 'comp.os.ms-windows.misc', 16 | 'comp.sys.ibm.pc.hardware', 17 | 'comp.sys.mac.hardware', 18 | 'comp.windows.x', 19 | 'misc.forsale', 20 | 'rec.autos', 21 | 'rec.motorcycles', 22 | 'rec.sport.baseball', 23 | 'rec.sport.hockey', 24 | 'sci.crypt', 25 | 'sci.electronics', 26 | 'sci.med', 27 | 'sci.space', 28 | 'soc.religion.christian', 29 | 'talk.politics.guns', 30 | 'talk.politics.mideast', 31 | 'talk.politics.misc', 32 | 'talk.religion.misc'] 33 | 34 | def preprocess_data(data): 35 | processed_data = [] 36 | for indx, sample in enumerate(data['data']): 37 | text, label = sample, data['target'][indx] 38 | label_name = data['target_names'][label] 39 | text = re.sub('\W+', ' ', text).lower().strip() 40 | processed_data.append( (text, label, label_name) ) 41 | return processed_data 42 | 43 | 44 | @RegisterDataset('news_group') 45 | class NewsGroupDataset(AbstractDataset): 46 | 47 | def __init__(self, args, word_to_indx, name, max_length=80): 48 | self.args = args 49 | self.args.num_class = 20 50 | self.name = name 51 | self.dataset = [] 52 | self.word_to_indx = word_to_indx 53 | self.max_length = max_length 54 | self.class_balance = {} 55 | 56 | if name in ['train', 'dev']: 57 | data = preprocess_data(fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'), categories=CATEGORIES)) 58 | random.shuffle(data) 59 | num_train = 
int(len(data)*.8) 60 | if name == 'train': 61 | data = data[:num_train] 62 | else: 63 | data = data[num_train:] 64 | else: 65 | data = preprocess_data(fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'), categories=CATEGORIES)) 66 | 67 | for indx, _sample in tqdm.tqdm(enumerate(data)): 68 | sample = self.processLine(_sample) 69 | 70 | if not sample['y'] in self.class_balance: 71 | self.class_balance[ sample['y'] ] = 0 72 | self.class_balance[ sample['y'] ] += 1 73 | self.dataset.append(sample) 74 | 75 | print ("Class balance", self.class_balance) 76 | 77 | if args.class_balance: 78 | raise NotImplementedError("NewsGroup dataset doesn't support balanced sampling") 79 | if args.objective == 'mse': 80 | raise NotImplementedError("News Group does not support Regression objective") 81 | 82 | ## Convert one line from beer dataset to {Text, Tensor, Labels} 83 | def processLine(self, row): 84 | text, label, label_name = row 85 | text = " ".join(text.split()[:self.max_length]) 86 | x = get_indices_tensor(text.split(), self.word_to_indx, self.max_length) 87 | sample = {'text':text,'x':x, 'y':label, 'y_name': label_name} 88 | return sample 89 | -------------------------------------------------------------------------------- /rationale_net/learn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yala/text_nn/476d1336f5be7178bc13b70a569a1a0b964b8244/rationale_net/learn/__init__.py -------------------------------------------------------------------------------- /rationale_net/learn/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.autograd as autograd 5 | import torch.nn.functional as F 6 | import rationale_net.utils.generic as generic 7 | import rationale_net.utils.metrics as metrics 8 | import tqdm 9 | import numpy as np 10 | import pdb 11 | import sklearn.metrics 12 | import rationale_net.utils.learn as learn 13 | 14 | 15 | def train_model(train_data, dev_data, model, gen, args): 16 | ''' 17 | Train model and tune on dev set. If model doesn't improve dev performance within args.patience 18 | epochs, then halve the learning rate, restore the model to best and continue training. 19 | 20 | At the end of training, the function will restore the model to best dev version. 
21 | 22 | returns epoch_stats: a dictionary of epoch level metrics for train and test 23 | returns model : best model from this call to train 24 | ''' 25 | 26 | if args.cuda: 27 | model = model.cuda() 28 | gen = gen.cuda() 29 | 30 | args.lr = args.init_lr 31 | optimizer = learn.get_optimizer([model, gen], args) 32 | 33 | num_epoch_sans_improvement = 0 34 | epoch_stats = metrics.init_metrics_dictionary(modes=['train', 'dev']) 35 | step = 0 36 | tuning_key = "dev_{}".format(args.tuning_metric) 37 | best_epoch_func = min if tuning_key == 'loss' else max 38 | 39 | train_loader = learn.get_train_loader(train_data, args) 40 | dev_loader = learn.get_dev_loader(dev_data, args) 41 | 42 | 43 | 44 | 45 | for epoch in range(1, args.epochs + 1): 46 | 47 | print("-------------\nEpoch {}:\n".format(epoch)) 48 | for mode, dataset, loader in [('Train', train_data, train_loader), ('Dev', dev_data, dev_loader)]: 49 | train_model = mode == 'Train' 50 | print('{}'.format(mode)) 51 | key_prefix = mode.lower() 52 | epoch_details, step, _, _, _, _ = run_epoch( 53 | data_loader=loader, 54 | train_model=train_model, 55 | model=model, 56 | gen=gen, 57 | optimizer=optimizer, 58 | step=step, 59 | args=args) 60 | 61 | epoch_stats, log_statement = metrics.collate_epoch_stat(epoch_stats, epoch_details, key_prefix, args) 62 | 63 | # Log performance 64 | print(log_statement) 65 | 66 | 67 | # Save model if beats best dev 68 | best_func = min if args.tuning_metric == 'loss' else max 69 | if best_func(epoch_stats[tuning_key]) == epoch_stats[tuning_key][-1]: 70 | num_epoch_sans_improvement = 0 71 | if not os.path.isdir(args.save_dir): 72 | os.makedirs(args.save_dir) 73 | # Subtract one because epoch is 1-indexed and arr is 0-indexed 74 | epoch_stats['best_epoch'] = epoch - 1 75 | torch.save(model, args.model_path) 76 | torch.save(gen, learn.get_gen_path(args.model_path)) 77 | else: 78 | num_epoch_sans_improvement += 1 79 | 80 | if not train_model: 81 | print('---- Best Dev {} is {:.4f} at epoch {}'.format( 82 | args.tuning_metric, 83 | epoch_stats[tuning_key][epoch_stats['best_epoch']], 84 | epoch_stats['best_epoch'] + 1)) 85 | 86 | if num_epoch_sans_improvement >= args.patience: 87 | print("Reducing learning rate") 88 | num_epoch_sans_improvement = 0 89 | model.cpu() 90 | gen.cpu() 91 | model = torch.load(args.model_path) 92 | gen = torch.load(learn.get_gen_path(args.model_path)) 93 | 94 | if args.cuda: 95 | model = model.cuda() 96 | gen = gen.cuda() 97 | args.lr *= .5 98 | optimizer = learn.get_optimizer([model, gen], args) 99 | 100 | # Restore model to best dev performance 101 | if os.path.exists(args.model_path): 102 | model.cpu() 103 | model = torch.load(args.model_path) 104 | gen.cpu() 105 | gen = torch.load(learn.get_gen_path(args.model_path)) 106 | 107 | return epoch_stats, model, gen 108 | 109 | 110 | def test_model(test_data, model, gen, args): 111 | ''' 112 | Run model on test data, and return loss, accuracy. 
113 | ''' 114 | if args.cuda: 115 | model = model.cuda() 116 | gen = gen.cuda() 117 | 118 | test_loader = torch.utils.data.DataLoader( 119 | test_data, 120 | batch_size=args.batch_size, 121 | shuffle=False, 122 | num_workers=args.num_workers, 123 | drop_last=False) 124 | 125 | test_stats = metrics.init_metrics_dictionary(modes=['test']) 126 | 127 | mode = 'Test' 128 | train_model = False 129 | key_prefix = mode.lower() 130 | print("-------------\nTest") 131 | epoch_details, _, losses, preds, golds, rationales = run_epoch( 132 | data_loader=test_loader, 133 | train_model=train_model, 134 | model=model, 135 | gen=gen, 136 | optimizer=None, 137 | step=None, 138 | args=args) 139 | 140 | test_stats, log_statement = metrics.collate_epoch_stat(test_stats, epoch_details, 'test', args) 141 | test_stats['losses'] = losses 142 | test_stats['preds'] = preds 143 | test_stats['golds'] = golds 144 | test_stats['rationales'] = rationales 145 | 146 | print(log_statement) 147 | 148 | return test_stats 149 | 150 | def run_epoch(data_loader, train_model, model, gen, optimizer, step, args): 151 | ''' 152 | Train model for one pass of train data, and return loss, acccuracy 153 | ''' 154 | eval_model = not train_model 155 | data_iter = data_loader.__iter__() 156 | 157 | losses = [] 158 | obj_losses = [] 159 | k_selection_losses = [] 160 | k_continuity_losses = [] 161 | preds = [] 162 | golds = [] 163 | losses = [] 164 | texts = [] 165 | rationales = [] 166 | 167 | if train_model: 168 | model.train() 169 | gen.train() 170 | else: 171 | gen.eval() 172 | model.eval() 173 | 174 | num_batches_per_epoch = len(data_iter) 175 | if train_model: 176 | num_batches_per_epoch = min(len(data_iter), 10000) 177 | 178 | for _ in tqdm.tqdm(range(num_batches_per_epoch)): 179 | batch = data_iter.next() 180 | if train_model: 181 | step += 1 182 | if step % 100 == 0 or args.debug_mode: 183 | args.gumbel_temprature = max( np.exp((step+1) *-1* args.gumbel_decay), .05) 184 | 185 | x_indx = learn.get_x_indx(batch, args, eval_model) 186 | text = batch['text'] 187 | y = autograd.Variable(batch['y'], volatile=eval_model) 188 | 189 | if args.cuda: 190 | x_indx, y = x_indx.cuda(), y.cuda() 191 | 192 | if train_model: 193 | optimizer.zero_grad() 194 | 195 | if args.get_rationales: 196 | mask, z = gen(x_indx) 197 | else: 198 | mask = None 199 | 200 | logit, _ = model(x_indx, mask=mask) 201 | 202 | if args.use_as_tagger: 203 | logit = logit.view(-1, 2) 204 | y = y.view(-1) 205 | 206 | loss = get_loss(logit, y, args) 207 | obj_loss = loss 208 | 209 | if args.get_rationales: 210 | selection_cost, continuity_cost = gen.loss(mask, x_indx) 211 | 212 | loss += args.selection_lambda * selection_cost 213 | loss += args.continuity_lambda * continuity_cost 214 | 215 | if train_model: 216 | loss.backward() 217 | optimizer.step() 218 | 219 | if args.get_rationales: 220 | k_selection_losses.append( generic.tensor_to_numpy(selection_cost)) 221 | k_continuity_losses.append( generic.tensor_to_numpy(continuity_cost)) 222 | 223 | obj_losses.append(generic.tensor_to_numpy(obj_loss)) 224 | losses.append( generic.tensor_to_numpy(loss) ) 225 | batch_softmax = F.softmax(logit, dim=-1).cpu() 226 | preds.extend(torch.max(batch_softmax, 1)[1].view(y.size()).data.numpy()) 227 | 228 | texts.extend(text) 229 | rationales.extend(learn.get_rationales(mask, text)) 230 | 231 | if args.use_as_tagger: 232 | golds.extend(batch['y'].view(-1).numpy()) 233 | else: 234 | golds.extend(batch['y'].numpy()) 235 | 236 | 237 | 238 | epoch_metrics = metrics.get_metrics(preds, golds, args) 
239 | 240 | epoch_stat = { 241 | 'loss' : np.mean(losses), 242 | 'obj_loss': np.mean(obj_losses) 243 | } 244 | 245 | for metric_k in epoch_metrics.keys(): 246 | epoch_stat[metric_k] = epoch_metrics[metric_k] 247 | 248 | if args.get_rationales: 249 | epoch_stat['k_selection_loss'] = np.mean(k_selection_losses) 250 | epoch_stat['k_continuity_loss'] = np.mean(k_continuity_losses) 251 | 252 | return epoch_stat, step, losses, preds, golds, rationales 253 | 254 | 255 | def get_loss(logit,y, args): 256 | if args.objective == 'cross_entropy': 257 | if args.use_as_tagger: 258 | loss = F.cross_entropy(logit, y, reduce=False) 259 | neg_loss = torch.sum(loss * (y == 0).float()) / torch.sum(y == 0).float() 260 | pos_loss = torch.sum(loss * (y == 1).float()) / torch.sum(y == 1).float() 261 | loss = args.tag_lambda * neg_loss + (1 - args.tag_lambda) * pos_loss 262 | else: 263 | loss = F.cross_entropy(logit, y) 264 | elif args.objective == 'mse': 265 | loss = F.mse_loss(logit, y.float()) 266 | else: 267 | raise Exception( 268 | "Objective {} not supported!".format(args.objective)) 269 | return loss 270 | -------------------------------------------------------------------------------- /rationale_net/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yala/text_nn/476d1336f5be7178bc13b70a569a1a0b964b8244/rationale_net/models/__init__.py -------------------------------------------------------------------------------- /rationale_net/models/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | import torch.nn.functional as F 5 | import pdb 6 | 7 | class CNN(nn.Module): 8 | 9 | def __init__(self, args, max_pool_over_time=False): 10 | super(CNN, self).__init__() 11 | 12 | self.args = args 13 | self.layers = [] 14 | for layer in range(args.num_layers): 15 | convs = [] 16 | for filt in args.filters: 17 | in_channels = args.embedding_dim if layer == 0 else args.filter_num * len( args.filters) 18 | kernel_size = filt 19 | new_conv = nn.Conv1d(in_channels=in_channels, out_channels=args.filter_num, kernel_size=kernel_size) 20 | self.add_module( 'layer_'+str(layer)+'_conv_'+str(filt), new_conv) 21 | convs.append(new_conv) 22 | 23 | self.layers.append(convs) 24 | 25 | self.max_pool = max_pool_over_time 26 | 27 | 28 | 29 | def _conv(self, x): 30 | layer_activ = x 31 | for layer in self.layers: 32 | next_activ = [] 33 | for conv in layer: 34 | left_pad = conv.kernel_size[0] - 1 35 | pad_tensor_size = [d for d in layer_activ.size()] 36 | pad_tensor_size[2] = left_pad 37 | left_pad_tensor =autograd.Variable( torch.zeros( pad_tensor_size ) ) 38 | if self.args.cuda: 39 | left_pad_tensor = left_pad_tensor.cuda() 40 | padded_activ = torch.cat( (left_pad_tensor, layer_activ), dim=2) 41 | next_activ.append( conv(padded_activ) ) 42 | 43 | # concat across channels 44 | layer_activ = F.relu( torch.cat(next_activ, 1) ) 45 | 46 | return layer_activ 47 | 48 | 49 | def _pool(self, relu): 50 | pool = F.max_pool1d(relu, relu.size(2)).squeeze(-1) 51 | return pool 52 | 53 | 54 | def forward(self, x): 55 | activ = self._conv(x) 56 | if self.max_pool: 57 | activ = self._pool(activ) 58 | return activ 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /rationale_net/models/empty.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import pdb 4 | 5 | class Empty(torch.nn.Module): 6 | def __init__(self): 7 | super(Empty, self).__init__() 8 | 9 | def forward(self, x): 10 | return x 11 | -------------------------------------------------------------------------------- /rationale_net/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | import torch.nn.functional as F 5 | import rationale_net.models.cnn as cnn 6 | import pdb 7 | 8 | class Encoder(nn.Module): 9 | 10 | def __init__(self, embeddings, args): 11 | super(Encoder, self).__init__() 12 | ### Encoder 13 | self.args = args 14 | vocab_size, hidden_dim = embeddings.shape 15 | self.embedding_dim = hidden_dim 16 | self.embedding_layer = nn.Embedding( vocab_size, hidden_dim) 17 | self.embedding_layer.weight.data = torch.from_numpy( embeddings ) 18 | self.embedding_layer.weight.requires_grad = True 19 | self.embedding_fc = nn.Linear( hidden_dim, hidden_dim ) 20 | self.embedding_bn = nn.BatchNorm1d( hidden_dim) 21 | 22 | if args.model_form == 'cnn': 23 | self.cnn = cnn.CNN(args, max_pool_over_time=(not args.use_as_tagger)) 24 | self.fc = nn.Linear( len(args.filters)*args.filter_num, args.hidden_dim) 25 | else: 26 | raise NotImplementedError("Model form {} not yet supported for encoder!".format(args.model_form)) 27 | 28 | self.dropout = nn.Dropout(args.dropout) 29 | self.hidden = nn.Linear(args.hidden_dim, args.num_class) 30 | 31 | def forward(self, x_indx, mask=None): 32 | ''' 33 | x_indx: batch of word indices 34 | mask: Mask to apply over embeddings for tao ratioanles 35 | ''' 36 | x = self.embedding_layer(x_indx.squeeze(1)) 37 | if self.args.cuda: 38 | x = x.cuda() 39 | if not mask is None: 40 | x = x * mask.unsqueeze(-1) 41 | x = F.relu( self.embedding_fc(x)) 42 | x = self.dropout(x) 43 | 44 | if self.args.model_form == 'cnn': 45 | x = torch.transpose(x, 1, 2) # Switch X to (Batch, Embed, Length) 46 | hidden = self.cnn(x) 47 | hidden = F.relu( self.fc(hidden) ) 48 | else: 49 | raise Exception("Model form {} not yet supported for encoder!".format(args.model_form)) 50 | 51 | hidden = self.dropout(hidden) 52 | logit = self.hidden(hidden) 53 | return logit, hidden 54 | -------------------------------------------------------------------------------- /rationale_net/models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | import torch.nn.functional as F 5 | import rationale_net.models.cnn as cnn 6 | import rationale_net.utils.learn as learn 7 | import pdb 8 | 9 | ''' 10 | The generator selects a rationale z from a document x that should be sufficient 11 | for the encoder to make it's prediction. 12 | 13 | Several froms of Generator are supported. 
Namely CNN with arbitary number of layers, and @taolei's FastKNN 14 | ''' 15 | class Generator(nn.Module): 16 | 17 | def __init__(self, embeddings, args): 18 | super(Generator, self).__init__() 19 | vocab_size, hidden_dim = embeddings.shape 20 | self.embedding_layer = nn.Embedding( vocab_size, hidden_dim) 21 | self.embedding_layer.weight.data = torch.from_numpy( embeddings ) 22 | self.embedding_layer.weight.requires_grad = False 23 | self.args = args 24 | if args.model_form == 'cnn': 25 | self.cnn = cnn.CNN(args, max_pool_over_time = False) 26 | 27 | self.z_dim = 2 28 | 29 | self.hidden = nn.Linear((len(args.filters)* args.filter_num), self.z_dim) 30 | self.dropout = nn.Dropout(args.dropout) 31 | 32 | 33 | 34 | def __z_forward(self, activ): 35 | ''' 36 | Returns prob of each token being selected 37 | ''' 38 | activ = activ.transpose(1,2) 39 | logits = self.hidden(activ) 40 | probs = learn.gumbel_softmax(logits, self.args.gumbel_temprature, self.args.cuda) 41 | z = probs[:,:,1] 42 | return z 43 | 44 | 45 | def forward(self, x_indx): 46 | ''' 47 | Given input x_indx of dim (batch, length), return z (batch, length) such that z 48 | can act as element-wise mask on x 49 | ''' 50 | if self.args.model_form == 'cnn': 51 | x = self.embedding_layer(x_indx.squeeze(1)) 52 | if self.args.cuda: 53 | x = x.cuda() 54 | x = torch.transpose(x, 1, 2) # Switch X to (Batch, Embed, Length) 55 | activ = self.cnn(x) 56 | else: 57 | raise NotImplementedError("Model form {} not yet supported for generator!".format(args.model_form)) 58 | 59 | z = self.__z_forward(F.relu(activ)) 60 | mask = self.sample(z) 61 | return mask, z 62 | 63 | 64 | def sample(self, z): 65 | ''' 66 | Get mask from probablites at each token. Use gumbel 67 | softmax at train time, hard mask at test time 68 | ''' 69 | mask = z 70 | if self.training: 71 | mask = z 72 | else: 73 | ## pointwise set <.5 to 0 >=.5 to 1 74 | mask = learn.get_hard_mask(z) 75 | return mask 76 | 77 | 78 | def loss(self, mask, x_indx): 79 | ''' 80 | Compute the generator specific costs, i.e selection cost, continuity cost, and global vocab cost 81 | ''' 82 | selection_cost = torch.mean( torch.sum(mask, dim=1) ) 83 | l_padded_mask = torch.cat( [mask[:,0].unsqueeze(1), mask] , dim=1) 84 | r_padded_mask = torch.cat( [mask, mask[:,-1].unsqueeze(1)] , dim=1) 85 | continuity_cost = torch.mean( torch.sum( torch.abs( l_padded_mask - r_padded_mask ) , dim=1) ) 86 | return selection_cost, continuity_cost 87 | 88 | -------------------------------------------------------------------------------- /rationale_net/models/tagger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.autograd as autograd 4 | import torch.nn.functional as F 5 | import rationale_net.models.cnn as cnn 6 | import pdb 7 | 8 | ''' 9 | Implements a CNN with arbitary number of layers for tagging (predicts 0/1 for each token in text if token matches label), no max pool over time. 
10 | ''' 11 | class Tagger(nn.Module): 12 | 13 | def __init__(self, embeddings, args): 14 | super(Tagger, self).__init__() 15 | vocab_size, hidden_dim = embeddings.shape 16 | self.embedding_layer = nn.Embedding(vocab_size, hidden_dim) 17 | self.embedding_layer.weight.data = torch.from_numpy(embeddings) 18 | self.embedding_layer.weight.requires_grad = False 19 | self.args = args 20 | if args.model_form == 'cnn': 21 | self.cnn = cnn.CNN(args, max_pool_over_time=False) 22 | 23 | self.hidden = nn.Linear((len(args.filters)*args.filter_num), args.num_tags) 24 | self.dropout = nn.Dropout(args.dropout) 25 | 26 | 27 | def forward(self, x_indx, mask): 28 | '''Given input x_indx of dim (batch_size, 1, max_length), return z (batch, length) such that z 29 | can act as element-wise mask on x''' 30 | if self.args.model_form == 'cnn': 31 | ## embedding layer takes in dim (batch_size, max_length), outputs x of dim (batch_size, max_length, hidden_dim) 32 | x = self.embedding_layer(x_indx.squeeze(1)) 33 | 34 | if self.args.cuda: 35 | x = x.cuda() 36 | ## switch x to dim (batch_size, hidden_dim, max_length) 37 | x = torch.transpose(x, 1, 2) 38 | ## activ of dim (batch_size, len(filters)*filter_num, max_length) 39 | activ = self.cnn(x) 40 | else: 41 | raise NotImplementedError("Model form {} not yet supported for generator!".format(args.model_form)) 42 | 43 | ## hidden layer takes activ transposed to dim (batch_size, max_length, len(filters)*filter_num) and outputs logit of dim (batch_size, max_length, num_tags) 44 | logit = self.hidden(torch.transpose(activ, 1, 2)) 45 | return logit, self.hidden 46 | -------------------------------------------------------------------------------- /rationale_net/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yala/text_nn/476d1336f5be7178bc13b70a569a1a0b964b8244/rationale_net/utils/__init__.py -------------------------------------------------------------------------------- /rationale_net/utils/embedding.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import numpy as np 3 | import torch 4 | import pickle 5 | import pdb 6 | 7 | 8 | NO_EMBEDDING_ERR = "Embedding {} not in EMBEDDING_REGISTRY! 
Available embeddings are {}" 9 | 10 | EMBEDDING_REGISTRY = {} 11 | 12 | 13 | def RegisterEmbedding(name): 14 | """Registers a dataset.""" 15 | 16 | def decorator(f): 17 | EMBEDDING_REGISTRY[name] = f 18 | return f 19 | return decorator 20 | 21 | 22 | # Depending on arg, return embeddings 23 | def get_embedding_tensor(args): 24 | if args.embedding not in EMBEDDING_REGISTRY: 25 | raise Exception( 26 | NO_EMBEDDING_ERR.format(args.embedding, EMBEDDING_REGISTRY.keys())) 27 | 28 | if args.embedding in EMBEDDING_REGISTRY: 29 | embeddings, word_to_indx = EMBEDDING_REGISTRY[args.embedding](args) 30 | 31 | args.embedding_dim = embeddings.shape[1] 32 | 33 | return embeddings, word_to_indx 34 | 35 | 36 | @RegisterEmbedding('beer') 37 | def getBeerEmbedding(args): 38 | embedding_path='raw_data/beer_review/review+wiki.filtered.200.txt.gz' 39 | lines = [] 40 | with gzip.open(embedding_path) as file: 41 | lines = file.readlines() 42 | file.close() 43 | embedding_tensor = [] 44 | word_to_indx = {} 45 | for indx, l in enumerate(lines): 46 | word, emb = l.split()[0], l.split()[1:] 47 | vector = [float(x) for x in emb ] 48 | if indx == 0: 49 | embedding_tensor.append( np.zeros( len(vector) ) ) 50 | embedding_tensor.append(vector) 51 | word_to_indx[word] = indx+1 52 | embedding_tensor = np.array(embedding_tensor, dtype=np.float32) 53 | return embedding_tensor, word_to_indx 54 | 55 | @RegisterEmbedding('glove') 56 | def getGloveEmbedding(args): 57 | embedding_path='data/embeddings/glove.6B/glove.6B.300d.txt' 58 | lines = [] 59 | with open(embedding_path) as file: 60 | lines = file.readlines() 61 | file.close() 62 | embedding_tensor = [] 63 | word_to_indx = {} 64 | for indx, l in enumerate(lines): 65 | word, emb = l.split()[0], l.split()[1:] 66 | if not len(emb) == 300: 67 | continue 68 | vector = [float(x) for x in emb ] 69 | if indx == 0: 70 | embedding_tensor.append( np.zeros( len(vector) ) ) 71 | embedding_tensor.append(vector) 72 | word_to_indx[word] = indx+1 73 | embedding_tensor = np.array(embedding_tensor, dtype=np.float32) 74 | return embedding_tensor, word_to_indx 75 | 76 | 77 | def get_indices_tensor(text_arr, word_to_indx, max_length): 78 | ''' 79 | -text_arr: array of word tokens 80 | -word_to_indx: mapping of word -> index 81 | -max length of return tokens 82 | 83 | returns tensor of same size as text with each words corresponding 84 | index 85 | ''' 86 | nil_indx = 0 87 | text_indx = [ word_to_indx[x] if x in word_to_indx else nil_indx for x in text_arr][:max_length] 88 | if len(text_indx) < max_length: 89 | text_indx.extend( [nil_indx for _ in range(max_length - len(text_indx))]) 90 | 91 | x = torch.LongTensor([text_indx]) 92 | 93 | return x 94 | -------------------------------------------------------------------------------- /rationale_net/utils/generic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pdb 3 | import argparse 4 | 5 | def tensor_to_numpy(tensor): 6 | return tensor.data[0] 7 | 8 | class Namespace: 9 | def __init__(self, **kwargs): 10 | self.__dict__.update(kwargs) 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='Rationale-Net Classifier') 15 | #setup 16 | parser.add_argument('--train', action='store_true', default=False, help='Whether or not to train model') 17 | parser.add_argument('--test', action='store_true', default=False, help='Whether or not to run model on test set') 18 | # device 19 | parser.add_argument('--cuda', action='store_true', default=False, help='enable the gpu' ) 
20 | parser.add_argument('--num_gpus', type=int, default=1, help='Num GPUs to use. More than one gpu turns on multi_gpu training with nn.DataParallel.') 21 | parser.add_argument('--debug_mode', action='store_true', default=False, help='debug mode' ) 22 | parser.add_argument('--class_balance', action='store_true', default=False, help='use balanced samlping for train loaded' ) 23 | # learning 24 | parser.add_argument('--objective', default='cross_entropy', help='choose which loss objective to use') 25 | parser.add_argument('--aspect', default='overall', help='which aspect to train/eval on') 26 | parser.add_argument('--init_lr', type=float, default=0.001, help='initial learning rate [default: 0.001]') 27 | parser.add_argument('--epochs', type=int, default=256, help='number of epochs for train [default: 256]') 28 | parser.add_argument('--batch_size', type=int, default=128, help='batch size for training [default: 64]') 29 | parser.add_argument('--patience', type=int, default=10, help='Num epochs of no dev progress before half learning rate [default: 10]') 30 | parser.add_argument('--tuning_metric', type=str, default='loss', help='Metric to judge dev set results. Possible options loss, accuracy, precision, recall or f1, where precision/recall/f1 are all microaveraged. [default: loss]') 31 | #paths 32 | parser.add_argument('--save_dir', type=str, default='snapshot', help='where to save the snapshot') 33 | parser.add_argument('--results_path', type=str, default='', help='where to dump model config and epoch stats. If get_rationales is set to true, rationales for the test set will also be stored here.') 34 | parser.add_argument('--snapshot', type=str, default=None, help='filename of model snapshot to load[default: None]') 35 | # data loading 36 | parser.add_argument('--num_workers' , type=int, default=4, help='num workers for data loader') 37 | # model 38 | parser.add_argument('--model_form', type=str, default='cnn', help="Form of model, i.e cnn, rnn, etc.") 39 | parser.add_argument('--hidden_dim', type=int, default=100, help="Dim of hidden layer") 40 | parser.add_argument('--num_layers', type=int, default=1, help="Num layers of model_form to use") 41 | parser.add_argument('--dropout', type=float, default=0.1, help='the probability for dropout [default: 0.5]') 42 | parser.add_argument('--weight_decay', type=float, default=1e-3, help='L2 norm penalty [default: 1e-3]') 43 | parser.add_argument('--filter_num', type=int, default=100, help='number of each kind of kernel') 44 | parser.add_argument('--filters', type=str, default='3,4,5', help='comma-separated kernel size to use for convolution') 45 | # data 46 | parser.add_argument('--dataset', default='news_group', help='choose which dataset to run on. [default: news_group]') 47 | parser.add_argument('--embedding', default='glove', help='choose what embeddings to use. To use them, please download them to "embeddings/glove.6B.300d.txt and set this argument to "glove" [default: random] ') 48 | 49 | # gumbel 50 | parser.add_argument('--gumbel_temprature', type=float, default=1, help="Start temprature for gumbel softmax. This is annealed via exponential decay") 51 | parser.add_argument('--gumbel_decay', type=float, default=1e-5, help="Start temprature for gumbel softmax. This is annealed via linear decay") 52 | # rationale 53 | parser.add_argument('--get_rationales', action='store_true', default=False, help="output attributions for dataset. 
Note, will only be stored for test set in results file, as indicated by results_path") 54 | parser.add_argument('--selection_lambda', type=float, default=.01, help="y1 in Gen cost L + y1||z|| + y2|zt - zt-1| + y3|{z}|") 55 | parser.add_argument('--continuity_lambda', type=float, default=.01, help="y2 in Gen cost L + y1||z|| + y2|zt - zt-1|+ y3|{z}|") 56 | parser.add_argument('--num_class', type=int, default=2, help="num classes") 57 | 58 | # tagging task support. Note, does not support rationales x tagging 59 | parser.add_argument('--use_as_tagger', action='store_true', default=False, help="Use model for a taggign task, i.e with labels per word in the document. Note only supports binary tagging") 60 | parser.add_argument('--tag_lambda', type=float, default=.5, help="Lambda to weight the null entity class") 61 | 62 | args = parser.parse_args() 63 | 64 | # update args and print 65 | args.filters = [int(k) for k in args.filters.split(',')] 66 | if args.objective == 'mse': 67 | args.num_class = 1 68 | 69 | print("\nParameters:") 70 | for attr, value in sorted(args.__dict__.items()): 71 | print("\t{}={}".format(attr.upper(), value)) 72 | 73 | return args 74 | 75 | 76 | -------------------------------------------------------------------------------- /rationale_net/utils/learn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torch.utils.data as data 6 | import pdb 7 | 8 | def get_train_loader(train_data, args): 9 | if args.class_balance: 10 | sampler = data.sampler.WeightedRandomSampler( 11 | weights=train_data.weights, 12 | num_samples=len(train_data), 13 | replacement=True) 14 | train_loader = data.DataLoader( 15 | train_data, 16 | num_workers=args.num_workers, 17 | sampler=sampler, 18 | batch_size=args.batch_size) 19 | else: 20 | train_loader = data.DataLoader( 21 | train_data, 22 | batch_size=args.batch_size, 23 | shuffle=True, 24 | num_workers=args.num_workers, 25 | drop_last=False) 26 | 27 | return train_loader 28 | 29 | def get_rationales(mask, text): 30 | if mask is None: 31 | return text 32 | masked_text = [] 33 | for i, t in enumerate(text): 34 | sample_mask = list(mask.data[i]) 35 | original_words = t.split() 36 | words = [ w if m > .5 else "_" for w,m in zip(original_words, sample_mask) ] 37 | masked_sample = " ".join(words) 38 | masked_text.append(masked_sample) 39 | return masked_text 40 | 41 | 42 | 43 | def get_dev_loader(dev_data, args): 44 | dev_loader = data.DataLoader( 45 | dev_data, 46 | batch_size=args.batch_size, 47 | shuffle=False, 48 | num_workers=args.num_workers, 49 | drop_last=False) 50 | return dev_loader 51 | 52 | def get_optimizer(models, args): 53 | ''' 54 | -models: List of models (such as Generator, classif, memory, etc) 55 | -args: experiment level config 56 | 57 | returns: torch optimizer over models 58 | ''' 59 | params = [] 60 | for model in models: 61 | params.extend([param for param in model.parameters() if param.requires_grad]) 62 | return torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 63 | 64 | 65 | def get_x_indx(batch, args, eval_model): 66 | x_indx = autograd.Variable(batch['x'], volatile=eval_model) 67 | return x_indx 68 | 69 | 70 | 71 | def get_hard_mask(z, return_ind=False): 72 | ''' 73 | -z: torch Tensor where each element probablity of element 74 | being selected 75 | -args: experiment level config 76 | 77 | returns: A torch variable that is binary mask of z >= .5 78 | 
''' 79 | max_z, ind = torch.max(z, dim=-1) 80 | if return_ind: 81 | del z 82 | return ind 83 | masked = torch.ge(z, max_z.unsqueeze(-1)).float() 84 | del z 85 | return masked 86 | 87 | def get_gen_path(model_path): 88 | ''' 89 | -model_path: path of encoder model 90 | 91 | returns: path of generator 92 | ''' 93 | return '{}.gen'.format(model_path) 94 | 95 | def one_hot(label, num_class): 96 | vec = torch.zeros( (1, num_class) ) 97 | vec[0][label] = 1 98 | return vec 99 | 100 | 101 | def gumbel_softmax(input, temperature, cuda): 102 | noise = torch.rand(input.size()) 103 | noise.add_(1e-9).log_().neg_() 104 | noise.add_(1e-9).log_().neg_() 105 | noise = autograd.Variable(noise) 106 | if cuda: 107 | noise = noise.cuda() 108 | x = (input + noise) / temperature 109 | x = F.softmax(x.view(-1, x.size()[-1]), dim=-1) 110 | return x.view_as(input) 111 | -------------------------------------------------------------------------------- /rationale_net/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import sklearn.metrics 2 | 3 | def collate_epoch_stat(stat_dict, epoch_details, mode, args): 4 | ''' 5 | Update stat_dict with details from epoch_details and create 6 | log statement 7 | 8 | - stat_dict: a dictionary of statistics lists to update 9 | - epoch_details: list of statistics for a given epoch 10 | - mode: train, dev or test 11 | - args: model run configuration 12 | 13 | returns: 14 | -stat_dict: updated stat_dict with epoch details 15 | -log_statement: log statement sumarizing new epoch 16 | 17 | ''' 18 | log_statement_details = '' 19 | for metric in epoch_details: 20 | loss = epoch_details[metric] 21 | stat_dict['{}_{}'.format(mode, metric)].append(loss) 22 | 23 | log_statement_details += ' -{}: {}'.format(metric, loss) 24 | 25 | log_statement = '\n {} - {}\n--'.format( 26 | args.objective, log_statement_details ) 27 | 28 | return stat_dict, log_statement 29 | 30 | def get_metrics(preds, golds, args): 31 | metrics = {} 32 | 33 | if args.objective in ['cross_entropy', 'margin']: 34 | metrics['accuracy'] = sklearn.metrics.accuracy_score(y_true=golds, y_pred=preds) 35 | metrics['confusion_matrix'] = sklearn.metrics.confusion_matrix(y_true=golds,y_pred=preds) 36 | metrics['precision'] = sklearn.metrics.precision_score(y_true=golds, y_pred=preds, average="weighted") 37 | metrics['recall'] = sklearn.metrics.recall_score(y_true=golds,y_pred=preds, average="weighted") 38 | metrics['f1'] = sklearn.metrics.f1_score(y_true=golds,y_pred=preds, average="weighted") 39 | 40 | metrics['mse'] = "NA" 41 | 42 | elif args.objective == 'mse': 43 | metrics['mse'] = sklearn.metrics.mean_squared_error(y_true=golds, y_pred=preds) 44 | metrics['confusion_matrix'] = "NA" 45 | metrics['accuracy'] = "NA" 46 | metrics['precision'] = "NA" 47 | metrics['recall'] = "NA" 48 | metrics['f1'] = 'NA' 49 | 50 | return metrics 51 | 52 | 53 | 54 | 55 | 56 | def init_metrics_dictionary(modes): 57 | ''' 58 | Create dictionary with empty array for each metric in each mode 59 | ''' 60 | epoch_stats = {} 61 | metrics = [ 62 | 'loss', 'obj_loss', 'k_selection_loss', 63 | 'k_continuity_loss', 'accuracy', 'precision', 'recall', 'f1', 'confusion_matrix', 'mse'] 64 | for metric in metrics: 65 | for mode in modes: 66 | key = "{}_{}".format(mode, metric) 67 | epoch_stats[key] = [] 68 | 69 | return epoch_stats 70 | -------------------------------------------------------------------------------- /rationale_net/utils/model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import rationale_net.models.encoder as encoder 3 | import rationale_net.models.generator as generator 4 | import rationale_net.models.tagger as tagger 5 | import rationale_net.models.empty as empty 6 | import rationale_net.utils.learn as learn 7 | import os 8 | import pdb 9 | 10 | def get_model(args, embeddings, train_data): 11 | if args.snapshot is None: 12 | if args.use_as_tagger == True: 13 | gen = empty.Empty() 14 | model = tagger.Tagger(embeddings, args) 15 | else: 16 | gen = generator.Generator(embeddings, args) 17 | model = encoder.Encoder(embeddings, args) 18 | else : 19 | print('\nLoading model from [%s]...' % args.snapshot) 20 | try: 21 | gen_path = learn.get_gen_path(args.snapshot) 22 | if os.path.exists(gen_path): 23 | gen = torch.load(gen_path) 24 | model = torch.load(args.snapshot) 25 | except : 26 | print("Sorry, This snapshot doesn't exist."); exit() 27 | 28 | if args.num_gpus > 1: 29 | model = nn.DataParallel(model, 30 | device_ids=range(args.num_gpus)) 31 | 32 | if not gen is None: 33 | gen = nn.DataParallel(gen, 34 | device_ids=range(args.num_gpus)) 35 | return gen, model 36 | -------------------------------------------------------------------------------- /rationale_net/utils/parsing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import hashlib 4 | 5 | 6 | POSS_VAL_NOT_LIST = "Flag {} has an invalid list of values: {}. Length of list must be >=1" 7 | 8 | 9 | def md5(key): 10 | ''' 11 | returns a hashed with md5 string of the key 12 | ''' 13 | return hashlib.md5(key.encode()).hexdigest() 14 | 15 | def parse_dispatcher_config(config): 16 | ''' 17 | Parses an experiment config, and creates jobs. For flags that are expected to be a single item, 18 | but the config contains a list, this will return one job for each item in the list. 19 | :config - experiment_config 20 | 21 | returns: jobs - a list of flag strings, each of which encapsulates one job. 22 | *Example: --train --cuda --dropout=0.1 ... 
23 | returns: experiment_axies - axies that the grid search is searching over 24 | ''' 25 | jobs = [""] 26 | experiment_axies = [] 27 | search_space = config['search_space'] 28 | 29 | # Go through the tree of possible jobs and enumerate into a list of jobs 30 | for ind, flag in enumerate(search_space): 31 | possible_values = search_space[flag] 32 | if len(possible_values) > 1: 33 | experiment_axies.append(flag) 34 | 35 | children = [] 36 | if len(possible_values) == 0 or type(possible_values) is not list: 37 | raise Exception(POSS_VAL_NOT_LIST.format(flag, possible_values)) 38 | for value in possible_values: 39 | for parent_job in jobs: 40 | if type(value) is bool: 41 | if value: 42 | new_job_str = "{} --{}".format(parent_job, flag) 43 | else: 44 | new_job_str = parent_job 45 | elif type(value) is list: 46 | val_list_str = " ".join([str(v) for v in value]) 47 | new_job_str = "{} --{} {}".format(parent_job, flag, 48 | val_list_str) 49 | else: 50 | new_job_str = "{} --{} {}".format(parent_job, flag, value) 51 | children.append(new_job_str) 52 | jobs = children 53 | 54 | return jobs, experiment_axies 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl 2 | torchvision 3 | sklearn 4 | scipy 5 | pydot 6 | matplotlib 7 | tqdm 8 | twilio 9 | -------------------------------------------------------------------------------- /requirements3.txt: -------------------------------------------------------------------------------- 1 | http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl 2 | torchvision 3 | sklearn 4 | scipy 5 | pydot 6 | matplotlib 7 | tqdm 8 | twilio 9 | -------------------------------------------------------------------------------- /scripts/dispatcher.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import os 4 | import multiprocessing 5 | import pickle 6 | import csv 7 | from twilio.rest import Client 8 | import json 9 | import sys 10 | from os.path import dirname, realpath 11 | 12 | sys.path.append(dirname(dirname(realpath(__file__)))) 13 | 14 | import rationale_net.utils.parsing as parsing 15 | 16 | EXPERIMENT_CRASH_MSG = "ALERT! job:[{}] has crashed! Check logfile at:[{}]" 17 | CONFIG_NOT_FOUND_MSG = "ALERT! {} config {} file does not exist!" 18 | RESULTS_PATH_APPEAR_ERR = 'results_path should not appear in config. It will be determined automatically per job' 19 | SUCESSFUL_SEARCH_STR = "SUCCESS! Grid search results dumped to {}. Best dev loss: {}, dev accuracy: {:.3f}" 20 | 21 | RESULT_KEY_STEMS = ['{}_loss', '{}_obj_loss', '{}_k_selection_loss', 22 | '{}_k_continuity_loss','{}_metric'] 23 | 24 | LOG_KEYS = ['results_path', 'model_path', 'log_path'] 25 | SORT_KEY = 'dev_loss' 26 | 27 | parser = argparse.ArgumentParser(description='OncoNet Grid Search Dispatcher. 
For use information, see `doc/README.md`') 28 | parser.add_argument("--experiment_config_path", required=True, type=str, help="Path of experiment config") 29 | parser.add_argument("--alert_config_path", type=str, default='configs/alert_config.json', help="Path of alert config") 30 | parser.add_argument('--log_dir', type=str, default="logs", help="path to store logs and detailed job level result files") 31 | parser.add_argument('--result_path', type=str, default="results/grid_search.csv", help="path to store grid_search table. This is preferably on shared storage") 32 | parser.add_argument('--rerun_experiments', action='store_true', default=False, help='whether to rerun experiments with the same result file location') 33 | 34 | 35 | def send_text_msg(msg, alert_config, twilio_config): 36 | ''' 37 | Send a text message using twilio acct specified twilio conf to numbers 38 | specified in alert_conf. 39 | If suppress_alerts is turned on, do nothing 40 | :msg: - body of text message 41 | :alert_config: - dictionary with a list fo numbers to send message to 42 | :twilio-config: - dictionary with twilio SID, TOKEN, and phone number 43 | ''' 44 | if alert_config['suppress_alerts']: 45 | return 46 | client = Client(twilio_config['ACCOUNT_SID'], twilio_config['AUTH_TOKEN']) 47 | for number in [alert_config['alert_nums']]: 48 | client.messages.create( 49 | to=number, from_=twilio_config['twilio_num'], body=msg) 50 | 51 | 52 | def launch_experiment(gpu, flag_string, alert_conf, twilio_conf): 53 | ''' 54 | Launch an experiment and direct logs and results to a unique filepath. 55 | Alert of something goes wrong. 56 | :gpu: gpu to run this machine on. 57 | :flag_string: flags to use for this model run. Will be fed into 58 | scripts/main.py 59 | ''' 60 | if not os.path.isdir(args.log_dir): 61 | os.makedirs(args.log_dir) 62 | 63 | log_name = parsing.md5(flag_string) 64 | log_stem = os.path.join(args.log_dir, log_name) 65 | log_path = '{}.txt'.format(log_stem) 66 | results_path = "{}.results".format(log_stem) 67 | 68 | experiment_string = "CUDA_VISIBLE_DEVICES={} python -u scripts/main.py {} --results_path {}".format( 69 | gpu, flag_string, results_path) 70 | 71 | # forward logs to logfile 72 | shell_cmd = "{} > {} 2>&1".format(experiment_string, log_path) 73 | print("Lauched exp: {}".format(shell_cmd)) 74 | if not os.path.exists(results_path) or args.rerun_experiments: 75 | subprocess.call(shell_cmd, shell=True) 76 | 77 | if not os.path.exists(results_path): 78 | # running this process failed, alert me 79 | job_fail_msg = EXPERIMENT_CRASH_MSG.format(experiment_string, log_path) 80 | send_text_msg(job_fail_msg, alert_conf, twilio_conf) 81 | 82 | return results_path, log_path 83 | 84 | 85 | def worker(gpu, job_queue, done_queue, alert_config, twilio_config): 86 | ''' 87 | Worker thread for each gpu. Consumes all jobs and pushes results to done_queue. 88 | :gpu - gpu this worker can access. 89 | :job_queue - queue of available jobs. 90 | :done_queue - queue where to push results. 
91 | ''' 92 | while not job_queue.empty(): 93 | params = job_queue.get() 94 | if params is None: 95 | return 96 | done_queue.put( 97 | launch_experiment(gpu, params, alert_config, twilio_config)) 98 | 99 | 100 | if __name__ == "__main__": 101 | 102 | args = parser.parse_args() 103 | if not os.path.exists(args.experiment_config_path): 104 | print(CONFIG_NOT_FOUND_MSG.format("experiment", args.experiment_config_path)) 105 | sys.exit(1) 106 | experiment_config = json.load(open(args.experiment_config_path, 'r')) 107 | 108 | if 'results_path' in experiment_config['search_space']: 109 | print (RESULTS_PATH_APPEAR_ERR) 110 | sys.exit(1) 111 | 112 | if not os.path.exists(args.alert_config_path): 113 | print(CONFIG_NOT_FOUND_MSG.format("alert", args.alert_config_path)) 114 | sys.exit(1) 115 | alert_config = json.load(open(args.alert_config_path, 'r')) 116 | 117 | twilio_conf_path = alert_config['path_to_twilio_secret'] 118 | if not os.path.exists(twilio_conf_path): 119 | print(CONFIG_NOT_FOUND_MSG.format("twilio", twilio_conf_path)) 120 | 121 | twilio_config = None 122 | if not alert_config['suppress_alerts']: 123 | twilio_config = json.load(open(twilio_conf_path, 'r')) 124 | 125 | job_list, experiment_axies = parsing.parse_dispatcher_config(experiment_config) 126 | job_queue = multiprocessing.Queue() 127 | done_queue = multiprocessing.Queue() 128 | 129 | for job in job_list: 130 | job_queue.put(job) 131 | print("Launching Dispatcher with {} jobs!".format(len(job_list))) 132 | print() 133 | for gpu in experiment_config['available_gpus']: 134 | print("Start gpu worker {}".format(gpu)) 135 | multiprocessing.Process(target=worker, args=(gpu, job_queue, done_queue, alert_config, twilio_config)).start() 136 | print() 137 | 138 | summary = [] 139 | result_keys = [] 140 | for mode in ['train','dev','test']: 141 | result_keys.extend( [k.format(mode) for k in RESULT_KEY_STEMS ]) 142 | for _ in range(len(job_list)): 143 | result_path, log_path = done_queue.get() 144 | assert result_path is not None 145 | try: 146 | result_dict = pickle.load(open(result_path, 'rb')) 147 | except Exception as e: 148 | print("Experiment failed! 
Logs are located at: {}".format(log_path)) 149 | continue 150 | 151 | result_dict['log_path'] = log_path 152 | # Get results from best epoch and move to top level of results dict 153 | best_epoch_indx = result_dict['epoch_stats']['best_epoch'] 154 | present_result_keys = [] 155 | for k in result_keys: 156 | if (k in result_dict['test_stats'] and len(result_dict['test_stats'][k])>0) or (k in result_dict['epoch_stats'] and len(result_dict['epoch_stats'][k])>0): 157 | present_result_keys.append(k) 158 | if 'test' in k: 159 | result_dict[k] = result_dict['test_stats'][k][0] 160 | else: 161 | result_dict[k] = result_dict['epoch_stats'][k][best_epoch_indx] 162 | 163 | 164 | summary_columns = experiment_axies + present_result_keys + LOG_KEYS 165 | # Only export keys we want to see in sheet to csv 166 | summary_dict = {} 167 | for key in summary_columns: 168 | summary_dict[key] = result_dict[key] 169 | summary.append(summary_dict) 170 | summary = sorted(summary, key=lambda k: k[SORT_KEY]) 171 | 172 | dump_result_string = SUCESSFUL_SEARCH_STR.format( 173 | args.result_path, summary[0]['dev_loss'], summary[0]['dev_metric'] 174 | ) 175 | # Write summary to csv 176 | with open(args.result_path, 'w') as out_file: 177 | writer = csv.DictWriter(out_file, fieldnames=summary_columns) 178 | writer.writeheader() 179 | for experiment in summary: 180 | writer.writerow(experiment) 181 | 182 | print(dump_result_string) 183 | send_text_msg(dump_result_string, alert_config, twilio_config) 184 | -------------------------------------------------------------------------------- /scripts/main.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, realpath 2 | import sys 3 | sys.path.append(dirname(dirname(realpath(__file__)))) 4 | import argparse 5 | 6 | import rationale_net.datasets.factory as dataset_factory 7 | import rationale_net.utils.embedding as embedding 8 | import rationale_net.utils.model as model_factory 9 | import rationale_net.utils.generic as generic 10 | import rationale_net.learn.train as train 11 | import os 12 | import torch 13 | import datetime 14 | import pickle 15 | import pdb 16 | 17 | 18 | if __name__ == '__main__': 19 | # update args and print 20 | args = generic.parse_args() 21 | 22 | embeddings, word_to_indx = embedding.get_embedding_tensor(args) 23 | 24 | train_data, dev_data, test_data = dataset_factory.get_dataset(args, word_to_indx) 25 | 26 | results_path_stem = args.results_path.split('/')[-1].split('.')[0] 27 | args.model_path = '{}.pt'.format(os.path.join(args.save_dir, results_path_stem)) 28 | 29 | # model 30 | gen, model = model_factory.get_model(args, embeddings, train_data) 31 | 32 | print() 33 | # train 34 | if args.train : 35 | epoch_stats, model, gen = train.train_model(train_data, dev_data, model, gen, args) 36 | args.epoch_stats = epoch_stats 37 | save_path = args.results_path 38 | print("Save train/dev results to", save_path) 39 | args_dict = vars(args) 40 | pickle.dump(args_dict, open(save_path,'wb') ) 41 | 42 | 43 | # test 44 | if args.test : 45 | test_stats = train.test_model(test_data, model, gen, args) 46 | args.test_stats = test_stats 47 | args.train_data = train_data 48 | args.test_data = test_data 49 | 50 | save_path = args.results_path 51 | print("Save test results to", save_path) 52 | args_dict = vars(args) 53 | pickle.dump(args_dict, open(save_path,'wb') ) 54 | -------------------------------------------------------------------------------- /scripts/preprocess/preprocess_snli.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import re 4 | 5 | if __name__ == "__main__": 6 | regex = re.compile('[^a-zA-Z]') 7 | for mode in ['train', 'dev', 'test']: 8 | 9 | with open('raw_data/snli_1.0/snli_1.0_{}.jsonl'.format(mode), 'r') as f: 10 | lines = f.readlines() 11 | 12 | raw_data = [ json.loads(line) for line in lines] 13 | 14 | data = [] 15 | for ind, row in enumerate(raw_data): 16 | if row['gold_label'] == '-': 17 | continue 18 | concat_text = '{} \t {}'.format( 19 | row['sentence1'].lower(), row['sentence2'].lower()) 20 | data.append({ 21 | 'text1': regex.sub(' ', row['sentence1']), 22 | 'text2': regex.sub(' ', row['sentence2']), 23 | 'text' : regex.sub(' ',concat_text), 24 | 'label': row['gold_label'], 25 | 'uid': ind 26 | }) 27 | pickle.dump(data[:300], open('raw_data/snli_1.0/{}.debug.p'.format(mode),'wb')) 28 | pickle.dump(data, open('raw_data/snli_1.0/{}.p'.format(mode),'wb')) 29 | -------------------------------------------------------------------------------- /tutorial/TextCNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Text Classification with CNN (PyTorch)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## 1.0 Introduction\n", 15 | "\n", 16 | "Text classification is one of the most common Natural Language Processing tasks. It consists of encoding a text into a tensor that can then be used by a machine learning method to predict a label, which may represent any category.\n", 17 | "\n", 18 | "Text classification has been used for a wide range of purposes, such as text categorization, spam detection, information extraction, sentiment analysis, and so on.\n", 19 | "\n", 20 | "In this short tutorial, we will develop a simple Convolutional Neural Network (CNN) for text classification in PyTorch. All code will be commented to increase readability.\n", 21 | "\n", 22 | "We will train and test our model on the 20 Newsgroup Dataset following the steps below:\n", 23 | " \n", 24 | "1. Load the word embeddings (e.g. Word2Vec or Glove)\n", 25 | "2. Load the dataset (i.e. 20 Newsgroup Dataset)\n", 26 | "3. Define the hyperparameters (e.g. arguments)\n", 27 | "4. Define the model (i.e. CNN)\n", 28 | "5. Train and Validate the model (Epochs, Metrics, etc.)\n", 29 | "6. Test the model (Metrics)\n", 30 | "\n", 31 | "The reader can expand and re-adapt our system to work on different datasets and for different purposes.\n", 32 | "\n", 33 | "\n", 34 | "### 1.1 References and Acknowledgements\n", 35 | "\n", 36 | "Most of the code described below is adapted from:\n", 37 | "- TextCNN (Kim, 2014): https://arxiv.org/abs/1408.5882\n", 38 | "- Rationale Net (Lei et al., 2016): https://arxiv.org/abs/1606.04155\n", 39 | "- Extraction from Breast Pathology Reports (Yala et al., 2016): https://www.biorxiv.org/content/early/2016/10/10/079913\n", 40 | "\n", 41 | "We recommend that the reader go through these papers to gain a clear understanding of the models and the theory behind them." 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## 2.0 Task\n", 49 | "\n", 50 | "The 20 Newsgroups dataset (http://qwone.com/~jason/20Newsgroups/) is a collection of approximately 20,000 newsgroup documents, partitioned across 20 different categories (med, space, atheism, etc.). 
This dataset has been previously adopted for both clustering and classification of documents.\n", 51 | "\n", 52 | "In this tutorial, we will load the dataset through the Scikit-Learn interface and we will use it for text classification. See section 4.0." 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## 3.0 Word Embeddings\n", 60 | "\n", 61 | "\"Word embeddings\" refers to a set of language modeling and feature learning techniques that allow computers to learn word representations as dense vectors of real numbers (as opposed to sparse co-occurrence vectors).\n", 62 | "\n", 63 | "The two most common word embedding types are Word2Vec (either Continuous Bag-of-words or Skip-Gram) and Glove, even though a number of other algorithms have been proposed over the years. In this tutorial we do not intend to describe how such algorithms work, but we suggest that the reader look up at least some basic information about them.\n", 64 | "\n", 65 | "What we can briefly mention here is that word embeddings rely on the *Distributional Hypothesis* (Harris, 1954), according to which words that occur in similar contexts tend to be similar. If we count or predict the contexts in which words occur, we can learn vectorial representations that are expected to represent similarity by means of distance in the generated vectorial semantic space. Word vectors representing similar meanings will be closer to each other than word vectors representing different ones.\n", 66 | "\n", 67 | "\n", 68 | "### 3.1 Load the Word Embeddings\n", 69 | "\n", 70 | "The first step in making our algorithm work is to provide it with vectorial word representations.\n", 71 | "\n", 72 | "One way to do so would be to collect the vocabulary used in our target dataset and learn the vectorial representations for each word from a large corpus.\n", 73 | "\n", 74 | "Another, more practical, way consists of loading pre-trained word embeddings, which can be easily downloaded from the Web. This is possible because we can expect that the majority of words used in our newsgroup dataset are common and frequent enough to appear in the pre-trained word embeddings. Such an assumption would be wrong if we had to deal with the medical or pharmaceutical domain, as the vocabulary would contain very rare words.\n", 75 | "\n", 76 | "In the code below, we will use Glove embeddings (https://nlp.stanford.edu/projects/glove/), but the reader can use a different kind of embeddings."
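To make "similarity by means of distance" concrete, here is a minimal sketch of a cosine-similarity check. It assumes the `emb_tensor` matrix and `word_to_indx` dictionary produced by the loader defined in the next cell, and the example words are only assumptions about what the Glove vocabulary contains.

```
# Minimal sketch: cosine similarity between two loaded word vectors.
# Assumes `emb_tensor` and `word_to_indx` from the load_embeddings cell below.
import numpy as np

def cosine_similarity(u, v):
    # Cosine of the angle between two embedding vectors: 1.0 = same direction.
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

# Hypothetical usage, assuming these words are in the Glove vocabulary:
# vec = lambda w: emb_tensor[word_to_indx[w]]
# cosine_similarity(vec('hockey'), vec('baseball'))   # expected to be relatively high
# cosine_similarity(vec('hockey'), vec('chemistry'))  # expected to be lower
```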
77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 1, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# Load the Embeddings\n", 86 | "import numpy as np\n", 87 | "\n", 88 | "# Set the path where you have downloaded the embeddings\n", 89 | "emb_path = \"/scratch1/esantus/text_nn/data/embeddings/glove.6B/glove.6B.300d.txt\"\n", 90 | "\n", 91 | "# Set the embedding size\n", 92 | "emb_dims = 300\n", 93 | "\n", 94 | "\n", 95 | "def load_embeddings(emb_path, emb_dims):\n", 96 | " '''\n", 97 | " Load the embeddings from a text file\n", 98 | " \n", 99 | " :param emb_path: Path of the text file\n", 100 | " :param emb_dims: Embedding dimensions\n", 101 | " \n", 102 | " :return emb_tensor: tensor containing all word embeddings\n", 103 | " :return word_to_indx: dictionary with word:index\n", 104 | " '''\n", 105 | "\n", 106 | " # Load the file\n", 107 | " lines = open(emb_path).readlines()\n", 108 | " \n", 109 | " # Creating the list and adding the PADDING embedding\n", 110 | " emb_tensor = [np.zeros(emb_dims)]\n", 111 | " word_to_indx = {'PADDING_WORD':0}\n", 112 | " \n", 113 | " # For each line, save the embedding and the word:index\n", 114 | " for indx, l in enumerate(lines):\n", 115 | " word, emb = l.split()[0], l.split()[1:]\n", 116 | " \n", 117 | " if not len(emb) == emb_dims:\n", 118 | " continue\n", 119 | " \n", 120 | " # Update the embedding list and the word:index dictionary\n", 121 | " emb_tensor.append([float(x) for x in emb])\n", 122 | " word_to_indx[word] = len(emb_tensor) - 1\n", 123 | " \n", 124 | " # Turning the list into a numpy object\n", 125 | " emb_tensor = np.array(emb_tensor, dtype=np.float32)\n", 126 | " return emb_tensor, word_to_indx" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "The function load_embeddings takes as input the embedding path (pointing to a text file with one word and vector per line) and the embedding dimensions.\n", 134 | "\n", 135 | "It loads the embeddings into emb_tensor, adding a zero-padding embedding in position zero. For each word in the emb_tensor, the word index is recorded in the word_to_indx dictionary.\n", 136 | "\n", 137 | "Below we call the function and print the dimensions of both the vector tensor and dictionary." 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 2, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "Words: 400001\n", 150 | "Vectors (+ zero-padding): (400001, 300)\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "# Calling load_embeddings and printing the size of the returned objects\n", 156 | "emb_tensor, word_to_indx = load_embeddings(emb_path, emb_dims)\n", 157 | "\n", 158 | "print('Words: {}\\nVectors (+ zero-padding): {}'.format(len(word_to_indx.keys()), emb_tensor.shape))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "## 4.0 Load the Dataset\n", 166 | "\n", 167 | "In section 2.0 we briefly introduced the task. In this section we show how to load the dataset using the Scikit-Learn API.\n", 168 | "\n", 169 | "The dataset needs to be processed in such a way that it can then be used by our Convolutional Neural Network for classification: the classes below serve exactly this purpose."
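Before the full dataset class, the short sketch below shows the underlying Scikit-Learn call that the class wraps and the fields it returns; the two categories passed here are only an example subset.

```
# Minimal sketch of the raw Scikit-Learn call wrapped by the Newsgroup class below.
from sklearn.datasets import fetch_20newsgroups

bunch = fetch_20newsgroups(subset='train',
                           remove=('headers', 'footers', 'quotes'),
                           categories=['rec.sport.hockey', 'sci.space'])

print(len(bunch.data))     # number of raw documents in the selected categories
print(bunch.target_names)  # category names, e.g. ['rec.sport.hockey', 'sci.space']
print(bunch.target[:5])    # integer labels aligned with bunch.data
```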
170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 3, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "# Load the Dataset\n", 179 | "\n", 180 | "from sklearn.datasets import fetch_20newsgroups\n", 181 | "from abc import ABCMeta, abstractmethod, abstractproperty\n", 182 | "import torch.utils.data as data\n", 183 | "import torch\n", 184 | "\n", 185 | "import re\n", 186 | "import random\n", 187 | "import tqdm\n", 188 | "\n", 189 | "\n", 190 | "# Classes in the dataset\n", 191 | "classes = ['alt.atheism',\n", 192 | " 'comp.graphics',\n", 193 | " 'comp.os.ms-windows.misc',\n", 194 | " 'comp.sys.ibm.pc.hardware',\n", 195 | " 'comp.sys.mac.hardware',\n", 196 | " 'comp.windows.x',\n", 197 | " 'misc.forsale',\n", 198 | " 'rec.autos',\n", 199 | " 'rec.motorcycles',\n", 200 | " 'rec.sport.baseball',\n", 201 | " 'rec.sport.hockey',\n", 202 | " 'sci.crypt',\n", 203 | " 'sci.electronics',\n", 204 | " 'sci.med',\n", 205 | " 'sci.space',\n", 206 | " 'soc.religion.christian',\n", 207 | " 'talk.politics.guns',\n", 208 | " 'talk.politics.mideast',\n", 209 | " 'talk.politics.misc',\n", 210 | " 'talk.religion.misc']\n", 211 | "\n", 212 | "\n", 213 | "class AbstractDataset(data.Dataset):\n", 214 | " '''\n", 215 | " Abstract class that adds general method to the Newsgroup dataset\n", 216 | " '''\n", 217 | " \n", 218 | " __metaclass__ = ABCMeta\n", 219 | "\n", 220 | " def __len__(self):\n", 221 | " return len(self.dataset)\n", 222 | "\n", 223 | " def __getitem__(self, index):\n", 224 | " sample = self.dataset[index]\n", 225 | " return sample\n", 226 | "\n", 227 | "\n", 228 | "class Newsgroup(AbstractDataset):\n", 229 | " '''\n", 230 | " Newsgroup dataset loader\n", 231 | " '''\n", 232 | " \n", 233 | " def __init__(self, set_type, classes, word_to_indx, class_balance_true=True, max_length=80):\n", 234 | " '''\n", 235 | " Load the dataset from SK-Learn\n", 236 | "\n", 237 | " :param set_type: string containing either 'train', 'dev' or 'test'\n", 238 | " :param classes: list of strings containing the classes\n", 239 | " :param word_to_indx: dictionary of word:index\n", 240 | " :param max_length: integer with max word to consider\n", 241 | " :return: nothing\n", 242 | " '''\n", 243 | "\n", 244 | " # Deterministic randomization\n", 245 | " random.seed(0)\n", 246 | " \n", 247 | " n_classes = len(classes)\n", 248 | " class_balance = {}\n", 249 | " self.dataset = []\n", 250 | "\n", 251 | " # If train or dev...\n", 252 | " if set_type in ['train', 'dev']:\n", 253 | " data = self.preprocess(fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'),\n", 254 | " categories=classes))\n", 255 | " \n", 256 | " # Randomly split train in 80-20%\n", 257 | " random.shuffle(data)\n", 258 | " num_train = int(len(data)*.8)\n", 259 | " if set_type == 'train':\n", 260 | " data = data[:num_train]\n", 261 | " else:\n", 262 | " data = data[num_train:]\n", 263 | " \n", 264 | " # If test... 
\n", 265 | " else:\n", 266 | " data = self.preprocess(fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'),\n", 267 | " categories=classes))\n", 268 | "\n", 269 | " # For every unprocessed_sample in the created set, process it\n", 270 | " for indx, unprocessed_sample in tqdm.tqdm(enumerate(data)):\n", 271 | " sample = self.process_line(unprocessed_sample, word_to_indx, max_length)\n", 272 | " \n", 273 | " # If the sample is not empty, save it and add its y to the class_balance dictionary\n", 274 | " if sample['text'] != '':\n", 275 | " if not sample['y'] in class_balance:\n", 276 | " class_balance[sample['y']] = 0\n", 277 | " class_balance[sample['y']] += 1\n", 278 | " self.dataset.append(sample)\n", 279 | "\n", 280 | " \n", 281 | " def preprocess(self, data):\n", 282 | " '''\n", 283 | " Return a list of (text, label and label_name)\n", 284 | "\n", 285 | " :param data: 20 newsgroup dataset as imported by SK-Learn\n", 286 | " \n", 287 | " :return processed_data: list of text, label and label_name\n", 288 | " '''\n", 289 | " processed_data = []\n", 290 | " for indx, sample in enumerate(data['data']):\n", 291 | " text, label = sample, data['target'][indx]\n", 292 | " label_name = data['target_names'][label]\n", 293 | " text = re.sub('\\W+', ' ', text).lower().strip()\n", 294 | " processed_data.append((text, label, label_name))\n", 295 | " return processed_data\n", 296 | "\n", 297 | " \n", 298 | " def get_indices_tensor(self, text_arr, word_to_indx, max_length):\n", 299 | " '''\n", 300 | " Return a tensor of max_length with the word indices\n", 301 | " \n", 302 | " :param text_arr: text array\n", 303 | " :param word_to_indx: dictionary word:index\n", 304 | " :param max_length: maximum length of returned tensors\n", 305 | " \n", 306 | " :return x: tensor containing the indices\n", 307 | " '''\n", 308 | " \n", 309 | " pad_indx = 0\n", 310 | " text_indx = [word_to_indx[x] if x in word_to_indx else pad_indx for x in text_arr][:max_length]\n", 311 | " \n", 312 | " # Padding\n", 313 | " if len(text_indx) < max_length:\n", 314 | " text_indx.extend([pad_indx for _ in range(max_length - len(text_indx))])\n", 315 | "\n", 316 | " x = torch.LongTensor([text_indx])\n", 317 | "\n", 318 | " return x\n", 319 | "\n", 320 | "\n", 321 | " def process_line(self, row, word_to_indx, max_length, case_insensitive=True):\n", 322 | " '''\n", 323 | " Return every line as a dictionary with text, x, y, y_name\n", 324 | "\n", 325 | " :param row: document (or comment)\n", 326 | " :param word_to_indx: dictionary of word:index\n", 327 | " :param max_length: integer with max word to consider\n", 328 | " \n", 329 | " :return sample: dictionary of text, x, y, y_name\n", 330 | " '''\n", 331 | " \n", 332 | " text, label, label_name = row\n", 333 | " \n", 334 | " if case_insensitive:\n", 335 | " text = \" \".join(text.split()[:max_length]).lower()\n", 336 | " else:\n", 337 | " text = \" \".join(text.split()[:max_length])\n", 338 | " \n", 339 | " x = self.get_indices_tensor(text.split(), word_to_indx, max_length)\n", 340 | " \n", 341 | " sample = {'text':text,'x':x, 'y':label, 'y_name': label_name}\n", 342 | " return sample" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": {}, 348 | "source": [ 349 | "The class AbstractDataset adds general method to the Newsgroup dataset.\n", 350 | "\n", 351 | "The class Newsgroup loads the dataset and process it, turning every line of it in a dictionary with the following keys:\n", 352 | "\n", 353 | "- text: the text of the comment\n", 354 | "- x: 
tensor containing the indices of the words in text\n", 355 | "- y: label (an integer)\n", 356 | "- y_name: name of the label\n", 357 | "\n", 358 | "Below we load the dataset in the train, dev and test sets and we print one sample." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 4, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "name": "stderr", 368 | "output_type": "stream", 369 | "text": [ 370 | "9051it [00:00, 10572.29it/s]\n", 371 | "2263it [00:00, 7198.71it/s]\n", 372 | "7532it [00:02, 3131.95it/s]" 373 | ] 374 | }, 375 | { 376 | "name": "stdout", 377 | "output_type": "stream", 378 | "text": [ 379 | "{'y': 19, 'text': u'dr england s story deleted it was a nice read the first time through it isn t so much a matter of interpretation of bible texts that sets mormonism apart from orthodoxy as it is a matter of fabrication about 20 years ago _national lampoon_ had some comic strips in them that were drawn by neal adams they were called son o god comics it was a parody of the jesus in the bible in the comic there were a', 'y_name': 'talk.religion.misc', 'x': tensor([[ 6457, 564, 1535, 524, 16202, 21, 16, 8, 3083, 1466,\n", 380 | " 1, 59, 80, 132, 21, 75361, 2160, 101, 182, 8,\n", 381 | " 1121, 4, 6513, 4, 5490, 8239, 13, 2304, 51747, 2726,\n", 382 | " 26, 22887, 20, 21, 15, 8, 1121, 4, 22078, 60,\n", 383 | " 325, 83, 364, 0, 0, 41, 78, 4250, 11393, 7,\n", 384 | " 102, 13, 36, 2572, 22, 13824, 4127, 40, 36, 176,\n", 385 | " 631, 4869, 1534, 6109, 21, 16, 8, 13302, 4, 1,\n", 386 | " 3994, 7, 1, 5490, 7, 1, 4250, 64, 36, 8]])}\n" 387 | ] 388 | }, 389 | { 390 | "name": "stderr", 391 | "output_type": "stream", 392 | "text": [ 393 | "\n" 394 | ] 395 | } 396 | ], 397 | "source": [ 398 | "# Loading the dataset\n", 399 | "train = Newsgroup('train', classes, word_to_indx, class_balance_true=True, max_length=80)\n", 400 | "dev = Newsgroup('dev', classes, word_to_indx, class_balance_true=True, max_length=80)\n", 401 | "test = Newsgroup('test', classes, word_to_indx, class_balance_true=True, max_length=80)\n", 402 | "\n", 403 | "# Printing 3 datapoints\n", 404 | "for datapoint in train[:1]:\n", 405 | " print(datapoint)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "## 5.0 Define the Hyperparameters\n", 413 | "\n", 414 | "Whenever we train a neural network, a large set of hyperparameters need to be defined and tuned. Such parameters are generally tuned looking at the performance on the development set.\n", 415 | "\n", 416 | "In this section we define the default arguments. We will see in our experiments that such parameters are already good enough to obtain high accuracy on the Newsgroup dataset." 
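Before fixing the hyperparameters, a quick sanity check of the index encoding printed above is to invert `word_to_indx` and map a sample's `x` tensor back to words. This is a minimal sketch that assumes the `train` and `word_to_indx` objects built in the previous cells.

```
# Sanity check: decode one encoded sample back into words.
# Assumes `train` and `word_to_indx` from the cells above.
indx_to_word = {indx: word for word, indx in word_to_indx.items()}

sample = train[0]
decoded = [indx_to_word.get(int(i), 'PADDING_WORD') for i in sample['x'][0]]
print(' '.join(decoded))
# The decoded string should match the start of sample['text'], except that
# out-of-vocabulary words come back as PADDING_WORD (index 0).
```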
417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 5, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "# Set the parameters\n", 426 | "\n", 427 | "args = {'train':True, 'test':False, 'cuda':False, 'class_balance':False,\n", 428 | " 'init_lr':0.001, 'epochs':4, 'batch_size':128, 'patience':10,\n", 429 | " 'save_dir':'snapshot', 'model_path':'model.pt', 'results_path':'snapshot/results.txt', 'model':'TextCNN',\n", 430 | " 'hidden_dims':100, 'num_layers':1, 'dropout':0.1, 'weight_decay':1e-3,\n", 431 | " 'filter_num':100, 'filters':[3, 4, 5], 'num_class':20, 'emb_dims':300,\n", 432 | " 'tuning_metric':'loss', 'num_workers':4, 'objective':'cross_entropy'}\n", 433 | "\n", 434 | "#'gumbel_temprature':1, 'gumbel_decay':1e-5,'tag_lambda':.5" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "## 6.0 Defining the Model\n", 442 | "\n", 443 | "In this section, we see how to create a Convolutional Neural Network for text classification. We do not intend here to discuss the theory behind CNNs, as the reader can easily find sources online for it (a nice tutorial can be found here: http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/). We would instead propose the commented code below.\n", 444 | "\n", 445 | "Our implementation is organized in two classes:\n", 446 | "- one is the Encoder, which loads the embeddings, calls the model and returns the logits for the output classes;\n", 447 | "- the other is the model, implemented as a TextCNN, which takes in input a three dimensional tensor (batch times word_number times emb_dimensions) and returns the activation." 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 6, 453 | "metadata": { 454 | "collapsed": true 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "# Defining the Encoder and the Model classes\n", 459 | "\n", 460 | "import pdb\n", 461 | "import torch\n", 462 | "import torch.nn as nn\n", 463 | "import torch.autograd as autograd\n", 464 | "import torch.nn.functional as F\n", 465 | "\n", 466 | "\n", 467 | "# Encoder\n", 468 | "class Encoder(nn.Module):\n", 469 | " '''\n", 470 | " Load the embeddings and encode them\n", 471 | " '''\n", 472 | "\n", 473 | " def __init__(self, embeddings, args):\n", 474 | " '''\n", 475 | " Load embeddings and call the TextCNN model\n", 476 | " \n", 477 | " :param embeddings: tensor with word embeddings\n", 478 | " :param model: default is 'TextCNN'\n", 479 | " \n", 480 | " :return: nothing\n", 481 | " '''\n", 482 | " super(Encoder, self).__init__()\n", 483 | " \n", 484 | " # Saving the parameters\n", 485 | " self.model = args['model']\n", 486 | " self.num_class = args['num_class']\n", 487 | " self.hidden_dims = args['hidden_dims']\n", 488 | " self.num_layers = args['num_layers']\n", 489 | " self.filters = args['filters']\n", 490 | " self.filter_num = args['filter_num']\n", 491 | " self.cuda = args['cuda']\n", 492 | " self.dropout = args['dropout']\n", 493 | " \n", 494 | " # Loading the word embeddings in the Neural Network\n", 495 | " vocab_size, hidden_dim = embeddings.shape\n", 496 | " self.emb_dims = hidden_dim\n", 497 | " self.emb_layer = nn.Embedding(vocab_size, hidden_dim)\n", 498 | " self.emb_layer.weight.data = torch.from_numpy(embeddings)\n", 499 | " self.emb_layer.weight.requires_grad = True\n", 500 | " self.emb_fc = nn.Linear(hidden_dim, hidden_dim)\n", 501 | " self.emb_bn = nn.BatchNorm1d(hidden_dim)\n", 502 | " \n", 503 | " # Calling the model, followed 
by a fully connected hidden layer\n", 504 | " if self.model == 'TextCNN':\n", 505 | " self.cnn = TextCNN(args, max_pool_over_time=True)\n", 506 | " # The hidden fully connected layer size is given by the number of filters\n", 507 | " # times the filter size, by the number of hidden dimensions\n", 508 | " self.fc = nn.Linear(len(self.filters) * self.filter_num, hidden_dim)\n", 509 | " else:\n", 510 | " raise NotImplementedError(\"Model {} not yet supported for encoder!\".format(model))\n", 511 | "\n", 512 | " # Dropout and final layer\n", 513 | " self.dropout = nn.Dropout(self.dropout)\n", 514 | " self.hidden = nn.Linear(hidden_dim, self.num_class)\n", 515 | " \n", 516 | " \n", 517 | " def forward(self, x_indx):\n", 518 | " '''\n", 519 | " Forward step\n", 520 | " \n", 521 | " :param x_indx: batch of word indices\n", 522 | " \n", 523 | " :return logit: predictions\n", 524 | " :return: hidden layer\n", 525 | " '''\n", 526 | " \n", 527 | " x = self.emb_layer(x_indx.squeeze(1))\n", 528 | " if self.cuda:\n", 529 | " x = x.cuda()\n", 530 | " \n", 531 | " # Non linear projection with dropout\n", 532 | " x = F.relu(self.emb_fc(x))\n", 533 | " x = self.dropout(x)\n", 534 | " # TextNN, fully connected and non linearity\n", 535 | " if self.model == 'TextCNN':\n", 536 | " x = torch.transpose(x, 1, 2) # Transpose x dimensions into (Batch, Emb, Length)\n", 537 | " hidden = self.cnn(x)\n", 538 | " hidden = F.relu(self.fc(hidden))\n", 539 | " else:\n", 540 | " raise Exception(\"Model {} not yet supported for encoder!\".format(self.model))\n", 541 | "\n", 542 | " # Dropout and final layer\n", 543 | " hidden = self.dropout(hidden)\n", 544 | " logit = self.hidden(hidden)\n", 545 | " return logit, hidden\n", 546 | "\n", 547 | "\n", 548 | "# Model\n", 549 | "class TextCNN(nn.Module):\n", 550 | " '''\n", 551 | " CNN for Text Classification\n", 552 | " '''\n", 553 | "\n", 554 | " def __init__(self, args, max_pool_over_time=False):\n", 555 | " '''\n", 556 | " Convolutional Neural Network\n", 557 | " \n", 558 | " :param num_layers: number of layers\n", 559 | " :param filters: filters shape\n", 560 | " :param filter_num: number of filters\n", 561 | " :param emb_dims: embedding dimensions\n", 562 | " :param max_pool_over_time: boolean\n", 563 | " \n", 564 | " :return: nothing\n", 565 | " '''\n", 566 | " super(TextCNN, self).__init__()\n", 567 | "\n", 568 | " # Saving the parameters\n", 569 | " self.num_layers = args['num_layers']\n", 570 | " self.filters = args['filters']\n", 571 | " self.filter_num = args['filter_num']\n", 572 | " self.emb_dims = args['emb_dims']\n", 573 | " self.cuda = args['cuda']\n", 574 | " self.max_pool = max_pool_over_time\n", 575 | " \n", 576 | " self.layers = []\n", 577 | " \n", 578 | " # For every layer...\n", 579 | " for l in range(self.num_layers):\n", 580 | " convs = []\n", 581 | " \n", 582 | " # For every filter...\n", 583 | " for f in self.filters:\n", 584 | " # Defining the sizes\n", 585 | " in_channels = self.emb_dims if l == 0 else self.filter_num * len(self.filters)\n", 586 | " kernel_size = f\n", 587 | " \n", 588 | " # Adding the convolutions in the list\n", 589 | " conv = nn.Conv1d(in_channels=in_channels, out_channels=self.filter_num, kernel_size=kernel_size)\n", 590 | " self.add_module('layer_' + str(l) + '_conv_' + str(f), conv)\n", 591 | " convs.append(conv)\n", 592 | " \n", 593 | " self.layers.append(convs)\n", 594 | "\n", 595 | "\n", 596 | " def _conv(self, x):\n", 597 | " '''\n", 598 | " Left padding and returning the activation\n", 599 | " \n", 600 | " :param x: input 
tensor (batch, emb, length)\n", 601 | " :return layer_activ: activation\n", 602 | " '''\n", 603 | " \n", 604 | " layer_activ = x\n", 605 | " \n", 606 | " for layer in self.layers:\n", 607 | " next_activ = []\n", 608 | " \n", 609 | " for conv in layer:\n", 610 | " # Setting the padding dimensions: it is like adding\n", 611 | " # kernel_size - 1 empty embeddings\n", 612 | " left_pad = conv.kernel_size[0] - 1\n", 613 | " pad_tensor_size = [d for d in layer_activ.size()]\n", 614 | " pad_tensor_size[2] = left_pad\n", 615 | " left_pad_tensor = autograd.Variable(torch.zeros(pad_tensor_size))\n", 616 | " \n", 617 | " if self.cuda:\n", 618 | " left_pad_tensor = left_pad_tensor.cuda()\n", 619 | " \n", 620 | " # Concatenating the padding to the tensor\n", 621 | " padded_activ = torch.cat((left_pad_tensor, layer_activ), dim=2)\n", 622 | " \n", 623 | " # onvolution activation\n", 624 | " next_activ.append(conv(padded_activ))\n", 625 | "\n", 626 | " # Concatenating accross channels\n", 627 | " layer_activ = F.relu(torch.cat(next_activ, 1))\n", 628 | " #pdb.set_trace()\n", 629 | " return layer_activ\n", 630 | "\n", 631 | "\n", 632 | " def _pool(self, relu):\n", 633 | " '''\n", 634 | " Max Pool Over Time\n", 635 | " '''\n", 636 | " \n", 637 | " pool = F.max_pool1d(relu, relu.size(2)).squeeze(-1)\n", 638 | " return pool\n", 639 | "\n", 640 | "\n", 641 | " def forward(self, x):\n", 642 | " '''\n", 643 | " Forward steps over the x\n", 644 | " \n", 645 | " :param x: input (batch, emb, length)\n", 646 | "\n", 647 | " :return activ: activation\n", 648 | " '''\n", 649 | " \n", 650 | " activ = self._conv(x)\n", 651 | " \n", 652 | " # Pooling over time?\n", 653 | " if self.max_pool:\n", 654 | " activ = self._pool(activ)\n", 655 | " \n", 656 | " return activ" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 7, 662 | "metadata": {}, 663 | "outputs": [ 664 | { 665 | "name": "stdout", 666 | "output_type": "stream", 667 | "text": [ 668 | "Output logits for the first (randomly sorted) element of the dataset:\n", 669 | "\n", 670 | "\n", 671 | "tensor([[ 0.0795, 0.1743, 0.0644, 0.0376, 0.0002, 0.0198, 0.0035, 0.0052,\n", 672 | " -0.0534, -0.0350, 0.0040, 0.0520, 0.1227, -0.0657, 0.0526, -0.0068,\n", 673 | " 0.0724, -0.0000, 0.0232, 0.0251]], grad_fn=)\n" 674 | ] 675 | } 676 | ], 677 | "source": [ 678 | "# Creating the encoder and TextCNN, and printing an output from a random input\n", 679 | "\n", 680 | "encoder = Encoder(emb_tensor, args) \n", 681 | "\n", 682 | "print(\"Output logits for the first (randomly sorted) element of the dataset:\\n\\n\")\n", 683 | "print(encoder(train[0]['x'])[0])" 684 | ] 685 | }, 686 | { 687 | "cell_type": "markdown", 688 | "metadata": { 689 | "collapsed": true 690 | }, 691 | "source": [ 692 | "## 7.0 Train the Model\n", 693 | "\n", 694 | "After loading the word embeddings and the dataset, we defined the model and the encoder. At this point, it remains to train the system and finally to evaluate it.\n", 695 | "\n", 696 | "The training code is relatively complicated, so we split it into utilities and core functions. Every function is properly commented, and we hope the reader can easily understand their goal." 697 | ] 698 | }, 699 | { 700 | "cell_type": "markdown", 701 | "metadata": {}, 702 | "source": [ 703 | "### 7.1 Utilities\n", 704 | "\n", 705 | "All the functions listed below are of support for the core training functions implemented in the next section." 
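Before moving to the training utilities, a quick shape check of the encoder instantiated above helps confirm how the pieces fit together. The expected sizes below are a sketch under the default arguments (filters [3, 4, 5], filter_num 100, 300-dimensional embeddings, 20 classes), not guaranteed values for other configurations.

```
# Shape check for the Encoder/TextCNN defined above.
# Max-pool-over-time collapses the sequence, leaving len(filters) * filter_num
# = 3 * 100 = 300 CNN features, which the fully connected layer maps back to the
# embedding dimension (300) before the final projection to num_class logits.
logit, hidden = encoder(train[0]['x'])
print(logit.size())   # expected: (1, 20)  - one row of class logits
print(hidden.size())  # expected: (1, 300) - hidden representation fed to the final layer
```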
706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": 8, 711 | "metadata": { 712 | "collapsed": true 713 | }, 714 | "outputs": [], 715 | "source": [ 716 | "# Train the model\n", 717 | "import sklearn.metrics\n", 718 | "import sys, os\n", 719 | "\n", 720 | "def get_optimizer(models, args):\n", 721 | " '''\n", 722 | " Save the parameters of every model in models and pass them to\n", 723 | " Adam optimizer.\n", 724 | " \n", 725 | " :param models: list of models (such as TextCNN, etc.)\n", 726 | " :param args: arguments\n", 727 | " \n", 728 | " :return: torch optimizer over models\n", 729 | " '''\n", 730 | " params = []\n", 731 | " for model in models:\n", 732 | " params.extend([param for param in model.parameters() if param.requires_grad])\n", 733 | " return torch.optim.Adam(params, lr=args['lr'], weight_decay=args['weight_decay'])\n", 734 | "\n", 735 | "\n", 736 | "def init_metrics_dictionary(modes):\n", 737 | " '''\n", 738 | " Create dictionary with empty array for each metric in each mode\n", 739 | " \n", 740 | " :param modes: list with either train, dev or test\n", 741 | " \n", 742 | " :return epoch_stats: statistics for a given epoch\n", 743 | " '''\n", 744 | " epoch_stats = {}\n", 745 | " metrics = ['loss', 'obj_loss', 'k_selection_loss', 'k_continuity_loss',\n", 746 | " 'accuracy', 'precision', 'recall', 'f1', 'confusion_matrix', 'mse']\n", 747 | " for metric in metrics:\n", 748 | " for mode in modes:\n", 749 | " key = \"{}_{}\".format(mode, metric)\n", 750 | " epoch_stats[key] = []\n", 751 | " return epoch_stats\n", 752 | "\n", 753 | "\n", 754 | "def get_train_loader(train_data, args):\n", 755 | " '''\n", 756 | " Iterative train loader with sampler and replacer if class_balance\n", 757 | " is true, normal otherwise.\n", 758 | " \n", 759 | " :param train_data: training data\n", 760 | " :param args: arguments\n", 761 | " \n", 762 | " :return train_loader: iterable training set\n", 763 | " '''\n", 764 | " \n", 765 | " if args['class_balance']:\n", 766 | " # If the class_balance is true: sample and replace\n", 767 | " sampler = data.sampler.WeightedRandomSampler(\n", 768 | " weights=train_data.weights,\n", 769 | " num_samples=len(train_data),\n", 770 | " replacement=True)\n", 771 | " train_loader = data.DataLoader(\n", 772 | " train_data,\n", 773 | " num_workers=args['num_workers'],\n", 774 | " sampler=sampler,\n", 775 | " batch_size=args['batch_size'])\n", 776 | " else:\n", 777 | " # If the class_balance is false, do not sample\n", 778 | " train_loader = data.DataLoader(\n", 779 | " train_data,\n", 780 | " batch_size=args['batch_size'],\n", 781 | " shuffle=True,\n", 782 | " num_workers=args['num_workers'],\n", 783 | " drop_last=False)\n", 784 | " return train_loader\n", 785 | "\n", 786 | "\n", 787 | "def get_dev_loader(dev_data, args):\n", 788 | " '''\n", 789 | " Iterative dev loader\n", 790 | " \n", 791 | " :param dev_data: dev set\n", 792 | " :param args: arguments\n", 793 | " \n", 794 | " :return dev_loader: iterative dev set\n", 795 | " '''\n", 796 | " \n", 797 | " dev_loader = data.DataLoader(\n", 798 | " dev_data,\n", 799 | " batch_size=args['batch_size'],\n", 800 | " shuffle=False,\n", 801 | " num_workers=args['num_workers'],\n", 802 | " drop_last=False)\n", 803 | " return dev_loader\n", 804 | "\n", 805 | "\n", 806 | "def get_x_indx(batch, eval_model):\n", 807 | " '''\n", 808 | " Given a batch, return all the x\n", 809 | " \n", 810 | " :param batch: batch of dictionaries\n", 811 | " :param eval_model: true or false, for volatile\n", 812 | " \n", 813 | " 
:return x_indx: tensor of batch*x\n", 814 | " '''\n", 815 | " \n", 816 | " x_indx = autograd.Variable(batch['x'], volatile=eval_model)\n", 817 | " return x_indx\n", 818 | "\n", 819 | "\n", 820 | "def get_loss(logit, y, args):\n", 821 | " '''\n", 822 | " Return the cross entropy or mse loss\n", 823 | " \n", 824 | " :param logit: predictions\n", 825 | " :param y: gold standard\n", 826 | " :param args: arguments\n", 827 | " \n", 828 | " :return loss: loss\n", 829 | " '''\n", 830 | " \n", 831 | " if args['objective'] == 'cross_entropy':\n", 832 | " loss = F.cross_entropy(logit, y)\n", 833 | " elif args['objective'] == 'mse':\n", 834 | " loss = F.mse_loss(logit, y.float())\n", 835 | " else:\n", 836 | " raise Exception(\"Objective {} not supported!\".format(args['objective']))\n", 837 | " return loss\n", 838 | "\n", 839 | "\n", 840 | "def tensor_to_numpy(tensor):\n", 841 | " '''\n", 842 | " Return a numpy matrix from a tensor\n", 843 | "\n", 844 | " :param tensor: tensor\n", 845 | " \n", 846 | " :return numpy_matrix: numpy matrix\n", 847 | " '''\n", 848 | " return tensor.data[0]\n", 849 | "\n", 850 | "\n", 851 | "def get_metrics(preds, golds, args):\n", 852 | " '''\n", 853 | " Return the metrics given predictions and golds\n", 854 | " \n", 855 | " :param preds: list of predictions\n", 856 | " :param golds: list of golds\n", 857 | " :param args: arguments\n", 858 | " \n", 859 | " :return metrics: metrics dictionary\n", 860 | " '''\n", 861 | " metrics = {}\n", 862 | "\n", 863 | " if args['objective'] in ['cross_entropy', 'margin']:\n", 864 | " metrics['accuracy'] = sklearn.metrics.accuracy_score(y_true=golds, y_pred=preds)\n", 865 | " metrics['confusion_matrix'] = sklearn.metrics.confusion_matrix(y_true=golds,y_pred=preds)\n", 866 | " metrics['precision'] = sklearn.metrics.precision_score(y_true=golds, y_pred=preds, average=\"weighted\")\n", 867 | " metrics['recall'] = sklearn.metrics.recall_score(y_true=golds,y_pred=preds, average=\"weighted\")\n", 868 | " metrics['f1'] = sklearn.metrics.f1_score(y_true=golds,y_pred=preds, average=\"weighted\")\n", 869 | " metrics['mse'] = \"NA\"\n", 870 | " elif args['objective'] == 'mse':\n", 871 | " metrics['mse'] = sklearn.metrics.mean_squared_error(y_true=golds, y_pred=preds)\n", 872 | " metrics['confusion_matrix'] = \"NA\"\n", 873 | " metrics['accuracy'] = \"NA\"\n", 874 | " metrics['precision'] = \"NA\"\n", 875 | " metrics['recall'] = \"NA\"\n", 876 | " metrics['f1'] = 'NA'\n", 877 | " return metrics\n", 878 | "\n", 879 | "\n", 880 | "def collate_epoch_stat(stat_dict, epoch_details, mode, args):\n", 881 | " '''\n", 882 | " Update stat_dict with details from epoch_details and create\n", 883 | " log statement\n", 884 | "\n", 885 | " :param stat_dict: a dictionary of statistics lists to update\n", 886 | " :param epoch_details: list of statistics for a given epoch\n", 887 | " :param mode: train, dev or test\n", 888 | " :param args: model run configuration\n", 889 | "\n", 890 | " :return stat_dict: updated stat_dict with epoch details\n", 891 | " :return log_statement: log statement sumarizing new epoch\n", 892 | "\n", 893 | " '''\n", 894 | " log_statement_details = ''\n", 895 | " for metric in epoch_details:\n", 896 | " loss = epoch_details[metric]\n", 897 | " stat_dict['{}_{}'.format(mode, metric)].append(loss)\n", 898 | "\n", 899 | " log_statement_details += ' -{}: {}'.format(metric, loss)\n", 900 | "\n", 901 | " log_statement = '\\n {} - {}\\n--'.format(args['objective'], log_statement_details )\n", 902 | "\n", 903 | " return stat_dict, log_statement" 904 | 
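As a small illustration of the helpers above before they are used inside the training loop, the sketch below exercises get_metrics, init_metrics_dictionary and collate_epoch_stat on a toy list of predictions; the toy values are assumptions made only for this example.

```
# Toy example: exercise the metric helpers defined above.
toy_preds = [0, 1, 1, 3]
toy_golds = [0, 1, 2, 3]

toy_metrics = get_metrics(toy_preds, toy_golds, args)
print(toy_metrics['accuracy'])  # 0.75

stats = init_metrics_dictionary(modes=['dev'])
stats, log_statement = collate_epoch_stat(stats, {'loss': 0.5, 'accuracy': toy_metrics['accuracy']},
                                          'dev', args)
print(stats['dev_accuracy'])    # [0.75]
print(log_statement)
```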
] 905 | }, 906 | { 907 | "cell_type": "markdown", 908 | "metadata": {}, 909 | "source": [ 910 | "### 7.2 Core Functions\n", 911 | "\n", 912 | "Below we present the core functions for the training." 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": 9, 918 | "metadata": { 919 | "collapsed": true 920 | }, 921 | "outputs": [], 922 | "source": [ 923 | "# Run each epoch\n", 924 | "def run_epoch(data_loader, train_model, model, optimizer, step, args):\n", 925 | " '''\n", 926 | " Train model for one pass of train data, and return loss, acccuracy\n", 927 | " \n", 928 | " :param data_loader: iterable dataset\n", 929 | " :param train_model: true if training, false otherwise\n", 930 | " :param model: text classifier, such as TextCNN\n", 931 | " :param optimizer: Adam\n", 932 | " :param args: arguments\n", 933 | " \n", 934 | " :return epoch_stat:\n", 935 | " :return step: number of steps\n", 936 | " :return losses: list of losses\n", 937 | " :return preds: list of predictions\n", 938 | " :return golds: list of gold standards\n", 939 | " '''\n", 940 | " \n", 941 | " eval_model = not train_model\n", 942 | " data_iter = data_loader.__iter__()\n", 943 | "\n", 944 | " losses = []\n", 945 | " obj_losses = []\n", 946 | " \n", 947 | " preds = []\n", 948 | " golds = []\n", 949 | " texts = []\n", 950 | "\n", 951 | " if train_model:\n", 952 | " model.train()\n", 953 | " else:\n", 954 | " model.eval()\n", 955 | "\n", 956 | " num_batches_per_epoch = len(data_iter)\n", 957 | " if train_model:\n", 958 | " num_batches_per_epoch = min(len(data_iter), 10000)\n", 959 | "\n", 960 | " for _ in tqdm.tqdm(range(num_batches_per_epoch)):\n", 961 | " # Get the batch\n", 962 | " batch = data_iter.next()\n", 963 | " \n", 964 | " if train_model:\n", 965 | " step += 1\n", 966 | " #if step % 100 == 0:\n", 967 | " # args['gumbel_temprature'] = max(np.exp((step+1) * -1 * args['gumbel_decay']), .05)\n", 968 | "\n", 969 | " # Load X and Y\n", 970 | " x_indx = get_x_indx(batch, eval_model)\n", 971 | " text = batch['text']\n", 972 | " y = autograd.Variable(batch['y'], volatile=eval_model)\n", 973 | "\n", 974 | " if args['cuda']:\n", 975 | " x_indx, y = x_indx.cuda(), y.cuda()\n", 976 | "\n", 977 | " if train_model:\n", 978 | " optimizer.zero_grad()\n", 979 | "\n", 980 | " logit, _ = model(x_indx)\n", 981 | "\n", 982 | " # Calculate the loss\n", 983 | " loss = get_loss(logit, y, args)\n", 984 | " obj_loss = loss\n", 985 | "\n", 986 | " # Backward step\n", 987 | " if train_model:\n", 988 | " loss.backward()\n", 989 | " optimizer.step()\n", 990 | "\n", 991 | " # Saving loss\n", 992 | " obj_losses.append(tensor_to_numpy(obj_loss))\n", 993 | " losses.append(tensor_to_numpy(loss))\n", 994 | " \n", 995 | " # Softmax, preds, text and gold\n", 996 | " batch_softmax = F.softmax(logit, dim=-1).cpu()\n", 997 | " preds.extend(torch.max(batch_softmax, 1)[1].view(y.size()).data.numpy())\n", 998 | " texts.extend(text)\n", 999 | " golds.extend(batch['y'].numpy())\n", 1000 | "\n", 1001 | " # Get metrics\n", 1002 | " epoch_metrics = get_metrics(preds, golds, args)\n", 1003 | " epoch_stat = {'loss' : np.mean(losses), 'obj_loss': np.mean(obj_losses)}\n", 1004 | "\n", 1005 | " for metric_k in epoch_metrics.keys():\n", 1006 | " epoch_stat[metric_k] = epoch_metrics[metric_k]\n", 1007 | "\n", 1008 | " return epoch_stat, step, losses, preds, golds\n", 1009 | "\n", 1010 | "\n", 1011 | "def train_model(train_data, dev_data, model, args):\n", 1012 | " '''\n", 1013 | " Train model on the training and tune it on the dev set.\n", 1014 | " \n", 
1015 | " If model does not improve dev performance within patience\n", 1016 | " epochs, best model is restored and the learning rate halved\n", 1017 | " to continue training.\n", 1018 | "\n", 1019 | " At the end of training, the function will restore the best model\n", 1020 | " on the dev set.\n", 1021 | "\n", 1022 | " :param train_data: preprocessed data\n", 1023 | " :param dev_data: preprocessed data\n", 1024 | " :param models: models to be used for text classification\n", 1025 | " :param args: hyperparameters\n", 1026 | " \n", 1027 | " :return epoch_stats: a dictionary of metrics for train and dev\n", 1028 | " :return model: best model\n", 1029 | " '''\n", 1030 | " \n", 1031 | " snapshot = '{}'.format(os.path.join(args['save_dir'], args['model_path']))\n", 1032 | "\n", 1033 | " if args['cuda']:\n", 1034 | " model = model.cuda()\n", 1035 | "\n", 1036 | " args['lr'] = args['init_lr']\n", 1037 | " optimizer = get_optimizer([model], args)\n", 1038 | "\n", 1039 | " num_epoch_sans_improvement = 0\n", 1040 | " epoch_stats = init_metrics_dictionary(modes=['train', 'dev'])\n", 1041 | " step = 0\n", 1042 | " tuning_key = \"dev_{}\".format(args['tuning_metric'])\n", 1043 | " best_epoch_func = min if tuning_key == 'loss' else max\n", 1044 | "\n", 1045 | " train_loader = get_train_loader(train_data, args)\n", 1046 | " dev_loader = get_dev_loader(dev_data, args)\n", 1047 | "\n", 1048 | " # For every epoch...\n", 1049 | " for epoch in range(1, args['epochs'] + 1):\n", 1050 | " print(\"-------------\\nEpoch {}:\\n\".format(epoch))\n", 1051 | " \n", 1052 | " # Load the training and dev sets...\n", 1053 | " for mode, dataset, loader in [('Train', train_data, train_loader),\n", 1054 | " ('Dev', dev_data, dev_loader)]:\n", 1055 | " \n", 1056 | " train_model = mode == 'Train'\n", 1057 | " print('{}'.format(mode))\n", 1058 | " key_prefix = mode.lower()\n", 1059 | " epoch_details, step, _, _, _ = run_epoch(data_loader=loader, train_model=train_model, model=model,\n", 1060 | " optimizer=optimizer, step=step, args=args)\n", 1061 | " \n", 1062 | " epoch_stats, log_statement = collate_epoch_stat(epoch_stats, epoch_details, key_prefix, args)\n", 1063 | " \n", 1064 | " # Log performance\n", 1065 | " print(log_statement)\n", 1066 | "\n", 1067 | " # Save model if beats best dev\n", 1068 | " best_func = min if args['tuning_metric'] == 'loss' else max\n", 1069 | " if best_func(epoch_stats[tuning_key]) == epoch_stats[tuning_key][-1]:\n", 1070 | " num_epoch_sans_improvement = 0\n", 1071 | " if not os.path.isdir(args['save_dir']):\n", 1072 | " os.makedirs(args['save_dir'])\n", 1073 | " # Subtract one because epoch is 1-indexed and arr is 0-indexed\n", 1074 | " epoch_stats['best_epoch'] = epoch - 1\n", 1075 | " torch.save(model, snapshot)\n", 1076 | " else:\n", 1077 | " num_epoch_sans_improvement += 1\n", 1078 | "\n", 1079 | " if not train_model:\n", 1080 | " print('---- Best Dev {} is {:.4f} at epoch {}'.format(\n", 1081 | " args['tuning_metric'], epoch_stats[tuning_key][epoch_stats['best_epoch']],\n", 1082 | " epoch_stats['best_epoch'] + 1))\n", 1083 | "\n", 1084 | " # If the number of epochs without improvements is high, reduce the learning rate\n", 1085 | " if num_epoch_sans_improvement >= args['patience']:\n", 1086 | " print(\"Reducing learning rate\")\n", 1087 | " num_epoch_sans_improvement = 0\n", 1088 | " model.cpu()\n", 1089 | " model = torch.load(snapshot)\n", 1090 | "\n", 1091 | " if args['cuda']:\n", 1092 | " model = model.cuda()\n", 1093 | " args['lr'] *= .5\n", 1094 | " optimizer = get_optimizer([model], 
args)\n", 1095 | "\n", 1096 | " # Restore model to best dev performance\n", 1097 | " if os.path.exists(args['model_path']):\n", 1098 | " model.cpu()\n", 1099 | " model = torch.load(snapshot)\n", 1100 | "\n", 1101 | " return epoch_stats, model" 1102 | ] 1103 | }, 1104 | { 1105 | "cell_type": "markdown", 1106 | "metadata": {}, 1107 | "source": [ 1108 | "Let's start to train the model." 1109 | ] 1110 | }, 1111 | { 1112 | "cell_type": "code", 1113 | "execution_count": 10, 1114 | "metadata": {}, 1115 | "outputs": [ 1116 | { 1117 | "name": "stdout", 1118 | "output_type": "stream", 1119 | "text": [ 1120 | "-------------\n", 1121 | "Epoch 1:\n", 1122 | "\n", 1123 | "Train\n" 1124 | ] 1125 | }, 1126 | { 1127 | "name": "stderr", 1128 | "output_type": "stream", 1129 | "text": [ 1130 | " 0%| | 0/69 [00:00