├── .gitignore ├── LICENSE ├── README.md ├── layers.py ├── metric.py ├── models.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Chris (Tu) NGUYEN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
 14 | 
 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 21 | SOFTWARE.
 22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Semi-Supervised NLP
 2 | 
 3 | Implementation of semi-supervised learning techniques (UDA, MixMatch, Mean Teacher), with a focus on NLP.
 4 | 
 5 | Notes:
 6 | - Instead of the `mixup` used in the original paper, I use [Manifold Mixup](https://arxiv.org/abs/1806.05236), which is better suited for NLP applications.
 7 | 
 8 | ## Encoder
 9 | 
10 | - Any `encoder` can be used: Transformer, LSTM, etc. The default is [LSTMWeightDrop](https://github.com/chris-tng/semi-supervised-nlp/blob/8148d07f79a24acdd23e4ae55b1c12b7cf2ae7b7/layers.py#L130), used in [AWD-LSTM](https://arxiv.org/pdf/1708.02182.pdf) and inspired by `fast.ai` v1.
11 | 
12 | - Since this repo is mainly concerned with exploring SSL techniques, using a Transformer can be overkill: it could dominate the gains made by SSL, not to mention the long training time.
13 | 
14 | ## Data Augmentation
15 | 
16 | There are many data augmentation techniques in Computer Vision, but far fewer in NLP; strong data augmentation for NLP is still an open research problem. So far, the most effective technique I have found is `back-translation`, which the UDA paper also confirms. There are many ways to perform back-translation; one simple option is [MarianMT](https://huggingface.co/transformers/model_doc/marian.html), shipped with the excellent `huggingface-transformers`.
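Below is a minimal back-translation sketch with MarianMT. The `Helsinki-NLP/opus-mt-*` checkpoint names and the English-French round trip are illustrative choices rather than anything this repo pins down, and the tokenizer call assumes a reasonably recent `transformers` release:

```python
from transformers import MarianMTModel, MarianTokenizer

def back_translate(texts, src="en", pivot="fr"):
    "Round-trip src -> pivot -> src to produce paraphrased augmentations."
    def translate(batch, model_name):
        tok = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        enc = tok(batch, return_tensors="pt", padding=True, truncation=True)
        out = model.generate(**enc)
        return tok.batch_decode(out, skip_special_tokens=True)

    pivoted = translate(texts, f"Helsinki-NLP/opus-mt-{src}-{pivot}")
    return translate(pivoted, f"Helsinki-NLP/opus-mt-{pivot}-{src}")

augmented = back_translate(["the movie was surprisingly good"])
```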
17 | 18 | - Some data augmentation techniques I would like to explore 19 | 20 | - [ ] TF-IDF word replacement 21 | - [ ] Sentence permutation 22 | - [ ] Nearest neighbor sentence replacement 23 | 24 | ## Citations 25 | 26 | ``` 27 | @article{xie2019unsupervised, 28 | title={Unsupervised Data Augmentation for Consistency Training}, 29 | author={Xie, Qizhe and Dai, Zihang and Hovy, Eduard and Luong, Minh-Thang and Le, Quoc V}, 30 | journal={arXiv preprint arXiv:1904.12848}, 31 | year={2019} 32 | } 33 | 34 | @article{berthelot2019mixmatch, 35 | title={MixMatch: A Holistic Approach to Semi-Supervised Learning}, 36 | author={Berthelot, David and Carlini, Nicholas and Goodfellow, Ian and Papernot, Nicolas and Oliver, Avital and Raffel, Colin}, 37 | journal={arXiv preprint arXiv:1905.02249}, 38 | year={2019} 39 | } 40 | ``` 41 | 42 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # + 2 | __all__ = ["RNNDropout", "WeightDropout", "EmbeddingDropout", 3 | "LSTMWeightDrop", "HANAttention"] 4 | 5 | from typing import List 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | from typing import Tuple 12 | # - 13 | 14 | from utils import softmax_with_mask, sequence_mask 15 | 16 | 17 | # ### RNNDropout 18 | 19 | def dropout_mask(x: Tensor, size: Tuple, p: float): 20 | """Return a dropout mask of the same type as `x`, 21 | size `size`, with probability `p` to nullify an element.""" 22 | return x.new(*size).bernoulli_(1-p).div_(1-p) 23 | 24 | 25 | class RNNDropout(nn.Module): 26 | "Dropout with probability `p` that is consistent on the seq_len dimension." 27 | def __init__(self, p: float=0.5): 28 | super().__init__() 29 | self.p = p 30 | 31 | def forward(self, x: Tensor): 32 | """batch-major x of shape (batch_size, seq_len, feature_size)""" 33 | assert x.ndim == 3, f"Expect x of dimension 3, whereas dim x is {x.ndim}" 34 | if not self.training or self.p == 0.: return x 35 | return x * dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p) 36 | 37 | 38 | # ### WeightDropout 39 | 40 | # + 41 | import warnings 42 | 43 | class WeightDropout(nn.Module): 44 | """ 45 | Wrapper around another layer in which some weights 46 | will be replaced by 0 during training. 47 | Args: 48 | - module {nn.Module}: the module being wrapped 49 | - weight_p {float}: probability of dropout 50 | - layer_names {List[str]}: names of weights of `module` being dropped out. 51 | By default: it drops hidden to hidden connection of LSTM 52 | """ 53 | 54 | def __init__(self, module: nn.Module, weight_p: float, 55 | layer_names: List[str]=['weight_hh_l0']): 56 | super().__init__() 57 | self.module, self.weight_p = module, weight_p 58 | self.layer_names = layer_names 59 | for layer in self.layer_names: 60 | # Makes a copy of the weights of the selected layers. 61 | w = getattr(self.module, layer) 62 | self.register_parameter(f'{layer}_raw', nn.Parameter(w.data)) 63 | self.module._parameters[layer] = F.dropout(w, p=self.weight_p, training=False) 64 | 65 | def _setweights(self): 66 | "Apply dropout to the raw weights." 
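        # `{layer}_raw` (registered in __init__) holds the trainable weights; each
        # forward pass re-derives a dropped-out copy from it and writes that copy
        # into the wrapped module's `_parameters` before the module is called.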
67 | for layer in self.layer_names: 68 | raw_w = getattr(self, f'{layer}_raw') 69 | self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=self.training) 70 | 71 | def forward(self, *args): 72 | self._setweights() 73 | with warnings.catch_warnings(): 74 | #To avoid the warning that comes because the weights aren't flattened. 75 | warnings.simplefilter("ignore") 76 | return self.module.forward(*args) 77 | 78 | def reset(self): 79 | for layer in self.layer_names: 80 | raw_w = getattr(self, f'{layer}_raw') 81 | self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=False) 82 | if hasattr(self.module, 'reset'): self.module.reset() 83 | 84 | 85 | # - 86 | 87 | # ### EmbeddingDropout 88 | 89 | class EmbeddingDropout(nn.Module): 90 | "Apply dropout with probabily `embed_p` to an embedding layer `emb`." 91 | 92 | def __init__(self, emb: nn.Module, embed_p: float): 93 | super().__init__() 94 | self.emb, self.embed_p = emb, embed_p 95 | 96 | def forward(self, x: Tensor, scale: float=None): 97 | if self.training and self.embed_p != 0: 98 | size = (self.emb.weight.size(0),1) 99 | mask = dropout_mask(self.emb.weight.data, size, self.embed_p) 100 | masked_embed = self.emb.weight * mask 101 | else: 102 | masked_embed = self.emb.weight 103 | 104 | if scale: 105 | masked_embed.mul_(scale) 106 | return F.embedding(x, masked_embed, 107 | self.emb.padding_idx or -1, self.emb.max_norm, 108 | self.emb.norm_type, self.emb.scale_grad_by_freq, 109 | self.emb.sparse) 110 | 111 | 112 | # ### LSTMWeightDrop 113 | 114 | # + 115 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 116 | 117 | def one_param(m): 118 | "First parameter in `m`" 119 | return first(m.parameters()) 120 | 121 | def first(x): return next(iter(x)) 122 | 123 | def _to_detach(b): 124 | return [(o[0].detach(), o[1].detach()) if isinstance(o, tuple) else o.detach() 125 | for o in b] 126 | 127 | 128 | # + 129 | from typing import Tuple 130 | 131 | class LSTMWeightDrop(nn.Module): 132 | """ 133 | LSTM with dropouts 134 | 135 | Args: 136 | - input_p : float - RNNDropout applied to input after embedding 137 | - weight_p : float - WeightDropout applied to hidden-hidden connection of LSTM 138 | - hidden_p : float - RNNDropout applied to two of the inner LSTMs 139 | - hidden_sz : int - total hidden size including bidir 140 | 141 | Outputs: 142 | - raw outputs : List[torch.Tensor] - activation for each layer without dropout in reverse order, last at index 0 143 | - outputs : List[torch.Tensor] - activation for each layer with dropout in reverse order, last at index 0 144 | """ 145 | 146 | def __init__(self, input_size, hidden_size, num_layers=1, 147 | bidirectional=False, hidden_p=0.2, input_p=0.6, 148 | weight_p=0.5, pack_pad_seq=False): 149 | super().__init__() 150 | self.input_size, self.hidden_size = input_size, hidden_size 151 | self.num_layers = num_layers 152 | self.pack_pad_seq = pack_pad_seq 153 | self.batch_sz = 1 154 | self.n_dir = 2 if bidirectional else 1 155 | self.rnns = nn.ModuleList([ 156 | self._one_rnn( 157 | n_in = input_size if layer_idx == 0 else hidden_size*self.n_dir, 158 | n_out = hidden_size, 159 | bidir = bidirectional, weight_p = weight_p) 160 | for layer_idx in range(num_layers)] 161 | ) 162 | self.input_dropout = RNNDropout(input_p) 163 | self.hidden_dropouts = nn.ModuleList( 164 | [RNNDropout(hidden_p) for l in range(num_layers)] 165 | ) 166 | 167 | def forward(self, x: Tensor, x_lens=None): 168 | """ 169 | Args: 170 | - x : Tensor - batch-major input of shape 
`(batch, seq_len, emb_sz)` 171 | - x_lens : Tensor - 1D tensor containing sequence length 172 | 173 | Outputs: 174 | - raw outputs : List[Tensor] - activation for each layer without dropout 175 | - outputs : List[Tensor] - activation for each layer with dropout 176 | """ 177 | batch_sz, seq_len = x.shape[:2] 178 | if batch_sz != self.batch_sz: 179 | self.batch_sz = batch_sz 180 | # self.reset() 181 | 182 | all_h = self.input_dropout(x) 183 | last_hiddens, raw_outputs, outputs = [], [], [] 184 | for layer_idx, (rnn, hid_dropout) in enumerate(zip(self.rnns, self.hidden_dropouts)): 185 | if self.pack_pad_seq: 186 | if x_lens is not None: 187 | all_h = pack_padded_sequence( 188 | all_h, x_lens, batch_first=True, enforce_sorted=False) 189 | else: 190 | raise ValueError("Please supply `x_lens` when pack_pad_seq=True") 191 | # all_h, last_hidden = rnn(all_h, self.last_hiddens[layer_idx]) 192 | all_h, last_hidden = rnn(all_h) 193 | if self.pack_pad_seq: 194 | all_h = pad_packed_sequence(all_h, batch_first=True)[0] 195 | 196 | last_hiddens.append(last_hidden) 197 | raw_outputs.append(all_h) 198 | # apply dropout to hidden states except last layer 199 | if layer_idx != self.num_layers - 1: 200 | all_h = hid_dropout(all_h) 201 | outputs.append(all_h) 202 | # self.last_hiddens = _to_detach(last_hiddens) 203 | self.raw_ouputs = raw_outputs 204 | self.outputs = outputs 205 | return all_h, last_hidden 206 | 207 | def _one_rnn(self, n_in, n_out, bidir, weight_p): 208 | "Return one of the inner rnn wrapped by WeightDropout" 209 | rnn = nn.LSTM(n_in, n_out, 1, batch_first=True, bidirectional=bidir) 210 | return WeightDropout(rnn, weight_p) 211 | 212 | def _init_h0(self, layer_idx: int) -> Tuple: 213 | "Init (h0, c0) as zero tensors for layer i" 214 | h0 = one_param(self).new_zeros(self.n_dir, 215 | self.batch_sz, self.hidden_size) 216 | c0 = one_param(self).new_zeros(self.n_dir, 217 | self.batch_sz, self.hidden_size) 218 | return (h0, c0) 219 | 220 | def reset(self): 221 | "Reset the hidden states - (for weightdrop)" 222 | [r.reset() for r in self.rnns if hasattr(r, 'reset')] 223 | self.last_hiddens = [self._init_h0(l) for l in range(self.num_layers)] 224 | 225 | 226 | # - 227 | def init_weight_bias(model, init_range=0.1): 228 | for name_w, w in model.named_parameters(): 229 | if "weight" in name_w: 230 | w.data.uniform_(-init_range, init_range) 231 | elif "bias" in name_w: 232 | w.bias.data.fill_(0.) 
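# A minimal usage sketch of `LSTMWeightDrop` (the shapes and hyper-parameters below
# are illustrative assumptions, not values used elsewhere in this repo):
#
#     encoder = LSTMWeightDrop(input_size=300, hidden_size=256, num_layers=2,
#                              bidirectional=True, pack_pad_seq=True)
#     x = torch.randn(8, 40, 300)              # (batch, seq_len, emb_sz)
#     x_lens = torch.randint(5, 41, (8,))      # true length of each sequence
#     all_h, (h_n, c_n) = encoder(x, x_lens)   # all_h: (8, max(x_lens), 2 * 256)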
233 | 234 | 235 | # ### HANAttention 236 | 237 | class HANAttention(nn.Module): 238 | """ 239 | HAN Attention described in [Hierarchial Attention Networks - ACL16] 240 | with multi-head mechanism and diversity penalization specified in 241 | [A Structure Self-Attentive Sentence Embedding - ICLR17] 242 | sometimes referred as SelfAttention 243 | 244 | Attrs: 245 | - input_size: num of features / embedding sz 246 | - n_heads: num of subspaces to project input 247 | - pool_mode : flatten to return summary vectors, 248 | otherwise sum across features 249 | """ 250 | 251 | def __init__(self, input_size: int, attention_size: int, n_heads: int, 252 | pool_mode: str="flatten"): 253 | super().__init__() 254 | self.n_heads, self.pool_mode = n_heads, pool_mode 255 | self.proj = nn.Linear(input_size, attention_size, bias=False) 256 | self.queries = nn.Linear(attention_size, n_heads, False) 257 | self.reset_parameters() 258 | 259 | def reset_parameters(self): 260 | torch.nn.init.xavier_normal_(self.proj.weight) 261 | torch.nn.init.xavier_normal_(self.queries.weight) 262 | 263 | def forward(self, x: Tensor, x_lens: Tensor=None): 264 | """ 265 | Args: 266 | x : input of shape `(batch_sz, seq_len, n_features)` 267 | x_lens : lengths of x of shape `(batch_sz)` 268 | """ 269 | x_proj = torch.tanh(self.proj(x)) 270 | x_queries_sim = self.queries(x_proj) 271 | if x_lens is not None: 272 | masks = sequence_mask(x_lens).unsqueeze(-1) 273 | # attn_w: (batch_sz, seq_len, n_head) 274 | attn_w = softmax_with_mask(x_queries_sim, 275 | masks.expand_as(x_queries_sim), dim=1) 276 | else: 277 | attn_w = F.softmax(x_queries_sim, dim=1) 278 | # x_attended: (batch_sz, n_head, n_features) 279 | x_attended = attn_w.transpose(2, 1) @ x 280 | self.attn_w = attn_w 281 | return self.pool(x_attended), attn_w 282 | 283 | def pool(self, x): 284 | return x.flatten(1, 2) if self.pool_mode=="flatten" else x.mean(dim=1) 285 | 286 | def diversity(self): 287 | "Don't seem to work at all" 288 | # cov: (batch_sz, n_head, n_head) 289 | cov = self.attn_w.transpose(2, 1).bmm(self.attn_w) - torch.eye(self.n_head, device=self.attn_w.device).unsqueeze(0) 290 | return (cov**2).sum(dim=[1, 2]) 291 | 292 | 293 | # ### Mixup 294 | 295 | def pad_seq_len(x: Tensor, max_len: int): 296 | "Pad seq len dimension by 0 - appending zero word vectors" 297 | size = (x.size(0), max_len - x.size(1), x.size(2)) 298 | pad = x.new_zeros(*size) 299 | return torch.cat([x, pad], dim=1) 300 | 301 | 302 | class ManifoldMixup(nn.Module): 303 | """ 304 | Perform manifold mixup on `seq_len` dimension 305 | """ 306 | 307 | def forward(self, x1: Tensor, x2: Tensor, m: float = 1.): 308 | """ 309 | Args: 310 | - x1: shape (batch_size, seq_len, feature_size) 311 | - m: mixup factor 312 | """ 313 | assert x1.ndim == x2.ndim == 3 314 | # seq_lens might be different at this point 315 | if x1.size(1) != x2.size(1): 316 | max_seq_len = max(x1.size(1), x2.size(1)) 317 | if x1.size(1) < max_seq_len: 318 | x1 = pad_seq_len(x1, max_seq_len) 319 | else: 320 | x2 = pad_seq_len(x2, max_seq_len) 321 | 322 | x_mixup = m * x1 + (1 - m) * x2 323 | return x_mixup 324 | 325 | 326 | 327 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional, List, Any, Callable, Dict, Tuple 5 | import torch 6 | 7 | 8 | # inspired from https://github.com/google/CommonLoopUtils/tree/master/clu/metric_writers 9 | # 10 | 
# + declarative programming using `dataclass` 11 | # + nice trick to return inner Subclass for fluent interface 12 | # 13 | # ```python 14 | # @dataclass 15 | # class Metrics(Collection): 16 | # top5acc: Accuracy.from_output("top5acc") 17 | # ``` 18 | 19 | class Metric: 20 | "Interface for computing metrics" 21 | 22 | @classmethod 23 | def from_model_output(cls, *args, **kwargs) -> "Metric": 24 | raise NotImplementedError("Must override from_model_output()") 25 | 26 | def merge(self, other: "Metric") -> "Metric": 27 | """Returns `Metric` that is the accumulation of `self` and `other`. 28 | Args: 29 | other: A `Metric` whose inermediate values should be accumulated onto the 30 | values of `self`. 31 | Returns: 32 | A new `Metric` that accumulates the value from both `self` and `other`. 33 | """ 34 | raise NotImplementedError("Must override merge()") 35 | 36 | def compute(self): 37 | "Computes final metrics from intermediate values." 38 | raise NotImplementedError("Must override compute()") 39 | 40 | @classmethod 41 | def from_fun(cls, fun: Callable): # pylint: disable=g-bare-generic 42 | """Calls `cls.from_model_output` with the return value from `fun`.""" 43 | 44 | class Fun(cls): 45 | @classmethod 46 | def from_model_output(cls, **model_output) -> Metric: 47 | return super().from_model_output(fun(**model_output)) 48 | 49 | return Fun 50 | 51 | @classmethod 52 | def from_output(cls, name: str): # pylint: disable=g-bare-generic 53 | """Calls `cls.from_model_output` with model output named `name`.""" 54 | 55 | class FromOutput(cls): 56 | @classmethod 57 | def from_model_output(cls, **model_output) -> Metric: 58 | return super().from_model_output(model_output[name]) 59 | 60 | return FromOutput 61 | 62 | 63 | Metric.__doc__ += """ 64 | Refer to `Collection` for computing multipel metrics at the same time. 65 | 66 | Synopsis: 67 | 68 | @dataclass 69 | class Average(Metric): 70 | total: torch.Tensor 71 | count: torch.Tensor 72 | @classmethod 73 | def from_model_output(cls, value: jnp.array, **_) -> Metric: 74 | return cls(total=value.sum(), count=jnp.prod(value.shape)) 75 | def merge(self, other: Metric) -> Metric: 76 | return type(self)( 77 | total=self.total + other.total, 78 | count=self.count + other.count, 79 | ) 80 | def compute(self): 81 | return self.total / self.count 82 | 83 | average = None 84 | for value in range(data): 85 | update = Average.from_model_output(value) 86 | average = update if average is None else average.merge(update) 87 | print(average.compute()) 88 | """ 89 | 90 | 91 | # ### Average 92 | 93 | @dataclass 94 | class Average(Metric): 95 | """Compute the average of `values`. 
96 | 97 | Optionally taking a mask to ignore values with mask = 0 98 | - values : ndim = 0 or ndim = 1 99 | - masks : shape same as values 100 | """ 101 | total: torch.Tensor # accumulation 102 | count: torch.Tensor # number of merges 103 | 104 | @classmethod 105 | def from_model_output(cls, values: torch.Tensor, 106 | mask: Optional[torch.Tensor]=None, **_) -> Metric: 107 | if values.ndim == 0: 108 | values = values[None] # prepend 1 109 | if mask is None: 110 | mask = torch.ones(values.shape).to(values.device) 111 | return cls( 112 | total=(mask* values).sum(), 113 | count=mask.sum() 114 | ) 115 | 116 | def merge(self, other: "Average") -> "Average": 117 | # assert total of the same shape 118 | return type(self)( 119 | total=self.total + other.total, 120 | count=self.count + other.count 121 | ) 122 | 123 | def compute(self) -> Any: 124 | return self.total / self.count 125 | 126 | 127 | # ### Accuracy 128 | 129 | @dataclass 130 | class Accuracy(Average): 131 | """Computes the average accuracy from model outputs `logits` and `labels`. 132 | 133 | - `labels` {int32} : shape (num_classes) 134 | - `logits` : shape (batch_size, num_classes) 135 | """ 136 | 137 | @classmethod 138 | def from_model_output(cls, *, 139 | logits: torch.Tensor, 140 | labels: torch.Tensor, **kwargs) -> Metric: 141 | return super().from_model_output( 142 | values=(logits.argmax(axis=-1) == labels).float(), **kwargs 143 | ) 144 | 145 | 146 | # ### Loss 147 | 148 | @dataclass 149 | class Loss(Average): 150 | "Computes the average `loss`" 151 | 152 | @classmethod 153 | def from_model_output(cls, loss: torch.Tensor, **kwargs) -> Metric: 154 | return super().from_model_output(values=loss, **kwargs) 155 | 156 | 157 | # ### Std 158 | 159 | @dataclass 160 | class Std(Metric): 161 | "Computes the standard deviation of a scalar or a batch of scalars." 162 | total: torch.Tensor 163 | sum_of_squares: torch.Tensor 164 | count: torch.Tensor 165 | 166 | @classmethod 167 | def from_model_output(cls, values: torch.Tensor, 168 | mask: Optional[torch.Tensor] = None, 169 | **_) -> Metric: 170 | if values.ndim == 0: 171 | values = values[None] 172 | # utils.check_param(values, ndim=1) 173 | if mask is None: 174 | mask = torch.ones(values.shape[0]) 175 | return cls( 176 | total=values.sum(), 177 | sum_of_squares=(mask * values**2).sum(), 178 | count=mask.sum(), 179 | ) 180 | 181 | def merge(self, other: "Std") -> "Std": 182 | # _assert_same_shape(self.total, other.total) 183 | return type(self)( 184 | total=self.total + other.total, 185 | sum_of_squares=self.sum_of_squares + other.sum_of_squares, 186 | count=self.count + other.count, 187 | ) 188 | 189 | def compute(self) -> Any: 190 | # var(X) = 1/N \sum_i (x_i - mean)^2 191 | # = 1/N \sum_i (x_i^2 - 2 x_i mean + mean^2) 192 | # = 1/N ( \sum_i x_i^2 - 2 mean \sum_i x_i + N * mean^2 ) 193 | # = 1/N ( \sum_i x_i^2 - 2 mean N mean + N * mean^2 ) 194 | # = 1/N ( \sum_i x_i^2 - N * mean^2 ) 195 | # = \sum_i x_i^2 / N - mean^2 196 | mean = self.total / self.count 197 | return (self.sum_of_squares / self.count - mean**2)**.5 198 | 199 | 200 | # ### Collection 201 | 202 | @dataclass 203 | class _ReductionCounter(Metric): 204 | """Pseudo metric that keeps track of the total number of `.merge()`.""" 205 | value: torch.Tensor 206 | 207 | def merge(self, other: "_ReductionCounter") -> "_ReductionCounter": 208 | return _ReductionCounter(self.value + other.value) 209 | 210 | 211 | @dataclass 212 | class Collection: 213 | "Updates a collection of `Metric` from model outputs." 
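    # Concrete collections subclass this as a dataclass; every annotated field is a
    # `Metric` subclass, and `_from_model_output` builds one metric per field found
    # in `cls.__annotations__` (see the synopsis appended to `Collection.__doc__`).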
214 | _reduction_counter: _ReductionCounter 215 | 216 | @classmethod 217 | def _from_model_output(cls, **kwargs) -> "Collection": 218 | return cls( 219 | _reduction_counter=_ReductionCounter(torch.tensor(1)), 220 | **{ 221 | metric_name: metric.from_model_output(**kwargs) 222 | for metric_name, metric in cls.__annotations__.items() 223 | } 224 | ) 225 | 226 | @classmethod 227 | def single_from_model_output(cls, **kwargs) -> "Collection": 228 | return cls._from_model_output(**kwargs) 229 | 230 | def merge(self, other: "Collection") -> "Collection": 231 | """Returns `Collection` that is the accumulation of `self` and `other`.""" 232 | return type(self)(**{ 233 | metric_name: metric.merge(getattr(other, metric_name)) 234 | for metric_name, metric in vars(self).items() 235 | }) 236 | 237 | def reduce(self) -> "Collection": 238 | """Reduces the collection by calling `Metric.reduce()` on each metric.""" 239 | return type(self)(**{ 240 | metric_name: metric.reduce() 241 | for metric_name, metric in vars(self).items() 242 | }) 243 | 244 | def compute(self) -> Dict[str, torch.Tensor]: 245 | """Computes metrics and returns them as Python numbers/lists.""" 246 | ndim = self._reduction_counter.value.ndim 247 | if ndim != 0: 248 | raise ValueError( 249 | f"Collection is still replicated (ndim={ndim}). Maybe you forgot to " 250 | f"call a flax.jax_utils.unreplicate() or a Collections.reduce()?") 251 | return { 252 | metric_name: metric.compute() 253 | for metric_name, metric in vars(self).items() 254 | if metric_name != "_reduction_counter" 255 | } 256 | 257 | 258 | Collection.__doc__ +=""" 259 | Synopsis: 260 | @dataclass 261 | class Metrics(Collection): 262 | accuracy: Accuracy 263 | 264 | metrics = None 265 | for inputs, labels in data: 266 | logits = model(inputs) 267 | update = Metrics.single_from_model_output(logits=logits, labels=labels) 268 | metrics = update if metrics is None else metrics.merge() 269 | print(metrics.compute()) 270 | """ 271 | 272 | 273 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | 6 | from layers import ManifoldMixup 7 | 8 | 9 | class GenericMixupModel(nn.Module): 10 | """ 11 | Generic model consists of an embedding layer, an encoder, a pooler and a classification head 12 | """ 13 | 14 | def __init__(self, embed: nn.Module, encoder: nn.Module, 15 | pooler: nn.Module, cls_head: nn.Module, use_mixup: bool = True): 16 | super().__init__() 17 | self.embed = embed 18 | self.encoder = encoder 19 | self.pooler = pooler 20 | self.cls_head = cls_head 21 | if use_mixup: 22 | self.mixup = nn.Sequential( 23 | ManifoldMixup(), nn.ReLU() 24 | ) 25 | 26 | def forward(self, x1, x1_lens=None, x2=None, x2_lens=None, mixup_factor: float=1.): 27 | """ 28 | - x2: example to mixup with x1 29 | - mixup_factor: 1 no mixup 30 | """ 31 | x1_embed = self.embed(x1) 32 | x1_encoded, _ = self.encoder(x1_embed, x1_lens) 33 | 34 | if x2 is not None: 35 | x2_embed = self.embed(x2) 36 | x2_encoded, _ = self.encoder(x2_embed, x2_lens) 37 | x_encoded = self.mixup(x1_encoded, x2_encoded, mixup_factor) 38 | else: 39 | x_encoded = x1_encoded 40 | 41 | x_pooled = self.pooler(x_encoded) 42 | logits = self.cls_head(x_pooled) 43 | return logits 44 | 45 | 46 | class EMAModel: 47 | 48 | def __init__(self, model): 49 | self.original = model 50 | self.model = copy.deepcopy(model) 51 | 52 | def 
__call__(self, **kwargs): 53 | return self.model(**kwargs) 54 | 55 | def update_parameters(self, alpha, global_step): 56 | alpha = min(1 - 1/(global_step+1), alpha) 57 | for ema_p, p in zip(self.model.parameters(), self.original.parameters()): 58 | # ema * alpha + (1 - alpha) * p 59 | ema_p.data.mul_(alpha).add_(1 - alpha, p.data) 60 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # + 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | # - 8 | 9 | class Trainer: 10 | """abstract class for trainer""" 11 | 12 | 13 | from models import EMAModel 14 | from metric import Collection, Loss 15 | from dataclasses import dataclass 16 | 17 | 18 | # ### Mean Teacher 19 | 20 | @dataclass 21 | class Metrics(Collection): 22 | loss: Loss.from_output("loss") 23 | sup_loss: Loss.from_output("sup_loss") 24 | sup_cst_loss: Loss.from_output("sup_cst_loss") 25 | 26 | 27 | # + 28 | 29 | # sup_loss_fn = nn.CrossEntropyLoss(reduction="none") # "none" 30 | # consistency_fn = softmax_mse_loss 31 | 32 | class MeanTeacherTrainer(Trainer): 33 | 34 | def train(self, n_epochs, model, train_dl, loss_fn, consistency_loss_fn, optimizer, lr, **kwargs): 35 | assert "cst_factor" in kwargs 36 | cst_factor = kwargs["cst_factor"] 37 | 38 | optimizer = self.optimizer(model.parameters(), lr=lr) 39 | 40 | metrics = None 41 | for epoch in range(n_epochs): 42 | model.train() ; ema_model.train() 43 | 44 | for x, x_lens, xa, xa_lens, y in train_dl: 45 | x, xa, y = x.cuda(), xa.cuda(), y.cuda() 46 | 47 | # SUPERVISED LOSS 48 | logits = model(x, x_lens) 49 | sup_loss = loss_fn(logits, y).mean() 50 | 51 | # CONSISTENCY LOSS 52 | logits_aug = ema_model(xa, xa_lens).detach() # no backprop for teacher 53 | 54 | sup_cst_loss = cst_factor * consistency_loss_fn(logits, logits_aug, True).mean() 55 | loss = sup_loss + sup_cst_loss 56 | loss.backward() 57 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_value) 58 | optimizer.step() ; optimizer.zero_grad() 59 | ema_model.update_parameters(ema_decay, global_step) 60 | 61 | update = Metrics.single_from_model_output( 62 | loss=loss, 63 | sup_loss=sup_loss, 64 | sup_cst_loss=sup_cst_loss 65 | ) 66 | 67 | with torch.no_grad(): 68 | metrics = update if metrics is None else metrics.merge(update) 69 | logger.info(f"\t[{epoch}][{i}] Loss: {metrics.compute()}") 70 | 71 | 72 | # - 73 | 74 | # ### UDA 75 | 76 | class UDATrainer(Trainer): 77 | 78 | def get_tsa_threshold(self, schedule, global_step, num_train_steps, start, end): 79 | training_progress = torch.tensor(float(global_step) / float(num_train_steps)) 80 | if schedule == "linear_schedule": 81 | threshold = training_progress 82 | elif schedule == "exp_schedule": 83 | scale = 5 84 | threshold = torch.exp( (training_progress - 1) * scale ) 85 | elif schedule == "log_schedule": 86 | scale = 5 87 | threshold = 1 - torch.exp(-training_progress * scale) 88 | output = threshold * (end - start) + start 89 | return output 90 | 91 | 92 | def train(self, n_epochs, model, train_dl, unlabeled_dl, loss_fn, consistency_loss_fn, unsup_loss_fn, optimizer, lr, **kwargs): 93 | optimizer = self.optimizer(model.parameters(), lr=lr) 94 | 95 | for epoch in range(n_epochs): 96 | _sup_loss = 0. ; _unsup_loss = 0. 97 | _sup_cst_loss = 0. ; _unsup_cst_loss = 0. 98 | _sup_acc = 0. ; _unsup_acc = 0. 99 | _sup_agreement = 0. ; _unsup_agreement = 0. ; 100 | _unsup_ce_loss = 0. ; _unsup_ce_loss_au = 0. 
101 | _n_sup = 0 ; _n_unsup = 0 102 | model.train() 103 | 104 | for i, sample in enumerate(train_dl): 105 | x, x_lens, xa, xa_lens, y = sample 106 | x=x.cuda() ; y=y.cuda() ; xa=xa.cuda() 107 | 108 | try: 109 | x_un, x_un_lens, xa_un, xa_un_lens, y_un = next(unlabeled_it) 110 | except StopIteration: 111 | unlabeled_it = iter(unlabeled_dl) 112 | x_un, x_un_lens, xa_un, xa_un_lens, y_un = next(unlabeled_it) 113 | x_un=x_un.cuda(); xa_un=xa_un.cuda(); y_un=y_un.cuda() 114 | 115 | logits = model(x, x_lens) 116 | logits_au = model(xa, xa_lens) 117 | sup_cst_loss = sup_cst * consistency_loss_fn(logits, logits_au, True) 118 | 119 | sup_loss = loss_fn(logits, y) 120 | # === TSA === 121 | global_step += 1 122 | tsa_threshold = self.get_tsa_threshold(tsa_schedule, global_step, n_warmup, start=1/n_classes+0.02, end=.75).cuda() 123 | larger_than_threshold = torch.exp(-sup_loss) > tsa_threshold 124 | loss_mask = 1. - larger_than_threshold.float() 125 | loss = 0. 126 | # broadcasting with loss mask: should we do loss mask 127 | n_sup = loss_mask.sum() 128 | if n_sup > 0: 129 | sup_loss = (sup_loss * loss_mask).sum(-1) / n_sup 130 | sup_cst_loss = (sup_cst_loss * loss_mask).sum(-1) / n_sup 131 | loss += sup_loss + sup_cst_loss 132 | 133 | # === UNSUPERVISED === 134 | logits_un = model(x_un, x_un_lens) 135 | logits_un_au = model(xa_un, xa_un_lens) 136 | 137 | prob_un = logits_un.softmax(-1) 138 | unsup_cst_loss = unsup_cst*consistency_loss_fn(logits_un, logits_un_au, True) 139 | # loss += unsup_cst_loss 140 | 141 | # Confidence based masking for unlabeled 142 | # only release examples with conf > min threshold 143 | if uda_conf_threshold > 0.: 144 | unsup_loss_mask = (prob_un.max(-1)[0] > uda_conf_threshold).float() 145 | else: 146 | unsup_loss_mask = torch.ones(logits_un.size(0)).cuda() 147 | 148 | # Sharpening for unlabeled aug 149 | # prob_aug = (logits_aug_un.softmax(-1) / sharpen_T).log_softmax(-1) 150 | # prob_aug = sharpen(logits_aug_un.softmax(-1), sharpen_T).exp().log_softmax(-1) 151 | # prob_aug = logits_aug_un / uda_softmax_temp 152 | 153 | # Pseudo-labels 154 | # original data is relatively better predictor 155 | prob_un = sharpen(prob_un, sharpen_T) 156 | unsup_loss = unsup_loss_fn(logits_un_au.log_softmax(-1), prob_un).sum(-1) 157 | 158 | # avoid / by 0 159 | n_unsup = unsup_loss_mask.sum() 160 | if n_unsup > 0.: 161 | unsup_loss = (unsup_loss * unsup_loss_mask).sum(-1) / n_unsup 162 | unsup_cst_loss = (unsup_cst_loss * unsup_loss_mask).sum(-1)/n_unsup 163 | unsup_loss *= uda_coeff 164 | loss += unsup_loss + unsup_cst_loss 165 | 166 | # unsup_cst_loss 167 | if loss > 0.: 168 | loss.backward() 169 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_value) 170 | optimizer.step() ; optimizer.zero_grad() 171 | with torch.no_grad(): 172 | _sup_loss += sup_loss.item() if n_sup > 0 else 0. 173 | _unsup_loss += unsup_loss.item() if n_unsup > 0 else 0. # loss_fn(logits_un, y_un).mean().item() 174 | _sup_cst_loss += sup_cst_loss.item() if n_sup > 0 else 0. 175 | _unsup_cst_loss += unsup_cst_loss.item() if n_unsup > 0 else 0. 
176 | _sup_acc += (logits.argmax(-1) == y).float().mean().item() 177 | _unsup_acc += (logits_un.argmax(-1) == y_un).float().mean().item() 178 | _sup_agreement += (logits.argmax(-1) == logits_au.argmax(-1)).float().mean().item() 179 | _unsup_agreement += (logits_un.argmax(-1) == logits_un_au.argmax(-1)).float().mean().item() 180 | _n_sup += n_sup.item() 181 | _n_unsup += n_unsup.item() 182 | _unsup_ce_loss += loss_fn(logits_un, y_un).mean().item() 183 | _unsup_ce_loss_au += loss_fn(logits_un_au, y_un).mean().item() 184 | _sup_loss /= (i+1) ; _unsup_loss /= (i+1) ; _sup_cst_loss /= (i+1) 185 | _unsup_cst_loss /= (i+1) ; _sup_acc /= (i+1) ; _unsup_acc /= (i+1) 186 | _sup_agreement /= (i+1) ; _unsup_agreement /= (i+1) 187 | _n_sup /= (i+1) ; _n_unsup /= (i+1) 188 | _unsup_ce_loss /= (i+1) ; _unsup_ce_loss_au /= (i+1) 189 | msg.info(f"\t[{i}] Loss Sup: {_sup_loss:.4f} - Unsup: {_unsup_loss:.4f}") 190 | msg.info(f"\tCST Sup: {_sup_cst_loss:.4f} - Unsup: {_unsup_cst_loss:.4f}") 191 | msg.info(f"\tAcc Sup: {_sup_acc:.4f} - Unsup: {_unsup_acc:.4f}") 192 | msg.info(f"\tAgreement Sup: {_sup_agreement:.4f} - Unsup: {_unsup_agreement:.4f}") 193 | msg.info(f"\tTSA Num Exs Sup: {_n_sup:.2f} - Unsup: {_n_unsup:.2f} - threshold: {tsa_threshold.item():.2f}") 194 | msg.info(f"\tUnsup CE: {_unsup_ce_loss:.4f} - CE au: {_unsup_ce_loss_au:.4f}") 195 | 196 | best_val, patience, _val_loss, _val_acc = val_step(model, model_name, val_dl, best_val, patience) 197 | msg.info(f"\tElapsed: {time.time() - start:.4f}") 198 | 199 | train_sup_loss += [_sup_loss] ; train_unsup_loss += [_unsup_loss] 200 | train_sup_cst_loss += [_sup_cst_loss] ; train_unsup_cst_loss += [_unsup_cst_loss] 201 | train_sup_acc += [_sup_acc] ; train_unsup_acc += [_unsup_acc] 202 | train_sup_agreement += [_sup_agreement] ; train_unsup_agreement += [_unsup_agreement] 203 | val_loss += [_val_loss] ; val_acc += [_val_acc] 204 | train_unsup_ce_loss += [_unsup_ce_loss] ; train_unsup_ce_loss_au += [_unsup_ce_loss_au] 205 | train_n_sup += [_n_sup] ; train_n_unsup += [_n_unsup] 206 | if patience > max_patience: raise StopIteration() 207 | 208 | 209 | class MixMatchTrainer(Trainer): 210 | 211 | def train(self, n_epochs, model, train_dl, unlabeled_dl, loss_fn, consistency_loss_fn, unsup_loss_fn, optimizer, lr, **kwargs): 212 | 213 | best_val = 1e+4 ; max_patience = 70 ; patience = 0 214 | # METRIC 215 | train_loss = [] ; train_unsup_acc = [] 216 | train_n_unsup = [] 217 | val_loss = [] ; val_acc = [] 218 | 219 | unlabeled_it = iter(unlabeled_dl) 220 | n_warmup = len(train_dl) * 100 221 | global_step = 0 222 | 223 | for epoch in range(n_epochs): 224 | start = time.time() 225 | msg.divider() 226 | msg.info(f"\t=== EPOCH {epoch} ===") 227 | _loss = 0. ; _unsup_acc = 0. ; _n_unsup=0. 
228 | model.train() 229 | 230 | for i, sample in enumerate(train_dl): 231 | global_step += 1 232 | x, x_lens, xa, xa_lens, y_ = sample 233 | x=x.cuda() ; y_=y_.cuda() ; xa=xa.cuda() 234 | 235 | try: 236 | x_un, x_un_lens, xa_un, xa_un_lens, y_un_ = next(unlabeled_it) 237 | except StopIteration: 238 | unlabeled_it = iter(unlabeled_dl) 239 | x_un, x_un_lens, xa_un, xa_un_lens, y_un_ = next(unlabeled_it) 240 | x_un=x_un.cuda(); xa_un=xa_un.cuda(); y_un_=y_un_.cuda() 241 | y = F.one_hot(y_, n_classes).float() 242 | y_un = F.one_hot(y_un_, n_classes).float() 243 | 244 | # Pseudo-labels 245 | # original data is relatively better predictor 246 | model.eval() # turn off drop out 247 | prob_un = model(x_un, x_un_lens).softmax(-1).detach() 248 | 249 | # Confidence based masking for unlabeled 250 | # only release examples with conf > min threshold 251 | unsup_loss_mask = (prob_un.max(-1)[0] > unlabeled_conf) 252 | 253 | prob_un = sharpen(prob_un, sharpen_T).detach() 254 | model.train() 255 | 256 | n_unsup = unsup_loss_mask.sum() 257 | if n_unsup > 0.: 258 | # threshold the growth 259 | if epoch > 30: 260 | growth_factor = 1.1 261 | else: 262 | growth_factor = 0. 263 | max_unsup = int(min(max(train_n_unsup[-1], 2) * growth_factor, n_unsup)) + 1 264 | max_unsup = 1 if max_unsup < 2 else max_unsup 265 | n_unsup = max_unsup 266 | 267 | x_un = x_un[unsup_loss_mask][:n_unsup] 268 | prob_un = prob_un[unsup_loss_mask][:n_unsup] 269 | x_un_lens = [x_un_lens[j] 270 | for j, k in enumerate(unsup_loss_mask) 271 | if k == True][:n_unsup] 272 | # concat data 273 | y_all = torch.cat([y, prob_un], dim=0) 274 | 275 | # Padding 276 | max_seq_len = max(x.size(1), x_un.size(1)) 277 | if max_seq_len > x.size(1): 278 | pad = x.new_zeros(x.size(0), max_seq_len - x.size(1)) 279 | x = torch.cat([x, pad], dim=1) 280 | else: 281 | pad = x_un.new_zeros(x_un.size(0), max_seq_len - x_un.size(1)) 282 | x_un = torch.cat([x_un, pad], dim=1) 283 | assert x.size(1) == x_un.size(1) 284 | x_all = torch.cat([x, x_un], dim=0) 285 | x_lens = x_lens + x_un_lens 286 | else: 287 | x_all = x 288 | y_all = y 289 | 290 | l = np.random.beta(alpha, alpha) 291 | l = max(l, 1-l) 292 | 293 | idx = torch.randperm(x_all.size(0)) 294 | x1, x2 = x_all, x_all[idx] 295 | y1, y2 = y_all, y_all[idx] 296 | x1_lens, x2_lens = x_lens, [x_lens[j] for j in idx] 297 | # mix target 298 | y_mix = l * y1 + (1-l) * y2 299 | 300 | logits_mix = model(x1, x1_lens, x2, x2_lens, l) 301 | loss = loss_fn(logits_mix.log_softmax(-1), y_mix).sum(-1).mean() 302 | loss.backward() 303 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_value) 304 | optimizer.step() ; optimizer.zero_grad() 305 | with torch.no_grad(): 306 | _loss += loss.item() 307 | _unsup_acc = (prob_un.argmax(-1) == y_un_[unsup_loss_mask][:prob_un.size(0)]).float().mean().item() if n_unsup > 0 else 0. 
308 |                     _n_unsup += n_unsup
309 | 
310 |             _loss /= (i+1) ; _n_unsup /= (i+1) ; _unsup_acc /= (i+1)
311 |             msg.info(f"\t[{i}] Loss Sup: {_loss:.4f} - Acc unsup: {_unsup_acc:.4f}")
312 |             msg.info(f"\tN unsup: {_n_unsup:.4f}")
313 | 
314 |             best_val, patience, _val_loss, _val_acc = val_step(model, model_name, val_dl, best_val, patience)
315 |             msg.info(f"\tElapsed: {time.time() - start:.4f}")
316 | 
317 |             train_loss += [_loss] ; train_unsup_acc += [_unsup_acc]
318 |             train_n_unsup += [_n_unsup]
319 |             val_loss += [_val_loss] ; val_acc += [_val_acc]
320 |             if patience > max_patience: raise StopIteration()
321 | 
322 | 
323 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | from torch import Tensor
 4 | 
 5 | 
 6 | def sharpen(p: Tensor, T: float):
 7 |     "Sharpen a probability distribution"
 8 |     p1 = p**(1/T)
 9 |     return p1 / p1.sum(dim=-1, keepdim=True)
10 | 
11 | 
12 | def softmax_with_mask(x: Tensor, mask: Tensor, dim: int=-1):
13 |     """
14 |     Perform softmax over x's dim, factoring in a boolean `mask` of the same shape
15 |     """
16 | 
17 |     assert x.shape == mask.shape, f"Input's shape {x.shape} and mask's shape {mask.shape} need to be equal"
18 |     scores = F.softmax(x, dim)
19 |     masked_scores = scores * mask.float()
20 |     return masked_scores / (masked_scores.sum(dim, keepdim=True) + 1e-10)
21 | 
22 | 
23 | def sequence_mask(lengths: Tensor, max_len: int=None):
24 |     """
25 |     Creates a boolean mask from sequence lengths
26 |     - lengths: 1D tensor
27 |     """
28 |     batch_size = lengths.numel()
29 |     max_len = max_len or lengths.max()
30 |     return (torch.arange(0, max_len, device=lengths.device).type_as(lengths)
31 |             .unsqueeze(0).expand(batch_size, max_len).lt(lengths.unsqueeze(1)))
32 | 
--------------------------------------------------------------------------------
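For reference, a minimal sketch of how the helpers in `utils.py` compose; the tensors are illustrative, and the pattern mirrors how `HANAttention` and the trainers use these functions:

```python
import torch
from utils import sequence_mask, softmax_with_mask, sharpen

scores = torch.randn(2, 5)                      # e.g. attention scores for 2 sequences
lens = torch.tensor([5, 3])                     # true lengths (second sequence is padded)
mask = sequence_mask(lens, max_len=5)           # (2, 5) boolean mask, False on padding
attn = softmax_with_mask(scores, mask, dim=-1)  # padded positions get ~0 weight
peaked = sharpen(attn, T=0.5)                   # lower temperature -> sharper pseudo-labels
```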