├── README.md
├── requirements.txt
└── src
    ├── bart_aug
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-37.pyc
    │   │   ├── denoised_dataset.cpython-37.pyc
    │   │   └── masking_task.cpython-37.pyc
    │   ├── denoised_dataset.py
    │   └── masking_task.py
    ├── bert_aug
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── bert_model.cpython-37.pyc
    │   │   ├── bert_model.cpython-38.pyc
    │   │   ├── data_processors.cpython-37.pyc
    │   │   └── data_processors.cpython-38.pyc
    │   ├── backtranslation.py
    │   ├── bert_classifier.py
    │   ├── bert_model.py
    │   ├── cbert.py
    │   ├── cgpt2.py
    │   ├── cmodbert.py
    │   ├── cmodbertp.py
    │   ├── data_processors.py
    │   └── eda.py
    ├── scripts
    │   ├── bart_snips_lower.sh
    │   ├── bart_stsa_lower.sh
    │   ├── bart_trec_lower.sh
    │   ├── bert_snips_lower.sh
    │   ├── bert_stsa_lower.sh
    │   └── bert_trec_lower.sh
    └── utils
        ├── __init__.py
        ├── bpe_encoder.py
        ├── convert_num_to_text_labels.py
        ├── create_fsl_dataset.py
        ├── download_and_prepare_datasets.sh
        └── gpt2_bpe
            ├── dict.txt
            ├── encoder.json
            └── vocab.bpe

/README.md:
--------------------------------------------------------------------------------
Source code for the ACL 2022 paper: Text Smoothing: Enhance Various Data Augmentation Methods on Text Classification Tasks

Our work is mainly based on [Data Augmentation using Pre-trained Transformer Models](https://github.com/amazon-research/transformers-data-augmentation).

The code contains implementations of the following data augmentation methods:
- TextSmoothing
- EDA + TextSmoothing
- Backtranslation + TextSmoothing
- CBERT + TextSmoothing
- BERT Prepend + TextSmoothing
- GPT-2 Prepend + TextSmoothing
- BART Prepend + TextSmoothing

## Datasets

In the paper, we use three datasets from the following resources:
- STSA-2: [https://github.com/1024er/cbert_aug/tree/crayon/datasets/stsa.binary](https://github.com/1024er/cbert_aug/tree/crayon/datasets/stsa.binary)
- TREC: [https://github.com/1024er/cbert_aug/tree/crayon/datasets/TREC](https://github.com/1024er/cbert_aug/tree/crayon/datasets/TREC)
- SNIPS: [https://github.com/MiuLab/SlotGated-SLU/tree/master/data/snips](https://github.com/MiuLab/SlotGated-SLU/tree/master/data/snips)

### Low-data regime experiment setup
Run the `src/utils/download_and_prepare_datasets.sh` script to prepare all datasets.
`download_and_prepare_datasets.sh` performs the following steps:
1. Download the data from GitHub.
2. Replace numeric labels with text labels for the STSA-2 and TREC datasets.
3. For a given dataset, create 15 random splits of the train and dev data.

## Dependencies

To run this code, you need the following dependencies:
- PyTorch 1.5
- fairseq 0.9
- transformers 2.9

## How to run
To run a data augmentation experiment for a given dataset, run the corresponding bash script in the `scripts` folder.
For example, to run data augmentation on the `snips` dataset:
- run `scripts/bart_snips_lower.sh` for the BART experiment
- run `scripts/bert_snips_lower.sh` for the rest of the data augmentation methods

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
#pytorch==1.6
fairseq==0.9
transformers==2.9
--------------------------------------------------------------------------------
/src/bart_aug/__init__.py:
--------------------------------------------------------------------------------
 1 | from . import denoised_dataset
 2 | from . 
import masking_task 3 | -------------------------------------------------------------------------------- /src/bart_aug/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bart_aug/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/bart_aug/__pycache__/denoised_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bart_aug/__pycache__/denoised_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /src/bart_aug/__pycache__/masking_task.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bart_aug/__pycache__/masking_task.cpython-37.pyc -------------------------------------------------------------------------------- /src/bart_aug/denoised_dataset.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | # Original Copyright Facebook, Inc. and its affiliates. Licensed under the MIT License as part of 4 | # fairseq package. 5 | 6 | import numpy as np 7 | import torch 8 | import math 9 | 10 | from fairseq.data import data_utils, FairseqDataset 11 | 12 | 13 | def collate( 14 | samples, 15 | pad_idx, 16 | eos_idx, 17 | vocab, 18 | left_pad_source=False, 19 | left_pad_target=False, 20 | input_feeding=True, 21 | ): 22 | assert input_feeding 23 | if len(samples) == 0: 24 | return {} 25 | 26 | def merge(key, left_pad, move_eos_to_beginning=False): 27 | return data_utils.collate_tokens( 28 | [s[key] for s in samples], 29 | pad_idx, eos_idx, left_pad, move_eos_to_beginning, 30 | ) 31 | 32 | id = torch.LongTensor([s['id'] for s in samples]) 33 | src_tokens = merge('source', left_pad=left_pad_source) 34 | # sort by descending source length 35 | src_lengths = torch.LongTensor([s['source'].numel() for s in samples]) 36 | src_lengths, sort_order = src_lengths.sort(descending=True) 37 | id = id.index_select(0, sort_order) 38 | src_tokens = src_tokens.index_select(0, sort_order) 39 | 40 | prev_output_tokens = None 41 | target = None 42 | if samples[0].get('target', None) is not None: 43 | target = merge('target', left_pad=left_pad_target) 44 | target = target.index_select(0, sort_order) 45 | ntokens = sum(len(s['target']) for s in samples) 46 | 47 | if input_feeding: 48 | # we create a shifted version of targets for feeding the 49 | # previous output token(s) into the next decoder step 50 | prev_output_tokens = merge( 51 | 'target', 52 | left_pad=left_pad_target, 53 | move_eos_to_beginning=True, 54 | ) 55 | prev_output_tokens = prev_output_tokens.index_select(0, sort_order) 56 | else: 57 | ntokens = sum(len(s['source']) for s in samples) 58 | 59 | batch = { 60 | 'id': id, 61 | 'ntokens': ntokens, 62 | 'net_input': { 63 | 'src_tokens': src_tokens, 64 | 'src_lengths': src_lengths, 65 | }, 66 | 'target': target, 67 | 'nsentences': samples[0]['source'].size(0), 68 | } 69 | if prev_output_tokens is not None: 70 | batch['net_input']['prev_output_tokens'] = prev_output_tokens 71 | 
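    # The returned mini-batch follows fairseq's seq2seq convention: 'net_input' holds the
    # noised source tokens and lengths for the encoder, 'target' keeps the original
    # sentence to be reconstructed, and 'prev_output_tokens' is the EOS-shifted target
    # used as decoder input for teacher forcing.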
return batch 72 | 73 | 74 | class BARTDenoisingDataset(FairseqDataset): 75 | """ 76 | A wrapper around TokenBlockDataset for BART dataset. 77 | 78 | Args: 79 | dataset (TokenBlockDataset): dataset to wrap 80 | sizes (List[int]): sentence lengths 81 | vocab (~fairseq.data.Dictionary): vocabulary 82 | mask_idx (int): dictionary index used for masked token 83 | mask_whole_words: only mask whole words. This should be a byte mask 84 | over vocab indices, indicating whether it is the beginning of a 85 | word. We will extend any mask to encompass the whole word. 86 | shuffle (bool, optional): shuffle the elements before batching. 87 | Default: ``True`` 88 | seed: Seed for random number generator for reproducibility. 89 | args: argparse arguments. 90 | """ 91 | 92 | def __init__( 93 | self, 94 | dataset, 95 | sizes, 96 | vocab, 97 | mask_idx, 98 | mask_whole_words, 99 | shuffle, 100 | seed, 101 | args 102 | ): 103 | self.dataset = dataset 104 | 105 | self.sizes = sizes 106 | 107 | self.vocab = vocab 108 | self.shuffle = shuffle 109 | self.seed = seed 110 | self.mask_idx = mask_idx 111 | self.mask_whole_word = mask_whole_words 112 | self.mask_ratio = args.mask 113 | self.random_ratio = args.mask_random 114 | self.insert_ratio = args.insert 115 | self.tokens_to_keep = args.tokens_to_keep 116 | 117 | if args.bpe != 'gpt2': 118 | self.full_stop_index = self.vocab.index(".") 119 | else: 120 | assert args.bpe == 'gpt2' 121 | self.full_stop_index = self.vocab.index('13') 122 | 123 | self.tab_sep_index = self.vocab.index('\t') 124 | self.replace_length = args.replace_length 125 | if not self.replace_length in [-1, 0, 1]: 126 | raise (f'invalid arg: replace_length={self.replace_length}') 127 | if not args.mask_length in ['subword', 'word', 'span', 'span-poisson']: 128 | raise (f'invalid arg: mask-length={args.mask_length}') 129 | if args.mask_length == 'subword' and not args.replace_length in [0, 1]: 130 | raise (f'if using subwords, use replace-length=1 or 0') 131 | 132 | self.is_span_mask = (args.mask_length == 'span') 133 | self.mask_span_distribution = None 134 | if args.mask_length == 'span-poisson': 135 | _lambda = args.poisson_lambda 136 | 137 | lambda_to_the_k = 1 138 | e_to_the_minus_lambda = math.exp(-_lambda) 139 | k_factorial = 1 140 | ps = [] 141 | for k in range(0, 128): 142 | ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) 143 | lambda_to_the_k *= _lambda 144 | k_factorial *= (k + 1) 145 | if ps[-1] < 0.0000001: 146 | break 147 | ps = torch.FloatTensor(ps) 148 | self.mask_span_distribution = torch.distributions.Categorical(ps) 149 | 150 | self.epoch = 0 151 | torch.manual_seed(self.seed) 152 | 153 | def set_epoch(self, epoch, **unused): 154 | self.epoch = epoch 155 | 156 | def __getitem__(self, index): 157 | with data_utils.numpy_seed(self.seed, self.epoch, index): 158 | tokens = self.dataset[index] 159 | assert tokens[-1] == self.vocab.eos() 160 | source, target = tokens, tokens.clone() 161 | 162 | if self.mask_ratio > 0: 163 | if self.is_span_mask: 164 | source = self.add_multiple_words_mask(source, self.mask_ratio) 165 | else: 166 | source = self.add_whole_word_mask(source, self.mask_ratio) 167 | 168 | assert (source >= 0).all() 169 | assert (source[1:-1] >= 1).all() 170 | assert (source <= len(self.vocab)).all() 171 | assert source[0] == self.vocab.bos() 172 | assert source[-1] == self.vocab.eos() 173 | return { 174 | 'id': index, 175 | 'source': source, 176 | 'target': target, 177 | } 178 | 179 | def __len__(self): 180 | return len(self.dataset) 181 | 182 | def 
word_starts(self, source): 183 | if self.mask_whole_word is not None: 184 | is_word_start = self.mask_whole_word.gather(0, source) 185 | else: 186 | is_word_start = torch.ones(source.size()) 187 | is_word_start[0] = 0 188 | is_word_start[-1] = 0 189 | 190 | is_word_start[1] = 0 # exclude the first word. Label word 191 | # for i in range(1, self.tokens_to_keep+1): 192 | # is_word_start[i] = 0 193 | 194 | return is_word_start 195 | 196 | def add_whole_word_mask(self, source, p): 197 | is_word_start = self.word_starts(source) 198 | num_to_mask = int(math.ceil(is_word_start.float().sum() * p)) 199 | num_inserts = 0 200 | if num_to_mask == 0: 201 | return source 202 | 203 | if self.mask_span_distribution is not None: 204 | lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,)) 205 | 206 | # Make sure we have enough to mask 207 | cum_length = torch.cumsum(lengths, 0) 208 | while cum_length[-1] < num_to_mask: 209 | lengths = torch.cat([lengths, self.mask_span_distribution.sample(sample_shape=(num_to_mask,))], dim=0) 210 | cum_length = torch.cumsum(lengths, 0) 211 | 212 | # Trim to masking budget 213 | i = 0 214 | while cum_length[i] < num_to_mask: 215 | i += 1 216 | lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1]) 217 | num_to_mask = i + 1 218 | lengths = lengths[:num_to_mask] 219 | 220 | # Handle 0-length mask (inserts) separately 221 | lengths = lengths[lengths > 0] 222 | num_inserts = num_to_mask - lengths.size(0) 223 | num_to_mask -= num_inserts 224 | if num_to_mask == 0: 225 | return self.add_insertion_noise(source, num_inserts / source.size(0)) 226 | 227 | assert (lengths > 0).all() 228 | else: 229 | lengths = torch.ones((num_to_mask,)).long() 230 | 231 | assert is_word_start[-1] == 0 232 | word_starts = is_word_start.nonzero() 233 | indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1) 234 | mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio 235 | 236 | source_length = source.size(0) 237 | assert source_length - 1 not in indices 238 | to_keep = torch.ones(source_length, dtype=torch.bool) 239 | is_word_start[-1] = 255 # acts as a long length, so spans don't go over the end of doc 240 | if self.replace_length == 0: 241 | to_keep[indices] = 0 242 | else: 243 | # keep index, but replace it with [MASK] 244 | source[indices] = self.mask_idx 245 | source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),)) 246 | 247 | if self.mask_span_distribution is not None: 248 | assert len(lengths.size()) == 1 249 | assert lengths.size() == indices.size() 250 | lengths -= 1 251 | while indices.size(0) > 0: 252 | assert lengths.size() == indices.size() 253 | lengths -= is_word_start[indices + 1].long() 254 | uncompleted = lengths >= 0 255 | indices = indices[uncompleted] + 1 256 | mask_random = mask_random[uncompleted] 257 | lengths = lengths[uncompleted] 258 | if self.replace_length != -1: 259 | # delete token 260 | to_keep[indices] = 0 261 | else: 262 | # keep index, but replace it with [MASK] 263 | source[indices] = self.mask_idx 264 | source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),)) 265 | else: 266 | # A bit faster when all lengths are 1 267 | while indices.size(0) > 0: 268 | uncompleted = is_word_start[indices + 1] == 0 269 | indices = indices[uncompleted] + 1 270 | mask_random = mask_random[uncompleted] 271 | if self.replace_length != -1: 272 | # delete token 273 | to_keep[indices] = 0 274 | else: 275 | # keep index, but replace it with 
[MASK] 276 | source[indices] = self.mask_idx 277 | source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),)) 278 | 279 | assert source_length - 1 not in indices 280 | 281 | source = source[to_keep] 282 | 283 | if num_inserts > 0: 284 | source = self.add_insertion_noise(source, num_inserts / source.size(0)) 285 | 286 | return source 287 | 288 | def add_multiple_words_mask(self, source, p): 289 | is_word_start = self.word_starts(source) 290 | num_to_mask = int(math.ceil(is_word_start.float().sum() * p)) 291 | if num_to_mask == 0: 292 | return source 293 | 294 | assert is_word_start[-1] == 0 295 | word_starts = is_word_start.nonzero() 296 | start_index = word_starts.size(0)-num_to_mask 297 | if start_index < 1: 298 | print(source, is_word_start) 299 | return source 300 | 301 | mask_word_start_id = np.random.randint(start_index) 302 | 303 | source_length = source.size(0) 304 | to_keep = torch.ones(source_length, dtype=torch.bool) 305 | is_word_start[-1] = 255 # acts as a long length, so spans don't go over the end of doc 306 | 307 | # keep first index, but replace it with [MASK], and delete remaining index 308 | source[word_starts[mask_word_start_id]] = self.mask_idx 309 | #assert mask_word_start_id+num_to_mask < word_starts.size(0) 310 | #assert (word_starts[mask_word_start_id].item()+num_to_mask) < source_length 311 | try: 312 | for ind in range(word_starts[mask_word_start_id]+1, word_starts[mask_word_start_id+num_to_mask]): 313 | to_keep[ind] = 0 314 | except IndexError: 315 | print("Index error", source, is_word_start) 316 | pass 317 | 318 | source = source[to_keep] 319 | return source 320 | 321 | def collater(self, samples): 322 | """Merge a list of samples to form a mini-batch. 323 | Args: 324 | samples (List[dict]): samples to collate 325 | Returns: 326 | dict: a mini-batch of data 327 | """ 328 | return collate(samples, self.vocab.pad(), self.vocab.eos(), self.vocab) 329 | 330 | def num_tokens(self, index): 331 | """Return the number of tokens in a sample. This value is used to 332 | enforce ``--max-tokens`` during batching.""" 333 | return self.sizes[index] 334 | 335 | def size(self, index): 336 | """Return an example's size as a float or tuple. This value is used when 337 | filtering a dataset with ``--max-positions``.""" 338 | return self.sizes[index] 339 | 340 | def ordered_indices(self): 341 | """Return an ordered list of indices. Batches will be constructed based 342 | on this order.""" 343 | if self.shuffle: 344 | indices = np.random.permutation(len(self)) 345 | else: 346 | indices = np.arange(len(self)) 347 | return indices[np.argsort(self.sizes[indices], kind='mergesort')] 348 | 349 | def prefetch(self, indices): 350 | self.src.prefetch(indices) 351 | self.tgt.prefetch(indices) 352 | 353 | @property 354 | def supports_prefetch(self): 355 | return ( 356 | hasattr(self.src, 'supports_prefetch') 357 | and self.src.supports_prefetch 358 | and hasattr(self.tgt, 'supports_prefetch') 359 | and self.tgt.supports_prefetch 360 | ) 361 | -------------------------------------------------------------------------------- /src/bart_aug/masking_task.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | # Original Copyright Facebook, Inc. and its affiliates. Licensed under the MIT License as part of 4 | # fairseq package. 
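#
# This module registers the `mask_s2s` fairseq task used to fine-tune BART as a
# denoising augmenter: training sentences are loaded as token blocks, corrupted with
# whole-word or span masking (see BARTDenoisingDataset), and the model is trained to
# reconstruct the original text.  A minimal, illustrative invocation is sketched below;
# the released scripts under scripts/bart_*.sh are the authoritative reference, and the
# usual fairseq arguments (architecture, criterion, optimizer, ...) are omitted here:
#
#   fairseq-train <binarized-data-dir> \
#       --task mask_s2s \
#       --mask 0.3 --mask-length word --replace-length 1 --tokens-to-keep 2
#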
5 | 6 | 7 | import os 8 | 9 | from fairseq.data import ( 10 | data_utils, 11 | Dictionary, 12 | AppendTokenDataset, 13 | PrependTokenDataset, 14 | StripTokenDataset, 15 | TokenBlockDataset, 16 | ) 17 | from .denoised_dataset import BARTDenoisingDataset 18 | from fairseq.data.encoders.utils import get_whole_word_mask 19 | from fairseq.tasks import FairseqTask, register_task 20 | 21 | 22 | @register_task('mask_s2s') 23 | class DenoisingTaskS2S(FairseqTask): 24 | """ 25 | Denoising task for applying sequence to sequence denoising. 26 | """ 27 | 28 | @staticmethod 29 | def add_args(parser): 30 | """Add task-specific arguments to the parser.""" 31 | parser.add_argument('data', help='path to data directory') 32 | parser.add_argument('--tokens-per-sample', default=512, type=int, 33 | help='max number of total tokens over all segments' 34 | ' per sample for dataset') 35 | parser.add_argument('--raw-text', default=False, action='store_true', 36 | help='load raw text dataset') 37 | parser.add_argument( 38 | '--sample-break-mode', default="eos", type=str, 39 | help='mode for breaking sentence', 40 | ) 41 | parser.add_argument( 42 | '--mask', default=0.3, type=float, 43 | help='fraction of words/subwords that will be masked', 44 | ) 45 | parser.add_argument( 46 | '--mask-random', default=0.0, type=float, 47 | help='instead of using [MASK], use random token this often' 48 | ) 49 | parser.add_argument( 50 | '--insert', default=0.0, type=float, 51 | help='insert this percentage of additional random tokens', 52 | ) 53 | parser.add_argument( 54 | '--mask-length', default="word", type=str, 55 | choices=['subword', 'word', 'span', 'span-poisson'], 56 | help='mask length to choose' 57 | ) 58 | parser.add_argument( 59 | '--replace-length', default=-1, type=int, 60 | help='when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)' 61 | ) 62 | parser.add_argument( 63 | '--tokens-to-keep', default=2, type=int, 64 | help="Don't mask first tokens" 65 | ) 66 | 67 | # following 2 arguments are required for the GPT2 BPE encoding 68 | parser.add_argument( 69 | '--gpt2_encoder_json', default=None, type=str, 70 | help='GPT2 encoder path' 71 | ) 72 | parser.add_argument( 73 | '--gpt2_vocab_bpe', default=None, type=str, 74 | help='GPT2 vocab path' 75 | ) 76 | 77 | def __init__(self, args, dictionary): 78 | super().__init__(args) 79 | self.dictionary = dictionary 80 | self.seed = args.seed 81 | 82 | # add mask token 83 | self.mask_idx = self.dictionary.add_symbol('') 84 | 85 | @classmethod 86 | def setup_task(cls, args, **kwargs): 87 | """Setup the task. 88 | """ 89 | dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt')) 90 | print('| dictionary: {} types'.format(len(dictionary))) 91 | if not hasattr(args, 'shuffle_instance'): 92 | args.shuffle_instance = False 93 | return cls(args, dictionary) 94 | 95 | def load_dataset(self, split, epoch=0, combine=False, data_selector=None): 96 | """Load a given dataset split. 
97 | 98 | Args: 99 | split (str): name of the split (e.g., train, valid, test) 100 | """ 101 | 102 | paths = self.args.data.split(':') 103 | assert len(paths) > 0 104 | data_path = paths[epoch % len(paths)] 105 | split_path = os.path.join(data_path, split) 106 | 107 | dataset = data_utils.load_indexed_dataset( 108 | split_path, 109 | self.dictionary, 110 | self.args.dataset_impl, 111 | combine=combine, 112 | ) 113 | if dataset is None: 114 | raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path)) 115 | 116 | dataset = StripTokenDataset(dataset, self.dictionary.eos()) 117 | 118 | # create continuous blocks of tokens 119 | dataset = TokenBlockDataset( 120 | dataset, 121 | dataset.sizes, 122 | self.args.tokens_per_sample - 2, # one less for and one for 123 | pad=self.dictionary.pad(), 124 | eos=self.dictionary.eos(), 125 | break_mode="eos", 126 | document_sep_len=0 127 | ) 128 | 129 | # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) 130 | dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) 131 | dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) 132 | 133 | mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \ 134 | if self.args.mask_length != 'subword' else None 135 | 136 | self.datasets[split] = BARTDenoisingDataset( 137 | dataset, dataset.sizes, self.dictionary, self.mask_idx, 138 | mask_whole_words, shuffle=self.args.shuffle_instance, 139 | seed=self.seed, args=self.args 140 | ) 141 | print( 142 | "| Split: {0}, Loaded {1} samples of denoising_dataset".format( 143 | split, 144 | len(self.datasets[split]), 145 | ) 146 | ) 147 | 148 | def max_positions(self): 149 | """Return the max sentence length allowed by the task.""" 150 | #return (self.args.max_source_positions, self.args.max_target_positions) 151 | return (1024, 1024) 152 | 153 | @property 154 | def source_dictionary(self): 155 | """Return the source :class:`~fairseq.data.Dictionary`.""" 156 | return self.dictionary 157 | 158 | @property 159 | def target_dictionary(self): 160 | """Return the target :class:`~fairseq.data.Dictionary`.""" 161 | return self.dictionary 162 | -------------------------------------------------------------------------------- /src/bert_aug/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bert_aug/__init__.py -------------------------------------------------------------------------------- /src/bert_aug/__pycache__/bert_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bert_aug/__pycache__/bert_model.cpython-37.pyc -------------------------------------------------------------------------------- /src/bert_aug/__pycache__/bert_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bert_aug/__pycache__/bert_model.cpython-38.pyc -------------------------------------------------------------------------------- /src/bert_aug/__pycache__/data_processors.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bert_aug/__pycache__/data_processors.cpython-37.pyc -------------------------------------------------------------------------------- /src/bert_aug/__pycache__/data_processors.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/bert_aug/__pycache__/data_processors.cpython-38.pyc -------------------------------------------------------------------------------- /src/bert_aug/backtranslation.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import csv 5 | import logging 6 | import argparse 7 | import random 8 | 9 | import os 10 | import numpy as np 11 | import torch 12 | from fairseq.models.transformer import TransformerModel 13 | from data_processors import get_task_processor 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 18 | datefmt='%m/%d/%Y %H:%M:%S', 19 | level=logging.INFO) 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | 27 | ## Required parameters 28 | parser.add_argument("--data_dir", default="datasets", type=str, 29 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 30 | parser.add_argument("--output_dir", default="aug_data", type=str, 31 | help="The output dir for augmented dataset") 32 | parser.add_argument("--task_name",default="subj",type=str, 33 | help="The name of the task to train.") 34 | parser.add_argument("--train_batch_size", default=32, type=int, 35 | help="Total batch size for training.") 36 | parser.add_argument('--seed', type=int, default=42, 37 | help="random seed for initialization") 38 | parser.add_argument('--sample_num', type=int, default=1, 39 | help="sample number") 40 | parser.add_argument('--cache', default="fairseq_cache", type=str) 41 | parser.add_argument('--gpu', type=int, default=0, 42 | help="gpu id") 43 | args = parser.parse_args() 44 | 45 | print(args) 46 | backtranslation_using_en_de_model(args) 47 | 48 | 49 | def backtranslation_using_en_de_model(args): 50 | task_name = args.task_name 51 | os.makedirs(args.output_dir, exist_ok=True) 52 | 53 | random.seed(args.seed) 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | torch.cuda.manual_seed_all(args.seed) 57 | torch.backends.cudnn.deterministic = True 58 | 59 | os.makedirs(args.output_dir, exist_ok=True) 60 | processor = get_task_processor(task_name, args.data_dir) 61 | # load train and dev data 62 | train_examples = processor.get_train_examples() 63 | 64 | # load the best model 65 | en_de_model = TransformerModel.from_pretrained( 66 | os.path.join(args.cache, "wmt19.en-de.joined-dict.single_model"), 67 | checkpoint_file="model.pt", 68 | tokenizer='moses', 69 | bpe='fastbpe' 70 | ) 71 | 72 | de_en_model = TransformerModel.from_pretrained( 73 | os.path.join(args.cache, "wmt19.de-en.joined-dict.single_model"), 74 | checkpoint_file="model.pt", 75 | tokenizer='moses', 76 | bpe='fastbpe' 77 | ) 78 | 79 | # en_de_model.to(device) 80 | # de_en_model.to(device) 81 | 82 | save_train_path = os.path.join(args.output_dir, "bt_aug.tsv") 83 | save_train_file = 
open(save_train_path, 'w') 84 | tsv_writer = csv.writer(save_train_file, delimiter='\t') 85 | for example in train_examples: 86 | text = example.text_a 87 | de_example = en_de_model.translate(text, remove_bpe=True) 88 | back_translated_example = de_en_model.translate(de_example, remove_bpe=True) 89 | tsv_writer.writerow([example.label, back_translated_example]) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() -------------------------------------------------------------------------------- /src/bert_aug/bert_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | 5 | import torch 6 | import argparse 7 | 8 | from data_processors import get_data 9 | from bert_model import Classifier 10 | import random 11 | 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | print('device:', device) 15 | 16 | 17 | def main(args): 18 | random.seed(args.seed) 19 | torch.manual_seed(args.seed) 20 | torch.cuda.manual_seed_all(args.seed) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | examples, label_list = get_data( 24 | task=args.task, 25 | data_dir=args.data_dir, 26 | data_seed=args.seed) 27 | t_total = len(examples['train']) // args.epochs 28 | 29 | classifier = Classifier(label_list=label_list, device=device, cache_dir=args.cache,temp_rate=args.temp_rate,smooth_rate=args.smooth_rate) 30 | classifier.get_optimizer(learning_rate=args.learning_rate, 31 | warmup_steps=args.warmup_steps, 32 | t_total=t_total) 33 | 34 | classifier.load_data( 35 | 'train', examples['train'], args.batch_size, max_length=args.max_seq_length, shuffle=True) 36 | classifier.load_data( 37 | 'dev', examples['dev'], args.batch_size, max_length=args.max_seq_length, shuffle=False) 38 | classifier.load_data( 39 | 'test', examples['test'], args.batch_size, max_length=args.max_seq_length, shuffle=False) 40 | 41 | print('=' * 60, '\n', 'Training', '\n', '=' * 60, sep='') 42 | best_dev_acc, final_test_acc = -1., -1. 43 | for epoch in range(args.epochs): 44 | classifier.train_epoch() 45 | dev_acc = classifier.evaluate('dev') 46 | 47 | if epoch >= args.min_epochs: 48 | do_test = (dev_acc > best_dev_acc) 49 | best_dev_acc = max(best_dev_acc, dev_acc) 50 | else: 51 | do_test = False 52 | 53 | print('Epoch {}, Dev Acc: {:.4f}, Best Ever: {:.4f}'.format( 54 | epoch, 100. * dev_acc, 100. * best_dev_acc)) 55 | 56 | if do_test: 57 | final_test_acc = classifier.evaluate('test') 58 | print('Test Acc: {:.4f}'.format(100. * final_test_acc)) 59 | 60 | print('Final Dev Acc: {:.4f}, Final Test Acc: {:.4f}'.format( 61 | 100. * best_dev_acc, 100. * final_test_acc)) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | 67 | parser.add_argument('--task', choices=['stsa', 'snips', 'trec']) 68 | parser.add_argument('--data_dir', type=str, help="Data dir path with {train, dev, test}.tsv") 69 | parser.add_argument('--seed', default=159, type=int) 70 | parser.add_argument('--hidden_dropout_prob', default=0.1, type=float) 71 | parser.add_argument("--warmup_steps", default=100, type=int, 72 | help="Linear warmup over warmup_steps.") 73 | 74 | parser.add_argument("--max_seq_length", default=64, type=int, 75 | help="The maximum total input sequence length after tokenization. 
" 76 | "Sequences longer than this will be truncated, sequences shorter will be padded.") 77 | 78 | parser.add_argument('--cache', default="transformers_cache", type=str) 79 | 80 | parser.add_argument('--epochs', default=8, type=int) 81 | parser.add_argument('--min_epochs', default=0, type=int) 82 | parser.add_argument("--learning_rate", default=4e-5, type=float) 83 | parser.add_argument('--batch_size', default=8, type=int) 84 | parser.add_argument('--temp_rate', default=1.0, type=float) 85 | parser.add_argument('--smooth_rate', default=0.5, type=float) 86 | 87 | args = parser.parse_args() 88 | print(args) 89 | main(args) 90 | 91 | -------------------------------------------------------------------------------- /src/bert_aug/bert_model.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | # Original Copyright huggingface and its affiliates. Licensed under the Apache-2.0 License as part 4 | # of huggingface's transformers package. 5 | # Credit https://github.com/huggingface/transformers/blob/master/examples/run_glue.py 6 | 7 | from transformers import BertTokenizer 8 | from transformers.modeling_bert import BertForSequenceClassification,BertForMaskedLM 9 | #from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertLMPredictionHead,BertForMaskedLM 10 | from transformers import AdamW, get_linear_schedule_with_warmup 11 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 12 | 13 | from tqdm import tqdm 14 | 15 | import torch 16 | from torch.utils.data import DataLoader, TensorDataset 17 | 18 | BERT_MODEL = '/share/wuxing/beifen_gaochaochen/gaochaochen/STS/model/bert-base-uncased' 19 | 20 | 21 | class Classifier: 22 | def __init__(self, label_list, device, cache_dir,temp_rate,smooth_rate): 23 | self._label_list = label_list 24 | self._device = device 25 | 26 | self._tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, 27 | do_lower_case=True, 28 | cache_dir=cache_dir) 29 | 30 | self._model = BertForSequenceClassification.from_pretrained(BERT_MODEL, 31 | num_labels=len(label_list), 32 | cache_dir=cache_dir) 33 | self._model.to(device) 34 | 35 | self._optimizer = None 36 | 37 | self.smooth_model = BertForMaskedLM.from_pretrained(BERT_MODEL).to(device) 38 | self.temp_rate=temp_rate 39 | self.smooth_rate=smooth_rate 40 | 41 | 42 | for params in self.smooth_model.parameters(): 43 | params.requires_grad = False 44 | 45 | self._dataset = {} 46 | self._data_loader = {} 47 | 48 | def load_data(self, set_type, examples, batch_size, max_length, shuffle): 49 | self._dataset[set_type] = examples 50 | self._data_loader[set_type] = _make_data_loader( 51 | examples=examples, 52 | label_list=self._label_list, 53 | tokenizer=self._tokenizer, 54 | batch_size=batch_size, 55 | max_length=max_length, 56 | shuffle=shuffle) 57 | 58 | def get_optimizer(self, learning_rate, warmup_steps, t_total): 59 | self._optimizer, self._scheduler = _get_optimizer( 60 | self._model, learning_rate=learning_rate, 61 | warmup_steps=warmup_steps, t_total=t_total) 62 | 63 | def train_epoch(self): 64 | self._model.train() 65 | 66 | for step, batch in enumerate(tqdm(self._data_loader['train'], 67 | desc='Training')): 68 | batch = tuple(t.to(self._device) for t in batch) 69 | inputs = {'input_ids': batch[0], 70 | 'attention_mask': batch[1], 71 | 'token_type_ids': batch[2], 72 | 'labels': batch[3]} 73 | 74 | 
self._optimizer.zero_grad() 75 | outputs = self._model(**inputs) 76 | loss = outputs[0] 77 | loss.backward() 78 | self._optimizer.step() 79 | self._scheduler.step() 80 | 81 | input_smooth = {'input_ids': batch[0], 82 | 'attention_mask': batch[1], 83 | 'token_type_ids': batch[2], 84 | } 85 | input_probs = self.smooth_model( 86 | **input_smooth 87 | ) 88 | 89 | word_embeddings = self._model.get_input_embeddings().to(self._device) 90 | one_hot = torch.zeros_like(input_probs[0]).scatter_(2,inputs['input_ids'].reshape(inputs['input_ids'].shape[0],inputs['input_ids'].shape[1],1).long(),1.0).to(self._device) 91 | 92 | 93 | now_probs = self.smooth_rate*(torch.nn.functional.softmax(input_probs[0]/self.temp_rate, dim=-1).to(self._device))+(1-self.smooth_rate)*one_hot # 4 2 0.5 0.25 94 | inputs_embeds_smooth = now_probs @ word_embeddings.weight 95 | input_new_smooth={ 96 | 'attention_mask': batch[1], 97 | 'token_type_ids': batch[2], 98 | 'inputs_embeds': inputs_embeds_smooth, 99 | 'labels': batch[3] 100 | } 101 | outputs_smooth = self._model(**input_new_smooth)[0] 102 | 103 | self._optimizer.zero_grad() 104 | loss = outputs_smooth 105 | loss.backward() 106 | self._optimizer.step() 107 | self._scheduler.step() 108 | 109 | def evaluate(self, set_type): 110 | self._model.eval() 111 | 112 | preds_all, labels_all = [], [] 113 | data_loader = self._data_loader[set_type] 114 | 115 | for batch in tqdm(data_loader, 116 | desc="Evaluating {} set".format(set_type)): 117 | batch = tuple(t.to(self._device) for t in batch) 118 | inputs = {'input_ids': batch[0], 119 | 'attention_mask': batch[1], 120 | 'token_type_ids': batch[2], 121 | 'labels': batch[3]} 122 | 123 | with torch.no_grad(): 124 | outputs = self._model(**inputs) 125 | tmp_eval_loss, logits = outputs[:2] 126 | preds = torch.argmax(logits, dim=1) 127 | 128 | preds_all.append(preds) 129 | labels_all.append(inputs["labels"]) 130 | 131 | preds_all = torch.cat(preds_all, dim=0) 132 | labels_all = torch.cat(labels_all, dim=0) 133 | 134 | return torch.sum(preds_all == labels_all).item() / labels_all.shape[0] 135 | 136 | 137 | def _get_optimizer(model, learning_rate, warmup_steps, t_total): 138 | # Prepare optimizer and schedule (linear warmup and decay) 139 | no_decay = ['bias', 'LayerNorm.weight'] 140 | optimizer_grouped_parameters = [ 141 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 142 | 'weight_decay': 0.01}, 143 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 144 | 'weight_decay': 0.0} 145 | ] 146 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8) 147 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, 148 | num_training_steps=t_total) 149 | return optimizer, scheduler 150 | 151 | 152 | def _make_data_loader(examples, label_list, tokenizer, batch_size, max_length, shuffle): 153 | features = convert_examples_to_features(examples, 154 | tokenizer, 155 | label_list=label_list, 156 | max_length=max_length, 157 | output_mode="classification") 158 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 159 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 160 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 161 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 162 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 163 | 164 | return 
DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) 165 | -------------------------------------------------------------------------------- /src/bert_aug/cbert.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | # Original Copyright: https://github.com/1024er/cbert_aug/blob/develop/aug_dataset.py 4 | 5 | import csv 6 | import os 7 | import logging 8 | import argparse 9 | import random 10 | from tqdm import tqdm, trange 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 16 | 17 | from transformers.tokenization_bert import BertTokenizer 18 | from transformers.modeling_bert import BertForMaskedLM 19 | 20 | from transformers import AdamW 21 | from data_processors import get_task_processor 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | BERT_MODEL = 'bert-base-uncased' 26 | 27 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 28 | datefmt='%m/%d/%Y %H:%M:%S', 29 | level=logging.INFO) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class InputFeatures(object): 35 | """A single set of features of data.""" 36 | 37 | def __init__(self, init_ids, input_ids, input_mask, segment_ids, masked_lm_labels): 38 | self.init_ids = init_ids 39 | self.input_ids = input_ids 40 | self.input_mask = input_mask 41 | self.segment_ids = segment_ids 42 | self.masked_lm_labels = masked_lm_labels 43 | 44 | 45 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, seed=12345): 46 | """Loads a data file into a list of `InputBatch`s.""" 47 | 48 | label_map = {} 49 | for (i, label) in enumerate(label_list): 50 | label_map[label] = i 51 | 52 | features = [] 53 | # ---- 54 | dupe_factor = 5 55 | masked_lm_prob = 0.15 56 | max_predictions_per_seq = 20 57 | rng = random.Random(seed) 58 | 59 | for (ex_index, example) in enumerate(examples): 60 | tokens_a = tokenizer.tokenize(example.text_a) 61 | segment_id = label_map[example.label] 62 | # Account for [CLS] and [SEP] with "- 2" 63 | if len(tokens_a) > max_seq_length - 2: 64 | tokens_a = tokens_a[0:(max_seq_length - 2)] 65 | 66 | # Due to we use conditional bert, we need to place label information in segment_ids 67 | tokens = [] 68 | segment_ids = [] 69 | tokens.append("[CLS]") 70 | segment_ids.append(segment_id) 71 | for token in tokens_a: 72 | tokens.append(token) 73 | segment_ids.append(segment_id) 74 | tokens.append("[SEP]") 75 | segment_ids.append(segment_id) 76 | masked_lm_labels = [-100] * max_seq_length 77 | 78 | cand_indexes = [] 79 | for (i, token) in enumerate(tokens): 80 | if token == "[CLS]" or token == "[SEP]": 81 | continue 82 | cand_indexes.append(i) 83 | 84 | rng.shuffle(cand_indexes) 85 | len_cand = len(cand_indexes) 86 | 87 | output_tokens = list(tokens) 88 | 89 | num_to_predict = min(max_predictions_per_seq, 90 | max(1, int(round(len(tokens) * masked_lm_prob)))) 91 | 92 | masked_lms_pos = [] 93 | covered_indexes = set() 94 | for index in cand_indexes: 95 | if len(masked_lms_pos) >= num_to_predict: 96 | break 97 | if index in covered_indexes: 98 | continue 99 | covered_indexes.add(index) 100 | 101 | masked_token = None 102 | # 80% of the time, replace with [MASK] 103 | if rng.random() < 0.8: 104 | masked_token = "[MASK]" 105 | else: 106 | # 10% of 
the time, keep original 107 | if rng.random() < 0.5: 108 | masked_token = tokens[index] 109 | # 10% of the time, replace with random word 110 | else: 111 | masked_token = tokens[cand_indexes[rng.randint(0, len_cand - 1)]] 112 | 113 | masked_lm_labels[index] = tokenizer.convert_tokens_to_ids([tokens[index]])[0] 114 | output_tokens[index] = masked_token 115 | masked_lms_pos.append(index) 116 | 117 | init_ids = tokenizer.convert_tokens_to_ids(tokens) 118 | input_ids = tokenizer.convert_tokens_to_ids(output_tokens) 119 | 120 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 121 | # tokens are attended to. 122 | input_mask = [1] * len(input_ids) 123 | 124 | # Zero-pad up to the sequence length. 125 | while len(input_ids) < max_seq_length: 126 | init_ids.append(0) 127 | input_ids.append(0) 128 | input_mask.append(0) 129 | segment_ids.append(0) # ?segment_id 130 | 131 | assert len(init_ids) == max_seq_length 132 | assert len(input_ids) == max_seq_length 133 | assert len(input_mask) == max_seq_length 134 | assert len(segment_ids) == max_seq_length 135 | 136 | if ex_index < 2: 137 | logger.info("*** Example ***") 138 | logger.info("guid: %s" % (example.guid)) 139 | logger.info("tokens: %s" % " ".join( 140 | [str(x) for x in tokens])) 141 | logger.info("init_ids: %s" % " ".join([str(x) for x in init_ids])) 142 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 143 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 144 | logger.info( 145 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 146 | logger.info("masked_lm_labels: %s" % " ".join([str(x) for x in masked_lm_labels])) 147 | 148 | features.append( 149 | InputFeatures(init_ids=init_ids, 150 | input_ids=input_ids, 151 | input_mask=input_mask, 152 | segment_ids=segment_ids, 153 | masked_lm_labels=masked_lm_labels)) 154 | return features 155 | 156 | 157 | def prepare_data(features): 158 | all_init_ids = torch.tensor([f.init_ids for f in features], dtype=torch.long) 159 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 160 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 161 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 162 | all_masked_lm_labels = torch.tensor([f.masked_lm_labels for f in features], 163 | dtype=torch.long) 164 | tensor_data = TensorDataset(all_init_ids, all_input_ids, all_input_mask, all_segment_ids, 165 | all_masked_lm_labels) 166 | return tensor_data 167 | 168 | 169 | def rev_wordpiece(str): 170 | #print(str) 171 | if len(str) > 1: 172 | for i in range(len(str)-1, 0, -1): 173 | if str[i] == '[PAD]': 174 | str.remove(str[i]) 175 | elif len(str[i]) > 1 and str[i][0]=='#' and str[i][1]=='#': 176 | str[i-1] += str[i][2:] 177 | str.remove(str[i]) 178 | return " ".join(str[1:-1]) 179 | 180 | 181 | def main(): 182 | parser = argparse.ArgumentParser() 183 | 184 | ## Required parameters 185 | parser.add_argument("--data_dir", default="datasets", type=str, 186 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 187 | parser.add_argument("--output_dir", default="aug_data", type=str, 188 | help="The output dir for augmented dataset") 189 | parser.add_argument("--task_name",default="subj",type=str, 190 | help="The name of the task to train.") 191 | parser.add_argument("--max_seq_length", default=64, type=int, 192 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 193 | "Sequences longer than this will be truncated, and sequences shorter \n" 194 | "than this will be padded.") 195 | # parser.add_argument("--do_lower_case", default=False, action='store_true', 196 | # help="Set this flag if you are using an uncased model.") 197 | parser.add_argument('--cache', default="transformers_cache", type=str) 198 | parser.add_argument("--train_batch_size", default=8, type=int, 199 | help="Total batch size for training.") 200 | parser.add_argument("--learning_rate", default=4e-5, type=float, 201 | help="The initial learning rate for Adam.") 202 | parser.add_argument("--num_train_epochs", default=10.0, type=float, 203 | help="Total number of training epochs to perform.") 204 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 205 | help="Proportion of training to perform linear learning rate warmup for. " 206 | "E.g., 0.1 = 10%% of training.") 207 | parser.add_argument('--seed', type=int, default=42, 208 | help="random seed for initialization") 209 | parser.add_argument('--sample_num', type=int, default=1, 210 | help="sample number") 211 | parser.add_argument('--sample_ratio', type=int, default=7, 212 | help="sample ratio") 213 | parser.add_argument('--gpu', type=int, default=0, 214 | help="gpu id") 215 | parser.add_argument('--temp', type=float, default=1.0, 216 | help="temperature") 217 | 218 | args = parser.parse_args() 219 | 220 | print(args) 221 | train_cbert_and_augment(args) 222 | 223 | 224 | def compute_dev_loss(model, dev_dataloader): 225 | model.eval() 226 | sum_loss = 0. 227 | for step, batch in enumerate(dev_dataloader): 228 | batch = tuple(t.to(device) for t in batch) 229 | _, input_ids, input_mask, segment_ids, masked_ids = batch 230 | inputs = {'input_ids': batch[1], 231 | 'attention_mask': batch[2], 232 | 'token_type_ids': batch[3], 233 | 'masked_lm_labels': batch[4]} 234 | 235 | outputs = model(**inputs) 236 | loss = outputs[0] 237 | sum_loss += loss.item() 238 | return sum_loss 239 | 240 | 241 | def augment_train_data(model, tokenizer, train_data, label_list, args): 242 | # load the best model 243 | train_sampler = SequentialSampler(train_data) 244 | train_dataloader = DataLoader(train_data, sampler=train_sampler, 245 | batch_size=args.train_batch_size) 246 | best_model_path = os.path.join(args.output_dir, "best_cbert.pt") 247 | if os.path.exists(best_model_path): 248 | model.load_state_dict(torch.load(best_model_path)) 249 | else: 250 | raise ValueError("Unable to find the saved model at {}".format(best_model_path)) 251 | 252 | save_train_path = os.path.join(args.output_dir, "cbert_aug.tsv") 253 | save_train_file = open(save_train_path, 'w') 254 | 255 | MASK_id = tokenizer.convert_tokens_to_ids(['[MASK]'])[0] 256 | tsv_writer = csv.writer(save_train_file, delimiter='\t') 257 | 258 | for step, batch in enumerate(train_dataloader): 259 | model.eval() 260 | batch = tuple(t.to(device) for t in batch) 261 | init_ids, _, input_mask, segment_ids, _ = batch 262 | input_lens = [sum(mask).item() for mask in input_mask] 263 | masked_idx = np.squeeze( 264 | [np.random.randint(0, l, max(l // args.sample_ratio, 1)) for l in input_lens]) 265 | for ids, idx in zip(init_ids, masked_idx): 266 | ids[idx] = MASK_id 267 | 268 | inputs = {'input_ids': init_ids, 269 | 'attention_mask': input_mask, 270 | 'token_type_ids': segment_ids} 271 | 272 | outputs = model(**inputs) 273 | predictions = outputs[0] # model(init_ids, segment_ids, input_mask) 274 | predictions = F.softmax(predictions / args.temp, dim=2) 275 | 276 | for ids, idx, preds, seg in 
zip(init_ids, masked_idx, predictions, segment_ids): 277 | preds = torch.multinomial(preds, args.sample_num, replacement=True)[idx] 278 | if len(preds.size()) == 2: 279 | preds = torch.transpose(preds, 0, 1) 280 | for pred in preds: 281 | ids[idx] = pred 282 | new_str = tokenizer.convert_ids_to_tokens(ids.cpu().numpy()) 283 | new_str = rev_wordpiece(new_str) 284 | tsv_writer.writerow([label_list[seg[0].item()], new_str]) 285 | 286 | 287 | def train_cbert_and_augment(args): 288 | task_name = args.task_name 289 | os.makedirs(args.output_dir, exist_ok=True) 290 | 291 | random.seed(args.seed) 292 | np.random.seed(args.seed) 293 | torch.manual_seed(args.seed) 294 | torch.cuda.manual_seed_all(args.seed) 295 | torch.backends.cudnn.deterministic = True 296 | 297 | os.makedirs(args.output_dir, exist_ok=True) 298 | processor = get_task_processor(task_name, args.data_dir) 299 | label_list = processor.get_labels(task_name) 300 | 301 | # load train and dev data 302 | train_examples = processor.get_train_examples() 303 | dev_examples = processor.get_dev_examples() 304 | 305 | tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, 306 | do_lower_case=True, 307 | cache_dir=args.cache) 308 | 309 | model = BertForMaskedLM.from_pretrained(BERT_MODEL, 310 | cache_dir=args.cache) 311 | 312 | if len(label_list) > 2: 313 | model.bert.embeddings.token_type_embeddings = torch.nn.Embedding(len(label_list), 768) 314 | model.bert.embeddings.token_type_embeddings.weight.data.normal_(mean=0.0, std=0.02) 315 | 316 | model.to(device) 317 | 318 | # train data 319 | train_features = convert_examples_to_features(train_examples, label_list, 320 | args.max_seq_length, 321 | tokenizer, args.seed) 322 | train_data = prepare_data(train_features) 323 | train_sampler = RandomSampler(train_data) 324 | train_dataloader = DataLoader(train_data, sampler=train_sampler, 325 | batch_size=args.train_batch_size) 326 | 327 | #dev data 328 | dev_features = convert_examples_to_features(dev_examples, label_list, 329 | args.max_seq_length, 330 | tokenizer, args.seed) 331 | dev_data = prepare_data(dev_features) 332 | dev_sampler = SequentialSampler(dev_data) 333 | dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, batch_size=args.train_batch_size) 334 | 335 | num_train_steps = int(len(train_features) / args.train_batch_size * args.num_train_epochs) 336 | logger.info("***** Running training *****") 337 | logger.info(" Num examples = %d", len(train_features)) 338 | logger.info(" Batch size = %d", args.train_batch_size) 339 | logger.info(" Num steps = %d", num_train_steps) 340 | 341 | # Prepare optimizer 342 | t_total = num_train_steps 343 | no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight'] 344 | optimizer_grouped_parameters = [ 345 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 346 | 'weight_decay': 0.01}, 347 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 348 | 'weight_decay': 0.0} 349 | ] 350 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8) 351 | 352 | best_dev_loss = float('inf') 353 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 354 | avg_loss = 0. 
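        # Conditional BERT (CBERT) fine-tuning: the class label is injected through the
        # token_type_ids / segment embeddings (resized above to one row per label when the
        # task has more than two classes), and the masked-LM loss below teaches the model
        # to fill [MASK] positions conditioned on that label, so that replacements sampled
        # in augment_train_data stay label-compatible.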
355 | model.train() 356 | for step, batch in enumerate(train_dataloader): 357 | batch = tuple(t.to(device) for t in batch) 358 | inputs = {'input_ids': batch[1], 359 | 'attention_mask': batch[2], 360 | 'token_type_ids': batch[3], 361 | 'masked_lm_labels': batch[4]} 362 | 363 | outputs = model(**inputs) 364 | loss = outputs[0] 365 | optimizer.zero_grad() 366 | loss.backward() 367 | avg_loss += loss.item() 368 | optimizer.step() 369 | 370 | if (step + 1) % 50 == 0: 371 | print("avg_loss: {}".format(avg_loss / 50)) 372 | avg_loss = 0. 373 | 374 | # eval on dev after every epoch 375 | dev_loss = compute_dev_loss(model, dev_dataloader) 376 | print("Epoch {}, Dev loss {}".format(epoch, dev_loss)) 377 | if dev_loss < best_dev_loss: 378 | best_dev_loss = dev_loss 379 | print("Saving model. Best dev so far {}".format(best_dev_loss)) 380 | save_model_path = os.path.join(args.output_dir, 'best_cbert.pt') 381 | torch.save(model.state_dict(), save_model_path) 382 | 383 | # augment data using the best model 384 | augment_train_data(model, tokenizer, train_data, label_list, args) 385 | 386 | 387 | if __name__ == "__main__": 388 | main() -------------------------------------------------------------------------------- /src/bert_aug/cgpt2.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import csv 5 | import os 6 | import logging 7 | import argparse 8 | import random 9 | from tqdm import tqdm, trange 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 14 | 15 | from transformers.tokenization_gpt2 import GPT2Tokenizer 16 | from transformers.modeling_gpt2 import GPT2LMHeadModel 17 | 18 | from transformers import AdamW 19 | from data_processors import get_task_processor 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | GPT2_MODEL = 'gpt2' 24 | EOS_TOKEN = '<|endoftext|>' 25 | SEP_TOKEN = '' 26 | 27 | STOP_TOKENS = [EOS_TOKEN, '<'] 28 | 29 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 30 | datefmt='%m/%d/%Y %H:%M:%S', 31 | level=logging.INFO) 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | class InputFeatures(object): 36 | """A single set of features of data.""" 37 | 38 | def __init__(self, examples): 39 | self.examples = examples 40 | 41 | 42 | def convert_examples_to_features(examples, block_size, tokenizer, seed=12345): 43 | """Loads a data file into a list of `InputBatch`s.""" 44 | 45 | features = [] 46 | 47 | text = "" 48 | for (ex_index, example) in enumerate(examples): 49 | if ex_index: 50 | text += " " + example.label + SEP_TOKEN + example.text_a + EOS_TOKEN 51 | else: 52 | text += example.label + SEP_TOKEN + example.text_a + EOS_TOKEN 53 | 54 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 55 | 56 | for i in range(0, len(tokenized_text) - block_size + 1, 57 | block_size): # Truncate in block of block_size 58 | features.append(InputFeatures( 59 | examples=tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i + block_size]))) 60 | 61 | return features 62 | 63 | 64 | def prepare_data(features): 65 | all_input_ids = torch.tensor([f.examples for f in features], dtype=torch.long) 66 | all_labels = torch.tensor([f.examples for f in features], dtype=torch.long) 67 | tensor_data = TensorDataset(all_input_ids, all_labels) 68 | return 
tensor_data 69 | 70 | 71 | def main(): 72 | parser = argparse.ArgumentParser() 73 | 74 | ## Required parameters 75 | parser.add_argument("--data_dir", default="datasets", type=str, 76 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 77 | parser.add_argument("--output_dir", default="aug_data", type=str, 78 | help="The output dir for augmented dataset") 79 | parser.add_argument("--max_seq_length", default=64, type=int, 80 | help="The maximum total input sequence length after WordPiece tokenization. \n" 81 | "Sequences longer than this will be truncated, and sequences shorter \n" 82 | "than this will be padded.") 83 | parser.add_argument("--block_size", default=64, type=int, 84 | help="The maximum total input sequence length after WordPiece tokenization. \n" 85 | "Sequences longer than this will be truncated, and sequences shorter \n" 86 | "than this will be padded.") 87 | parser.add_argument('--cache', default="transformers_cache", type=str) 88 | parser.add_argument("--task_name", default="trec", type=str, 89 | help="The name of the task to train.") 90 | parser.add_argument("--train_batch_size", default=32, type=int, 91 | help="Total batch size for training.") 92 | parser.add_argument("--learning_rate", default=4e-5, type=float, 93 | help="The initial learning rate for Adam.") 94 | parser.add_argument("--num_train_epochs", default=20.0, type=float, 95 | help="Total number of training epochs to perform.") 96 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 97 | help="Proportion of training to perform linear learning rate warmup for. " 98 | "E.g., 0.1 = 10%% of training.") 99 | parser.add_argument('--seed', type=int, default=42, 100 | help="random seed for initialization") 101 | parser.add_argument('--sample_num', type=int, default=1, 102 | help="sample number") 103 | parser.add_argument('--sample_ratio', type=int, default=7, 104 | help="sample ratio") 105 | parser.add_argument("--temperature", type=float, default=1.0, 106 | help="temperature of 0 implies greedy sampling") 107 | parser.add_argument("--repetition_penalty", type=float, default=1.0, 108 | help="primarily useful for CTRL model; in that case, use 1.2") 109 | parser.add_argument("--prefix", type=int, default=3) 110 | parser.add_argument("--top_k", type=int, default=0) 111 | parser.add_argument("--top_p", type=float, default=0.0) 112 | parser.add_argument('--gpu', type=int, default=0, 113 | help="gpu id") 114 | parser.add_argument('--temp', type=float, default=1.0, 115 | help="temperature") 116 | 117 | args = parser.parse_args() 118 | 119 | print(args) 120 | train_cmodgpt2_and_augment(args) 121 | 122 | 123 | def compute_dev_loss(model, dev_dataloader): 124 | model.eval() 125 | sum_loss = 0. 
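    # The summed LM loss on the dev blocks is used in train_cmodgpt2_and_augment below
    # to decide which fine-tuned GPT-2 checkpoint is kept for generating augmented examples.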
126 | for step, batch in enumerate(dev_dataloader): 127 | batch = tuple(t.to(device) for t in batch) 128 | inputs = {'input_ids': batch[0], 129 | 'labels': batch[1]} 130 | 131 | outputs = model(**inputs) 132 | loss = outputs[0] 133 | sum_loss += loss.item() 134 | return sum_loss 135 | 136 | 137 | def augment_train_data(model, tokenizer, train_examples, args): 138 | # load the best model 139 | best_model_path = os.path.join(args.output_dir, "best_cmodgpt2.pt") 140 | if os.path.exists(best_model_path): 141 | model.load_state_dict(torch.load(best_model_path)) 142 | model.to(device) 143 | else: 144 | raise ValueError("Unable to find the saved model at {}".format(best_model_path)) 145 | prefix_size = args.prefix 146 | save_train_path = os.path.join(args.output_dir, "cmodgpt2_aug_{}.tsv".format(prefix_size)) 147 | save_train_file = open(save_train_path, 'w') 148 | 149 | tsv_writer = csv.writer(save_train_file, delimiter='\t') 150 | 151 | prefix_text = None 152 | for ex_index, example in enumerate(train_examples): 153 | model.eval() 154 | if prefix_size > 0: 155 | prefix_text = " ".join(example.text_a.split(' ')[:prefix_size]) 156 | raw_text = example.label + SEP_TOKEN + prefix_text 157 | else: 158 | raw_text = example.label + SEP_TOKEN 159 | 160 | context_tokens = tokenizer.encode(raw_text, return_tensors='pt').to(device) 161 | out = model.generate( 162 | input_ids=context_tokens, 163 | max_length=args.max_seq_length, 164 | num_return_sequences=1, 165 | do_sample=True, 166 | temperature=args.temperature, 167 | top_k=args.top_k, 168 | top_p=args.top_p, 169 | repetition_penalty=args.repetition_penalty, 170 | pad_token_id=50256 171 | ) 172 | 173 | out = out[:, len(context_tokens):].tolist() 174 | for o in out: 175 | text = tokenizer.decode(o, clean_up_tokenization_spaces=True) 176 | eosn_index = 128 177 | for stop_token in STOP_TOKENS: 178 | idx = text.find(stop_token) 179 | if idx > 0: 180 | eosn_index = min(eosn_index, idx) 181 | text = text[: eosn_index] 182 | text = text.replace("\n", " ").replace(EOS_TOKEN, ' ').strip() 183 | if prefix_size > 0: 184 | text = prefix_text + " " + text 185 | tsv_writer.writerow([example.label, text]) 186 | 187 | 188 | def train_cmodgpt2_and_augment(args): 189 | task_name = args.task_name 190 | os.makedirs(args.output_dir, exist_ok=True) 191 | 192 | random.seed(args.seed) 193 | np.random.seed(args.seed) 194 | torch.manual_seed(args.seed) 195 | 196 | os.makedirs(args.output_dir, exist_ok=True) 197 | processor = get_task_processor(task_name, args.data_dir) 198 | #label_list = processor.get_labels(task_name) 199 | 200 | # load train and dev data 201 | train_examples = processor.get_train_examples() 202 | dev_examples = processor.get_dev_examples() 203 | 204 | tokenizer = GPT2Tokenizer.from_pretrained(GPT2_MODEL, 205 | do_lower_case=True, 206 | cache_dir=args.cache) 207 | 208 | args.block_size = min(args.block_size, tokenizer.max_len_single_sentence) 209 | 210 | model = GPT2LMHeadModel.from_pretrained(GPT2_MODEL, 211 | cache_dir=args.cache) 212 | 213 | model.to(device) 214 | 215 | # train data 216 | train_features = convert_examples_to_features(train_examples, 217 | args.block_size, 218 | tokenizer, args.seed) 219 | train_data = prepare_data(train_features) 220 | train_sampler = RandomSampler(train_data) 221 | train_dataloader = DataLoader(train_data, sampler=train_sampler, 222 | batch_size=args.train_batch_size) 223 | 224 | # dev data 225 | dev_features = convert_examples_to_features(dev_examples, 226 | args.block_size, 227 | tokenizer, args.seed) 228 | dev_data = 
prepare_data(dev_features) 229 | dev_sampler = SequentialSampler(dev_data) 230 | dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, 231 | batch_size=args.train_batch_size) 232 | 233 | num_train_steps = int(len(train_features) / args.train_batch_size * args.num_train_epochs) 234 | logger.info("***** Running training *****") 235 | logger.info(" Num examples = %d", len(train_features)) 236 | logger.info(" Batch size = %d", args.train_batch_size) 237 | logger.info(" Num steps = %d", num_train_steps) 238 | 239 | # Prepare optimizer and schedule (linear warmup and decay) 240 | t_total = num_train_steps 241 | no_decay = ['bias', 'LayerNorm.weight'] 242 | optimizer_grouped_parameters = [ 243 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 244 | 'weight_decay': 0.01}, 245 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 246 | 'weight_decay': 0.0} 247 | ] 248 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8) 249 | 250 | best_dev_loss = float('inf') 251 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 252 | avg_loss = 0. 253 | model.train() 254 | for step, batch in enumerate(train_dataloader): 255 | batch = tuple(t.to(device) for t in batch) 256 | 257 | inputs = {'input_ids': batch[0], 258 | 'labels': batch[1]} 259 | 260 | outputs = model(**inputs) 261 | loss = outputs[0] 262 | # loss = model(input_ids, segment_ids, input_mask, masked_ids) 263 | optimizer.zero_grad() 264 | loss.backward() 265 | avg_loss += loss.item() 266 | optimizer.step() 267 | model.zero_grad() 268 | if (step + 1) % 50 == 0: 269 | print("avg_loss: {}".format(avg_loss / 50)) 270 | # avg_loss = 0. 271 | 272 | # eval on dev after every epoch 273 | dev_loss = compute_dev_loss(model, dev_dataloader) 274 | print("Epoch {}, Dev loss {}".format(epoch, dev_loss)) 275 | if dev_loss < best_dev_loss: 276 | best_dev_loss = dev_loss 277 | print("Saving model. Best dev so far {}".format(best_dev_loss)) 278 | save_model_path = os.path.join(args.output_dir, 'best_cmodgpt2.pt') 279 | torch.save(model.state_dict(), save_model_path) 280 | 281 | # augment data using the best model 282 | augment_train_data(model, tokenizer, train_examples, args) 283 | 284 | 285 | if __name__ == "__main__": 286 | main() -------------------------------------------------------------------------------- /src/bert_aug/cmodbert.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
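# cmodbert.py -- label-prepend masked-LM augmentation: the class label is added to the
# BERT vocabulary and prepended to every sentence, a masked LM is fine-tuned on these
# label-prefixed inputs, and augmented examples are produced by masking random word
# pieces and sampling replacements from the fine-tuned model.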
2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import csv 5 | import os 6 | import logging 7 | import argparse 8 | import random 9 | from tqdm import tqdm, trange 10 | import json 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 16 | 17 | from transformers.tokenization_bert import BertTokenizer 18 | from transformers.modeling_bert import BertForMaskedLM, BertOnlyMLMHead 19 | 20 | from transformers import AdamW 21 | from data_processors import get_task_processor 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | BERT_MODEL = 'bert-base-uncased' 26 | 27 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 28 | datefmt='%m/%d/%Y %H:%M:%S', 29 | level=logging.INFO) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class InputFeatures(object): 35 | """A single set of features of data.""" 36 | 37 | def __init__(self, init_ids, input_ids, input_mask, masked_lm_labels): 38 | self.init_ids = init_ids 39 | self.input_ids = input_ids 40 | self.input_mask = input_mask 41 | self.masked_lm_labels = masked_lm_labels 42 | 43 | 44 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, seed=12345): 45 | """Loads a data file into a list of `InputBatch`s.""" 46 | 47 | features = [] 48 | # ---- 49 | # dupe_factor = 5 50 | masked_lm_prob = 0.15 51 | max_predictions_per_seq = 20 52 | rng = random.Random(seed) 53 | 54 | 55 | for (ex_index, example) in enumerate(examples): 56 | modified_example = example.label + " " + example.text_a 57 | tokens_a = tokenizer.tokenize(modified_example) 58 | # Account for [CLS] and [SEP] and label with "- 3" 59 | if len(tokens_a) > max_seq_length - 3: 60 | tokens_a = tokens_a[0:(max_seq_length - 3)] 61 | 62 | # take care of prepending the class label in this code 63 | tokens = [] 64 | tokens.append("[CLS]") 65 | for token in tokens_a: 66 | tokens.append(token) 67 | tokens.append("[SEP]") 68 | masked_lm_labels = [-100] * max_seq_length 69 | 70 | cand_indexes = [] 71 | for (i, token) in enumerate(tokens): 72 | # making sure that masking of # prepended label is avoided 73 | if token == "[CLS]" or token == "[SEP]" or (token in label_list and i == 1): 74 | continue 75 | cand_indexes.append(i) 76 | 77 | rng.shuffle(cand_indexes) 78 | len_cand = len(cand_indexes) 79 | 80 | output_tokens = list(tokens) 81 | 82 | num_to_predict = min(max_predictions_per_seq, 83 | max(1, int(round(len(tokens) * masked_lm_prob)))) 84 | 85 | masked_lms_pos = [] 86 | covered_indexes = set() 87 | for index in cand_indexes: 88 | if len(masked_lms_pos) >= num_to_predict: 89 | break 90 | if index in covered_indexes: 91 | continue 92 | covered_indexes.add(index) 93 | 94 | masked_token = None 95 | # 80% of the time, replace with [MASK] 96 | if rng.random() < 0.8: 97 | masked_token = "[MASK]" 98 | else: 99 | # 10% of the time, keep original 100 | if rng.random() < 0.5: 101 | masked_token = tokens[index] 102 | # 10% of the time, replace with random word 103 | else: 104 | masked_token = tokens[cand_indexes[rng.randint(0, len_cand - 1)]] 105 | 106 | masked_lm_labels[index] = tokenizer.convert_tokens_to_ids([tokens[index]])[0] 107 | output_tokens[index] = masked_token 108 | masked_lms_pos.append(index) 109 | 110 | init_ids = tokenizer.convert_tokens_to_ids(tokens) 111 | input_ids = tokenizer.convert_tokens_to_ids(output_tokens) 112 | 113 | # The mask has 1 for real tokens and 0 for 
padding tokens. Only real 114 | # tokens are attended to. 115 | input_mask = [1] * len(input_ids) 116 | 117 | # Zero-pad up to the sequence length. 118 | while len(input_ids) < max_seq_length: 119 | init_ids.append(0) 120 | input_ids.append(0) 121 | input_mask.append(0) 122 | 123 | assert len(init_ids) == max_seq_length 124 | assert len(input_ids) == max_seq_length 125 | assert len(input_mask) == max_seq_length 126 | 127 | if ex_index < 2: 128 | logger.info("*** Example ***") 129 | logger.info("guid: %s" % (example.guid)) 130 | logger.info("tokens: %s" % " ".join( 131 | [str(x) for x in tokens])) 132 | logger.info("init_ids: %s" % " ".join([str(x) for x in init_ids])) 133 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 134 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 135 | logger.info("masked_lm_labels: %s" % " ".join([str(x) for x in masked_lm_labels])) 136 | 137 | features.append( 138 | InputFeatures(init_ids=init_ids, 139 | input_ids=input_ids, 140 | input_mask=input_mask, 141 | masked_lm_labels=masked_lm_labels)) 142 | return features 143 | 144 | 145 | def prepare_data(features): 146 | all_init_ids = torch.tensor([f.init_ids for f in features], dtype=torch.long) 147 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 148 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 149 | all_masked_lm_labels = torch.tensor([f.masked_lm_labels for f in features], 150 | dtype=torch.long) 151 | tensor_data = TensorDataset(all_init_ids, all_input_ids, all_input_mask, all_masked_lm_labels) 152 | return tensor_data 153 | 154 | 155 | def rev_wordpiece(str): 156 | #print(str) 157 | if len(str) > 1: 158 | for i in range(len(str)-1, 0, -1): 159 | if str[i] == '[PAD]': 160 | str.remove(str[i]) 161 | elif len(str[i]) > 1 and str[i][0]=='#' and str[i][1]=='#': 162 | str[i-1] += str[i][2:] 163 | str.remove(str[i]) 164 | return " ".join(str[2:-1]) 165 | 166 | 167 | def main(): 168 | parser = argparse.ArgumentParser() 169 | 170 | ## Required parameters 171 | parser.add_argument("--data_dir", default="datasets", type=str, 172 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 173 | parser.add_argument("--output_dir", default="aug_data", type=str, 174 | help="The output dir for augmented dataset") 175 | parser.add_argument("--task_name",default="subj",type=str, 176 | help="The name of the task to train.") 177 | parser.add_argument("--max_seq_length", default=64, type=int, 178 | help="The maximum total input sequence length after WordPiece tokenization. \n" 179 | "Sequences longer than this will be truncated, and sequences shorter \n" 180 | "than this will be padded.") 181 | # parser.add_argument("--do_lower_case", default=False, action='store_true', 182 | # help="Set this flag if you are using an uncased model.") 183 | parser.add_argument('--cache', default="transformers_cache", type=str) 184 | parser.add_argument("--train_batch_size", default=8, type=int, 185 | help="Total batch size for training.") 186 | parser.add_argument("--learning_rate", default=4e-5, type=float, 187 | help="The initial learning rate for Adam.") 188 | parser.add_argument("--num_train_epochs", default=10.0, type=float, 189 | help="Total number of training epochs to perform.") 190 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 191 | help="Proportion of training to perform linear learning rate warmup for. 
" 192 | "E.g., 0.1 = 10%% of training.") 193 | parser.add_argument('--seed', type=int, default=42, 194 | help="random seed for initialization") 195 | parser.add_argument('--sample_num', type=int, default=1, 196 | help="sample number") 197 | parser.add_argument('--sample_ratio', type=int, default=7, 198 | help="sample ratio") 199 | parser.add_argument('--gpu', type=int, default=0, 200 | help="gpu id") 201 | parser.add_argument('--temp', type=float, default=1.0, 202 | help="temperature") 203 | 204 | args = parser.parse_args() 205 | 206 | print(args) 207 | train_cmodbert_and_augment(args) 208 | 209 | 210 | def compute_dev_loss(model, dev_dataloader): 211 | model.eval() 212 | sum_loss = 0. 213 | for step, batch in enumerate(dev_dataloader): 214 | batch = tuple(t.to(device) for t in batch) 215 | _, input_ids, input_mask, masked_ids = batch 216 | inputs = {'input_ids': batch[1], 217 | 'attention_mask': batch[2], 218 | 'masked_lm_labels': batch[3]} 219 | 220 | outputs = model(**inputs) 221 | loss = outputs[0] 222 | sum_loss += loss.item() 223 | return sum_loss 224 | 225 | 226 | def augment_train_data(model, tokenizer, train_data, label_list, args): 227 | # load the best model 228 | 229 | train_sampler = SequentialSampler(train_data) 230 | train_dataloader = DataLoader(train_data, sampler=train_sampler, 231 | batch_size=args.train_batch_size) 232 | best_model_path = os.path.join(args.output_dir, "best_cmodbert.pt") 233 | if os.path.exists(best_model_path): 234 | model.load_state_dict(torch.load(best_model_path)) 235 | else: 236 | raise ValueError("Unable to find the saved model at {}".format(best_model_path)) 237 | 238 | save_train_path = os.path.join(args.output_dir, "cmodbert_aug.tsv") 239 | save_train_file = open(save_train_path, 'w') 240 | 241 | MASK_id = tokenizer.convert_tokens_to_ids(['[MASK]'])[0] 242 | tsv_writer = csv.writer(save_train_file, delimiter='\t') 243 | 244 | for step, batch in enumerate(train_dataloader): 245 | model.eval() 246 | batch = tuple(t.to(device) for t in batch) 247 | init_ids, _, input_mask, _ = batch 248 | input_lens = [sum(mask).item() for mask in input_mask] 249 | masked_idx = np.squeeze( 250 | [np.random.randint(2, l, max((l-2) // args.sample_ratio, 1)) for l in input_lens]) 251 | for ids, idx in zip(init_ids, masked_idx): 252 | ids[idx] = MASK_id 253 | 254 | inputs = {'input_ids': init_ids, 255 | 'attention_mask': input_mask} 256 | 257 | outputs = model(**inputs) 258 | predictions = outputs[0] # model(init_ids, segment_ids, input_mask) 259 | predictions = F.softmax(predictions / args.temp, dim=2) 260 | 261 | for ids, idx, preds in zip(init_ids, masked_idx, predictions): 262 | preds = torch.multinomial(preds, args.sample_num, replacement=True)[idx] 263 | if len(preds.size()) == 2: 264 | preds = torch.transpose(preds, 0, 1) 265 | for pred in preds: 266 | ids[idx] = pred 267 | new_str = tokenizer.convert_ids_to_tokens(ids.cpu().numpy()) 268 | label = new_str[1] 269 | new_str = rev_wordpiece(new_str) 270 | tsv_writer.writerow([label, new_str]) 271 | 272 | 273 | def train_cmodbert_and_augment(args): 274 | task_name = args.task_name 275 | os.makedirs(args.output_dir, exist_ok=True) 276 | 277 | random.seed(args.seed) 278 | np.random.seed(args.seed) 279 | torch.manual_seed(args.seed) 280 | 281 | os.makedirs(args.output_dir, exist_ok=True) 282 | processor = get_task_processor(task_name, args.data_dir) 283 | label_list = processor.get_labels(task_name) 284 | 285 | # load train and dev data 286 | train_examples = processor.get_train_examples() 287 | dev_examples = 
processor.get_dev_examples() 288 | 289 | tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, 290 | do_lower_case=True, 291 | cache_dir=args.cache) 292 | 293 | model = BertForMaskedLM.from_pretrained(BERT_MODEL, 294 | cache_dir=args.cache) 295 | 296 | tokenizer.add_tokens(label_list) 297 | # Adding embeddings such that they are randomly initialized, however, follow instructions about initializing them 298 | # intelligently 299 | model.resize_token_embeddings(len(tokenizer)) 300 | model.cls = BertOnlyMLMHead(model.config) 301 | 302 | model.to(device) 303 | 304 | # train data 305 | train_features = convert_examples_to_features(train_examples, label_list, 306 | args.max_seq_length, 307 | tokenizer, args.seed) 308 | train_data = prepare_data(train_features) 309 | train_sampler = RandomSampler(train_data) 310 | train_dataloader = DataLoader(train_data, sampler=train_sampler, 311 | batch_size=args.train_batch_size) 312 | 313 | 314 | # dev data 315 | dev_features = convert_examples_to_features(dev_examples, label_list, 316 | args.max_seq_length, 317 | tokenizer, args.seed) 318 | dev_data = prepare_data(dev_features) 319 | dev_sampler = SequentialSampler(dev_data) 320 | dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, 321 | batch_size=args.train_batch_size) 322 | 323 | num_train_steps = int(len(train_features) / args.train_batch_size * args.num_train_epochs) 324 | logger.info("***** Running training *****") 325 | logger.info(" Num examples = %d", len(train_features)) 326 | logger.info(" Batch size = %d", args.train_batch_size) 327 | logger.info(" Num steps = %d", num_train_steps) 328 | 329 | # Prepare optimizer 330 | t_total = num_train_steps 331 | no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight'] 332 | optimizer_grouped_parameters = [ 333 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 334 | 'weight_decay': 0.01}, 335 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 336 | 'weight_decay': 0.0} 337 | ] 338 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8) 339 | 340 | best_dev_loss = float('inf') 341 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 342 | avg_loss = 0. 343 | model.train() 344 | for step, batch in enumerate(train_dataloader): 345 | batch = tuple(t.to(device) for t in batch) 346 | _, input_ids, input_mask, masked_ids = batch 347 | inputs = {'input_ids': batch[1], 348 | 'attention_mask': batch[2], 349 | 'masked_lm_labels': batch[3]} 350 | 351 | outputs = model(**inputs) 352 | loss = outputs[0] 353 | #loss = model(input_ids, segment_ids, input_mask, masked_ids) 354 | loss.backward() 355 | avg_loss += loss.item() 356 | optimizer.step() 357 | model.zero_grad() 358 | if (step + 1) % 50 == 0: 359 | print("avg_loss: {}".format(avg_loss / 50)) 360 | avg_loss = 0. 361 | 362 | # eval on dev after every epoch 363 | dev_loss = compute_dev_loss(model, dev_dataloader) 364 | print("Epoch {}, Dev loss {}".format(epoch, dev_loss)) 365 | if dev_loss < best_dev_loss: 366 | best_dev_loss = dev_loss 367 | print("Saving model. 
Best dev so far {}".format(best_dev_loss)) 368 | save_model_path = os.path.join(args.output_dir, 'best_cmodbert.pt') 369 | torch.save(model.state_dict(), save_model_path) 370 | 371 | # augment data using the best model 372 | augment_train_data(model, tokenizer, train_data, label_list, args) 373 | 374 | 375 | if __name__ == "__main__": 376 | main() -------------------------------------------------------------------------------- /src/bert_aug/cmodbertp.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import csv 5 | import os 6 | import logging 7 | import argparse 8 | import random 9 | from tqdm import tqdm, trange 10 | import json 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 16 | 17 | from transformers.tokenization_bert import BertTokenizer 18 | from transformers.modeling_bert import BertForMaskedLM, BertOnlyMLMHead 19 | 20 | from transformers import AdamW 21 | from data_processors import get_task_processor 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | BERT_MODEL = 'bert-base-uncased' 26 | 27 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 28 | datefmt='%m/%d/%Y %H:%M:%S', 29 | level=logging.INFO) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class InputFeatures(object): 35 | """A single set of features of data.""" 36 | 37 | def __init__(self, init_ids, input_ids, input_mask, masked_lm_labels, label_length): 38 | self.init_ids = init_ids 39 | self.input_ids = input_ids 40 | self.input_mask = input_mask 41 | self.masked_lm_labels = masked_lm_labels 42 | self.label_length = label_length 43 | 44 | 45 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, seed=12345): 46 | """Loads a data file into a list of `InputBatch`s.""" 47 | 48 | features = [] 49 | # ---- 50 | # dupe_factor = 5 51 | masked_lm_prob = 0.15 52 | max_predictions_per_seq = 20 53 | rng = random.Random(seed) 54 | 55 | for (ex_index, example) in enumerate(examples): 56 | modified_example = example.label + " " + example.text_a 57 | label_len = len(tokenizer.tokenize(example.label)) 58 | tokens_a = tokenizer.tokenize(modified_example) 59 | # Account for [CLS] and [SEP] and label with "(2+label_len)" 60 | if len(tokens_a) > max_seq_length - (2+label_len): 61 | tokens_a = tokens_a[0:(max_seq_length - (2+label_len))] 62 | 63 | # take care of prepending the class label in this code 64 | tokens = [] 65 | tokens.append("[CLS]") 66 | for token in tokens_a: 67 | tokens.append(token) 68 | tokens.append("[SEP]") 69 | masked_lm_labels = [-100] * max_seq_length 70 | 71 | cand_indexes = [] 72 | for (i, token) in enumerate(tokens): 73 | # making sure that masking of # prepended label is avoided 74 | if token == "[CLS]" or token == "[SEP]" or (i < label_len + 1): 75 | continue 76 | cand_indexes.append(i) 77 | 78 | rng.shuffle(cand_indexes) 79 | len_cand = len(cand_indexes) 80 | 81 | output_tokens = list(tokens) 82 | 83 | num_to_predict = min(max_predictions_per_seq, 84 | max(1, int(round(len(tokens) * masked_lm_prob)))) 85 | 86 | masked_lms_pos = [] 87 | covered_indexes = set() 88 | for index in cand_indexes: 89 | if len(masked_lms_pos) >= num_to_predict: 90 | break 91 | if index in covered_indexes: 92 | continue 93 | 
covered_indexes.add(index) 94 | 95 | masked_token = None 96 | # 80% of the time, replace with [MASK] 97 | if rng.random() < 0.8: 98 | masked_token = "[MASK]" 99 | else: 100 | # 10% of the time, keep original 101 | if rng.random() < 0.5: 102 | masked_token = tokens[index] 103 | # 10% of the time, replace with random word 104 | else: 105 | masked_token = tokens[cand_indexes[rng.randint(0, len_cand - 1)]] 106 | 107 | masked_lm_labels[index] = tokenizer.convert_tokens_to_ids([tokens[index]])[0] 108 | output_tokens[index] = masked_token 109 | masked_lms_pos.append(index) 110 | 111 | init_ids = tokenizer.convert_tokens_to_ids(tokens) 112 | input_ids = tokenizer.convert_tokens_to_ids(output_tokens) 113 | 114 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 115 | # tokens are attended to. 116 | input_mask = [1] * len(input_ids) 117 | 118 | # Zero-pad up to the sequence length. 119 | while len(input_ids) < max_seq_length: 120 | init_ids.append(0) 121 | input_ids.append(0) 122 | input_mask.append(0) 123 | 124 | assert len(init_ids) == max_seq_length 125 | assert len(input_ids) == max_seq_length 126 | assert len(input_mask) == max_seq_length 127 | 128 | if ex_index < 2: 129 | logger.info("*** Example ***") 130 | logger.info("guid: %s" % (example.guid)) 131 | logger.info("tokens: %s" % " ".join( 132 | [str(x) for x in tokens])) 133 | logger.info("init_ids: %s" % " ".join([str(x) for x in init_ids])) 134 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 135 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 136 | logger.info("masked_lm_labels: %s" % " ".join([str(x) for x in masked_lm_labels])) 137 | logger.info("label_len: %s" % str(label_len)) 138 | 139 | features.append( 140 | InputFeatures(init_ids=init_ids, 141 | input_ids=input_ids, 142 | input_mask=input_mask, 143 | masked_lm_labels=masked_lm_labels, 144 | label_length=label_len)) 145 | return features 146 | 147 | 148 | def prepare_data(features): 149 | all_init_ids = torch.tensor([f.init_ids for f in features], dtype=torch.long) 150 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 151 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 152 | all_masked_lm_labels = torch.tensor([f.masked_lm_labels for f in features], 153 | dtype=torch.long) 154 | all_label_lengths = torch.tensor([f.label_length for f in features], 155 | dtype=torch.long) 156 | tensor_data = TensorDataset(all_init_ids, all_input_ids, all_input_mask, all_masked_lm_labels, all_label_lengths) 157 | return tensor_data 158 | 159 | 160 | def rev_wordpiece(str): 161 | #print(str) 162 | if len(str) > 1: 163 | for i in range(len(str)-1, 0, -1): 164 | if str[i] == '[PAD]': 165 | str.remove(str[i]) 166 | elif len(str[i]) > 1 and str[i][0]=='#' and str[i][1]=='#': 167 | str[i-1] += str[i][2:] 168 | str.remove(str[i]) 169 | return " ".join(str[2:-1]) 170 | 171 | 172 | def main(): 173 | parser = argparse.ArgumentParser() 174 | 175 | ## Required parameters 176 | parser.add_argument("--data_dir", default="datasets", type=str, 177 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 178 | parser.add_argument("--output_dir", default="aug_data", type=str, 179 | help="The output dir for augmented dataset") 180 | parser.add_argument("--task_name",default="subj",type=str, 181 | help="The name of the task to train.") 182 | parser.add_argument("--max_seq_length", default=64, type=int, 183 | help="The maximum total input sequence length after WordPiece tokenization. \n" 184 | "Sequences longer than this will be truncated, and sequences shorter \n" 185 | "than this will be padded.") 186 | # parser.add_argument("--do_lower_case", default=False, action='store_true', 187 | # help="Set this flag if you are using an uncased model.") 188 | parser.add_argument('--cache', default="transformers_cache", type=str) 189 | parser.add_argument("--train_batch_size", default=8, type=int, 190 | help="Total batch size for training.") 191 | parser.add_argument("--learning_rate", default=4e-5, type=float, 192 | help="The initial learning rate for Adam.") 193 | parser.add_argument("--num_train_epochs", default=10.0, type=float, 194 | help="Total number of training epochs to perform.") 195 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 196 | help="Proportion of training to perform linear learning rate warmup for. " 197 | "E.g., 0.1 = 10%% of training.") 198 | parser.add_argument('--seed', type=int, default=42, 199 | help="random seed for initialization") 200 | parser.add_argument('--sample_num', type=int, default=1, 201 | help="sample number") 202 | parser.add_argument('--sample_ratio', type=int, default=7, 203 | help="sample ratio") 204 | parser.add_argument('--gpu', type=int, default=0, 205 | help="gpu id") 206 | parser.add_argument('--temp', type=float, default=1.0, 207 | help="temperature") 208 | 209 | args = parser.parse_args() 210 | 211 | print(args) 212 | train_cmodbertp_and_augment(args) 213 | 214 | 215 | def compute_dev_loss(model, dev_dataloader): 216 | model.eval() 217 | sum_loss = 0. 
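    # The dev loss accumulated below is a sum over batches rather than a mean, so its
    # value scales with the size of the dev split; it is still a valid criterion for the
    # best-checkpoint selection in the training loop because the dev set is fixed across epochs.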
218 | for step, batch in enumerate(dev_dataloader): 219 | batch = tuple(t.to(device) for t in batch) 220 | _, input_ids, input_mask, masked_ids, label_lengths = batch 221 | inputs = {'input_ids': batch[1], 222 | 'attention_mask': batch[2], 223 | 'masked_lm_labels': batch[3]} 224 | 225 | outputs = model(**inputs) 226 | loss = outputs[0] 227 | sum_loss += loss.item() 228 | return sum_loss 229 | 230 | 231 | def augment_train_data(model, tokenizer, train_data, label_list, args): 232 | # load the best model 233 | 234 | train_sampler = SequentialSampler(train_data) 235 | train_dataloader = DataLoader(train_data, sampler=train_sampler, 236 | batch_size=args.train_batch_size) 237 | best_model_path = os.path.join(args.output_dir, "best_cmodbertp.pt") 238 | if os.path.exists(best_model_path): 239 | model.load_state_dict(torch.load(best_model_path)) 240 | else: 241 | raise ValueError("Unable to find the saved model at {}".format(best_model_path)) 242 | 243 | save_train_path = os.path.join(args.output_dir, "cmodbertp_aug.tsv") 244 | save_train_file = open(save_train_path, 'w') 245 | 246 | MASK_id = tokenizer.convert_tokens_to_ids(['[MASK]'])[0] 247 | tsv_writer = csv.writer(save_train_file, delimiter='\t') 248 | 249 | for step, batch in enumerate(train_dataloader): 250 | model.eval() 251 | batch = tuple(t.to(device) for t in batch) 252 | init_ids, _, input_mask, _, label_lengths = batch 253 | input_lens = [sum(mask).item() for mask in input_mask] 254 | masked_idx = np.squeeze( 255 | [np.random.randint(1 + label_lengths[i].cpu().numpy(), l, 256 | max((l - (1 + label_lengths[i].cpu().numpy())) // args.sample_ratio, 1)) 257 | for i, l in enumerate(input_lens)]) 258 | for ids, idx in zip(init_ids, masked_idx): 259 | ids[idx] = MASK_id 260 | 261 | inputs = {'input_ids': init_ids, 262 | 'attention_mask': input_mask} 263 | 264 | outputs = model(**inputs) 265 | predictions = outputs[0] # model(init_ids, segment_ids, input_mask) 266 | predictions = F.softmax(predictions / args.temp, dim=2) 267 | 268 | lower_label_list = [x.lower() for x in label_list] 269 | for ids, idx, preds, label_len in zip(init_ids, masked_idx, predictions, label_lengths): 270 | preds = torch.multinomial(preds, args.sample_num, replacement=True)[idx] 271 | if len(preds.size()) == 2: 272 | preds = torch.transpose(preds, 0, 1) 273 | for pred in preds: 274 | ids[idx] = pred 275 | new_str = tokenizer.convert_ids_to_tokens(ids.cpu().numpy()) 276 | label_cand = "" 277 | for x in range(label_len): 278 | label_cand += '{}'.format(new_str[1+x][2:] if new_str[1+x].startswith('##') else new_str[1+x]) 279 | label = label_list[lower_label_list.index(label_cand.lower())] 280 | new_str = rev_wordpiece(new_str) 281 | tsv_writer.writerow([label, new_str]) 282 | 283 | 284 | def train_cmodbertp_and_augment(args): 285 | task_name = args.task_name 286 | os.makedirs(args.output_dir, exist_ok=True) 287 | 288 | random.seed(args.seed) 289 | np.random.seed(args.seed) 290 | torch.manual_seed(args.seed) 291 | 292 | os.makedirs(args.output_dir, exist_ok=True) 293 | processor = get_task_processor(task_name, args.data_dir) 294 | label_list = processor.get_labels(task_name) 295 | 296 | # load train and dev data 297 | train_examples = processor.get_train_examples() 298 | dev_examples = processor.get_dev_examples() 299 | 300 | tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, 301 | do_lower_case=True, 302 | cache_dir=args.cache) 303 | 304 | model = BertForMaskedLM.from_pretrained(BERT_MODEL, 305 | cache_dir=args.cache) 306 | 307 | # tokenizer.add_tokens(label_list) 
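    # Unlike cmodbert.py, this variant keeps the label words as ordinary BERT word
    # pieces instead of adding them to the vocabulary, which is why tokenizer.add_tokens
    # and the embedding-resizing lines here remain commented out and label_length is
    # tracked per example during feature conversion.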
308 | # Adding embeddings such that they are randomly initialized, however, follow instructions about initializing them 309 | # intelligently 310 | # model.resize_token_embeddings(len(tokenizer)) 311 | # model.cls = BertOnlyMLMHead(model.config) 312 | 313 | model.to(device) 314 | 315 | # train data 316 | train_features = convert_examples_to_features(train_examples, label_list, 317 | args.max_seq_length, 318 | tokenizer, args.seed) 319 | train_data = prepare_data(train_features) 320 | train_sampler = RandomSampler(train_data) 321 | train_dataloader = DataLoader(train_data, sampler=train_sampler, 322 | batch_size=args.train_batch_size) 323 | 324 | 325 | # dev data 326 | dev_features = convert_examples_to_features(dev_examples, label_list, 327 | args.max_seq_length, 328 | tokenizer, args.seed) 329 | dev_data = prepare_data(dev_features) 330 | dev_sampler = SequentialSampler(dev_data) 331 | dev_dataloader = DataLoader(dev_data, sampler=dev_sampler, 332 | batch_size=args.train_batch_size) 333 | 334 | num_train_steps = int(len(train_features) / args.train_batch_size * args.num_train_epochs) 335 | logger.info("***** Running training *****") 336 | logger.info(" Num examples = %d", len(train_features)) 337 | logger.info(" Batch size = %d", args.train_batch_size) 338 | logger.info(" Num steps = %d", num_train_steps) 339 | 340 | # Prepare optimizer 341 | t_total = num_train_steps 342 | no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight'] 343 | optimizer_grouped_parameters = [ 344 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 345 | 'weight_decay': 0.01}, 346 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 347 | 'weight_decay': 0.0} 348 | ] 349 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8) 350 | 351 | best_dev_loss = float('inf') 352 | for epoch in trange(int(args.num_train_epochs), desc="Epoch"): 353 | avg_loss = 0. 354 | model.train() 355 | for step, batch in enumerate(train_dataloader): 356 | batch = tuple(t.to(device) for t in batch) 357 | _, input_ids, input_mask, masked_ids, label_lengths = batch 358 | inputs = {'input_ids': batch[1], 359 | 'attention_mask': batch[2], 360 | 'masked_lm_labels': batch[3]} 361 | 362 | outputs = model(**inputs) 363 | loss = outputs[0] 364 | #loss = model(input_ids, segment_ids, input_mask, masked_ids) 365 | loss.backward() 366 | avg_loss += loss.item() 367 | optimizer.step() 368 | model.zero_grad() 369 | if (step + 1) % 50 == 0: 370 | print("avg_loss: {}".format(avg_loss / 50)) 371 | avg_loss = 0. 372 | 373 | # eval on dev after every epoch 374 | dev_loss = compute_dev_loss(model, dev_dataloader) 375 | print("Epoch {}, Dev loss {}".format(epoch, dev_loss)) 376 | if dev_loss < best_dev_loss: 377 | best_dev_loss = dev_loss 378 | print("Saving model. Best dev so far {}".format(best_dev_loss)) 379 | save_model_path = os.path.join(args.output_dir, 'best_cmodbertp.pt') 380 | torch.save(model.state_dict(), save_model_path) 381 | 382 | # augment data using the best model 383 | augment_train_data(model, tokenizer, train_data, label_list, args) 384 | 385 | 386 | if __name__ == "__main__": 387 | main() -------------------------------------------------------------------------------- /src/bert_aug/data_processors.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
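# data_processors.py -- shared TSV readers for the stsa, trec and snips tasks; each
# processor yields InputExample lists for the train/dev/test splits and derives the
# label set from the label column of train.tsv.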
2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import random 5 | import os 6 | import csv 7 | 8 | 9 | def get_task_processor(task, data_dir): 10 | """ 11 | A TSV processor for stsa, trec and snips dataset. 12 | """ 13 | if task == 'stsa': 14 | return TSVDataProcessor(data_dir=data_dir, skip_header=False, label_col=0, text_col=1) 15 | elif task == 'trec': 16 | return TSVDataProcessor(data_dir=data_dir, skip_header=False, label_col=0, text_col=1) 17 | elif task == 'snips': 18 | return TSVDataProcessor(data_dir=data_dir, skip_header=False, label_col=0, text_col=1) 19 | else: 20 | raise ValueError('Unknown task') 21 | 22 | 23 | def get_data(task, data_dir, data_seed=159): 24 | random.seed(data_seed) 25 | processor = get_task_processor(task, data_dir) 26 | 27 | examples = dict() 28 | 29 | examples['train'] = processor.get_train_examples() 30 | examples['dev'] = processor.get_dev_examples() 31 | examples['test'] = processor.get_test_examples() 32 | 33 | for key, value in examples.items(): 34 | print('#{}: {}'.format(key, len(value))) 35 | return examples, processor.get_labels(task) 36 | 37 | 38 | class InputExample: 39 | """A single training/test example for simple sequence classification.""" 40 | 41 | def __init__(self, guid, text_a, text_b=None, label=None): 42 | self.guid = guid 43 | self.text_a = text_a 44 | self.text_b = text_b 45 | self.label = label 46 | 47 | 48 | class InputFeatures(object): 49 | """A single set of features of data.""" 50 | 51 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 52 | self.input_ids = input_ids 53 | self.input_mask = input_mask 54 | self.segment_ids = segment_ids 55 | self.label_id = label_id 56 | 57 | def __getitem__(self, item): 58 | return [self.input_ids, self.input_mask, 59 | self.segment_ids, self.label_id][item] 60 | 61 | 62 | class DatasetProcessor(object): 63 | """Base class for data converters for sequence classification data sets.""" 64 | 65 | def get_train_examples(self): 66 | """Gets a collection of `InputExample`s for the train set.""" 67 | raise NotImplementedError() 68 | 69 | def get_dev_examples(self): 70 | """Gets a collection of `InputExample`s for the dev set.""" 71 | raise NotImplementedError() 72 | 73 | def get_test_examples(self): 74 | """Gets a collection of `InputExample`s for the dev set.""" 75 | raise NotImplementedError() 76 | 77 | def get_labels(self, task_name): 78 | """Gets the list of labels for this data set.""" 79 | raise NotImplementedError() 80 | 81 | @classmethod 82 | def _read_tsv(cls, input_file, quotechar=None): 83 | """Reads a tab separated value file.""" 84 | with open(input_file, "r") as f: 85 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 86 | lines = [] 87 | for line in reader: 88 | lines.append(line) 89 | return lines 90 | 91 | 92 | class TSVDataProcessor(DatasetProcessor): 93 | """Processor for dataset to be augmented.""" 94 | 95 | def __init__(self, data_dir, skip_header, label_col, text_col): 96 | self.data_dir = data_dir 97 | self.skip_header = skip_header 98 | self.label_col = label_col 99 | self.text_col = text_col 100 | 101 | def get_train_examples(self): 102 | """See base class.""" 103 | return self._create_examples( 104 | self._read_tsv(os.path.join(self.data_dir, "train.tsv")), "train") 105 | 106 | def get_dev_examples(self): 107 | """See base class.""" 108 | return self._create_examples( 109 | self._read_tsv(os.path.join(self.data_dir, "dev.tsv")), "dev") 110 | 111 | def get_test_examples(self): 112 | """See base class.""" 113 | return self._create_examples( 114 | 
self._read_tsv(os.path.join(self.data_dir, "test.tsv")), "test") 115 | 116 | def get_labels(self, task_name): 117 | """add your dataset here""" 118 | labels = set() 119 | with open(os.path.join(self.data_dir, "train.tsv"), "r") as in_file: 120 | for line in in_file: 121 | labels.add(line.split("\t")[self.label_col]) 122 | return sorted(labels) 123 | 124 | def _create_examples(self, lines, set_type): 125 | """Creates examples for the training and dev sets.""" 126 | examples = [] 127 | for (i, line) in enumerate(lines): 128 | if self.skip_header and i == 0: 129 | continue 130 | guid = "%s-%s" % (set_type, i) 131 | text_a = line[self.text_col] 132 | label = line[self.label_col] 133 | examples.append( 134 | InputExample(guid=guid, text_a=text_a, label=label)) 135 | return examples 136 | 137 | 138 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 139 | """Truncates a sequence pair in place to the maximum length.""" 140 | 141 | # This is a simple heuristic which will always truncate the longer sequence 142 | # one token at a time. This makes more sense than truncating an equal percent 143 | # of tokens from each, since if one sequence is very short then each token 144 | # that's truncated likely contains more information than a longer sequence. 145 | while True: 146 | total_length = len(tokens_a) + len(tokens_b) 147 | if total_length <= max_length: 148 | break 149 | if len(tokens_a) > len(tokens_b): 150 | tokens_a.pop() 151 | else: 152 | tokens_b.pop() 153 | -------------------------------------------------------------------------------- /src/bert_aug/eda.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | # Original Copyright https://github.com/jasonwei20/eda_nlp 4 | 5 | import random 6 | from random import shuffle 7 | import argparse 8 | 9 | # random.seed(1) 10 | 11 | # stop words list 12 | stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 13 | 'ours', 'ourselves', 'you', 'your', 'yours', 14 | 'yourself', 'yourselves', 'he', 'him', 'his', 15 | 'himself', 'she', 'her', 'hers', 'herself', 16 | 'it', 'its', 'itself', 'they', 'them', 'their', 17 | 'theirs', 'themselves', 'what', 'which', 'who', 18 | 'whom', 'this', 'that', 'these', 'those', 'am', 19 | 'is', 'are', 'was', 'were', 'be', 'been', 'being', 20 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 21 | 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 22 | 'because', 'as', 'until', 'while', 'of', 'at', 23 | 'by', 'for', 'with', 'about', 'against', 'between', 24 | 'into', 'through', 'during', 'before', 'after', 25 | 'above', 'below', 'to', 'from', 'up', 'down', 'in', 26 | 'out', 'on', 'off', 'over', 'under', 'again', 27 | 'further', 'then', 'once', 'here', 'there', 'when', 28 | 'where', 'why', 'how', 'all', 'any', 'both', 'each', 29 | 'few', 'more', 'most', 'other', 'some', 'such', 'no', 30 | 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 31 | 'very', 's', 't', 'can', 'will', 'just', 'don', 32 | 'should', 'now', ''] 33 | 34 | # cleaning up text 35 | import re 36 | 37 | 38 | def get_only_chars(line): 39 | clean_line = "" 40 | 41 | line = line.replace("’", "") 42 | line = line.replace("'", "") 43 | line = line.replace("-", " ") # replace hyphens with spaces 44 | line = line.replace("\t", " ") 45 | line = line.replace("\n", " ") 46 | line = line.lower() 47 | 48 | for char in line: 49 | if char in 'qwertyuiopasdfghjklzxcvbnm ': 50 | clean_line += char 
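            # any other character (digits, punctuation, stray symbols) is replaced with a
            # space and later collapsed by the re.sub call below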
51 | else: 52 | clean_line += ' ' 53 | 54 | clean_line = re.sub(' +', ' ', clean_line) # delete extra spaces 55 | if clean_line[0] == ' ': 56 | clean_line = clean_line[1:] 57 | return clean_line 58 | 59 | 60 | ######################################################################## 61 | # Synonym replacement 62 | # Replace n words in the sentence with synonyms from wordnet 63 | ######################################################################## 64 | 65 | # for the first time you use wordnet 66 | # import nltk 67 | # nltk.download('wordnet') 68 | from nltk.corpus import wordnet 69 | 70 | 71 | def synonym_replacement(words, n): 72 | new_words = words.copy() 73 | random_word_list = list(set([word for word in words if word not in stop_words])) 74 | random.shuffle(random_word_list) 75 | num_replaced = 0 76 | for random_word in random_word_list: 77 | synonyms = get_synonyms(random_word) 78 | if len(synonyms) >= 1: 79 | synonym = random.choice(list(synonyms)) 80 | new_words = [synonym if word == random_word else word for word in new_words] 81 | # print("replaced", random_word, "with", synonym) 82 | num_replaced += 1 83 | if num_replaced >= n: # only replace up to n words 84 | break 85 | 86 | # this is stupid but we need it, trust me 87 | sentence = ' '.join(new_words) 88 | new_words = sentence.split(' ') 89 | 90 | return new_words 91 | 92 | 93 | def get_synonyms(word): 94 | synonyms = set() 95 | for syn in wordnet.synsets(word): 96 | for l in syn.lemmas(): 97 | synonym = l.name().replace("_", " ").replace("-", " ").lower() 98 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm']) 99 | synonyms.add(synonym) 100 | if word in synonyms: 101 | synonyms.remove(word) 102 | return list(synonyms) 103 | 104 | 105 | ######################################################################## 106 | # Random deletion 107 | # Randomly delete words from the sentence with probability p 108 | ######################################################################## 109 | 110 | def random_deletion(words, p): 111 | # obviously, if there's only one word, don't delete it 112 | if len(words) == 1: 113 | return words 114 | 115 | # randomly delete words with probability p 116 | new_words = [] 117 | for word in words: 118 | r = random.uniform(0, 1) 119 | if r > p: 120 | new_words.append(word) 121 | 122 | # if you end up deleting all words, just return a random word 123 | if len(new_words) == 0: 124 | rand_int = random.randint(0, len(words) - 1) 125 | return [words[rand_int]] 126 | 127 | return new_words 128 | 129 | 130 | ######################################################################## 131 | # Random swap 132 | # Randomly swap two words in the sentence n times 133 | ######################################################################## 134 | 135 | def random_swap(words, n): 136 | new_words = words.copy() 137 | for _ in range(n): 138 | new_words = swap_word(new_words) 139 | return new_words 140 | 141 | 142 | def swap_word(new_words): 143 | random_idx_1 = random.randint(0, len(new_words) - 1) 144 | random_idx_2 = random_idx_1 145 | counter = 0 146 | while random_idx_2 == random_idx_1: 147 | random_idx_2 = random.randint(0, len(new_words) - 1) 148 | counter += 1 149 | if counter > 3: 150 | return new_words 151 | new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1] 152 | return new_words 153 | 154 | 155 | ######################################################################## 156 | # Random insertion 157 | # Randomly insert n words 
into the sentence 158 | ######################################################################## 159 | 160 | def random_insertion(words, n): 161 | new_words = words.copy() 162 | for _ in range(n): 163 | add_word(new_words) 164 | return new_words 165 | 166 | 167 | def add_word(new_words): 168 | synonyms = [] 169 | counter = 0 170 | while len(synonyms) < 1: 171 | random_word = new_words[random.randint(0, len(new_words) - 1)] 172 | synonyms = get_synonyms(random_word) 173 | counter += 1 174 | if counter >= 10: 175 | return 176 | random_synonym = synonyms[0] 177 | random_idx = random.randint(0, len(new_words) - 1) 178 | new_words.insert(random_idx, random_synonym) 179 | 180 | 181 | ######################################################################## 182 | # main data augmentation function 183 | ######################################################################## 184 | 185 | def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9): 186 | sentence = get_only_chars(sentence) 187 | words = sentence.split(' ') 188 | words = [word for word in words if word is not ''] 189 | num_words = len(words) 190 | 191 | augmented_sentences = [] 192 | num_new_per_technique = int(num_aug / 4) + 1 193 | n_sr = max(1, int(alpha_sr * num_words)) 194 | n_ri = max(1, int(alpha_ri * num_words)) 195 | n_rs = max(1, int(alpha_rs * num_words)) 196 | 197 | # sr 198 | for _ in range(num_new_per_technique): 199 | a_words = synonym_replacement(words, n_sr) 200 | augmented_sentences.append(' '.join(a_words)) 201 | 202 | # ri 203 | for _ in range(num_new_per_technique): 204 | a_words = random_insertion(words, n_ri) 205 | augmented_sentences.append(' '.join(a_words)) 206 | 207 | # rs 208 | for _ in range(num_new_per_technique): 209 | a_words = random_swap(words, n_rs) 210 | augmented_sentences.append(' '.join(a_words)) 211 | 212 | # rd 213 | for _ in range(num_new_per_technique): 214 | a_words = random_deletion(words, p_rd) 215 | augmented_sentences.append(' '.join(a_words)) 216 | 217 | augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences] 218 | shuffle(augmented_sentences) 219 | 220 | # trim so that we have the desired number of augmented sentences 221 | if num_aug >= 1: 222 | augmented_sentences = augmented_sentences[:num_aug] 223 | else: 224 | keep_prob = num_aug / len(augmented_sentences) 225 | augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob] 226 | 227 | # append the original sentence 228 | augmented_sentences.append(sentence) 229 | 230 | return augmented_sentences 231 | 232 | 233 | def gen_eda(train_orig, output_file, alpha, num_aug=1): 234 | 235 | writer = open(output_file, 'w') 236 | lines = open(train_orig, 'r').readlines() 237 | 238 | for i, line in enumerate(lines): 239 | parts = line[:-1].split('\t') 240 | label = parts[0] 241 | sentence = parts[1] 242 | aug_sentences = eda(sentence, alpha_sr=alpha, alpha_ri=alpha, alpha_rs=alpha, p_rd=alpha, num_aug=num_aug) 243 | for aug_sentence in aug_sentences: 244 | writer.write(label + "\t" + aug_sentence + '\n') 245 | 246 | writer.close() 247 | print("generated augmented sentences with eda for " + train_orig + " to " + output_file + " with num_aug=" + str(num_aug)) 248 | 249 | 250 | def main(): 251 | parser = argparse.ArgumentParser() 252 | parser.add_argument("--input", required=True, type=str, help="input file of unaugmented data") 253 | parser.add_argument("--output", required=False, type=str, help="output file of unaugmented data") 254 | parser.add_argument("--num_aug", 
required=False, type=int, help="number of augmented sentences per original sentence") 255 | parser.add_argument("--alpha", required=False, type=float, help="percent of words in each sentence to be changed") 256 | parser.add_argument('--seed', default=1, type=int) 257 | args = parser.parse_args() 258 | 259 | random.seed(args.seed) 260 | #the output file 261 | output = None 262 | if args.output: 263 | output = args.output 264 | else: 265 | from os.path import dirname, basename, join 266 | output = join(dirname(args.input), 'eda_' + basename(args.input)) 267 | 268 | #number of augmented sentences to generate per original sentence 269 | num_aug = 1 #default 270 | if args.num_aug: 271 | num_aug = args.num_aug 272 | 273 | #how much to change each sentence 274 | alpha = 0.1#default 275 | if args.alpha: 276 | alpha = args.alpha 277 | 278 | gen_eda(args.input, output, alpha=alpha, num_aug=num_aug) 279 | 280 | 281 | if __name__ == "__main__": 282 | main() -------------------------------------------------------------------------------- /src/scripts/bart_snips_lower.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | WARMUP_UPDATES=60 4 | LR=1e-05 # Peak LR for polynomial LR scheduler. 5 | SRC=~/PretrainedDataAugment/src 6 | BART_PATH=~/bart.large 7 | CACHE=~/CACHE 8 | PREFIXSIZE=3 9 | MAXEPOCH=30 10 | TASK=snips 11 | temp_rate=1.0 12 | smooth_rate=0.5 13 | for NUMEXAMPLES in 10; 14 | do 15 | for i in {0..14}; 16 | do 17 | RAWDATADIR=~/datasets/${TASK}/exp_${i}_${NUMEXAMPLES} 18 | DATABIN=$RAWDATADIR/jointdatabin 19 | 20 | splits=( train dev ) 21 | for split in "${splits[@]}"; 22 | do 23 | python $SRC/utils/bpe_encoder.py \ 24 | --encoder-json $SRC/utils/gpt2_bpe/encoder.json \ 25 | --vocab-bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 26 | --inputs $RAWDATADIR/${split}.tsv \ 27 | --outputs $RAWDATADIR/${split}_bpe.src \ 28 | --workers 1 --keep-empty --tsv --dataset $TASK 29 | done 30 | 31 | fairseq-preprocess --user-dir=$SRC/bart_aug --only-source \ 32 | --task mask_s2s \ 33 | --trainpref $RAWDATADIR/train_bpe.src \ 34 | --validpref $RAWDATADIR/dev_bpe.src \ 35 | --destdir $DATABIN \ 36 | --srcdict $BART_PATH/dict.txt 37 | 38 | # Run data generation with different noise setting 39 | for mr in 40; 40 | do 41 | MRATIO=0.${mr} 42 | for MASKLEN in word span; 43 | do 44 | MODELDIR=$RAWDATADIR/bart_${MASKLEN}_mask_${MRATIO}_checkpoints 45 | mkdir $MODELDIR 46 | 47 | CUDA_VISIBLE_DEVICES=0 fairseq-train $DATABIN/ \ 48 | --user-dir=$SRC/bart_aug \ 49 | --restore-file $BART_PATH/model.pt \ 50 | --arch bart_large \ 51 | --task mask_s2s \ 52 | --bpe gpt2 \ 53 | --gpt2_encoder_json $SRC/utils/gpt2_bpe/encoder.json \ 54 | --gpt2_vocab_bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 55 | --layernorm-embedding \ 56 | --share-all-embeddings \ 57 | --save-dir $MODELDIR\ 58 | --seed $i \ 59 | --share-decoder-input-output-embed \ 60 | --reset-optimizer --reset-dataloader --reset-meters \ 61 | --required-batch-size-multiple 1 \ 62 | --max-tokens 2000 \ 63 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 64 | --dropout 0.1 --attention-dropout 0.1 \ 65 | --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \ 66 | --clip-norm 0.0 \ 67 | --lr-scheduler polynomial_decay --lr $LR \ 68 | --warmup-updates $WARMUP_UPDATES \ 69 | --replace-length 1 --mask-length $MASKLEN --mask $MRATIO --fp16 --update-freq 1 \ 70 | --max-epoch $MAXEPOCH --no-epoch-checkpoints > $MODELDIR/bart.log 71 | 72 | CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATABIN \ 73 
| --user-dir=$SRC/bart_aug \ 74 | --task mask_s2s --tokens-to-keep $PREFIXSIZE \ 75 | --seed ${i} \ 76 | --bpe gpt2 \ 77 | --gpt2_encoder_json $SRC/utils/gpt2_bpe/encoder.json \ 78 | --gpt2_vocab_bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 79 | --path $MODELDIR/checkpoint_best.pt \ 80 | --replace-length 1 --mask-length $MASKLEN --mask $MRATIO \ 81 | --batch-size 64 --beam 5 --lenpen 5 \ 82 | --no-repeat-ngram-size 2 \ 83 | --max-len-b 50 --prefix-size $PREFIXSIZE \ 84 | --gen-subset train > $MODELDIR/bart_l5_${PREFIXSIZE}.gen 85 | 86 | grep ^H $MODELDIR/bart_l5_${PREFIXSIZE}.gen | cut -f3 > $MODELDIR/bart_l5_gen_${PREFIXSIZE}.bpe 87 | rm $MODELDIR/checkpoint_last.pt 88 | python $SRC/utils/bpe_encoder.py \ 89 | --encoder-json $SRC/utils/gpt2_bpe/encoder.json \ 90 | --vocab-bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 91 | --inputs $MODELDIR/bart_l5_gen_${PREFIXSIZE}.bpe \ 92 | --outputs $MODELDIR/bart_l5_gen_${PREFIXSIZE}.tsv --dataset $TASK \ 93 | --workers 1 --keep-empty --decode --tsv 94 | done 95 | done 96 | 97 | ######################## 98 | ## BART Classifier 99 | ######################## 100 | 101 | for mr in 40; 102 | do 103 | MRATIO=0.${mr} 104 | for MASKLEN in span word; 105 | do 106 | MODELDIR=$RAWDATADIR/bart_${MASKLEN}_mask_${MRATIO}_checkpoints 107 | 108 | cat $RAWDATADIR/train.tsv $MODELDIR/bart_l5_gen_${PREFIXSIZE}.tsv > $MODELDIR/train.tsv 109 | cp $RAWDATADIR/test.tsv $MODELDIR/test.tsv 110 | cp $RAWDATADIR/dev.tsv $MODELDIR/dev.tsv 111 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $MODELDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_bart_l5_${MASKLEN}_mask_${MRATIO}_prefix_${PREFIXSIZE}.log 112 | done 113 | done 114 | done 115 | done 116 | -------------------------------------------------------------------------------- /src/scripts/bart_stsa_lower.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | WARMUP_UPDATES=60 4 | LR=1e-05 # Peak LR for polynomial LR scheduler. 
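# Key settings for this STSA run (the home-directory paths below reflect the original
# experiment layout and may need adjusting):
#   PREFIXSIZE             - number of leading target tokens kept as the generation prompt
#   MRATIO / MASKLEN       - BART denoising mask ratio and mask unit (word or span),
#                            swept in the loops further down
#   temp_rate, smooth_rate - text-smoothing parameters forwarded to bert_classifier.py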
5 | SRC=~/PretrainedDataAugment/src 6 | BART_PATH=~/bart.large 7 | CACHE=~/CACHE 8 | PREFIXSIZE=2 9 | MAXEPOCH=30 10 | TASK=stsa 11 | temp_rate=1.0 12 | smooth_rate=0.5 13 | 14 | for NUMEXAMPLES in 10; 15 | do 16 | for i in {0..14}; 17 | do 18 | RAWDATADIR=~/datasets/${TASK}/exp_${i}_${NUMEXAMPLES} 19 | DATABIN=$RAWDATADIR/jointdatabin 20 | 21 | splits=( train dev ) 22 | for split in "${splits[@]}"; 23 | do 24 | python $SRC/utils/bpe_encoder.py \ 25 | --encoder-json $SRC/utils/gpt2_bpe/encoder.json \ 26 | --vocab-bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 27 | --inputs $RAWDATADIR/${split}.tsv \ 28 | --outputs $RAWDATADIR/${split}_bpe.src \ 29 | --workers 1 --keep-empty --tsv --dataset $TASK 30 | done 31 | 32 | fairseq-preprocess --user-dir=$SRC/bart_aug --only-source \ 33 | --task mask_s2s \ 34 | --trainpref $RAWDATADIR/train_bpe.src \ 35 | --validpref $RAWDATADIR/dev_bpe.src \ 36 | --destdir $DATABIN \ 37 | --srcdict $BART_PATH/dict.txt 38 | 39 | # Run data generation with different noise setting 40 | for mr in 40; 41 | do 42 | MRATIO=0.${mr} 43 | for MASKLEN in word span; 44 | do 45 | MODELDIR=$RAWDATADIR/bart_${MASKLEN}_mask_${MRATIO}_checkpoints 46 | mkdir $MODELDIR 47 | 48 | CUDA_VISIBLE_DEVICES=0 fairseq-train $DATABIN/ \ 49 | --user-dir=$SRC/bart_aug \ 50 | --restore-file $BART_PATH/model.pt \ 51 | --arch bart_large \ 52 | --task mask_s2s \ 53 | --bpe gpt2 \ 54 | --gpt2_encoder_json $SRC/utils/gpt2_bpe/encoder.json \ 55 | --gpt2_vocab_bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 56 | --layernorm-embedding \ 57 | --share-all-embeddings \ 58 | --save-dir $MODELDIR\ 59 | --seed $i \ 60 | --share-decoder-input-output-embed \ 61 | --reset-optimizer --reset-dataloader --reset-meters \ 62 | --required-batch-size-multiple 1 \ 63 | --max-tokens 2000 \ 64 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 65 | --dropout 0.1 --attention-dropout 0.1 \ 66 | --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \ 67 | --clip-norm 0.0 \ 68 | --lr-scheduler polynomial_decay --lr $LR \ 69 | --warmup-updates $WARMUP_UPDATES \ 70 | --replace-length 1 --mask-length $MASKLEN --mask $MRATIO --fp16 --update-freq 1 \ 71 | --max-epoch $MAXEPOCH --no-epoch-checkpoints > $MODELDIR/bart.log 72 | 73 | CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATABIN \ 74 | --user-dir=$SRC/bart_aug \ 75 | --task mask_s2s --tokens-to-keep $PREFIXSIZE \ 76 | --seed ${i} \ 77 | --bpe gpt2 \ 78 | --gpt2_encoder_json $SRC/utils/gpt2_bpe/encoder.json \ 79 | --gpt2_vocab_bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 80 | --path $MODELDIR/checkpoint_best.pt \ 81 | --replace-length 1 --mask-length $MASKLEN --mask $MRATIO \ 82 | --batch-size 64 --beam 5 --lenpen 5 \ 83 | --no-repeat-ngram-size 2 \ 84 | --max-len-b 50 --prefix-size $PREFIXSIZE \ 85 | --gen-subset train > $MODELDIR/bart_l5_${PREFIXSIZE}.gen 86 | 87 | grep ^H $MODELDIR/bart_l5_${PREFIXSIZE}.gen | cut -f3 > $MODELDIR/bart_l5_gen_${PREFIXSIZE}.bpe 88 | rm $MODELDIR/checkpoint_last.pt 89 | python $SRC/utils/bpe_encoder.py \ 90 | --encoder-json $SRC/utils/gpt2_bpe/encoder.json \ 91 | --vocab-bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 92 | --inputs $MODELDIR/bart_l5_gen_${PREFIXSIZE}.bpe \ 93 | --outputs $MODELDIR/bart_l5_gen_${PREFIXSIZE}.tsv --dataset $TASK \ 94 | --workers 1 --keep-empty --decode --tsv 95 | done 96 | done 97 | 98 | ######################## 99 | ## BART Classifier 100 | ######################## 101 | 102 | for mr in 40; 103 | do 104 | MRATIO=0.${mr} 105 | for MASKLEN in span word; 106 | do 107 | 
MODELDIR=$RAWDATADIR/bart_${MASKLEN}_mask_${MRATIO}_checkpoints 108 | 109 | cat $RAWDATADIR/train.tsv $MODELDIR/bart_l5_gen_${PREFIXSIZE}.tsv > $MODELDIR/train.tsv 110 | cp $RAWDATADIR/test.tsv $MODELDIR/test.tsv 111 | cp $RAWDATADIR/dev.tsv $MODELDIR/dev.tsv 112 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $MODELDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_bart_l5_${MASKLEN}_mask_${MRATIO}_prefix_${PREFIXSIZE}.log 113 | done 114 | done 115 | done 116 | done 117 | -------------------------------------------------------------------------------- /src/scripts/bart_trec_lower.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | WARMUP_UPDATES=60 4 | LR=1e-05 # Peak LR for polynomial LR scheduler. 5 | SRC=~/PretrainedDataAugment/src 6 | BART_PATH=~/bart.large 7 | CACHE=~/CACHE 8 | MAXEPOCH=30 9 | PREFIXSIZE=3 10 | TASK=trec 11 | temp_rate=1.0 12 | smooth_rate=0.5 13 | 14 | for NUMEXAMPLES in 10; 15 | do 16 | for i in {0..14}; 17 | do 18 | RAWDATADIR=~/datasets/${TASK}/exp_${i}_${NUMEXAMPLES} 19 | DATABIN=$RAWDATADIR/jointdatabin 20 | 21 | splits=( train dev ) 22 | for split in "${splits[@]}"; 23 | do 24 | python $SRC/utils/bpe_encoder.py \ 25 | --encoder-json $SRC/utils/gpt2_bpe/encoder.json \ 26 | --vocab-bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 27 | --inputs $RAWDATADIR/${split}.tsv \ 28 | --outputs $RAWDATADIR/${split}_bpe.src \ 29 | --workers 1 --keep-empty --tsv --dataset $TASK 30 | done 31 | 32 | fairseq-preprocess --user-dir=$SRC/bart_aug --only-source \ 33 | --task mask_s2s \ 34 | --trainpref $RAWDATADIR/train_bpe.src \ 35 | --validpref $RAWDATADIR/dev_bpe.src \ 36 | --destdir $DATABIN \ 37 | --srcdict $BART_PATH/dict.txt 38 | 39 | # Run data generation with different noise setting 40 | for mr in 40; 41 | do 42 | MRATIO=0.${mr} 43 | for MASKLEN in word span; 44 | do 45 | MODELDIR=$RAWDATADIR/bart_${MASKLEN}_mask_${MRATIO}_checkpoints 46 | mkdir $MODELDIR 47 | 48 | CUDA_VISIBLE_DEVICES=0 fairseq-train $DATABIN/ \ 49 | --user-dir=$SRC/bart_aug \ 50 | --restore-file $BART_PATH/model.pt \ 51 | --arch bart_large \ 52 | --task mask_s2s \ 53 | --bpe gpt2 \ 54 | --gpt2_encoder_json $SRC/utils/gpt2_bpe/encoder.json \ 55 | --gpt2_vocab_bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 56 | --layernorm-embedding \ 57 | --share-all-embeddings \ 58 | --save-dir $MODELDIR\ 59 | --seed $i \ 60 | --share-decoder-input-output-embed \ 61 | --reset-optimizer --reset-dataloader --reset-meters \ 62 | --required-batch-size-multiple 1 \ 63 | --max-tokens 2000 \ 64 | --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ 65 | --dropout 0.1 --attention-dropout 0.1 \ 66 | --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \ 67 | --clip-norm 0.0 \ 68 | --lr-scheduler polynomial_decay --lr $LR \ 69 | --warmup-updates $WARMUP_UPDATES \ 70 | --replace-length 1 --mask-length $MASKLEN --mask $MRATIO --fp16 --update-freq 1 \ 71 | --max-epoch $MAXEPOCH --no-epoch-checkpoints > $MODELDIR/bart.log 72 | 73 | CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATABIN \ 74 | --user-dir=$SRC/bart_aug \ 75 | --task mask_s2s --tokens-to-keep $PREFIXSIZE \ 76 | --seed ${i} \ 77 | --bpe gpt2 \ 78 | --gpt2_encoder_json $SRC/utils/gpt2_bpe/encoder.json \ 79 | --gpt2_vocab_bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 80 | --path $MODELDIR/checkpoint_best.pt \ 81 | --replace-length 1 --mask-length $MASKLEN --mask $MRATIO \ 82 | --batch-size 64 --beam 5 --lenpen 5 \ 83 | 
--no-repeat-ngram-size 2 \ 84 | --max-len-b 50 --prefix-size $PREFIXSIZE \ 85 | --gen-subset train > $MODELDIR/bart_l5_${PREFIXSIZE}.gen 86 | 87 | grep ^H $MODELDIR/bart_l5_${PREFIXSIZE}.gen | cut -f3 > $MODELDIR/bart_l5_gen_${PREFIXSIZE}.bpe 88 | rm $MODELDIR/checkpoint_last.pt 89 | python $SRC/utils/bpe_encoder.py \ 90 | --encoder-json $SRC/utils/gpt2_bpe/encoder.json \ 91 | --vocab-bpe $SRC/utils/gpt2_bpe/vocab.bpe \ 92 | --inputs $MODELDIR/bart_l5_gen_${PREFIXSIZE}.bpe \ 93 | --outputs $MODELDIR/bart_l5_gen_${PREFIXSIZE}.tsv --dataset $TASK \ 94 | --workers 1 --keep-empty --decode --tsv 95 | done 96 | done 97 | 98 | ######################## 99 | ## BART Classifier 100 | ######################## 101 | 102 | for mr in 40; 103 | do 104 | MRATIO=0.${mr} 105 | for MASKLEN in span word; 106 | do 107 | MODELDIR=$RAWDATADIR/bart_${MASKLEN}_mask_${MRATIO}_checkpoints 108 | 109 | cat $RAWDATADIR/train.tsv $MODELDIR/bart_l5_gen_${PREFIXSIZE}.tsv > $MODELDIR/train.tsv 110 | cp $RAWDATADIR/test.tsv $MODELDIR/test.tsv 111 | cp $RAWDATADIR/dev.tsv $MODELDIR/dev.tsv 112 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $MODELDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_bart_l5_${MASKLEN}_mask_${MRATIO}_prefix_${PREFIXSIZE}.log 113 | done 114 | done 115 | done 116 | done 117 | -------------------------------------------------------------------------------- /src/scripts/bert_snips_lower.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SRC=~/PretrainedDataAugment/src 4 | CACHE=~/CACHE 5 | TASK=snips 6 | temp_rate=1.0 7 | smooth_rate=0.5 8 | for NUMEXAMPLES in 10; 9 | do 10 | for i in {0..14}; 11 | do 12 | RAWDATADIR=~/datasets/${TASK}/exp_${i}_${NUMEXAMPLES} 13 | 14 | # Baseline classifier 15 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $RAWDATADIR --seed ${i} --learning_rate $BERTLR --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_baseline.log 16 | 17 | ############## 18 | ## EDA 19 | ############## 20 | 21 | EDADIR=$RAWDATADIR/eda 22 | mkdir $EDADIR 23 | python $SRC/bert_aug/eda.py --input $RAWDATADIR/train.tsv --output $EDADIR/eda_aug.tsv --num_aug=1 --alpha=0.1 --seed ${i} 24 | cat $RAWDATADIR/train.tsv $EDADIR/eda_aug.tsv > $EDADIR/train.tsv 25 | cp $RAWDATADIR/test.tsv $EDADIR/test.tsv 26 | cp $RAWDATADIR/dev.tsv $EDADIR/dev.tsv 27 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $EDADIR --seed ${i} --learning_rate $BERTLR --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_eda.log 28 | 29 | 30 | ####################### 31 | # GPT2 Classifier 32 | ####################### 33 | 34 | GPT2DIR=$RAWDATADIR/gpt2 35 | mkdir $GPT2DIR 36 | python $SRC/bert_aug/cgpt2.py --data_dir $RAWDATADIR --output_dir $GPT2DIR --task_name $TASK --num_train_epochs 25 --seed ${i} --top_p 0.9 --temp 1.0 --cache $CACHE 37 | cat $RAWDATADIR/train.tsv $GPT2DIR/cmodgpt2_aug_3.tsv > $GPT2DIR/train.tsv 38 | cp $RAWDATADIR/test.tsv $GPT2DIR/test.tsv 39 | cp $RAWDATADIR/dev.tsv $GPT2DIR/dev.tsv 40 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $GPT2DIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_gpt2_3.log 41 | 42 | # ####################### 43 | # # Backtranslation DA Classifier 44 | # ####################### 45 | 46 | BTDIR=$RAWDATADIR/bt 47 | mkdir $BTDIR 48 | python $SRC/bert_aug/backtranslation.py 
--data_dir $RAWDATADIR --output_dir $BTDIR --task_name $TASK --seed ${i} --cache $CACHE 49 | cat $RAWDATADIR/train.tsv $BTDIR/bt_aug.tsv > $BTDIR/train.tsv 50 | cp $RAWDATADIR/test.tsv $BTDIR/test.tsv 51 | cp $RAWDATADIR/dev.tsv $BTDIR/dev.tsv 52 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $BTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_bt.log 53 | 54 | # ####################### 55 | # # CBERT Classifier 56 | # ####################### 57 | 58 | CBERTDIR=$RAWDATADIR/cbert 59 | mkdir $CBERTDIR 60 | python $SRC/bert_aug/cbert.py --data_dir $RAWDATADIR --output_dir $CBERTDIR --task_name $TASK --num_train_epochs 10 --seed ${i} --cache $CACHE > $RAWDATADIR/cbert.log 61 | cat $RAWDATADIR/train.tsv $CBERTDIR/cbert_aug.tsv > $CBERTDIR/train.tsv 62 | cp $RAWDATADIR/test.tsv $CBERTDIR/test.tsv 63 | cp $RAWDATADIR/dev.tsv $CBERTDIR/dev.tsv 64 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CBERTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cbert.log 65 | 66 | # ####################### 67 | # # CMODBERT Classifier 68 | # ###################### 69 | 70 | CMODBERTDIR=$RAWDATADIR/cmodbert 71 | mkdir $CMODBERTDIR 72 | python $SRC/bert_aug/cmodbert.py --data_dir $RAWDATADIR --output_dir $CMODBERTDIR --task_name $TASK --num_train_epochs 150 --learning_rate 0.00015 --seed ${i} --cache $CACHE > $RAWDATADIR/cmodbert.log 73 | cat $RAWDATADIR/train.tsv $CMODBERTDIR/cmodbert_aug.tsv > $CMODBERTDIR/train.tsv 74 | cp $RAWDATADIR/test.tsv $CMODBERTDIR/test.tsv 75 | cp $RAWDATADIR/dev.tsv $CMODBERTDIR/dev.tsv 76 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CMODBERTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cmodbert.log 77 | 78 | # ####################### 79 | # # CMODBERTP Classifier 80 | # ###################### 81 | 82 | CMODBERTPDIR=$RAWDATADIR/cmodbertp 83 | mkdir $CMODBERTPDIR 84 | python $SRC/bert_aug/cmodbertp.py --data_dir $RAWDATADIR --output_dir $CMODBERTPDIR --task_name $TASK --num_train_epochs 10 --seed ${i} --cache $CACHE > $RAWDATADIR/cmodbertp.log 85 | cat $RAWDATADIR/train.tsv $CMODBERTPDIR/cmodbertp_aug.tsv > $CMODBERTPDIR/train.tsv 86 | cp $RAWDATADIR/test.tsv $CMODBERTPDIR/test.tsv 87 | cp $RAWDATADIR/dev.tsv $CMODBERTPDIR/dev.tsv 88 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CMODBERTPDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cmodbertp.log 89 | 90 | done 91 | done 92 | 93 | 94 | -------------------------------------------------------------------------------- /src/scripts/bert_stsa_lower.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SRC=~/PretrainedDataAugment/src 4 | CACHE=~/CACHE 5 | TASK=stsa 6 | temp_rate=1.0 7 | smooth_rate=0.5 8 | for NUMEXAMPLES in 10; 9 | do 10 | for i in {0..14}; 11 | do 12 | RAWDATADIR=~/datasets/${TASK}/exp_${i}_${NUMEXAMPLES} 13 | 14 | # Baseline classifier 15 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $RAWDATADIR --seed ${i} --learning_rate $BERTLR --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_baseline.log 16 | 17 | ############## 18 | ## EDA 19 | ############## 20 | 21 | EDADIR=$RAWDATADIR/eda 22 | mkdir $EDADIR 23 | python $SRC/bert_aug/eda.py --input $RAWDATADIR/train.tsv --output $EDADIR/eda_aug.tsv --num_aug=1 
--alpha=0.1 --seed ${i} 24 | cat $RAWDATADIR/train.tsv $EDADIR/eda_aug.tsv > $EDADIR/train.tsv 25 | cp $RAWDATADIR/test.tsv $EDADIR/test.tsv 26 | cp $RAWDATADIR/dev.tsv $EDADIR/dev.tsv 27 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $EDADIR --seed ${i} --learning_rate $BERTLR --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_eda.log 28 | 29 | 30 | ####################### 31 | # GPT2 Classifier 32 | ####################### 33 | 34 | GPT2DIR=$RAWDATADIR/gpt2 35 | mkdir $GPT2DIR 36 | python $SRC/bert_aug/cgpt2.py --data_dir $RAWDATADIR --output_dir $GPT2DIR --task_name $TASK --num_train_epochs 25 --seed ${i} --top_p 0.9 --temp 1.0 --cache $CACHE 37 | cat $RAWDATADIR/train.tsv $GPT2DIR/cmodgpt2_aug_3.tsv > $GPT2DIR/train.tsv 38 | cp $RAWDATADIR/test.tsv $GPT2DIR/test.tsv 39 | cp $RAWDATADIR/dev.tsv $GPT2DIR/dev.tsv 40 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $GPT2DIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_gpt2_3.log 41 | 42 | # ####################### 43 | # # Backtranslation DA Classifier 44 | # ####################### 45 | 46 | BTDIR=$RAWDATADIR/bt 47 | mkdir $BTDIR 48 | python $SRC/bert_aug/backtranslation.py --data_dir $RAWDATADIR --output_dir $BTDIR --task_name $TASK --seed ${i} --cache $CACHE 49 | cat $RAWDATADIR/train.tsv $BTDIR/bt_aug.tsv > $BTDIR/train.tsv 50 | cp $RAWDATADIR/test.tsv $BTDIR/test.tsv 51 | cp $RAWDATADIR/dev.tsv $BTDIR/dev.tsv 52 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $BTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_bt.log 53 | 54 | # ####################### 55 | # # CBERT Classifier 56 | # ####################### 57 | 58 | CBERTDIR=$RAWDATADIR/cbert 59 | mkdir $CBERTDIR 60 | python $SRC/bert_aug/cbert.py --data_dir $RAWDATADIR --output_dir $CBERTDIR --task_name $TASK --num_train_epochs 10 --seed ${i} --cache $CACHE > $RAWDATADIR/cbert.log 61 | cat $RAWDATADIR/train.tsv $CBERTDIR/cbert_aug.tsv > $CBERTDIR/train.tsv 62 | cp $RAWDATADIR/test.tsv $CBERTDIR/test.tsv 63 | cp $RAWDATADIR/dev.tsv $CBERTDIR/dev.tsv 64 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CBERTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cbert.log 65 | 66 | # ####################### 67 | # # CMODBERT Classifier 68 | # ###################### 69 | 70 | CMODBERTDIR=$RAWDATADIR/cmodbert 71 | mkdir $CMODBERTDIR 72 | python $SRC/bert_aug/cmodbert.py --data_dir $RAWDATADIR --output_dir $CMODBERTDIR --task_name $TASK --num_train_epochs 150 --learning_rate 0.00015 --seed ${i} --cache $CACHE > $RAWDATADIR/cmodbert.log 73 | cat $RAWDATADIR/train.tsv $CMODBERTDIR/cmodbert_aug.tsv > $CMODBERTDIR/train.tsv 74 | cp $RAWDATADIR/test.tsv $CMODBERTDIR/test.tsv 75 | cp $RAWDATADIR/dev.tsv $CMODBERTDIR/dev.tsv 76 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CMODBERTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cmodbert.log 77 | 78 | # ####################### 79 | # # CMODBERTP Classifier 80 | # ###################### 81 | 82 | CMODBERTPDIR=$RAWDATADIR/cmodbertp 83 | mkdir $CMODBERTPDIR 84 | python $SRC/bert_aug/cmodbertp.py --data_dir $RAWDATADIR --output_dir $CMODBERTPDIR --task_name $TASK --num_train_epochs 10 --seed ${i} --cache $CACHE > $RAWDATADIR/cmodbertp.log 85 | cat $RAWDATADIR/train.tsv $CMODBERTPDIR/cmodbertp_aug.tsv > 
$CMODBERTPDIR/train.tsv 86 | cp $RAWDATADIR/test.tsv $CMODBERTPDIR/test.tsv 87 | cp $RAWDATADIR/dev.tsv $CMODBERTPDIR/dev.tsv 88 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CMODBERTPDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cmodbertp.log 89 | 90 | done 91 | done 92 | 93 | 94 | -------------------------------------------------------------------------------- /src/scripts/bert_trec_lower.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SRC=~/PretrainedDataAugment/src 4 | CACHE=~/CACHE 5 | TASK=trec 6 | temp_rate=1.0 7 | smooth_rate=0.5 8 | for NUMEXAMPLES in 10; 9 | do 10 | for i in {0..14}; 11 | do 12 | RAWDATADIR=~/datasets/${TASK}/exp_${i}_${NUMEXAMPLES} 13 | 14 | # Baseline classifier 15 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $RAWDATADIR --seed ${i} --learning_rate $BERTLR --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_baseline.log 16 | 17 | ############## 18 | ## EDA 19 | ############## 20 | 21 | EDADIR=$RAWDATADIR/eda 22 | mkdir $EDADIR 23 | python $SRC/bert_aug/eda.py --input $RAWDATADIR/train.tsv --output $EDADIR/eda_aug.tsv --num_aug=1 --alpha=0.1 --seed ${i} 24 | cat $RAWDATADIR/train.tsv $EDADIR/eda_aug.tsv > $EDADIR/train.tsv 25 | cp $RAWDATADIR/test.tsv $EDADIR/test.tsv 26 | cp $RAWDATADIR/dev.tsv $EDADIR/dev.tsv 27 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $EDADIR --seed ${i} --learning_rate $BERTLR --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_eda.log 28 | 29 | 30 | ####################### 31 | # GPT2 Classifier 32 | ####################### 33 | 34 | GPT2DIR=$RAWDATADIR/gpt2 35 | mkdir $GPT2DIR 36 | python $SRC/bert_aug/cgpt2.py --data_dir $RAWDATADIR --output_dir $GPT2DIR --task_name $TASK --num_train_epochs 25 --seed ${i} --top_p 0.9 --temp 1.0 --cache $CACHE 37 | cat $RAWDATADIR/train.tsv $GPT2DIR/cmodgpt2_aug_3.tsv > $GPT2DIR/train.tsv 38 | cp $RAWDATADIR/test.tsv $GPT2DIR/test.tsv 39 | cp $RAWDATADIR/dev.tsv $GPT2DIR/dev.tsv 40 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $GPT2DIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_gpt2_3.log 41 | 42 | # ####################### 43 | # # Backtranslation DA Classifier 44 | # ####################### 45 | 46 | BTDIR=$RAWDATADIR/bt 47 | mkdir $BTDIR 48 | python $SRC/bert_aug/backtranslation.py --data_dir $RAWDATADIR --output_dir $BTDIR --task_name $TASK --seed ${i} --cache $CACHE 49 | cat $RAWDATADIR/train.tsv $BTDIR/bt_aug.tsv > $BTDIR/train.tsv 50 | cp $RAWDATADIR/test.tsv $BTDIR/test.tsv 51 | cp $RAWDATADIR/dev.tsv $BTDIR/dev.tsv 52 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $BTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_bt.log 53 | 54 | # ####################### 55 | # # CBERT Classifier 56 | # ####################### 57 | 58 | CBERTDIR=$RAWDATADIR/cbert 59 | mkdir $CBERTDIR 60 | python $SRC/bert_aug/cbert.py --data_dir $RAWDATADIR --output_dir $CBERTDIR --task_name $TASK --num_train_epochs 10 --seed ${i} --cache $CACHE > $RAWDATADIR/cbert.log 61 | cat $RAWDATADIR/train.tsv $CBERTDIR/cbert_aug.tsv > $CBERTDIR/train.tsv 62 | cp $RAWDATADIR/test.tsv $CBERTDIR/test.tsv 63 | cp $RAWDATADIR/dev.tsv $CBERTDIR/dev.tsv 64 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CBERTDIR 
--seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cbert.log 65 | 66 | # ####################### 67 | # # CMODBERT Classifier 68 | # ###################### 69 | 70 | CMODBERTDIR=$RAWDATADIR/cmodbert 71 | mkdir $CMODBERTDIR 72 | python $SRC/bert_aug/cmodbert.py --data_dir $RAWDATADIR --output_dir $CMODBERTDIR --task_name $TASK --num_train_epochs 150 --learning_rate 0.00015 --seed ${i} --cache $CACHE > $RAWDATADIR/cmodbert.log 73 | cat $RAWDATADIR/train.tsv $CMODBERTDIR/cmodbert_aug.tsv > $CMODBERTDIR/train.tsv 74 | cp $RAWDATADIR/test.tsv $CMODBERTDIR/test.tsv 75 | cp $RAWDATADIR/dev.tsv $CMODBERTDIR/dev.tsv 76 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CMODBERTDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cmodbert.log 77 | 78 | # ####################### 79 | # # CMODBERTP Classifier 80 | # ###################### 81 | 82 | CMODBERTPDIR=$RAWDATADIR/cmodbertp 83 | mkdir $CMODBERTPDIR 84 | python $SRC/bert_aug/cmodbertp.py --data_dir $RAWDATADIR --output_dir $CMODBERTPDIR --task_name $TASK --num_train_epochs 10 --seed ${i} --cache $CACHE > $RAWDATADIR/cmodbertp.log 85 | cat $RAWDATADIR/train.tsv $CMODBERTPDIR/cmodbertp_aug.tsv > $CMODBERTPDIR/train.tsv 86 | cp $RAWDATADIR/test.tsv $CMODBERTPDIR/test.tsv 87 | cp $RAWDATADIR/dev.tsv $CMODBERTPDIR/dev.tsv 88 | python $SRC/bert_aug/bert_classifier.py --task $TASK --data_dir $CMODBERTPDIR --seed ${i} --cache $CACHE --temp_rate ${temp_rate} --smooth_rate ${smooth_rate} > $RAWDATADIR/bert_cmodbertp.log 89 | 90 | done 91 | done 92 | 93 | 94 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caskcsg/TextSmoothing/cb9a1fd01732e5abe157c27562043425efcaadd1/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/bpe_encoder.py: -------------------------------------------------------------------------------- 1 | # Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | # Original Copyright Facebook, Inc. and its affiliates. Licensed under the MIT License as part of 4 | # fairseq package. 5 | 6 | 7 | import argparse 8 | import contextlib 9 | import sys 10 | 11 | from collections import Counter 12 | from multiprocessing import Pool 13 | 14 | from fairseq.data.encoders.gpt2_bpe_utils import get_encoder 15 | 16 | 17 | def get_labels(dataset_name, to_lower=True): 18 | """add your dataset here""" 19 | task_name = dataset_name.lower() 20 | if task_name == 'stsa': 21 | labels = ["Positive", "Negative"] 22 | elif task_name == 'trec': 23 | labels = ['Description', 'Entity', 'Abbreviation', 'Human', 'Location', 'Numeric'] 24 | elif task_name == "snips": 25 | labels = ["PlayMusic", "GetWeather", "RateBook", "SearchScreeningEvent", 26 | "SearchCreativeWork", "AddToPlaylist", "BookRestaurant"] 27 | else: 28 | raise ValueError("unknown dataset {}".format(dataset_name)) 29 | if to_lower: 30 | return [l.lower() for l in labels] 31 | else: 32 | return labels 33 | 34 | 35 | def main(): 36 | """ 37 | Helper script to encode raw text with the GPT-2 BPE using multiple processes. 
38 | 39 | The encoder.json and vocab.bpe files can be obtained here: 40 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json 41 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe 42 | """ 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument( 45 | "--encoder-json", 46 | help='path to encoder.json', 47 | ) 48 | parser.add_argument( 49 | "--vocab-bpe", 50 | type=str, 51 | help='path to vocab.bpe', 52 | ) 53 | parser.add_argument( 54 | "--inputs", 55 | nargs="+", 56 | default=['-'], 57 | help="input files to filter/encode", 58 | ) 59 | parser.add_argument( 60 | "--outputs", 61 | nargs="+", 62 | default=['-'], 63 | help="path to save encoded outputs", 64 | ) 65 | parser.add_argument( 66 | "--keep-empty", 67 | action="store_true", 68 | help="keep empty lines", 69 | ) 70 | parser.add_argument( 71 | "--decode", 72 | action="store_true", 73 | help="keep empty lines", 74 | ) 75 | parser.add_argument( 76 | "--tsv", 77 | action="store_true", 78 | help="Is a TSV file. If true, will merge the columns", 79 | ) 80 | parser.add_argument( 81 | "--label", 82 | action="store_true", 83 | help="Replace the labels with single BPE token", 84 | ) 85 | parser.add_argument( 86 | "--dataset", 87 | default="sst2", type=str, 88 | help="Dataset. Used for filtering invalid utterances", 89 | ) 90 | 91 | parser.add_argument("--workers", type=int, default=4) 92 | args = parser.parse_args() 93 | 94 | assert len(args.inputs) == len(args.outputs), \ 95 | "number of input and output paths should match" 96 | 97 | with contextlib.ExitStack() as stack: 98 | inputs = [ 99 | stack.enter_context(open(input, "r", encoding="utf-8")) 100 | if input != "-" else sys.stdin 101 | for input in args.inputs 102 | ] 103 | outputs = [ 104 | stack.enter_context(open(output, "w", encoding="utf-8")) 105 | if output != "-" else sys.stdout 106 | for output in args.outputs 107 | ] 108 | 109 | encoder = MultiprocessingEncoder(args) 110 | pool = Pool(args.workers, initializer=encoder.initializer) 111 | 112 | if args.decode: 113 | processed_lines = pool.imap(encoder.decode_lines, zip(*inputs), 100) 114 | else: 115 | processed_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) 116 | 117 | stats = Counter() 118 | for i, (filt, _lines) in enumerate(processed_lines, start=1): 119 | if filt == "PASS": 120 | for _line, output_h in zip(_lines, outputs): 121 | print(_line, file=output_h) 122 | else: 123 | stats["num_filtered_" + filt] += 1 124 | if i % 10000 == 0: 125 | print("processed {} lines".format(i), file=sys.stderr) 126 | 127 | for k, v in stats.most_common(): 128 | print("[{}] filtered {} lines".format(k, v), file=sys.stderr) 129 | 130 | 131 | class MultiprocessingEncoder(object): 132 | 133 | def __init__(self, args): 134 | self.args = args 135 | 136 | def initializer(self): 137 | global bpe 138 | bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe) 139 | 140 | def encode(self, line): 141 | global bpe 142 | ids = bpe.encode(line) 143 | return list(map(str, ids)) 144 | 145 | def decode(self, tokens): 146 | global bpe 147 | return bpe.decode(tokens) 148 | 149 | def encode_lines(self, lines): 150 | """ 151 | Encode a set of lines. All lines will be encoded together. 
152 | """ 153 | labels = get_labels(self.args.dataset) 154 | label_to_bpe_codes = {labels[i]: str(i+50265) for i in range(len(labels))} 155 | 156 | enc_lines = [] 157 | for line in lines: 158 | line = line.strip() 159 | if len(line) == 0 and not self.args.keep_empty: 160 | return ["EMPTY", None] 161 | 162 | if self.args.tsv: # merge columns 163 | fields = line.split("\t") 164 | label = fields[0] 165 | text = fields[1] 166 | if self.args.label: 167 | tokens = [label_to_bpe_codes[label]] + self.encode(text) 168 | else: 169 | line = " ".join([label, text]) 170 | tokens = self.encode(line) 171 | else: 172 | tokens = self.encode(line) 173 | 174 | enc_lines.append(" ".join(tokens)) 175 | return ["PASS", enc_lines] 176 | 177 | def decode_lines(self, lines): 178 | labels = get_labels(self.args.dataset) 179 | bpe_to_label_dict = {i + 50265: labels[i] for i in range(len(labels))} 180 | 181 | dec_lines = [] 182 | for line in lines: 183 | if self.args.tsv: # write in tsv format 184 | if self.args.label: 185 | tokens = line.strip().split() 186 | utterance_text = self.decode([int(t) for t in tokens[1:]]) # BPE ids are read back as strings, convert before decoding 187 | decoded_text = bpe_to_label_dict[int(tokens[0])] + "\t" + utterance_text 188 | else: 189 | try: 190 | tokens = map(int, line.strip().split()) 191 | decoded_text = self.decode(tokens) 192 | except: 193 | print(line) 194 | continue 195 | 196 | word_tokens = decoded_text.strip().split(" ") 197 | if word_tokens[0] in labels: 198 | decoded_text = word_tokens[0] + "\t" + " ".join(word_tokens[1:]) 199 | else: 200 | print("Invalid utterance {}".format(word_tokens)) 201 | continue 202 | else: 203 | tokens = map(int, line.strip().split()) 204 | decoded_text = self.decode(tokens) 205 | dec_lines.append(decoded_text) 206 | return ["PASS", dec_lines] 207 | 208 | 209 | if __name__ == "__main__": 210 | main() 211 | 212 | -------------------------------------------------------------------------------- /src/utils/convert_num_to_text_labels.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | import argparse 5 | 6 | def get_label_dict(dataset_name): 7 | """ 8 | Map numeric labels to the actual labels for STSA-Binary and TREC dataset 9 | 10 | Returns: Dict of {int: label} mapping 11 | 12 | """ 13 | if dataset_name == "stsa": 14 | return {"0": "Negative", "1": "Positive"} 15 | elif dataset_name == "trec": 16 | label_list = ['Description', 'Entity', 'Abbreviation', 'Human', 'Location', 'Numeric'] 17 | return {str(k): label_list[k] for k in range(len(label_list))} 18 | else: 19 | raise ValueError("Unknown dataset name") 20 | 21 | 22 | def prepare_data(input_file, output_file, dataset_name): 23 | """ 24 | Remove header line from dataset and change numeric label to text labels 25 | """ 26 | line_count = 0 27 | label_dict = get_label_dict(dataset_name) 28 | with open(output_file, "w") as out_fp: 29 | with open(input_file, "r") as in_fp: 30 | for line in in_fp: 31 | if line_count == 0: 32 | line_count += 1 33 | continue 34 | fields = line.strip().split("\t") 35 | sentence = fields[0] 36 | label = fields[1] 37 | out_fp.write("\t".join([label_dict[label], sentence])) 38 | out_fp.write("\n") 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser(description='Replace Numeric labels with Text labels') 43 | group = parser.add_argument_group(title="I/O params") 44 | group.add_argument('-i', type=str, help='Input file') 45 | group.add_argument('-o', type=str, help='Output file') 46 | group.add_argument('-d', type=str, help='DataSet name') 47 | 48 | args = parser.parse_args() 49 | prepare_data(args.i, args.o, args.d) -------------------------------------------------------------------------------- /src/utils/create_fsl_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: CC-BY-NC-4.0 3 | 4 | from collections import defaultdict 5 | import argparse 6 | import os 7 | import random 8 | import logging 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def process_data_file(exp_id, source_file, target_file, num_examples=None, to_lower=False): 15 | random.seed(exp_id) 16 | target_category_data = defaultdict(list) 17 | 18 | with open(target_file, "w") as out_file, open(source_file, "r") as in_file: 19 | for line in in_file: 20 | if to_lower: 21 | line = line.lower() 22 | fields = line.strip().split("\t") 23 | if len(fields) == 2: 24 | category = fields[0] 25 | example = fields[1] 26 | else: 27 | raise ValueError("Unknown format. 
Expecting a two col tsv file") 28 | 29 | two_col_line = "\t".join([category, example]) 30 | if num_examples is None: 31 | out_file.write(two_col_line) 32 | out_file.write("\n") 33 | else: 34 | target_category_data[category].append(two_col_line) 35 | 36 | if num_examples: 37 | # write num_seed utterances from target_category_data 38 | for cat, cat_data in target_category_data.items(): 39 | if num_examples < len(cat_data): 40 | seed_utterances = random.sample(cat_data, num_examples) 41 | else: 42 | seed_utterances = cat_data 43 | for two_col_line in seed_utterances: 44 | out_file.write(two_col_line) 45 | out_file.write("\n") 46 | 47 | 48 | def split_data(data_dir, num_train, num_dev, num_simulations, lower): 49 | all_training_data_file = os.path.join(data_dir, "train.tsv") 50 | dev_data_file = os.path.join(data_dir, "dev.tsv") 51 | test_data_file = os.path.join(data_dir, "test.tsv") 52 | 53 | for exp_id in range(num_simulations): 54 | exp_folder_path = os.path.join(data_dir, "exp_{}_{}".format(exp_id, num_train)) 55 | if not os.path.exists(exp_folder_path): 56 | os.mkdir(exp_folder_path) 57 | else: 58 | raise ValueError("Directory {} already exists".format(exp_folder_path)) 59 | 60 | # randomly select train data 61 | target_train_file = os.path.join(exp_folder_path, "train.tsv") 62 | process_data_file(exp_id, all_training_data_file, target_train_file, 63 | num_examples=num_train, to_lower=lower) 64 | 65 | # randomly select dev data 66 | target_dev_file = os.path.join(exp_folder_path, "dev.tsv") 67 | process_data_file(exp_id, dev_data_file, target_dev_file, 68 | num_examples=num_dev, to_lower=lower) 69 | 70 | # copy test file as it is 71 | target_test_file = os.path.join(exp_folder_path, "test.tsv") 72 | process_data_file(exp_id, test_data_file, target_test_file, 73 | num_examples=None, to_lower=lower) 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser(description='Select N utterances from target category') 78 | # data category group 79 | parser.add_argument('-datadir', help='Data Dir', type=str, required=True) 80 | 81 | # data split parameters 82 | parser.add_argument('-num_train', help='Number of training examples to select', type=int, 83 | required=True) 84 | parser.add_argument('-num_dev', help='Number of dev examples to select', type=int, 85 | required=True) 86 | parser.add_argument('-sim', help='Number of simulations', type=int, default=15) 87 | 88 | # data pre-processing steps 89 | parser.add_argument('-lower', action='store_true', default=False) 90 | args = parser.parse_args() 91 | split_data(data_dir=args.datadir, 92 | num_train=args.num_train, 93 | num_dev=args.num_dev, 94 | num_simulations=args.sim, 95 | lower=args.lower) 96 | -------------------------------------------------------------------------------- /src/utils/download_and_prepare_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir datasets 4 | NUMDEV=10 5 | NUMEXP=15 6 | 7 | # TREC dataset 8 | mkdir -p datasets/trec 9 | for split in train dev test; 10 | do 11 | wget -O datasets/trec/${split}.raw https://raw.githubusercontent.com/1024er/cbert_aug/crayon/datasets/TREC/${split}.tsv 12 | python convert_num_to_text_labels.py -i datasets/trec/${split}.raw -o datasets/trec/${split}.tsv -d trec 13 | rm datasets/trec/${split}.raw 14 | done 15 | python create_fsl_dataset.py -datadir datasets/trec -num_train 10 -num_dev $NUMDEV -sim $NUMEXP -lower 16 | 17 | 18 | # STSA dataset 19 | mkdir -p datasets/stsa 20 | for split in train 
dev test; 21 | do 22 | wget -O datasets/stsa/${split}.raw https://raw.githubusercontent.com/1024er/cbert_aug/crayon/datasets/stsa.binary/${split}.tsv 23 | python convert_num_to_text_labels.py -i datasets/stsa/${split}.raw -o datasets/stsa/${split}.tsv -d stsa 24 | rm datasets/stsa/${split}.raw 25 | done 26 | python create_fsl_dataset.py -datadir datasets/stsa -num_train 10 -num_dev $NUMDEV -sim $NUMEXP -lower 27 | 28 | 29 | # SNIPS dataset 30 | mkdir -p datasets/snips 31 | for split in train valid test; 32 | do 33 | wget -O datasets/snips/${split}.seq https://raw.githubusercontent.com/MiuLab/SlotGated-SLU/master/data/snips/${split}/seq.in 34 | wget -O datasets/snips/${split}.label https://raw.githubusercontent.com/MiuLab/SlotGated-SLU/master/data/snips/${split}/label 35 | paste -d'\t' datasets/snips/${split}.label datasets/snips/${split}.seq > datasets/snips/${split}.tsv 36 | rm datasets/snips/${split}.label 37 | rm datasets/snips/${split}.seq 38 | done 39 | 40 | mv datasets/snips/valid.tsv datasets/snips/dev.tsv 41 | python create_fsl_dataset.py -datadir datasets/snips -num_train 10 -num_dev $NUMDEV -sim $NUMEXP -lower 42 | --------------------------------------------------------------------------------