├── modals ├── __init__.py ├── losses.py ├── data_util.py ├── augmentation_transforms.py ├── augmentation_transforms_cpu.py ├── setup.py ├── custom_ops.py ├── policy.py └── trainer.py ├── train.py ├── LICENSE ├── scripts ├── search.sh └── train.sh ├── requirements.txt ├── networks └── blstm.py ├── README.md ├── .gitignore ├── search.py ├── schedule ├── policy_sst2.txt └── policy_trec.txt └── utility.py /modals/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from modals.trainer import TextModelTrainer 2 | from modals.setup import create_parser, create_hparams 3 | 4 | 5 | def main(FLAGS, hparams): 6 | start_epoch = 0 7 | trainer = TextModelTrainer(hparams, FLAGS.name) 8 | 9 | if FLAGS.restore is not None: 10 | start_epoch, _ = trainer.load_model(FLAGS.restore) 11 | 12 | for e in range(start_epoch+1, hparams['num_epochs']+1): 13 | trainer.run_model(e) 14 | 15 | if e % 20 == 0: 16 | # print(hparams) 17 | trainer.save_checkpoint(hparams['checkpoint_dir'], e) 18 | trainer._test(e, 'test') 19 | 20 | trainer.save_checkpoint(hparams['checkpoint_dir'], e) 21 | trainer._test(hparams['num_epochs'], 'test') 22 | 23 | 24 | if __name__ == "__main__": 25 | FLAGS = create_parser('train') 26 | hparams = create_hparams('train', FLAGS) 27 | main(FLAGS, hparams) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 jamestszhim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /scripts/search.sh: -------------------------------------------------------------------------------- 1 | if [[ $1 = "trec" ]]; then 2 | CUDA_VISIBLE_DEVICES=3 \ 3 | python search.py \ 4 | --model_name blstm \ 5 | --dataset trec \ 6 | --valid_size 0 \ 7 | --epochs 60 \ 8 | --gpu 0.15 --cpu 2 \ 9 | --num_samples 16 --perturbation_interval 3 \ 10 | --ray_name ray_experiment_trec \ 11 | --distance_metric loss \ 12 | --metric_learning \ 13 | --metric_loss random \ 14 | --metric_weight 0.01 \ 15 | --metric_margin 0.5 \ 16 | --enforce_prior \ 17 | --prior_weight 1 18 | elif [[ $1 = "sst2" ]]; then 19 | CUDA_VISIBLE_DEVICES=3 \ 20 | python search.py \ 21 | --model_name blstm \ 22 | --dataset sst2 \ 23 | --valid_size 0 \ 24 | --epochs 60 \ 25 | --gpu 0.15 --cpu 2 \ 26 | --num_samples 16 --perturbation_interval 3 \ 27 | --ray_name ray_experiment_sst2 \ 28 | --distance_metric loss \ 29 | --metric_learning \ 30 | --metric_loss random \ 31 | --metric_weight 0.03 \ 32 | --metric_margin 2 \ 33 | --enforce_prior \ 34 | --prior_weight 1 35 | fi 36 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | if [[ $1 = "trec" ]]; then 2 | CUDA_VISIBLE_DEVICES=3 \ 3 | python -u train.py \ 4 | --model_name blstm \ 5 | --dataset trec \ 6 | --valid_size 0 \ 7 | --subtrain_ratio 0.1 \ 8 | --policy_epochs 100 \ 9 | --epochs 6 \ 10 | --name trec_model \ 11 | --use_modals \ 12 | --temperature 1 \ 13 | --distance_metric loss \ 14 | --policy_path ./schedule/policy_trec.txt \ 15 | --enforce_prior \ 16 | --prior_weight 1 \ 17 | --metric_learning \ 18 | --metric_loss random \ 19 | --metric_weight 0.01 \ 20 | --metric_margin 0.5 21 | elif [[ $1 = "sst2" ]]; then 22 | CUDA_VISIBLE_DEVICES=3 \ 23 | python -u train.py \ 24 | --model_name blstm \ 25 | --dataset sst2 \ 26 | --valid_size 0 \ 27 | --subtrain_ratio 0.1 \ 28 | --policy_epochs 100 \ 29 | --epochs 60 \ 30 | --name trec_model \ 31 | --use_modals \ 32 | --temperature 1 \ 33 | --distance_metric loss \ 34 | --policy_path ./schedule/policy_sst2.txt \ 35 | --enforce_prior \ 36 | --prior_weight 1 \ 37 | --metric_learning \ 38 | --metric_loss random \ 39 | --metric_weight 0.03 \ 40 | --metric_margin 0.5 41 | fi 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.7.3 2 | aiohttp-cors==0.7.0 3 | aioredis==1.3.1 4 | async-timeout==3.0.1 5 | attrs==20.3.0 6 | blessings==1.7 7 | cachetools==4.2.0 8 | certifi==2020.12.5 9 | chardet==3.0.4 10 | click==7.1.2 11 | colorama==0.4.4 12 | colorful==0.5.4 13 | contextvars==2.4 14 | cycler==0.10.0 15 | dataclasses==0.8 16 | dill==0.3.3 17 | filelock==3.0.12 18 | google-api-core==1.25.0 19 | google-auth==1.24.0 20 | googleapis-common-protos==1.52.0 21 | gpustat==0.6.0 22 | grpcio==1.34.1 23 | hiredis==1.1.0 24 | idna==2.10 25 | idna-ssl==1.1.0 26 | immutables==0.14 27 | importlib-metadata==3.4.0 28 | joblib==1.0.0 29 | jsonschema==3.2.0 30 | kiwisolver==1.3.1 31 | matplotlib==3.3.3 32 | msgpack==1.0.2 33 | multidict==5.1.0 34 | numpy==1.19.5 35 | nvidia-ml-py3==7.352.0 36 | opencensus==0.7.12 37 | opencensus-context==0.1.2 38 | pandas==1.1.5 39 | Pillow==8.1.0 40 | prometheus-client==0.9.0 41 | protobuf==3.14.0 42 | psutil==5.8.0 43 | py-spy==0.3.4 44 | pyasn1==0.4.8 45 | pyasn1-modules==0.2.8 46 | pyparsing==2.4.7 47 | pyrsistent==0.17.3 48 | python-dateutil==2.8.1 49 | pytz==2020.5 50 | PyYAML==5.4 51 | ray==1.1.0 52 | redis==3.5.3 53 | requests==2.25.1 54 | rsa==4.7 55 | scikit-learn==0.24.1 56 | scipy==1.5.4 57 | six==1.15.0 58 | sklearn==0.0 59 | tabulate==0.8.7 60 | tensorboardX==2.1 61 | threadpoolctl==2.1.0 62 | torch==1.7.1 63 | torchtext==0.8.1 64 | tqdm==4.56.0 65 | typing-extensions==3.7.4.3 66 | urllib3==1.26.2 67 | yarl==1.6.3 68 | zipp==3.4.0 69 | -------------------------------------------------------------------------------- /modals/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def discriminator_loss(d_real, d_fake, eps): 7 | return -torch.mean(torch.log(d_real+eps)+torch.log(1-d_fake+eps)) 8 | 9 | 10 | def adverserial_loss(d_fake, eps): 11 | return -torch.mean(torch.log(d_fake+eps)) 12 | 13 | 14 | class OnlineTripletLoss(nn.Module): 15 | """ 16 | Online Triplets loss 17 | Takes a batch of embeddings and corresponding labels. 18 | Triplets are generated using triplet_selector object that take embeddings and targets and return indices of 19 | triplets 20 | """ 21 | 22 | def __init__(self, margin, triplet_selector): 23 | super(OnlineTripletLoss, self).__init__() 24 | self.margin = margin 25 | self.triplet_selector = triplet_selector 26 | 27 | def forward(self, embeddings, target): 28 | 29 | triplets = self.triplet_selector.get_triplets(embeddings, target) 30 | 31 | if embeddings.is_cuda: 32 | triplets = triplets.cuda() 33 | 34 | ap_distances = (embeddings[triplets[:, 0]] - 35 | embeddings[triplets[:, 1]]).pow(2).sum(1) # .pow(.5) 36 | an_distances = (embeddings[triplets[:, 0]] - 37 | embeddings[triplets[:, 2]]).pow(2).sum(1) # .pow(.5) 38 | # print(f'+ve: {ap_distances.mean()}\t-ve: {an_distances.mean()}') 39 | losses = F.relu(ap_distances - an_distances + self.margin) 40 | # losses = torch.max(an_distances - ap_distances+ self.margin, 0)[0] 41 | 42 | return losses.mean(), len(triplets) 43 | -------------------------------------------------------------------------------- /networks/blstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BiLSTM(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | # params: "n_" means dimension 10 | # number of unique words in vocabulary 11 | self.n_vocab = config['n_vocab'] 12 | self.n_layers = config['n_layers'] # number of layers 13 | # number of hidden nodes 14 | self.rnn_hidden = config['n_hidden']//( 15 | 2*config['n_layers']) if config['b_dir'] else config['n_hidden'] 16 | 17 | self.embedding = self._embedding( 18 | config['n_vocab'], config['n_embed'], config['pad_idx'], config['emb']) 19 | self.rnn = self._cell(config['n_embed'], self.rnn_hidden, 20 | config['n_layers'], config['rnn_drop'], config['b_dir']) 21 | self.dropout = nn.Dropout(config['fc_drop']) 22 | self.fc = nn.Linear(config['n_hidden'], config['n_output']) 23 | 24 | def _cell(self, n_embed, n_hidden, n_layers, drop_p, b_dir): 25 | cell = nn.LSTM(n_embed, n_hidden, n_layers, bidirectional=b_dir) 26 | return cell 27 | 28 | def _embedding(self, n_vocab, n_embed, pad_idx, emb): 29 | embedding = nn.Embedding( 30 | n_vocab, n_embed, padding_idx=pad_idx).from_pretrained(emb, freeze=True) 31 | return embedding 32 | 33 | def extract_features(self, texts, seq_lens): 34 | embedded = self.dropout(self.embedding(texts)) # sq_len X bs X n_EMB 35 | packed_embedded = nn.utils.rnn.pack_padded_sequence( 36 | embedded, seq_lens) # seq_len:128 [0]: lenght of each sentence 37 | rnn_out, (hidden, cell) = self.rnn( 38 | packed_embedded) # 1 X bs X n_hidden 39 | features = hidden.permute(1, 0, 2).reshape(len(seq_lens), -1) 40 | return features 41 | 42 | def classify(self, features): 43 | fc_out = self.fc(features) # 1 x bs x d_out 44 | # softmax_out = F.softmax(fc_out, dim=-1) 45 | return fc_out 46 | 47 | def forward(self, x, seq_lens): 48 | x = self.extract_features(x, seq_lens) 49 | return self.classify(x) 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | __Update (20 Jan 2020): MODALS on text data is avialable__ 2 | 3 | ### MODALS 4 | MODALS: Modality-agnostic Automated Data Augmentation in the Latent Space 5 | 6 | ### Table of Contents 7 | 8 | 1. [Introduction](#introduction) 9 | 2. [Getting Started](#getting-started) 10 | 3. [Run Search](#run-modals-search) 11 | 4. [Run Training](#run-modals-training) 12 | 5. [Citation](#citation) 13 | 14 | ### Introduction 15 | 16 | MODALS is a framework to apply automated data augmentation to augment data for any modality in a generic way. It exploits automated data augmentation 17 | to fine-tune four universal data transformation operations in the latent space to adapt the transform to data of different modalities. 18 | 19 | This repository contains code for the work "MODALS: Modality-agnostic Automated Data Augmentation in the Latent Space" (https://openreview.net/pdf?id=XjYgR6gbCEc) implemented using the PyTorch library. It includes searching and training of the SST2 and TREC6 datasets. 20 | 21 | ### Getting Started 22 | Code supports Python 3. 23 | 24 | #### Install requirements 25 | 26 | ```shell 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | #### Setting up directory path 31 | In `modals/setup.py`, specify the dataset path for `DATA_DIR` and the path to the directory that contains the glove embeddings for `EMB_DIR`. 32 | 33 | ### Run MODALS search 34 | Script to search for the augmentation policy for SST2 and TREC6 datasets is located in `scripts/search.sh`. Pass the dataset name as the arguement to call the script. 35 | 36 | For example, to search for the augmentation policy for SST2 dataset: 37 | 38 | ```shell 39 | bash scripts/search.sh sst2 40 | ``` 41 | 42 | The training log and candidate policies of the search will be output to the `./ray_experiments` directory. 43 | 44 | ### Run MODALS training 45 | Two searched policy is included in the `./schedule` directory. The script to apply the searched policy for training SST2 and TREC6 is located in `scripts/train.sh`. Pass the dataset name as the arguement to call the script. 46 | 47 | ```shell 48 | bash scripts/train.sh sst2 49 | ``` 50 | 51 | ### Citation 52 | If you use MODALS in your research, please cite: 53 | 54 | ``` 55 | @inproceedings{cheung2021modals, 56 | title = {{\{}MODALS{\}}: Modality-agnostic Automated Data Augmentation in the Latent Space}, 57 | author = {Tsz-Him Cheung and Dit-Yan Yeung}, 58 | booktitle = {International Conference on Learning Representations}, 59 | year = {2021}, 60 | url = {https://openreview.net/forum?id=XjYgR6gbCEc} 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # data 10 | data/ 11 | vector_cache/ 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | *.DS_Store 34 | __pycache__/ 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import ray 5 | import ray.tune as tune 6 | from modals.setup import create_hparams, create_parser 7 | from modals.trainer import TextModelTrainer 8 | from ray.tune.schedulers import PopulationBasedTraining 9 | 10 | 11 | class RayModel(tune.Trainable): 12 | def _setup(self, *args): 13 | self.trainer = TextModelTrainer(self.config) 14 | 15 | def _train(self): 16 | print(f'Starting Ray Iteration: {self._iteration}') 17 | train_acc, valid_acc = self.trainer.run_model(self._iteration) 18 | test_acc, test_loss = self.trainer._test(self._iteration, mode='test') 19 | return {'train_acc': train_acc, 'valid_acc': valid_acc, 'test_acc': test_acc} 20 | 21 | def _save(self, checkpoint_dir): 22 | print(checkpoint_dir) 23 | path = self.trainer.save_model(checkpoint_dir, self._iteration) 24 | print(path) 25 | return path 26 | 27 | def _restore(self, checkpoint_path): 28 | self.trainer.load_model(checkpoint_path) 29 | 30 | def reset_config(self, new_config): 31 | self.config = new_config 32 | self.trainer.reset_config(self.config) 33 | return True 34 | 35 | 36 | def search(): 37 | FLAGS = create_parser('search') 38 | hparams = create_hparams('search', FLAGS) 39 | 40 | # if FLAGS.restore: 41 | # train_spec["restore"] = FLAGS.restore 42 | 43 | def explore(config): 44 | """Custom explore function. 45 | 46 | Args: 47 | config: dictionary containing ray config params. 48 | 49 | Returns: 50 | Copy of config with modified augmentation policy. 51 | """ 52 | new_params = [] 53 | for i, param in enumerate(config["hp_policy"]): 54 | if random.random() < 0.2: 55 | new_params.append(random.randint(0, 10)) 56 | else: 57 | amt = np.random.choice( 58 | [0, 1, 2, 3], p=[0.25, 0.25, 0.25, 0.25]) 59 | amt = int(amt) 60 | if random.random() < 0.5: 61 | new_params.append(max(0, param - amt)) 62 | else: 63 | new_params.append(min(10, param + amt)) 64 | config["hp_policy"] = new_params 65 | return config 66 | 67 | ray.init() 68 | 69 | pbt = PopulationBasedTraining( 70 | time_attr="training_iteration", 71 | metric="valid_acc", 72 | mode='max', 73 | perturbation_interval=FLAGS.perturbation_interval, 74 | custom_explore_fn=explore, 75 | log_config=True) 76 | 77 | tune.run( 78 | RayModel, 79 | name=hparams['ray_name'], 80 | scheduler=pbt, 81 | reuse_actors=True, 82 | verbose=True, 83 | checkpoint_score_attr="valid_acc", 84 | checkpoint_freq=FLAGS.checkpoint_freq, 85 | resources_per_trial={"gpu": FLAGS.gpu, "cpu": FLAGS.cpu}, 86 | stop={"training_iteration": hparams['num_epochs']}, 87 | config=hparams, 88 | local_dir=FLAGS.ray_dir, 89 | num_samples=FLAGS.num_samples 90 | ) 91 | 92 | 93 | if __name__ == "__main__": 94 | search() 95 | -------------------------------------------------------------------------------- /modals/data_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import dill 7 | import torch 8 | import torchtext.data as data 9 | import torchtext.datasets as datasets 10 | from torch.utils.data import Sampler 11 | from torchtext.vocab import GloVe 12 | from modals.setup import EMB_DIR 13 | 14 | 15 | def save_txt_dataset(dataset, path): 16 | if not isinstance(path, Path): 17 | path = Path(path) 18 | path.mkdir(parents=True, exist_ok=True) 19 | torch.save(dataset.examples, path/"examples.pkl", pickle_module=dill) 20 | torch.save(dataset.fields, path/"fields.pkl", pickle_module=dill) 21 | 22 | 23 | def load_txt_dataset(path, fields): 24 | if not isinstance(path, Path): 25 | path = Path(path) 26 | examples = torch.load(path/"examples.pkl", pickle_module=dill) 27 | # fields = torch.load(path/"fields.pkl", pickle_module=dill) 28 | return data.Dataset(examples, fields) 29 | 30 | 31 | class SubsetSampler(Sampler): 32 | r"""Samples elements from a given list of indices, without replacement. 33 | Arguments: 34 | indices (sequence): a sequence of indices 35 | """ 36 | 37 | def __init__(self, indices): 38 | self.indices = indices 39 | 40 | def __iter__(self): 41 | return (i for i in self.indices) 42 | 43 | def __len__(self): 44 | return len(self.indices) 45 | 46 | 47 | def binarize(dataset): 48 | binary_examples = [] 49 | for example in dataset.examples: 50 | if example.label != 'neutral': 51 | binary_examples.append(example) 52 | dataset.examples = binary_examples 53 | return dataset 54 | 55 | 56 | def get_text_dataloaders(dataset_name, valid_size, batch_size, subtrain_ratio=1.0, dataroot='.data'): 57 | 58 | TEXT = data.Field(lower=True, include_lengths=True, batch_first=False) 59 | LABEL = data.Field(sequential=False) 60 | fields = {'text': TEXT, 'label': LABEL} 61 | 62 | if dataset_name == 'sst2': 63 | train, valid, test = datasets.SST.splits(TEXT, LABEL, root=dataroot) 64 | train, valid, test = binarize(train), binarize(valid), binarize(test) 65 | if subtrain_ratio < 1.0: 66 | train, hold_train = train.split( 67 | split_ratio=subtrain_ratio, stratified=True) 68 | classes = ['negative', 'positive'] 69 | elif dataset_name == 'trec': 70 | random.seed(0) 71 | train, test = datasets.TREC.splits( 72 | TEXT, LABEL, fine_grained=False, root=dataroot) 73 | if valid_size > 0: 74 | train, valid = train.split( 75 | stratified=True, random_state=random.getstate()) # default 0.7 76 | else: 77 | valid = None 78 | if subtrain_ratio < 1.0: 79 | train, hold_train = train.split( 80 | split_ratio=subtrain_ratio, stratified=True, random_state=random.getstate()) 81 | classes = ['DESC', 'ENTY', 'ABBR', 'HUM', 'NUM', 'LOC'] 82 | else: 83 | ValueError(f'Invalid dataset name={dataset_name}') 84 | 85 | TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300, cache=EMB_DIR)) 86 | LABEL.build_vocab(train) 87 | 88 | train_loader, valid_loader, test_loader = data.BucketIterator.splits( 89 | (train, valid, test), batch_size=batch_size, sort=True, sort_key=lambda x: len(x.text), 90 | sort_within_batch=True) 91 | 92 | print('### Dataset ###') 93 | print(f'=>{dataset_name}') 94 | print(f' |Train size:\t{len(train)}') 95 | if valid is not None: 96 | print(f' |Valid size:\t{len(valid)}') 97 | print(f' |Test size:\t{len(test)}') 98 | print(f' |Vocab size:\t{len(TEXT.vocab)}') 99 | 100 | return train_loader, valid_loader, test_loader, classes, TEXT.vocab 101 | -------------------------------------------------------------------------------- /schedule/policy_sst2.txt: -------------------------------------------------------------------------------- 1 | ["7", "1", 1, 2, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [1, 2, 8, 0, 1, 0, 6, 1, 2, 0, 0, 0, 3, 0, 2, 2], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 2 | ["2", "7", 25, 24, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [1, 2, 8, 0, 1, 0, 6, 1, 2, 0, 0, 0, 3, 0, 2, 2], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 6, 3, 1, 0, 7, 0, 2, 2, 2, 0, 1, 3, 0, 0], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 3 | ["14", "2", 31, 33, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 6, 3, 1, 0, 7, 0, 2, 2, 2, 0, 1, 3, 0, 0], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 6, 6, 0, 1, 5, 3, 2, 8, 2, 0, 0, 3, 10, 0], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 4 | ["10", "14", 41, 40, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 6, 6, 0, 1, 5, 3, 2, 8, 2, 0, 0, 3, 10, 0], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [9, 1, 7, 4, 0, 0, 3, 2, 5, 5, 2, 0, 0, 4, 10, 3], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 5 | ["4", "10", 47, 47, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [9, 1, 7, 4, 0, 0, 3, 2, 5, 5, 2, 0, 0, 4, 10, 3], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "sst2", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": false, "metric_learning": false, "pba": false, "autoaugment": false, "use_lpba": true, "policy_path": null, "hp_policy": [9, 1, 10, 2, 0, 0, 6, 0, 5, 4, 0, 1, 3, 4, 9, 1], "ray_name": "sst2/v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 6 | -------------------------------------------------------------------------------- /modals/augmentation_transforms.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from modals.custom_ops import cosine 9 | 10 | PARAMETER_MAX = 10 11 | 12 | 13 | def float_parameter(level, maxval): 14 | """Helper function to scale `val` between 0 and maxval . 15 | 16 | Args: 17 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 18 | maxval: Maximum value that the operation can have. This will be scaled 19 | to level/PARAMETER_MAX. 20 | 21 | Returns: 22 | A float that results from scaling `maxval` according to `level`. 23 | """ 24 | x = float(level) * maxval / PARAMETER_MAX 25 | 26 | return torch.as_tensor(x) 27 | 28 | 29 | def apply_policy_from_pool(policy, img, img_pool, verbose=0): 30 | """Apply the `policy` to the sentence. 31 | 32 | Args: 33 | policy: A list of tuples with the form (name, probability, magnitude) where 34 | `name` is the name of the augmentation operation to apply, `probability` 35 | is the probability of applying the operation and `magnitude` is what strength 36 | the operation to apply. 37 | img: Numpy image that will have `policy` applied to it. 38 | verbose: 0: no log 39 | 1: text log 40 | 2: visualization log 41 | 42 | Returns: 43 | The result of applying `policy` to `sentence`. 44 | """ 45 | label_img_pool = img_pool 46 | display = '=> ' 47 | count = np.random.choice([0, 1, 2, 3], p=[0.2, 0.7, 0.1, 0.0]) 48 | support_idxs = [] 49 | ximg = img 50 | if count != 0: 51 | policy = copy.copy(policy) 52 | random.shuffle(policy) 53 | for xform in policy: 54 | assert len(xform) == 3 55 | name, probability, magnitude = xform 56 | assert 0. <= probability <= 1. 57 | assert 0 <= magnitude <= PARAMETER_MAX 58 | xform_fn = NAME_TO_TRANSFORM[name].transformer( 59 | probability, magnitude) 60 | (ximg, support_idx), res = xform_fn(ximg, img_pool) # 1st: (img, support) 61 | if verbose > 0 and res: 62 | display += f"Op: {name}, Magnitude: {magnitude}, Prob: {probability} " 63 | if verbose > 1: 64 | support_idxs.append(support_idx) 65 | count -= res 66 | assert count >= 0 67 | if count == 0: 68 | break 69 | if verbose: 70 | print(display) 71 | return ximg, support_idxs 72 | else: 73 | return img, [] 74 | 75 | 76 | class TransformFunction(object): 77 | """Wraps the Transform function for pretty printing options.""" 78 | 79 | def __init__(self, func, name): 80 | self.f = func 81 | self.name = name 82 | 83 | def __repr__(self): 84 | return '<' + self.name + '>' 85 | 86 | def __call__(self, img, label_img_pool): 87 | return self.f(img, label_img_pool) 88 | 89 | 90 | class TransformT(object): 91 | """Each instance of this class represents a specific transform.""" 92 | 93 | def __init__(self, name, xform_fn): 94 | self.name = name 95 | self.xform = xform_fn 96 | 97 | def transformer(self, probability, magnitude): 98 | 99 | def return_function(img, label_img_pool): 100 | res = False 101 | s = [] 102 | if random.random() < probability: 103 | img, s = self.xform(img, label_img_pool, magnitude) 104 | res = True 105 | return (img, s), res 106 | 107 | name = self.name + '({:.1f},{})'.format(probability, magnitude) 108 | return TransformFunction(return_function, name) 109 | 110 | def do_transform(self, img, label_img_pool, magnitude): 111 | f = self.transformer(PARAMETER_MAX, magnitude) 112 | return f(img, label_img_pool) 113 | 114 | 115 | def _interpolate(img, class_info, magnitude): 116 | ''' this function interpolates target imgage with a pool of other images 117 | using a magnitude 118 | img: a 1D numpy arrays 119 | img_pool: a 2D numpy array''' 120 | m = float_parameter(magnitude, 1) 121 | x = img 122 | p = class_info['weights'] 123 | if len(p)<1: 124 | return img, [] 125 | k = max(1, int(len(class_info['pool']) * 0.05)) 126 | idxs = np.random.choice(len(class_info['pool']), k, p=p) #choose points near to the boundary 127 | distances = cosine(class_info['pool'][idxs]-class_info['mean'], x.detach().cpu().view(-1)-class_info['mean']) #but not too far from the seed 128 | idx = idxs[np.argmax(distances)] 129 | y = class_info['pool'][idx] 130 | x_hat = (y.cuda()-x)*m + x 131 | return x_hat, [idx] 132 | 133 | 134 | interpolate = TransformT('Interpolate', _interpolate) 135 | 136 | 137 | def _extrapolate(img, class_info, magnitude): 138 | ''' this function extrapolate target imgage with a pool of other images 139 | using a magnitude 140 | img: a 1D numpy arrays 141 | img_pool: a 2D numpy array''' 142 | 143 | m = float_parameter(magnitude, 1) 144 | x = img 145 | mu = class_info['mean'] 146 | x_hat = (x-mu.cuda())*m + x 147 | return x_hat, [] 148 | 149 | 150 | extrapolate = TransformT('Extrapolate', _extrapolate) 151 | 152 | 153 | def _linearpolate(img, class_info, magnitude): 154 | ''' this function linear move target imgage with a pool of other images 155 | using a magnitude 156 | img: a 1D numpy arrays 157 | img_pool: a 2D numpy array''' 158 | 159 | m = float_parameter(magnitude, 1) 160 | x = img 161 | if len(class_info['pool']) < 2: 162 | return x, [0,0] 163 | idx1, idx2 = random.sample(range(len(class_info['pool'])), 2) 164 | y1, y2 = class_info['pool'][idx1], class_info['pool'][idx2] 165 | x_hat = (y1.cuda()-y2.cuda())*m + x 166 | return x_hat, [idx1, idx2] 167 | 168 | 169 | linear_polate = TransformT('LinearPolate', _linearpolate) 170 | 171 | 172 | def _resample(img, class_info, magnitude): 173 | x = img 174 | m = float_parameter(magnitude, 1) 175 | noise = torch.randn(img.size()).cuda() 176 | x_hat = x+noise*class_info['sd'].cuda()*m 177 | return x_hat, [] 178 | 179 | 180 | resample = TransformT('Resample', _resample) 181 | 182 | HP_TRANSFORMS = [ 183 | interpolate, 184 | extrapolate, 185 | linear_polate, 186 | resample 187 | ] 188 | 189 | NAME_TO_TRANSFORM = collections.OrderedDict((t.name, t) for t in HP_TRANSFORMS) 190 | HP_TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys() 191 | NUM_HP_TRANSFORM = len(HP_TRANSFORM_NAMES) 192 | -------------------------------------------------------------------------------- /modals/augmentation_transforms_cpu.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import random 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | from modals.custom_ops import cosine 9 | 10 | PARAMETER_MAX = 10 11 | 12 | 13 | def float_parameter(level, maxval): 14 | """Helper function to scale `val` between 0 and maxval . 15 | 16 | Args: 17 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 18 | maxval: Maximum value that the operation can have. This will be scaled 19 | to level/PARAMETER_MAX. 20 | 21 | Returns: 22 | A float that results from scaling `maxval` according to `level`. 23 | """ 24 | x = float(level) * maxval / PARAMETER_MAX 25 | 26 | return torch.as_tensor(x) 27 | 28 | 29 | def apply_policy_from_pool(policy, img, img_pool, verbose=0): 30 | """Apply the `policy` to the sentence. 31 | 32 | Args: 33 | policy: A list of tuples with the form (name, probability, magnitude) where 34 | `name` is the name of the augmentation operation to apply, `probability` 35 | is the probability of applying the operation and `magnitude` is what strength 36 | the operation to apply. 37 | img: Numpy image that will have `policy` applied to it. 38 | verbose: 0: no log 39 | 1: text log 40 | 2: visualization log 41 | 42 | Returns: 43 | The result of applying `policy` to `sentence`. 44 | """ 45 | label_img_pool = img_pool 46 | display = '=> ' 47 | count = np.random.choice([0, 1, 2, 3], p=[0.2, 0.7, 0.1, 0.0]) 48 | support_idxs = [] 49 | ximg = img 50 | if count != 0: 51 | policy = copy.copy(policy) 52 | random.shuffle(policy) 53 | for xform in policy: 54 | assert len(xform) == 3 55 | name, probability, magnitude = xform 56 | assert 0. <= probability <= 1. 57 | assert 0 <= magnitude <= PARAMETER_MAX 58 | xform_fn = NAME_TO_TRANSFORM[name].transformer( 59 | probability, magnitude) 60 | (ximg, support_idx), res = xform_fn(ximg, img_pool) # 1st: (img, support) 61 | if verbose > 0 and res: 62 | display += f"Op: {name}, Magnitude: {magnitude}, Prob: {probability} " 63 | if verbose > 1: 64 | support_idxs.append(support_idx) 65 | count -= res 66 | assert count >= 0 67 | if count == 0: 68 | break 69 | if verbose: 70 | print(display) 71 | return ximg, support_idxs 72 | else: 73 | return img, [] 74 | 75 | 76 | class TransformFunction(object): 77 | """Wraps the Transform function for pretty printing options.""" 78 | 79 | def __init__(self, func, name): 80 | self.f = func 81 | self.name = name 82 | 83 | def __repr__(self): 84 | return '<' + self.name + '>' 85 | 86 | def __call__(self, img, label_img_pool): 87 | return self.f(img, label_img_pool) 88 | 89 | 90 | class TransformT(object): 91 | """Each instance of this class represents a specific transform.""" 92 | 93 | def __init__(self, name, xform_fn): 94 | self.name = name 95 | self.xform = xform_fn 96 | 97 | def transformer(self, probability, magnitude): 98 | 99 | def return_function(img, label_img_pool): 100 | res = False 101 | s = [] 102 | if random.random() < probability: 103 | img, s = self.xform(img, label_img_pool, magnitude) 104 | res = True 105 | return (img, s), res 106 | 107 | name = self.name + '({:.1f},{})'.format(probability, magnitude) 108 | return TransformFunction(return_function, name) 109 | 110 | def do_transform(self, img, label_img_pool, magnitude): 111 | f = self.transformer(PARAMETER_MAX, magnitude) 112 | return f(img, label_img_pool) 113 | 114 | 115 | def _interpolate(img, class_info, magnitude): 116 | ''' this function interpolates target imgage with a pool of other images 117 | using a magnitude 118 | img: a 1D numpy arrays 119 | img_pool: a 2D numpy array''' 120 | m = float_parameter(magnitude, 1) 121 | x = img 122 | p = class_info['weights'] 123 | k = max(1, int(len(class_info['pool']) * 0.05)) 124 | idxs = np.random.choice(len(class_info['pool']), k, p=p) #choose points near to the boundary 125 | # print(type(class_info['mean'])) 126 | # print(class_info['pool'][idxs]-class_info['mean']) 127 | distances = cosine(class_info['pool'][idxs]-class_info['mean'], x.detach().cpu()-class_info['mean']) #but not too far from the seed 128 | idx = idxs[np.argmax(distances)] 129 | y = class_info['pool'][idx] 130 | x_hat = (y-x)*m + x 131 | return x_hat, [idx] 132 | 133 | 134 | interpolate = TransformT('Interpolate', _interpolate) 135 | 136 | 137 | def _extrapolate(img, class_info, magnitude): 138 | ''' this function extrapolate target imgage with a pool of other images 139 | using a magnitude 140 | img: a 1D numpy arrays 141 | img_pool: a 2D numpy array''' 142 | 143 | m = float_parameter(magnitude, 1) 144 | x = img 145 | mu = class_info['mean'] 146 | x_hat = (x-mu)*m + x 147 | return x_hat, [] 148 | 149 | 150 | extrapolate = TransformT('Extrapolate', _extrapolate) 151 | 152 | 153 | def _linearpolate(img, class_info, magnitude): 154 | ''' this function linear move target imgage with a pool of other images 155 | using a magnitude 156 | img: a 1D numpy arrays 157 | img_pool: a 2D numpy array''' 158 | 159 | m = float_parameter(magnitude, 1) 160 | x = img 161 | if len(class_info['pool']) < 2: 162 | return x, [0,0] 163 | idx1, idx2 = random.sample(range(len(class_info['pool'])), 2) 164 | y1, y2 = class_info['pool'][idx1], class_info['pool'][idx2] 165 | x_hat = (y1-y2)*m + x 166 | return x_hat, [idx1, idx2] 167 | 168 | 169 | linear_polate = TransformT('LinearPolate', _linearpolate) 170 | 171 | 172 | def _resample(img, class_info, magnitude): 173 | x = img 174 | m = float_parameter(magnitude, 1) 175 | noise = torch.randn(img.size()) 176 | x_hat = x+noise*class_info['sd']*m 177 | return x_hat, [] 178 | 179 | 180 | resample = TransformT('Resample', _resample) 181 | 182 | HP_TRANSFORMS = [ 183 | interpolate, 184 | extrapolate, 185 | linear_polate, 186 | resample 187 | ] 188 | 189 | NAME_TO_TRANSFORM = collections.OrderedDict((t.name, t) for t in HP_TRANSFORMS) 190 | HP_TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys() 191 | NUM_HP_TRANSFORM = len(HP_TRANSFORM_NAMES) 192 | -------------------------------------------------------------------------------- /schedule/policy_trec.txt: -------------------------------------------------------------------------------- 1 | ["7", "4", 7, 8, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [0, 10, 0, 0, 0, 0, 0, 2, 2, 1, 0, 7, 0, 3, 1, 1], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 2 | ["7", "6", 13, 14, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 0, 3, 0, 0, 10, 0, 6, 2, 2, 4, 1, 0, 0, 0], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 3 | ["15", "7", 33, 34, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [0, 0, 0, 3, 0, 0, 10, 0, 6, 2, 2, 4, 1, 0, 0, 0], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [7, 0, 2, 4, 3, 0, 9, 4, 6, 4, 0, 7, 0, 8, 9, 0], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 4 | ["6", "15", 39, 38, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [7, 0, 2, 4, 3, 0, 9, 4, 6, 4, 0, 7, 0, 8, 9, 0], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [6, 3, 0, 8, 1, 0, 10, 4, 9, 10, 0, 10, 5, 8, 1, 5], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 5 | ["2", "6", 53, 52, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [6, 3, 0, 8, 1, 0, 10, 4, 9, 10, 0, 10, 5, 8, 1, 5], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}, {"valid_size": 500, "dataset_name": "trec", "dataset_dir": "", "checkpoint_dir": "", "batch_size": 100, "gradient_clipping_by_global_norm": 5.0, "mixup": false, "lr": 0.1, "wd": 0.0005, "momentum": 0.9, "milestones": "60,120,160", "gamma": 0.2, "gpu_device": "cuda:0", "mode": "search", "temperature": 1.0, "distance_metric": "loss", "enforce_prior": true, "metric_learning": true, "pba": false, "autoaugment": false, "prior_weight": 1.0, "metric_loss": "random", "metric_margin": 4.0, "metric_weight": 0.03, "use_lpba": true, "policy_path": null, "hp_policy": [7, 3, 6, 7, 1, 2, 10, 6, 9, 10, 2, 4, 5, 8, 1, 3], "ray_name": "v500e60_t1_dloss_@b0.03_r_m4.0_0429", "model_name": "blstm", "num_epochs": 60}] 6 | -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import math 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import time 7 | import pandas as pd 8 | import numpy as np 9 | 10 | def unpickle(file): 11 | import pickle 12 | with open(file, 'rb+') as fo: 13 | dict = pickle.load(fo, encoding='utf-8') 14 | return dict 15 | 16 | def get_image_stats(dataset): 17 | 18 | MEANS = { 19 | 'cifar10': (0.49139968, 0.48215841, 0.44653091), 20 | 'reduced_cifar10': (0.49056774, 0.48116026, 0.44726052), 21 | 'cifar100': (0.50707516, 0.48654887, 0.44091784), 22 | 'reduced_svhn': (0.45163885, 0.4557915, 0.48093327), 23 | 'ablation_svhn': (0.20385217, 0.20957996, 0.20804394), 24 | 'svhn': (0.43090966, 0.4302428, 0.44634357) 25 | } 26 | STDS = { 27 | 'cifar10': (0.24703223, 0.24348513, 0.26158784), 28 | 'reduced_cifar10': (0.24710728, 0.24451308, 0.26235099), 29 | 'cifar100': (0.26733429, 0.25643846, 0.27615047), 30 | 'reduced_svhn': (0.20385217, 0.20957996, 0.20804394), 31 | 'ablation_svhn': (0.20385217, 0.20957996, 0.20804394), 32 | 'svhn': (0.19652855, 0.19832038, 0.19942076) 33 | } 34 | 35 | return MEANS[dataset], STDS[dataset] 36 | 37 | 38 | def imshow(img, dataset, normalize=False): 39 | img = img.clone() 40 | if normalize: 41 | m, s = get_image_stats(dataset) 42 | for t, m, s in zip(img, m, s): 43 | t.mul_(s).add_(m) 44 | 45 | npimg = img.detach().cpu().numpy() 46 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 47 | plt.show() 48 | 49 | def cosine_lr(learning_rate, cur_step, total_step): 50 | """Cosine Learning rate. 51 | 52 | Args: 53 | learning_rate: Initial learning rate. 54 | epoch: Current epoch we are one. This is one based. 55 | iteration: Current batch in this epoch. 56 | batches_per_epoch: Batches per epoch. 57 | total_epochs: Total epochs you are training for. 58 | 59 | Returns: 60 | The learning rate to be used for this current batch. 61 | """ 62 | # t_total = total_epochs * batches_per_epoch 63 | # t_cur = float(epoch * batches_per_epoch + iteration) 64 | return 0.5 * learning_rate * (1 + np.cos(np.pi * cur_step / total_step)) 65 | 66 | 67 | def get_lr(learning_rate, iteration=None, total_iteration=None): 68 | """Returns the learning rate during training based on the current epoch.""" 69 | assert iteration is not None 70 | lr = cosine_lr(learning_rate, iteration, total_iteration) 71 | return lr 72 | 73 | 74 | def get_hms(seconds): 75 | m, s = divmod(seconds, 60) 76 | h, m = divmod(m, 60) 77 | 78 | return h, m, s 79 | 80 | 81 | def save_checkpoint(model, name, model_dir, epoch, loss_dict): 82 | path = os.path.join(model_dir, name) 83 | 84 | # save the checkpoint. 85 | if not os.path.exists(model_dir): 86 | os.makedirs(model_dir) 87 | torch.save({'state': model.state_dict(), 88 | 'epoch': epoch, 'loss': loss_dict}, path) 89 | 90 | # notify that we successfully saved the checkpoint. 91 | print('=> saved the model {name} to {path}'.format( 92 | name=name, path=path 93 | )) 94 | 95 | 96 | def load_checkpoint(model, name, model_dir): 97 | path = os.path.join(model_dir, name) 98 | 99 | # load the checkpoint. 100 | checkpoint = torch.load(path) 101 | print('=> loaded checkpoint of {name} from {path}'.format( 102 | name=name, path=(path) 103 | )) 104 | 105 | # load parameters and return the checkpoint's epoch and precision. 106 | model.load_state_dict(checkpoint['state']) 107 | epoch = checkpoint['epoch'] 108 | loss = checkpoint['loss'] 109 | return epoch, loss 110 | 111 | 112 | def export_feature(dataloader, net, save_path, device): 113 | 114 | img_features = [] 115 | img_labels = [] 116 | 117 | for idx, (images, labels) in enumerate(dataloader): 118 | images, labels = images.to(device), labels.to(device) 119 | x = net.extract_features(images) 120 | features = x.view(x.size(0), -1) 121 | img_features.extend(features.detach().cpu().numpy()) 122 | img_labels.extend(labels.cpu().numpy()) 123 | 124 | img_features = np.array(img_features) 125 | img_labels = np.array(img_labels) 126 | print(img_features.shape) 127 | print(img_labels.shape) 128 | np.save(f'feat_{save_path}', img_features) 129 | np.save(f'label_{save_path}', img_features) 130 | 131 | def to_tensor(): 132 | def _to_tensor(image): 133 | if len(image.shape) == 3: 134 | return torch.from_numpy( 135 | image.transpose(2, 0, 1).astype(np.float32)) 136 | else: 137 | return torch.from_numpy(image[None, :, :].astype(np.float32)) 138 | 139 | return _to_tensor 140 | 141 | def cutout(mask_size, p, cutout_inside, mask_color=(0, 0, 0)): 142 | mask_size_half = mask_size // 2 143 | offset = 1 if mask_size % 2 == 0 else 0 144 | 145 | def _cutout(image): 146 | image = np.asarray(image).copy() 147 | 148 | if np.random.random() > p: 149 | return image 150 | 151 | h, w = image.shape[:2] 152 | 153 | if cutout_inside: 154 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half 155 | cymin, cymax = mask_size_half, h + offset - mask_size_half 156 | else: 157 | cxmin, cxmax = 0, w + offset 158 | cymin, cymax = 0, h + offset 159 | 160 | cx = np.random.randint(cxmin, cxmax) 161 | cy = np.random.randint(cymin, cymax) 162 | xmin = cx - mask_size_half 163 | ymin = cy - mask_size_half 164 | xmax = xmin + mask_size 165 | ymax = ymin + mask_size 166 | xmin = max(0, xmin) 167 | ymin = max(0, ymin) 168 | xmax = min(w, xmax) 169 | ymax = min(h, ymax) 170 | image[ymin:ymax, xmin:xmax] = mask_color 171 | return image 172 | 173 | return _cutout 174 | 175 | def normalize(mean, std): 176 | mean = np.array(mean) 177 | std = np.array(std) 178 | 179 | def _normalize(image): 180 | image = np.asarray(image).astype(np.float32) / 255. 181 | image = (image - mean) / std 182 | return image 183 | 184 | return _normalize 185 | 186 | def mixup_data(x, y, alpha=1.0, use_cuda=True): 187 | '''Returns mixed inputs, pairs of targets, and lambda''' 188 | if alpha > 0: 189 | lam = np.random.beta(alpha, alpha) 190 | else: 191 | lam = 1 192 | 193 | batch_size = x.size()[0] 194 | if use_cuda: 195 | index = torch.randperm(batch_size).cuda() 196 | else: 197 | index = torch.randperm(batch_size) 198 | 199 | mixed_x = lam * x + (1 - lam) * x[index, :] 200 | y_a, y_b = y, y[index] 201 | return mixed_x, y_a, y_b, lam 202 | 203 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 204 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 205 | -------------------------------------------------------------------------------- /modals/setup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | from modals.augmentation_transforms import NUM_HP_TRANSFORM 5 | 6 | RAY_DIR = './ray_results' 7 | DATA_DIR = '' 8 | EMB_DIR = '' 9 | CP_DIR = './checkpoints' 10 | 11 | 12 | def create_parser(mode): 13 | """Create arg parser for flags.""" 14 | parser = argparse.ArgumentParser() 15 | 16 | ## Datasetting 17 | parser.add_argument('--data_dir', default=DATA_DIR, help='Directory where dataset is located.') 18 | parser.add_argument('--dataset', default='trec', choices=('malware', 'har', 'arem', 'hand', 'emnist', 'reduced_emnist', 'cifar10', 'cifar100', 'reduced_cifar10', 'svhn', 'reduced_svhn', 'sst2', 'trec', 'arcene', 'arrhythmia', 'breast','credit', 'iris', 'abalone', 'htru2', 'phishing')) 19 | parser.add_argument('--valid_size', type=int, default=500, help='Number of validation examples.') 20 | parser.add_argument('--subtrain_ratio', type=float, default=1.0, help='Ratio of sub training set') 21 | 22 | ## Model and training setting 23 | parser.add_argument('--model_name',default='wrn',choices=('elstm', 'tslstm', 'wrn-28x10', 'wrn-40x2', 'resnet', 'blstm', 'mlp', 'densenet')) 24 | parser.add_argument('--epochs', type=int, default=1, help='Number of epochs') 25 | parser.add_argument('--lr', type=float, default=0.1, help='learning rate') 26 | parser.add_argument('--wd', type=float, default=0.0005, help='weight decay') 27 | parser.add_argument('--bs', type=int, default=100, help='batch size') 28 | parser.add_argument('--gpu_device', type=str, default='cuda:0') 29 | 30 | parser.add_argument('--checkpoint_freq', type=int, default=50, help='Checkpoint frequency.') 31 | parser.add_argument('--checkpoint_dir', type=str, default=CP_DIR, help='checkpoint directory.') 32 | parser.add_argument('--restore', type=str, default=None, help='If specified, tries to restore from given path.') 33 | 34 | ## Custom Modifications 35 | parser.add_argument('--temperature', type=float, default=1, help='temperature') 36 | parser.add_argument('--enforce_prior', action='store_true', help='otherwise use no policy') 37 | parser.add_argument('--prior_weight', type=float, default=1, help='weight of prior loss') 38 | parser.add_argument('--distance_metric', type=str, default='l2', help='metric used to weight the supporting samples', choices=('l2', 'loss', 'same', 'cosine')) 39 | parser.add_argument('--policy_path', type=str, default=None, help='text file storing a policy') 40 | parser.add_argument('--metric_learning', action='store_true', help='use metric learning') 41 | parser.add_argument('--metric_loss', type=str, default='random_triplets', help='type of triplet loss', choices=('semihard', 'random', 'hardest')) 42 | parser.add_argument('--metric_margin', type=float, default=1.0, help='metric margin') 43 | parser.add_argument('--metric_weight', type=float, default=0.5, help='weight of metric loss') 44 | 45 | parser.add_argument('--mixup', action='store_true', help='mixup benchmark') 46 | parser.add_argument('--alpha', type=float, default=1.0, help='mixup parameter') 47 | parser.add_argument('--manifold_mixup', action='store_true', help='manifold mixup benchmark') 48 | 49 | if mode == 'train': 50 | parser.add_argument('--use_modals', action='store_true', help='otherwise use no policy') 51 | parser.add_argument('--hp_policy', type=str, default=None, help='either a comma separated list of values') 52 | parser.add_argument('--policy_epochs', type=int, default=200, help='number of epochs/iterations policy trained for') 53 | parser.add_argument('--name', type=str, default='autoaug') 54 | 55 | elif mode == 'search': 56 | ## Ray setting 57 | parser.add_argument('--ray_dir', type=str, default=RAY_DIR, help='Ray directory.') 58 | parser.add_argument('--num_samples', type=int, default=3, help='Number of Ray samples') 59 | parser.add_argument('--cpu', type=float, default=4, help='Allocated by Ray') 60 | parser.add_argument('--gpu', type=float, default=0.12, help='Allocated by Ray') 61 | parser.add_argument('--perturbation_interval', type=int, default=3) 62 | parser.add_argument('--ray_name', type=str, default='ray_experiment') 63 | 64 | else: 65 | raise ValueError('unknown state') 66 | 67 | return parser.parse_args() 68 | 69 | 70 | def create_hparams(mode, FLAGS): 71 | """Creates hyperparameters to pass into Ray config. 72 | 73 | Different options depending on search or eval mode. 74 | 75 | Args: 76 | mode: a string, 'train' or 'test' or 'search'. 77 | FLAGS: parsed command line flags. 78 | 79 | Returns: dict 80 | """ 81 | hparams = { 82 | 'valid_size': FLAGS.valid_size, 83 | 'dataset_name': FLAGS.dataset, 84 | 'dataset_dir': FLAGS.data_dir, 85 | 'checkpoint_dir': FLAGS.checkpoint_dir, 86 | 'batch_size': FLAGS.bs, 87 | 'gradient_clipping_by_global_norm': 5.0, 88 | 'mixup': FLAGS.mixup, 89 | 'lr': FLAGS.lr, 90 | 'wd': FLAGS.wd, 91 | 'momentum': 0.9, 92 | 'gpu_device': FLAGS.gpu_device, 93 | 'mode': mode, 94 | 'temperature': FLAGS.temperature if FLAGS.temperature<=1 else FLAGS.temperature/10, 95 | 'distance_metric': FLAGS.distance_metric, 96 | 'enforce_prior': FLAGS.enforce_prior, 97 | 'metric_learning': FLAGS.metric_learning, 98 | 'subtrain_ratio': FLAGS.subtrain_ratio, ## for text data controlling ratio of training data 99 | 'manifold_mixup': FLAGS.manifold_mixup, 100 | } 101 | 102 | if FLAGS.enforce_prior: 103 | hparams['prior_weight'] = FLAGS.prior_weight if FLAGS.prior_weight<=1 else FLAGS.prior_weight/10 104 | 105 | if FLAGS.metric_learning: 106 | hparams['metric_loss'] = FLAGS.metric_loss 107 | hparams['metric_margin'] = FLAGS.metric_margin 108 | hparams['metric_weight'] = FLAGS.metric_weight 109 | 110 | if mode == 'train': 111 | hparams['use_modals'] = FLAGS.use_modals 112 | hparams['policy_path'] = None 113 | hparams['hp_policy'] = None 114 | if FLAGS.use_modals: 115 | if FLAGS.hp_policy == 'random': 116 | # random policy 117 | parsed_policy = [random.randrange(0, 11) for i in range(NUM_HP_TRANSFORM * 4)] 118 | hparams['hp_policy'] = parsed_policy 119 | elif FLAGS.hp_policy == 'average': 120 | # random policy 121 | parsed_policy = [5]* (NUM_HP_TRANSFORM * 4) 122 | hparams['hp_policy'] = parsed_policy 123 | elif FLAGS.policy_path is not None: 124 | # supplied a schedule 125 | hparams['policy_path'] = FLAGS.policy_path 126 | else: 127 | # parse input into a fixed augmentation policy 128 | parsed_policy = FLAGS.hp_policy.split(',') 129 | parsed_policy = [int(p) for p in parsed_policy] 130 | hparams['hp_policy'] = parsed_policy 131 | 132 | elif mode == 'search': 133 | hparams['use_modals'] = True 134 | hparams['policy_path'] = None 135 | # default start value of 0 136 | hparams['hp_policy'] = [0 for _ in range(4 * NUM_HP_TRANSFORM)] 137 | hparams['ray_name'] = FLAGS.ray_name 138 | 139 | else: 140 | raise ValueError('unknown mode') 141 | 142 | # Child model 143 | hparams['model_name'] = FLAGS.model_name 144 | 145 | # epochs is put here for later setting for specific models 146 | hparams['num_epochs'] = FLAGS.epochs 147 | 148 | return hparams 149 | -------------------------------------------------------------------------------- /modals/custom_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import norm 3 | from itertools import combinations 4 | import torch 5 | 6 | 7 | def cosine(X, mu): 8 | Xm = np.dot(X, mu) 9 | norm_p = norm(X, axis=1) * norm(mu) 10 | distances = 1 - Xm/(norm_p+1e-5) 11 | return distances 12 | 13 | 14 | def pdist(vectors): 15 | distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum( 16 | dim=1).view(-1, 1) 17 | return distance_matrix 18 | 19 | 20 | class PairSelector: 21 | """ 22 | Implementation should return indices of positive pairs and negative pairs that will be passed to compute 23 | Contrastive Loss 24 | return positive_pairs, negative_pairs 25 | """ 26 | 27 | def __init__(self): 28 | pass 29 | 30 | def get_pairs(self, embeddings, labels): 31 | raise NotImplementedError 32 | 33 | 34 | class AllPositivePairSelector(PairSelector): 35 | """ 36 | Discards embeddings and generates all possible pairs given labels. 37 | If balance is True, negative pairs are a random sample to match the number of positive samples 38 | """ 39 | 40 | def __init__(self, balance=True): 41 | super(AllPositivePairSelector, self).__init__() 42 | self.balance = balance 43 | 44 | def get_pairs(self, embeddings, labels): 45 | labels = labels.cpu().data.numpy() 46 | all_pairs = np.array(list(combinations(range(len(labels)), 2))) 47 | all_pairs = torch.LongTensor(all_pairs) 48 | positive_pairs = all_pairs[( 49 | labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()] 50 | negative_pairs = all_pairs[( 51 | labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()] 52 | if self.balance: 53 | negative_pairs = negative_pairs[torch.randperm( 54 | len(negative_pairs))[:len(positive_pairs)]] 55 | 56 | return positive_pairs, negative_pairs 57 | 58 | 59 | class HardNegativePairSelector(PairSelector): 60 | """ 61 | Creates all possible positive pairs. For negative pairs, pairs with smallest distance are taken into consideration, 62 | matching the number of positive pairs. 63 | """ 64 | 65 | def __init__(self, cpu=True): 66 | super(HardNegativePairSelector, self).__init__() 67 | self.cpu = cpu 68 | 69 | def get_pairs(self, embeddings, labels): 70 | if self.cpu: 71 | embeddings = embeddings.cpu() 72 | distance_matrix = pdist(embeddings) 73 | 74 | labels = labels.cpu().data.numpy() 75 | all_pairs = np.array(list(combinations(range(len(labels)), 2))) 76 | all_pairs = torch.LongTensor(all_pairs) 77 | positive_pairs = all_pairs[( 78 | labels[all_pairs[:, 0]] == labels[all_pairs[:, 1]]).nonzero()] 79 | negative_pairs = all_pairs[( 80 | labels[all_pairs[:, 0]] != labels[all_pairs[:, 1]]).nonzero()] 81 | 82 | negative_distances = distance_matrix[negative_pairs[:, 83 | 0], negative_pairs[:, 1]] 84 | negative_distances = negative_distances.cpu().data.numpy() 85 | top_negatives = np.argpartition(negative_distances, len(positive_pairs))[ 86 | :len(positive_pairs)] 87 | top_negative_pairs = negative_pairs[torch.LongTensor(top_negatives)] 88 | 89 | return positive_pairs, top_negative_pairs 90 | 91 | 92 | class TripletSelector: 93 | """ 94 | Implementation should return indices of anchors, positive and negative samples 95 | return np array of shape [N_triplets x 3] 96 | """ 97 | 98 | def __init__(self): 99 | pass 100 | 101 | def get_triplets(self, embeddings, labels): 102 | raise NotImplementedError 103 | 104 | 105 | class AllTripletSelector(TripletSelector): 106 | """ 107 | Returns all possible triplets 108 | May be impractical in most cases 109 | """ 110 | 111 | def __init__(self): 112 | super(AllTripletSelector, self).__init__() 113 | 114 | def get_triplets(self, embeddings, labels): 115 | labels = labels.cpu().data.numpy() 116 | triplets = [] 117 | for label in set(labels): 118 | label_mask = (labels == label) 119 | label_indices = np.where(label_mask)[0] 120 | if len(label_indices) < 2: 121 | continue 122 | negative_indices = np.where(np.logical_not(label_mask))[0] 123 | # All anchor-positive pairs 124 | anchor_positives = list(combinations(label_indices, 2)) 125 | 126 | # Add all negatives for all positive pairs 127 | temp_triplets = [[anchor_positive[0], anchor_positive[1], neg_ind] for anchor_positive in anchor_positives 128 | for neg_ind in negative_indices] 129 | triplets += temp_triplets 130 | return torch.LongTensor(np.array(triplets)) 131 | 132 | 133 | def hardest_negative(loss_values): 134 | hard_negative = np.argmax(loss_values) 135 | return hard_negative if loss_values[hard_negative] > 0 else None 136 | 137 | 138 | def random_hard_negative(loss_values): 139 | hard_negatives = np.where(loss_values > 0)[0] 140 | return np.random.choice(hard_negatives) if len(hard_negatives) > 0 else None 141 | 142 | 143 | def semihard_negative(loss_values, margin): 144 | semihard_negatives = np.where(np.logical_and( 145 | loss_values < margin, loss_values > 0))[0] 146 | return np.random.choice(semihard_negatives) if len(semihard_negatives) > 0 else None 147 | 148 | 149 | class FunctionNegativeTripletSelector(TripletSelector): 150 | """ 151 | For each positive pair, takes the hardest negative sample (with the greatest triplet loss value) to create a triplet 152 | Margin should match the margin used in triplet loss. 153 | negative_selection_fn should take array of loss_values for a given anchor-positive pair and all negative samples 154 | and return a negative index for that pair 155 | """ 156 | 157 | def __init__(self, margin, negative_selection_fn, cpu=True): 158 | super(FunctionNegativeTripletSelector, self).__init__() 159 | self.cpu = cpu 160 | self.margin = margin 161 | self.negative_selection_fn = negative_selection_fn 162 | 163 | def get_triplets(self, embeddings, labels): 164 | if self.cpu: 165 | embeddings = embeddings.cpu() 166 | distance_matrix = pdist(embeddings) 167 | distance_matrix = distance_matrix.cpu() 168 | 169 | labels = labels.cpu().data.numpy() 170 | triplets = [] 171 | if len(list(set(labels))) == 1: 172 | return torch.LongTensor(triplets) 173 | 174 | for label in set(labels): 175 | label_mask = (labels == label) 176 | label_indices = np.where(label_mask)[0] 177 | if len(label_indices) < 2: 178 | continue 179 | negative_indices = np.where(np.logical_not(label_mask))[0] 180 | # All anchor-positive pairs 181 | anchor_positives = list(combinations(label_indices, 2)) 182 | anchor_positives = np.array(anchor_positives) 183 | 184 | ap_distances = distance_matrix[anchor_positives[:,0], anchor_positives[:, 1]] 185 | for anchor_positive, ap_distance in zip(anchor_positives, ap_distances): 186 | loss_values = ap_distance - distance_matrix[torch.LongTensor(np.array( 187 | [anchor_positive[0]])), torch.LongTensor(negative_indices)] + self.margin 188 | loss_values = loss_values.data.cpu().numpy() 189 | hard_negative = self.negative_selection_fn(loss_values) 190 | if hard_negative is not None: 191 | hard_negative = negative_indices[hard_negative] 192 | triplets.append( 193 | [anchor_positive[0], anchor_positive[1], hard_negative]) 194 | 195 | if len(triplets) == 0: 196 | if len(anchor_positive) < 2 and len(negative_indices) < 1: 197 | triplets.append( 198 | [anchor_positive, anchor_positive, negative_indices]) 199 | elif len(anchor_positive) >= 2 and len(negative_indices) < 1: 200 | triplets.append( 201 | [anchor_positive[0], anchor_positive[1], negative_indices]) 202 | else: 203 | triplets.append( 204 | [anchor_positive[0], anchor_positive[1], negative_indices[0]]) 205 | 206 | triplets = np.array(triplets) 207 | # print(triplets) 208 | 209 | return torch.LongTensor(triplets) 210 | 211 | 212 | def HardestNegativeTripletSelector(margin, cpu=False): 213 | return FunctionNegativeTripletSelector(margin=margin, negative_selection_fn=hardest_negative, cpu=cpu) 214 | 215 | 216 | def RandomNegativeTripletSelector(margin, cpu=False): 217 | return FunctionNegativeTripletSelector(margin=margin, negative_selection_fn=random_hard_negative, cpu=cpu) 218 | 219 | 220 | def SemihardNegativeTripletSelector(margin, cpu=False): 221 | return FunctionNegativeTripletSelector(margin=margin, negative_selection_fn=lambda x: semihard_negative(x, margin), cpu=cpu) 222 | -------------------------------------------------------------------------------- /modals/policy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from modals.custom_ops import cosine 8 | from sklearn.metrics.pairwise import euclidean_distances 9 | 10 | 11 | def parse_policy_log(policy_file): 12 | ### When Magic happends ### 13 | pbt_policy_file = open(policy_file, 'r').readlines() 14 | 15 | perturb_events = [] 16 | for perturb_event in pbt_policy_file: 17 | event = json.loads(perturb_event) 18 | perturb_events.append(event) 19 | 20 | initial_policy = perturb_events[0][4]['hp_policy'] 21 | policy_num_epochs = perturb_events[0][4]['num_epochs'] 22 | 23 | ''' 24 | epoch 0 15 20 35 100 25 | policy p0 p1 p2 p3 26 | => 27 | epoch 15-0 20-15 35-20 100-35 28 | policy p0 p1 p2 p3 29 | ''' 30 | 31 | # the epoch policy changed 32 | perturb_epoch = [0] 33 | 34 | # the policy changed 35 | perturb_policy = [initial_policy] # initial policy i.e. [0] policy 36 | 37 | for event in perturb_events: 38 | perturb_epoch.append(event[3]) 39 | perturb_policy.append(event[5]['hp_policy']) 40 | 41 | perturb_epoch.append(policy_num_epochs) 42 | 43 | # how many times running_policy[i] is ran 44 | n_repeats = [0] * (len(perturb_epoch)-1) 45 | 46 | for i in range(len(n_repeats)): 47 | n_repeats[i] = perturb_epoch[i+1] - perturb_epoch[i] 48 | 49 | assert len(perturb_policy) == len(n_repeats) 50 | assert sum(n_repeats) == policy_num_epochs 51 | 52 | return (n_repeats, perturb_policy) 53 | 54 | 55 | class RawPolicy(object): 56 | """Each instance of this class represents a specific transform.""" 57 | 58 | def __init__(self, mode, num_epochs, hp_policy=None, policy_path=None): 59 | """ 60 | search: pba, must be single 61 | train/test: using a pba schedule to train/test a child model, 62 | can be a schedule or a single. However, the piority is given 63 | to policy_path. 64 | """ 65 | assert mode in ['search', 'train', 'visualize'] 66 | if mode == 'search' or mode == 'visualize': 67 | assert hp_policy is not None 68 | self.type = 'single' 69 | self.emb = hp_policy 70 | else: 71 | if policy_path is not None: 72 | # Parse policy form pbt_policy_{i}.txt 73 | n_repeats, raw_policies = parse_policy_log(policy_path) 74 | if num_epochs != sum(n_repeats): 75 | print('Interpolating policy') 76 | ratio = num_epochs / sum(n_repeats) 77 | n_repeats = [math.floor(n * ratio) for n in n_repeats] 78 | n_pad = num_epochs - sum(n_repeats) 79 | n_repeats[-1] += n_pad 80 | assert num_epochs == sum(n_repeats) 81 | 82 | # Unroll a policy 83 | self.schedule = np.repeat( 84 | raw_policies, n_repeats, axis=0) 85 | self.type = 'schedule' 86 | 87 | elif hp_policy is not None: 88 | if isinstance(hp_policy[0], list): 89 | # provided schdule must match epochs 90 | assert len(hp_policy) == num_epochs 91 | self.type = 'schedule' 92 | self.schedule = hp_policy 93 | else: 94 | self.type = 'single' 95 | self.emb = hp_policy 96 | else: 97 | raise ValueError('You must provide hp_policy or policy path') 98 | 99 | 100 | class PolicyManager(object): 101 | """Manage policy.""" 102 | 103 | def __init__(self, aug_trans, raw_policy, num_classes, device): 104 | self.num_xform = aug_trans.NUM_HP_TRANSFORM 105 | self.xform_names = aug_trans.HP_TRANSFORM_NAMES 106 | self.apply_policy_from_pool = aug_trans.apply_policy_from_pool 107 | self.policy = None 108 | self.update_policy(raw_policy) 109 | self.num_classes = num_classes 110 | self.device = device 111 | self.criterion = nn.CrossEntropyLoss(reduction='none') 112 | 113 | def apply_pba(self, images): 114 | return 115 | 116 | def update_policy(self, raw_policy): 117 | self.raw_policy = raw_policy 118 | self.policy = self.parse_policy(raw_policy) 119 | if raw_policy.type == 'single': 120 | print(f'Updated hp policy: {self.policy}') 121 | 122 | def parse_policy(self, raw_policy): 123 | if raw_policy.type == 'single': 124 | return self._parse_one_policy(raw_policy.emb) 125 | elif raw_policy.type == 'schedule': 126 | policy = [] 127 | for one_emb in raw_policy.schedule: 128 | policy.append(self._parse_one_policy(one_emb)) 129 | return policy 130 | 131 | def _parse_one_policy(self, emb): 132 | assert len( 133 | emb) == 2*2*self.num_xform, f'raw policy was: {len(emb)}, supposed to be: {2*2*self.num_xform}' 134 | one_policy = [] 135 | for i, xform in enumerate(list(self.xform_names)*2): 136 | one_policy.append((xform, emb[2 * i] / 10., emb[2 * i + 1])) 137 | return one_policy 138 | 139 | def apply_policy(self, imgs, labels, epoch, batch_idx, verbose=1): 140 | inner_verbose = True if verbose > 1 else False 141 | cur_epoch = max(epoch-1, 0) 142 | running_policy = self.policy if self.raw_policy.type == 'single' else self.policy[ 143 | cur_epoch] 144 | x_imgs = imgs.new_empty(imgs.shape) 145 | 146 | for i, label in enumerate(labels): 147 | class_info = {'pool': self.feat_pool[self.idx_by_class[label]], 148 | 'weights': self.img_weights_by_class[label], 149 | 'mean': self.mean_by_class[label], 150 | 'sd': self.sd_by_class[label]} 151 | x_img, _ = self.apply_policy_from_pool(running_policy, imgs[i], 152 | class_info, inner_verbose) 153 | x_imgs[i] = x_img 154 | 155 | if batch_idx == 0 and verbose > 0: 156 | print(f'Applying policy {running_policy}') 157 | 158 | return x_imgs 159 | 160 | def print_data_pool_report(self, cc, cmu): 161 | print("### Data Pool Report ###") 162 | print(f'Cluster Closeness: {np.around(cc, 4)}') 163 | print("Cluster mean distance:") 164 | print(euclidean_distances(cmu)) 165 | print() 166 | 167 | def reset_text_data_pool(self, encoder, dataloader, temperature, weight_metric, dataset, verbose=False, return_report=False): 168 | # compute the sd and mean of each lass 169 | # save the features into a pool 170 | feat_pool = [] 171 | label_pool = [] 172 | loss_pool = [] 173 | feat_dim = 0 174 | encoder.eval() 175 | 176 | with torch.no_grad(): 177 | for batch in dataloader: 178 | inputs, seq_lens, labels = batch.text[0].to( 179 | self.device), batch.text[1].to(self.device), batch.label.to(self.device) 180 | 181 | if dataset == 'sst2' or dataset == 'trec': 182 | labels -= 1 # because I binarized the data 183 | 184 | features = encoder.extract_features(inputs, seq_lens) 185 | outputs = encoder.classify(features) 186 | loss = self.criterion(outputs, labels.to(self.device)) 187 | for i in range(len(labels)): 188 | feat_pool.append(features[i].cpu()) 189 | label_pool.append(labels[i].cpu()) 190 | loss_pool.append(loss[i].cpu()) 191 | 192 | feat_dim = feat_pool[0].shape[0] 193 | self.feat_pool = torch.stack( 194 | feat_pool).reshape(-1, feat_dim).double() # list of all images 195 | label_pool = torch.stack(label_pool).reshape(-1) 196 | loss_pool = torch.stack(loss_pool).reshape(-1) 197 | 198 | self.idx_by_class = [] # [0]: [1,4,2,...] <-index that belongs to class 0 199 | self.mean_by_class = [] # [0]: mean of class 0 200 | self.sd_by_class = [] # [0]: sd of class 0 201 | self.img_weights_by_class = [] # [0]: weights of the images in class 0 202 | 203 | cluster_closeness = [] 204 | class_means = [] 205 | 206 | for i in range(self.num_classes): 207 | img_idxs = torch.where(label_pool == i)[0] 208 | self.idx_by_class.append(img_idxs) 209 | class_imgs = self.feat_pool[img_idxs] 210 | class_loss = loss_pool[img_idxs] 211 | 212 | class_mean = class_imgs.mean(0) 213 | self.mean_by_class.append(class_mean) 214 | img_distances = np.linalg.norm( 215 | class_imgs-class_mean, ord=2, axis=1) # distance from mean 216 | 217 | class_means.append(class_mean.numpy()) 218 | # in-class pari-wise distance 219 | icpd = euclidean_distances(class_imgs) 220 | aicpd = np.sum(icpd)/(len(icpd)*(len(icpd)-1)) 221 | cluster_closeness.append(aicpd) 222 | 223 | if weight_metric == 'l2': 224 | # weight by l2 distance 225 | img_weights = img_distances 226 | img_weights -= img_weights.max(0) # numerical stability 227 | elif weight_metric == 'cosine': 228 | # weight by consine distance 229 | img_weights = cosine(class_imgs, class_mean) 230 | elif weight_metric == 'loss': 231 | # weight by loss 232 | img_weights = class_loss 233 | img_weights -= torch.max(img_weights) 234 | elif weight_metric == 'same': 235 | # uniform 236 | img_weights = np.ones(len(class_imgs))/len(class_imgs) 237 | 238 | img_weights = np.exp(img_weights/temperature) 239 | img_weights = img_weights/img_weights.sum() # normalize 240 | 241 | self.img_weights_by_class.append(np.nan_to_num(img_weights)) 242 | self.sd_by_class.append(class_imgs.std(0)) 243 | 244 | if verbose: 245 | self.print_data_pool_report(cluster_closeness, class_means) 246 | 247 | if return_report: 248 | cpd = euclidean_distances(class_means) 249 | acpd = np.sum(cpd)/(len(cpd)*(len(cpd)-1)) 250 | return np.mean(cluster_closeness), acpd 251 | 252 | def print_policy(self): 253 | print(self.policy) 254 | -------------------------------------------------------------------------------- /modals/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | from networks.blstm import BiLSTM 9 | from torch.autograd import Variable 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.optim import lr_scheduler 12 | from utility import mixup_criterion, mixup_data 13 | 14 | from modals.data_util import get_text_dataloaders 15 | from modals.policy import PolicyManager, RawPolicy 16 | 17 | if torch.cuda.is_available(): 18 | import modals.augmentation_transforms as aug_trans 19 | else: 20 | import modals.augmentation_transforms_cpu as aug_trans 21 | 22 | from modals.custom_ops import (HardestNegativeTripletSelector, 23 | RandomNegativeTripletSelector, 24 | SemihardNegativeTripletSelector) 25 | from modals.losses import (OnlineTripletLoss, adverserial_loss, 26 | discriminator_loss) 27 | 28 | 29 | def count_parameters(model): 30 | temp = sum(p.numel() for p in model.parameters() if p.requires_grad) 31 | print(f' |Trainable parameters: {temp}') 32 | 33 | 34 | def build_model(model_name, vocab, n_class, z_size=2): 35 | net = None 36 | if model_name == 'blstm': 37 | config = {'n_vocab': len(vocab), 38 | 'n_embed': 300, 39 | 'emb': vocab.vectors, 40 | 'n_hidden': 256, 41 | 'n_output': n_class, 42 | 'n_layers': 2, 43 | 'pad_idx': vocab.stoi[''], 44 | 'b_dir': True, 45 | 'rnn_drop': 0.2, 46 | 'fc_drop': 0.5} 47 | net = BiLSTM(config) 48 | z_size = 256 49 | else: 50 | ValueError(f'Invalid model name={model_name}') 51 | 52 | print('\n### Model ###') 53 | print(f'=> {model_name}') 54 | count_parameters(net) 55 | 56 | return net, z_size, model_name 57 | 58 | 59 | class Discriminator(nn.Module): 60 | def __init__(self, z_size): 61 | super(Discriminator, self).__init__() 62 | self.z_size = z_size 63 | self.fc1 = nn.Linear(z_size, 512) 64 | self.fc2 = nn.Linear(512, 128) 65 | self.fc3 = nn.Linear(128, 1) 66 | 67 | def forward(self, x): 68 | x = F.relu(self.fc1(x)) 69 | x = F.relu(self.fc2(x)) 70 | x = torch.sigmoid(self.fc3(x)) 71 | return x 72 | 73 | 74 | class TextModelTrainer(object): 75 | 76 | def __init__(self, hparams, name=''): 77 | self.hparams = hparams 78 | print(hparams) 79 | 80 | self.name = name 81 | 82 | random.seed(0) 83 | self.train_loader, self.valid_loader, self.test_loader, self.classes, self.vocab = get_text_dataloaders( 84 | hparams['dataset_name'], valid_size=hparams['valid_size'], batch_size=hparams['batch_size'], 85 | subtrain_ratio=hparams['subtrain_ratio'], dataroot=hparams['dataset_dir']) 86 | random.seed() 87 | 88 | self.device = torch.device( 89 | hparams['gpu_device'] if torch.cuda.is_available() else 'cpu') 90 | print() 91 | print('### Device ###') 92 | print(self.device) 93 | self.net, self.z_size, self.file_name = build_model( 94 | hparams['model_name'], self.vocab, len(self.classes)) 95 | self.net = self.net.to(self.device) 96 | 97 | self.criterion = nn.CrossEntropyLoss() 98 | if hparams['mode'] in ['train', 'search']: 99 | self.optimizer = optim.Adam(self.net.parameters(), 0.001) 100 | self.loss_dict = {'train': [], 'valid': []} 101 | 102 | if hparams['use_modals']: 103 | print("\n=> ### Policy ###") 104 | # print(f' |hp_policy: {hparams['hp_policy']}') 105 | # print(f' |policy_path: {hparams['policy_path']}') 106 | raw_policy = RawPolicy(mode=hparams['mode'], num_epochs=hparams['num_epochs'], 107 | hp_policy=hparams['hp_policy'], policy_path=hparams['policy_path']) 108 | transformations = aug_trans 109 | self.pm = PolicyManager( 110 | transformations, raw_policy, len(self.classes), self.device) 111 | 112 | print("\n### Loss ###") 113 | print('Classification Loss') 114 | 115 | if hparams['mixup']: 116 | print('Mixup') 117 | 118 | if hparams['enforce_prior']: 119 | print('Adversarial Loss') 120 | self.EPS = 1e-15 121 | self.D = Discriminator(self.z_size) 122 | self.D = self.D.to(self.device) 123 | self.D_optimizer = optim.SGD(self.D.parameters(), lr=0.01, 124 | momentum=hparams['momentum'], weight_decay=hparams['wd']) 125 | # self.G_optimizer = optim.Adam(self.net.parameters(), lr=0.001) 126 | 127 | if hparams['metric_learning']: 128 | margin = hparams['metric_margin'] 129 | metric_loss = hparams["metric_loss"] 130 | metric_weight = hparams["metric_weight"] 131 | print( 132 | f"Metric Loss (margin: {margin} loss: {metric_loss} weight: {metric_weight})") 133 | 134 | self.M_optimizer = optim.SGD( 135 | self.net.parameters(), momentum=0.9, lr=1e-3, weight_decay=1e-8) 136 | self.metric_weight = hparams['metric_weight'] 137 | 138 | if metric_loss == 'random': 139 | self.metric_loss = OnlineTripletLoss( 140 | margin, RandomNegativeTripletSelector(margin)) 141 | elif metric_loss == 'hardest': 142 | self.metric_loss = OnlineTripletLoss( 143 | margin, HardestNegativeTripletSelector(margin)) 144 | elif metric_loss == 'semihard': 145 | self.metric_loss = OnlineTripletLoss( 146 | margin, SemihardNegativeTripletSelector(margin)) 147 | 148 | def reset_model(self, z_size=256): 149 | # tunable z_size only use for visualization 150 | # if blstm is used, it is automatically 256 151 | self.net, self.z_size, self.file_name = build_model( 152 | self.hparams['model_name'], self.vocab, len(self.classes), z_size) 153 | self.net = self.net.to(self.device) 154 | self.optimizer = optim.Adam(self.net.parameters(), 0.001) 155 | self.loss_dict = {'train': [], 'valid': []} 156 | 157 | def reset_discriminator(self, z_size=256): 158 | self.D = Discriminator(z_size) 159 | self.D = self.D.to(self.device) 160 | self.D_optimizer = optim.SGD(self.D.parameters(), lr=0.01, 161 | momentum=self.hparams['momentum'], weight_decay=self.hparams['wd']) 162 | 163 | def update_policy(self, policy): 164 | raw_policy = RawPolicy(mode='train', num_epochs=1, 165 | hp_policy=policy, policy_path=None) 166 | self.pm.update_policy(raw_policy) 167 | 168 | def _train(self, cur_epoch): 169 | self.net.train() 170 | self.net.training = True 171 | self.scheduler = lr_scheduler.CosineAnnealingLR( 172 | self.optimizer, len(self.train_loader)) # cosine learning rate 173 | train_losses = 0.0 174 | clf_losses = 0.0 175 | metric_losses = 0.0 176 | d_losses = 0.0 177 | g_losses = 0.0 178 | correct = 0 179 | total = 0 180 | n_batch = len(self.train_loader) 181 | 182 | print(f'\n=> Training Epoch #{cur_epoch}') 183 | for batch_idx, batch in enumerate(self.train_loader): 184 | 185 | inputs, seq_lens, labels = batch.text[0].to( 186 | self.device), batch.text[1].to(self.device), batch.label.to(self.device) 187 | 188 | # if self.hparams['dataset_name'] == 'sst2': 189 | labels -= 1 # because I binarized the data 190 | 191 | seed_features = self.net.extract_features(inputs, seq_lens) 192 | features = seed_features 193 | 194 | if self.hparams['manifold_mixup']: 195 | features, targets_a, targets_b, lam = mixup_data(features, labels, 196 | 0.2, use_cuda=True) 197 | features, targets_a, targets_b = map(Variable, (features, 198 | targets_a, targets_b)) 199 | # apply pba transformation 200 | if self.hparams['use_modals']: 201 | features = self.pm.apply_policy( 202 | features, labels, cur_epoch, batch_idx, verbose=1).to(self.device) 203 | 204 | outputs = self.net.classify(features) # Forward Propagation 205 | 206 | if self.hparams['mixup']: 207 | inputs, targets_a, targets_b, lam = mixup_data(outputs, labels, 208 | self.hparams['alpha'], use_cuda=True) 209 | inputs, targets_a, targets_b = map(Variable, (outputs, 210 | targets_a, targets_b)) 211 | # freeze D 212 | if self.hparams['enforce_prior']: 213 | for p in self.D.parameters(): 214 | p.requires_grad = False 215 | 216 | # classification loss 217 | if self.hparams['mixup'] or self.hparams['manifold_mixup']: 218 | c_loss = mixup_criterion( 219 | self.criterion, outputs, targets_a, targets_b, lam) 220 | else: 221 | c_loss = self.criterion(outputs, labels) # Loss 222 | clf_losses += c_loss.item() 223 | 224 | # total loss 225 | loss = c_loss 226 | if self.hparams['metric_learning']: 227 | m_loss = self.metric_loss(seed_features, labels)[0] 228 | metric_losses += m_loss.item() 229 | loss = self.metric_weight * m_loss + \ 230 | (1-self.metric_weight) * c_loss 231 | 232 | train_losses += loss.item() 233 | 234 | if self.hparams['enforce_prior']: 235 | # Regularizer update 236 | # freeze D 237 | for p in self.D.parameters(): 238 | p.requires_grad = False 239 | self.net.train() 240 | d_fake = self.D(features) 241 | g_loss = self.hparams['prior_weight'] * \ 242 | adverserial_loss(d_fake, self.EPS) 243 | g_losses += g_loss.item() 244 | loss += g_loss 245 | 246 | self.optimizer.zero_grad() 247 | loss.backward() # Backward Propagation 248 | clip_grad_norm_(self.net.parameters(), 5.0) 249 | self.optimizer.step() # Optimizer update 250 | 251 | if self.hparams['enforce_prior']: 252 | # Discriminator update 253 | for p in self.D.parameters(): 254 | p.requires_grad = True 255 | 256 | features = self.net.extract_features(inputs, seq_lens) 257 | d_real = self.D(torch.randn(features.size()).to(self.device)) 258 | d_fake = self.D(F.softmax(features, dim=0)) 259 | d_loss = discriminator_loss(d_real, d_fake, self.EPS) 260 | self.D_optimizer.zero_grad() 261 | d_loss.backward() 262 | self.D_optimizer.step() 263 | d_losses += d_loss.item() 264 | 265 | # Accuracy 266 | _, predicted = torch.max(outputs.data, 1) 267 | total += labels.size(0) 268 | if self.hparams['mixup']: 269 | correct += (lam * predicted.eq(targets_a.data).cpu().sum().float() 270 | + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float()) 271 | else: 272 | correct += (predicted == labels).sum().item() 273 | 274 | # step 275 | step = (cur_epoch-1)*(len(self.train_loader)) + batch_idx 276 | total_steps = self.hparams['num_epochs']*len(self.train_loader) 277 | 278 | # logs 279 | display = f'| Epoch [{cur_epoch}/{self.hparams["num_epochs"]}]\tIter[{step}/{total_steps}]\tLoss: {train_losses/n_batch:.4f}\tAcc@1: {correct/total:.4f}\tclf_loss: {clf_losses/n_batch:.4f}' 280 | if self.hparams['enforce_prior']: 281 | display += f'\td_loss: {d_losses/n_batch:.4f}\tg_loss: {g_losses/n_batch:.4f}' 282 | if self.hparams['metric_learning']: 283 | display += f'\tmetric_loss: {metric_losses/n_batch:.4f}' 284 | print(display) 285 | 286 | return correct/total, train_losses/total 287 | 288 | def _test(self, cur_epoch, mode): 289 | self.net.eval() 290 | self.net.training = False 291 | correct = 0 292 | total = 0 293 | test_loss = 0.0 294 | data_loader = self.valid_loader if mode == 'valid' else self.test_loader 295 | 296 | with torch.no_grad(): 297 | for batch_idx, batch in enumerate(data_loader): 298 | inputs, seq_lens, labels = batch.text[0].to( 299 | self.device), batch.text[1].to(self.device), batch.label.to(self.device) 300 | 301 | # if self.hparams['dataset_name'] == 'sst2': 302 | labels -= 1 # because I binarized the data 303 | 304 | outputs = self.net(inputs, seq_lens) 305 | loss = self.criterion(outputs, labels) # Loss 306 | test_loss += loss.item() 307 | 308 | _, predicted = torch.max(outputs.data, 1) 309 | total += labels.size(0) 310 | correct += (predicted == labels).sum().item() 311 | 312 | torch.cuda.empty_cache() 313 | 314 | print( 315 | f'| ({mode}) Epoch #{cur_epoch}\t Loss: {test_loss/total:.4f}\t Acc@1: {correct/total:.4f}') 316 | 317 | return correct/total, test_loss/total 318 | 319 | def run_model(self, epoch): 320 | if self.hparams['use_modals']: 321 | self.pm.reset_text_data_pool( 322 | self.net, self.train_loader, self.hparams['temperature'], self.hparams['distance_metric'], self.hparams['dataset_name']) 323 | 324 | train_acc, tl = self._train(epoch) 325 | self.loss_dict['train'].append(tl) 326 | 327 | if self.hparams['valid_size'] > 0: 328 | val_acc, vl = self._test(epoch, mode='valid') 329 | self.loss_dict['valid'].append(vl) 330 | else: 331 | val_acc = 0.0 332 | 333 | return train_acc, val_acc 334 | 335 | # for benchmark 336 | def save_checkpoint(self, ckpt_dir, epoch): 337 | path = os.path.join( 338 | ckpt_dir, self.hparams['dataset_name'], f'{self.name}_{self.file_name}') 339 | if not os.path.exists(ckpt_dir): 340 | os.makedirs(ckpt_dir) 341 | 342 | torch.save({'state': self.net.state_dict(), 343 | 'epoch': epoch, 344 | 'loss': self.loss_dict, 345 | 'optimizer': self.optimizer.state_dict(), 346 | 'scheduler': self.scheduler.state_dict()}, path) 347 | 348 | print(f'=> saved the model {self.file_name} to {path}') 349 | return path 350 | 351 | # for ray 352 | def save_model(self, ckpt_dir, epoch): 353 | # save the checkpoint. 354 | print(self.file_name) 355 | print(ckpt_dir) 356 | path = os.path.join(ckpt_dir, self.file_name) 357 | if not os.path.exists(ckpt_dir): 358 | os.makedirs(ckpt_dir) 359 | 360 | torch.save({'state': self.net.state_dict(), 361 | 'epoch': epoch, 362 | 'loss': self.loss_dict, 363 | 'optimizer': self.optimizer.state_dict(), 364 | 'scheduler': self.scheduler.state_dict()}, path) 365 | 366 | print(f'=> saved the model {self.file_name} to {path}') 367 | return path 368 | 369 | def load_model(self, ckpt): 370 | # load the checkpoint. 371 | # path = os.path.join(ckpt_dir, self.model_name) 372 | # map_location='cuda:0') 373 | checkpoint = torch.load(ckpt, map_location=torch.device('cpu')) 374 | self.net.load_state_dict(checkpoint['state']) 375 | self.loss_dict = checkpoint['loss'] 376 | if self.hparams['mode'] != 'test': 377 | self.optimizer.load_state_dict(checkpoint['optimizer']) 378 | self.scheduler.load_state_dict(checkpoint['scheduler']) 379 | print(f'=> loaded checkpoint of {self.file_name} from {ckpt}') 380 | return checkpoint['epoch'], checkpoint['loss'] 381 | 382 | def reset_config(self, new_hparams): 383 | self.hparams = new_hparams 384 | new_policy = RawPolicy(mode=self.hparams['mode'], num_epochs=self.hparams['num_epochs'], 385 | hp_policy=self.hparams['hp_policy']) 386 | self.pm.update_policy(new_policy) 387 | return 388 | 389 | --------------------------------------------------------------------------------