├── config
│   └── base.yaml
├── README.md
└── models
    └── dinosr.py

/config/base.yaml:
--------------------------------------------------------------------------------
# @package _group_

common:
  fp16: true
  log_format: json
  log_interval: 200
  tensorboard_logdir: tb

checkpoint:
  save_interval: 5
  save_interval_updates: 25000
  keep_interval_updates: 1
  no_epoch_checkpoints: true
  load_checkpoint_on_all_dp_ranks: true

task:
  _name: audio_pretraining
  data: ???
  max_sample_size: 320000
  min_sample_size: 32000
  normalize: true

dataset:
  num_workers: 6
  max_tokens: 3800000
  skip_invalid_size_inputs_valid_test: true
  validate_interval: 5
  required_batch_size_multiple: 1
  disable_validation: true

distributed_training:
  distributed_world_size: 16
  ddp_backend: legacy_ddp

criterion:
  _name: model
  log_keys:
    - ema_decay
    - target_ppl
    - pred_ppl
    - codebook_decay

optimization:
  max_update: 400000
  lr: [0.0005]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: tri_stage
  phase_ratio: [0.03,0.47,0.50]

model:
  _name: dinosr
  extractor_mode: layer_norm
  encoder_layerdrop: 0.05
  dropout_input: 0.0
  dropout_features: 0.0
  feature_grad_mult: 1.0
  encoder_embed_dim: 768

  discrete: true
  codebook_size: 256
  average_top_k_layers: 8
  normal_init_codebook: false
  codebook_init_decay: 0.9
  instance_norm_target_layer: true

  mask_prob: 0.8
  mask_length: 10

  pos_conv_depth: 5
  conv_pos: 95

  ema_decay: 0.999
  ema_end_decay: 0.9999
  ema_anneal_end_step: 30000
  ema_transformer_only: true
  ema_layers_only: true

  require_same_masks: true
  mask_dropout: 0

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# [DinoSR: Self-Distillation and Online Clustering for Self-supervised Speech Representation Learning](https://arxiv.org/pdf/2305.10005.pdf)


### Setup

- Codebase preparation (based on [`fairseq`](https://github.com/facebookresearch/fairseq))

```
# we use fairseq to build the model
git clone https://github.com/facebookresearch/fairseq
cd fairseq
git checkout 47e279842ac8776e3964b0e45c320ad1d2ea6096 # we recommend using the commit DinoSR was developed on
pip install --editable ./

# plug in DinoSR
cd examples
git clone https://github.com/Alexander-H-Liu/dinosr.git
```

- Data preparation:

  Please follow the [`instructions provided by wav2vec 2.0`](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec) for pre-training/fine-tuning data preprocessing.


### Usage

- Training

  For the full list of hyper-parameters, see the [`config file`](config/base.yaml) and the [`model attributes`](models/dinosr.py), where the default settings used in the paper are provided.

```
# minimal example to reproduce model
python fairseq_cli/hydra_train.py -m \
    --config-dir examples/dinosr/config/ \
    --config-name base \
    task.data=/path/to/prepared/librispeech/ \
    common.user_dir=examples/dinosr &
```

- Loading the pre-trained model as a Python object

```
import fairseq
import argparse
code_path = "examples/dinosr"
fairseq.utils.import_user_module(argparse.Namespace(user_dir=code_path))
ckpt_path = "/path/to/the/checkpoint.pt"
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
model = models[0]
```

- Fine-tuning the pre-trained checkpoint for ASR

```
# minimal example for fine-tuning with 100hr data
python fairseq_cli/hydra_train.py -m \
    --config-dir examples/wav2vec/config/finetuning \
    --config-name base_100h \
    common.user_dir=examples/dinosr \
    task.data=/path/to/labeled/librispeech/ \
    model.w2v_path=/path/to/dinosr.ckpt \
    task.normalize=True
```

### Pre-trained checkpoint

The pre-trained checkpoint (without fine-tuning) can be downloaded [here](https://data.csail.mit.edu/placesaudio/dinosr/dinosr.ckpt).
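
The checkpoint can be loaded as shown above and then used directly as a feature extractor through `extract_features` (defined in [`models/dinosr.py`](models/dinosr.py)). Below is a minimal sketch; the random input and the per-utterance normalization are illustrative assumptions for a 16 kHz mono waveform (the provided config pre-trains with `task.normalize: true`):

```
import torch
import torch.nn.functional as F

model.eval()
wav = torch.randn(1, 16000 * 10)    # stand-in for a 10-second 16 kHz waveform, shape (batch, samples)
wav = F.layer_norm(wav, wav.shape)  # per-utterance normalization, matching task.normalize: true
with torch.no_grad():
    out = model.extract_features(source=wav, padding_mask=None, mask=False)
frame_features = out["x"]           # (batch, frames, encoder_embed_dim)
```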

--------------------------------------------------------------------------------
/models/dinosr.py:
--------------------------------------------------------------------------------
import logging
import math
from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    ConvFeatureExtractionModel,
    Wav2Vec2Config,
    TransformerEncoder,
)
from fairseq.modules import (
    GradMultiply,
    LayerNorm,
)
from fairseq.utils import index_put


logger = logging.getLogger(__name__)


@dataclass
class DinosrAudioConfig(Wav2Vec2Config):

    discrete: bool = field(default=False)
    codebook_size: int = field(default=256)
    normal_init_codebook: bool = field(default=False)
    codebook_init_decay: float = field(default=0.9)
    codebook_end_decay: float = field(default=0.9)
    codebook_end_decay_step: int = field(default=0)
    freeze_teacher_step: int = field(
        default=200001, metadata={"help": "step to freeze teacher"}
    )
    freeze_pre_enc_modules: bool = field(
        default=True, metadata={"help": "when freezing teacher, freeze the CNN extractor as well"}
    )
    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("dinosr", dataclass=DinosrAudioConfig)
class DinosrModel(BaseFairseqModel):
    def __init__(self, cfg: DinosrAudioConfig):
        super().__init__()
        self.cfg = cfg
        self.discrete = cfg.discrete

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.pre_encoder_copied = False
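        # DinoSR-specific setup (discrete=true in config/base.yaml): one codebook and one
        # linear prediction head per averaged teacher layer (average_top_k_layers of them).
        # The codebooks are kept in plain dicts rather than as nn.Parameters so that they
        # stay in fp32 and are updated by EMA-based online clustering in forward() instead
        # of by the optimizer; they are saved/restored manually in state_dict() and
        # _load_from_state_dict() below.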
        if self.discrete:
            assert cfg.instance_norm_target_layer
            assert not (cfg.layer_norm_targets or cfg.instance_norm_targets)
            self.codebook_size = cfg.codebook_size
            self.n_codebooks = cfg.average_top_k_layers
            self.codebook_decay = cfg.codebook_init_decay
            # Prediction heads
            self.heads = torch.nn.ModuleList([
                nn.Linear(
                    cfg.encoder_embed_dim,
                    cfg.codebook_size,
                )
                for i in range(self.n_codebooks)
                ]
            )
            # Codebook: use dictionary to store so codebooks are always in fp32
            if cfg.normal_init_codebook:
                codebooks = torch.normal(0.0, (1 / self.codebook_size**0.5),
                                         size=(self.n_codebooks, self.codebook_size, cfg.encoder_embed_dim))
            else:
                codebooks = torch.randn(self.n_codebooks, cfg.encoder_embed_dim, self.codebook_size)
                codebooks = F.instance_norm(codebooks).transpose(1,2)
            self.codebooks = {
                i:codebooks[i] for i in range(self.n_codebooks)
            }
            self.codebook_cnts = {
                i:torch.ones([self.codebook_size]) for i in range(self.n_codebooks)
            }
            self.shared_module_state_dict = None
        else:
            self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def move_codebook_to_gpu(self):
        # Move codebook to GPU
        device = next(self.encoder.parameters()).device
        self.codebooks = {
            i:self.codebooks[i].to(device) for i in range(self.n_codebooks)
        }
        self.codebook_cnts = {
            i:self.codebook_cnts[i].to(device) for i in range(self.n_codebooks)
        }

    def freeze_shared_modules(self):
        # Hack to avoid updating any of the shared modules (e.g., weight decay from the optimizer);
        # using WD=0 + torch.no_grad() for these modules still resulted in a higher loss somehow
        if self.shared_module_state_dict is None:
            self.shared_module_state_dict = {}
            self.shared_module_state_dict['feature_extractor'] = self.feature_extractor.state_dict()
            self.shared_module_state_dict['layer_norm'] = self.layer_norm.state_dict()
            self.shared_module_state_dict['post_extract_proj'] = self.post_extract_proj.state_dict()
        else:
            self.feature_extractor.load_state_dict(self.shared_module_state_dict['feature_extractor'])
            self.layer_norm.load_state_dict(self.shared_module_state_dict['layer_norm'])
            self.post_extract_proj.load_state_dict(self.shared_module_state_dict['post_extract_proj'])

    def copy_shared_modules(self):
        if not self.pre_encoder_copied:
            ema_config = EMAModuleConfig(
                ema_decay=1,
                ema_fp32=True,
            )
            self.cnn_copy = EMAModule(
                self.feature_extractor,
                ema_config,
                skip_keys=set(),
            )
            self.ln_copy = EMAModule(
                self.layer_norm,
                ema_config,
                skip_keys=set(),
            )
            self.proj_copy = EMAModule(
                self.post_extract_proj,
                ema_config,
                skip_keys=set(),
            )
            self.pre_encoder_copied = True
            logger.debug(f"pre-encoder modules copied for teacher model")

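    # set_num_updates drives the schedules used during pre-training: the teacher EMA decay
    # is annealed from ema_decay to ema_end_decay over ema_anneal_end_step updates
    # (0.999 -> 0.9999 over 30k updates in base.yaml), the teacher (and optionally the
    # shared pre-encoder modules) is frozen once freeze_teacher_step is reached, and the
    # codebook EMA decay used by the online clustering update in forward() is annealed
    # between codebook_init_decay and codebook_end_decay.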
    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.cfg.freeze_teacher_step!=-1 and num_updates>=self.cfg.freeze_teacher_step:
            if self.cfg.freeze_pre_enc_modules:
                self.freeze_shared_modules()
            else:
                self.copy_shared_modules()
            self.cfg.ema_end_decay = 1

        if self.ema is None and (self.discrete or self.final_proj is not None):
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        if self.cfg.codebook_init_decay == self.cfg.codebook_end_decay:
            self.codebook_decay = self.cfg.codebook_init_decay
        else:
            if num_updates >= self.cfg.codebook_end_decay_step:
                self.codebook_decay = self.cfg.codebook_end_decay
            else:
                self.codebook_decay = get_annealed_rate(
                    self.cfg.codebook_init_decay,
                    self.cfg.codebook_end_decay,
                    num_updates,
                    self.cfg.codebook_end_decay_step,
                )

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if self.shared_module_state_dict is not None:
            self.freeze_shared_modules()

        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        if self.discrete:
            for i in range(self.n_codebooks):
                state[prefix+f'_codebook{i}'] = self.codebooks[i]
                state[prefix+f'_codebook_cnts{i}'] = self.codebook_cnts[i]

        if self.pre_encoder_copied:
            state[prefix+'_pre_encoder_cnn'] = self.cnn_copy.fp32_params
            state[prefix+'_pre_encoder_ln'] = self.ln_copy.fp32_params
            state[prefix+'_pre_encoder_proj'] = self.proj_copy.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]

        if self.discrete:
            for i in range(self.n_codebooks):
                k = prefix+f'_codebook{i}'
                assert k in state_dict
                self.codebooks[i] = state_dict[k].contiguous()
                del state_dict[k]
                k = prefix+f'_codebook_cnts{i}'
                assert k in state_dict
                self.codebook_cnts[i] = state_dict[k].contiguous()
                del state_dict[k]

        k = prefix+'_pre_encoder_cnn'
        if self.pre_encoder_copied:
            assert k in state_dict
            self.cnn_copy.restore(state_dict[k],True)
            del state_dict[k]
            k = prefix+'_pre_encoder_ln'
            self.ln_copy.restore(state_dict[k],True)
            del state_dict[k]
            k = prefix+'_pre_encoder_proj'
            self.proj_copy.restore(state_dict[k],True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: DinosrAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

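    # Masking follows wav2vec 2.0: spans of mask_length frames are selected with
    # probability mask_prob (10 frames / 0.8 in base.yaml) and replaced by the learned
    # mask_emb vector; the student is only trained on these masked positions.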
    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations make sure that all values
            # before the output lengths indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.pre_encoder_copied:
            # Copied pre-encoder modules used for teacher model
            self.cnn_copy.model.eval()
            self.ln_copy.model.eval()
            self.proj_copy.model.eval()
            with torch.no_grad():
                pre_encoder_features = self.cnn_copy.model(source)
                pre_encoder_features = pre_encoder_features.transpose(1, 2)
                pre_encoder_features = self.ln_copy.model(pre_encoder_features)
                pre_encoder_features = self.proj_copy.model(pre_encoder_features)
        elif self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.cfg.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

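            # With discrete=true (DinoSR), the top-k instance-normalized teacher layers are
            # kept separately and only their masked positions are retained; each layer is
            # quantized against its own codebook further below. The else branch keeps the
            # data2vec-style behaviour of averaging the layers into a single regression target.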
            if self.discrete:
                target_layer_results = [
                    tl[mask_indices] for tl in target_layer_results
                ]
            else:
                y = sum(target_layer_results) / len(target_layer_results)

                if self.cfg.layer_norm_targets:
                    y = F.layer_norm(y.float(), y.shape[-1:])

                if self.cfg.instance_norm_targets:
                    y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

                if not permuted:
                    y = y.transpose(0, 1)

                y = y[mask_indices]

        x = x[mask_indices]

        if self.discrete:
            if self.codebooks[0].device != x.device:
                self.move_codebook_to_gpu()

            losses = 0
            target_ppl, pred_ppl = 0,0

            for i,target in enumerate(target_layer_results):
                # Quantize target
                with torch.no_grad():
                    codebook = self.codebooks[i].float() / self.codebook_cnts[i].unsqueeze(1)
                    neg_l2_dist = - (torch.sum(target**2, dim=1, keepdim=True)
                                     + torch.sum(codebook**2, dim=1)
                                     - 2 * torch.matmul(target, codebook.t()))
                    onehot_target = torch.zeros_like(neg_l2_dist)
                    onehot_target[range(len(neg_l2_dist)),neg_l2_dist.argmax(-1)] = 1.0
                # Compute loss
                pred = self.heads[i](x).float()
                pred = F.log_softmax(pred,dim=-1)
                loss = torch.sum(-onehot_target*pred,dim=-1)
                losses = losses + loss

                # Compute stats & update codebook
                with torch.no_grad():
                    # Stats
                    target_ppl += self.compute_ppl(onehot_target,input_onehot=True)
                    pred_ppl += self.compute_ppl(pred.float(),input_onehot=False)
                    if self.training and self.codebook_decay<1:
                        # Update codebook
                        # Note: this is done in a per-forward style,
                        # might wanna consider doing this in set_num_updates
                        count = onehot_target.sum(0)
                        memory = torch.matmul(onehot_target.t(), target)
                        if dist.is_initialized():
                            dist.all_reduce(memory) # Sum of embeddings
                            dist.all_reduce(count)  # Total counts
                        alpha = torch.ones_like(count).unsqueeze(1)
                        alpha[count!=0] = self.codebook_decay
                        self.codebook_cnts[i] = alpha.squeeze(1) * self.codebook_cnts[i] + (1-alpha).squeeze(1) * count
                        self.codebooks[i] = alpha * self.codebooks[i] + (1-alpha) * memory

            result["losses"]["cross_entropy"] = (losses/self.n_codebooks).sum()

        else:
            x = self.final_proj(x)

            sz = x.size(-1)

            if self.loss_beta == 0:
                loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
            else:
                loss = F.smooth_l1_loss(
                    x.float(), y.float(), reduction="none", beta=self.loss_beta
                ).sum(dim=-1)

            if self.loss_scale is not None:
                scale = self.loss_scale
            else:
                scale = 1 / math.sqrt(sz)

            result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

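        # Diagnostics: target_ppl / pred_ppl are the perplexity (2^entropy) of the average
        # codeword usage (see compute_ppl below), logged through criterion.log_keys in
        # base.yaml. Values close to codebook_size indicate uniform codeword usage, values
        # near 1 indicate codebook collapse.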
        with torch.no_grad():
            if self.discrete:
                result["target_ppl"] = target_ppl/self.n_codebooks
                result["pred_ppl"] = pred_ppl/self.n_codebooks
                result["codebook_decay"] = self.codebook_decay
            else:
                result["target_var"] = self.compute_var(y)
                result["pred_var"] = self.compute_var(x.float())

                if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
                    logger.error(
                        f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
                    )
                    raise Exception(
                        f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
                    )
                if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
                    logger.error(
                        f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
                    )
                    raise Exception(
                        f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
                    )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_ppl(y, input_onehot=False, tokenwise=False):
        # We track the avg. of 1-hot (argmax)
        if not input_onehot:
            y = y.softmax(dim=-1)
        if tokenwise:
            y = 2**(- y * (y+1e-8).log2()).sum(-1)
        y = y.mean(0)
        if dist.is_initialized():
            dist.all_reduce(y)
            y = y / dist.get_world_size()
        if not tokenwise:
            y = 2**(- y * (y+1e-8).log2()).sum()
        return y

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, padding_mask, mask=False, layer=None
    ):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.heads = None
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )

--------------------------------------------------------------------------------