├── .gitignore ├── LICENSE ├── README.md ├── audio ├── __init__.py ├── configs │ └── wav2vec2-pretraining.yaml ├── dataset.py ├── encoder.py └── trainer.py ├── data2vec.png ├── data2vec ├── __init__.py ├── data2vec.py └── ema.py ├── requirements.txt ├── text ├── __init__.py ├── configs │ └── roberta-pretraining.yaml ├── dataset.py ├── encoder.py └── trainer.py ├── train.py ├── utils.py └── vision ├── __init__.py ├── configs └── beit-pretraining.yaml ├── dataset.py ├── encoder.py ├── trainer.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | logs/ 132 | weights/ 133 | checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Aryan Shekarlaban 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # data2vec-pytorch 2 | ##### PyTorch implementation of "[data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555)" from Meta AI (FAIR) 3 | #### Disclaimer: This repo's goal is to make data2vec easier to understand hence it's not recommended to use for actual model pretraining but instead you'd better use the official version in fairseq or the ones provided on HuggingFace. 4 | Data2Vec is the first high-performance self-supervised algorithm that learns the same way in multiple modalities, including speech, vision and text. 5 | Most machines learn exclusively from labeled data. However, through self-supervised learning, machines are able to learn about the world just by observing it 6 | and then figuring out the structure of images, speech or text. This is a more scalable and efficient approach for machines to tackle new complex tasks, 7 | such as understanding text for more spoken languages. 8 | 9 | ![](data2vec.png) 10 | 11 | In summary, the method is as follows:
1. The encoder extracts features from the masked inputs. These features are the outputs of every transformer/linear layer.
2. The teacher, which is an EMA instance of the encoder (in eval mode), extracts features from the unmasked inputs.
3. Optional normalizations are applied to the teacher's layer outputs.
4. The encoder outputs are regressed by a projection block/layer.
5. The loss is calculated between the encoder outputs and the teacher outputs.

You can read the paper for more detail; a minimal code sketch of this procedure follows below.
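To make these five steps concrete, here is a minimal, self-contained PyTorch sketch of one training step. It is only an illustration under simplified assumptions: `TinyEncoder`, the random 15% mask and all hyperparameters are toy stand-ins invented for the example, not this repo's API; the real logic lives in `data2vec/data2vec.py` and `data2vec/ema.py`.

```python
# Simplified sketch of one Data2Vec training step (illustrative names, not the repo's API).
import copy

import torch
import torch.nn.functional as F
from torch import nn

embed_dim, top_k, ema_decay = 32, 4, 0.999

class TinyEncoder(nn.Module):
    """Stand-in for a transformer encoder that exposes every layer's output."""
    def __init__(self, num_layers=6):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)])

    def forward(self, x):
        states = []
        for layer in self.layers:
            x = torch.tanh(layer(x))
            states.append(x)
        return states  # one tensor per layer: [batch, seq_len, embed_dim]

student = TinyEncoder()
teacher = copy.deepcopy(student).requires_grad_(False)  # EMA copy acting as the teacher
regression_head = nn.Linear(embed_dim, embed_dim)

src = torch.randn(2, 10, embed_dim)       # unmasked inputs (what the teacher sees)
mask = torch.rand(2, 10) < 0.15           # bool indices of the positions to predict
mask[:, 0] = True                         # make sure at least one position is masked
masked_src = src.masked_fill(mask.unsqueeze(-1), 0.0)  # crude stand-in for real masking

# Steps 1-2: the student encodes the masked inputs, the teacher encodes the clean inputs.
x = student(masked_src)[-1]               # last layer output of the student
with torch.no_grad():
    teacher.eval()
    top_layers = teacher(src)[-top_k:]    # top-k layer outputs of the teacher
    # Step 3: normalize each layer, then average them to build the target.
    y = sum(F.layer_norm(t, t.shape[-1:]) for t in top_layers) / top_k

# Steps 4-5: project the student outputs and regress them onto the targets,
# comparing only the masked positions.
loss = F.smooth_l1_loss(regression_head(x[mask]), y[mask])
loss.backward()

# After the optimizer step, the teacher tracks the student via an EMA update.
with torch.no_grad():
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(ema_decay).add_(p_s, alpha=1 - ema_decay)
```

Note that gradients flow only through the student branch; the teacher is advanced purely by the EMA update, which is what `EMA.step` and `Data2Vec.ema_step` do in this repo (with a Smooth L1 loss for text/vision and MSE for audio).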
## Implementation
Data2Vec is already implemented in [fairseq](https://github.com/pytorch/fairseq/tree/main/examples/data2vec), which ships a separate implementation for each modality (text, vision, audio). According to the paper:
> Our primary goal is to design a single learning mechanism for different modalities.
Despite the unified learning regime, we still use modality-specific feature extractors and masking strategies.
This makes sense given the vastly different nature of the input data.

This implementation differs in that it provides a single Data2Vec model powered by a pluggable encoder (implemented with PyTorch + HuggingFace Transformers) and unifies the whole concept in a single module.
The only parts that remain modality-specific are feature extraction and masking:

- **Masking:** For each modality, the Dataset instance must return the masked source, the target and the mask tensor.

- **Feature Extraction:** Features are the outputs of the transformer/attention layers, so the encoder's forward method must return the outputs of all encoder blocks. HuggingFace Transformers and fairseq models return per-layer outputs out of the box.

This implementation uses HuggingFace Transformers models as encoders for Data2Vec; you can inspect them in the `encoder.py` file of each modality. You can also provide your own encoder model; just make sure it is Transformer-based, as the paper requires, and that it returns the outputs of every encoder layer.

**Note**: The goal of this implementation is to provide the necessary building blocks of Data2Vec so anyone can adapt it to their own use case with ease. To keep the code clean and simple, functionality such as mixed precision and distributed training is not included. If you only need to train a standard large-scale Data2Vec model, use the [official repo](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).

## Train
First things first, install the requirements:
```bash
pip install -r requirements.txt
```

#### **NLP**
Train a language model based on RoBERTa (HuggingFace) on WikiText-103.

Configure the related properties in `text/configs/roberta-pretraining.yaml` and run:
```bash
python train.py --config text/configs/roberta-pretraining.yaml
```

#### **Vision**
Run Masked Image Modeling training based on BEiT (HuggingFace).

Set the path to the image dataset in `vision/configs/beit-pretraining.yaml` under `dataset > path > train/test`, modify other properties as you desire, and run:
```bash
python train.py --config vision/configs/beit-pretraining.yaml
```

#### **Speech**
Audio pretraining based on Wav2Vec2 (HuggingFace) on the `timit` dataset. If you want to use another dataset such as `librispeech`, add it in `audio/dataset.py`; minor changes to the TIMIT class will do the job because both are loaded from HuggingFace `datasets`.

Configure other properties as you desire and run:
```bash
python train.py --config audio/configs/wav2vec2-pretraining.yaml
```

## Pre-trained Weights
**Note:** The weights of the models below were carefully ported from the original checkpoints in the `fairseq` version.

#### **RoBERTa**
Data2Vec model trained with RoBERTa as the encoder ([data2vec-roberta-base](https://huggingface.co/arxyzan/data2vec-roberta-base))
```python
from transformers import AutoModel, AutoConfig
from transformers import RobertaModel

checkpoint = 'arxyzan/data2vec-roberta-base'

# Option 1: load using AutoModel
data2vec_roberta = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by RobertaModel
data2vec_roberta = RobertaModel.from_pretrained(checkpoint)
```

#### **BEiT**
Data2Vec model trained with BEiT as the encoder ([data2vec-beit-base](https://huggingface.co/arxyzan/data2vec-beit-base))
```python
from transformers import AutoModel, AutoConfig
from transformers import BeitModel

checkpoint = 'arxyzan/data2vec-beit-base'

# Option 1: load using AutoModel
data2vec_beit = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by BeitModel
data2vec_beit = BeitModel.from_pretrained(checkpoint)
```

#### **Wav2Vec2**
Data2Vec model trained with Wav2Vec2 as the encoder ([data2vec-wav2vec2-base](https://huggingface.co/arxyzan/data2vec-wav2vec2-base))
```python
from transformers import AutoModel, AutoConfig
from transformers import Wav2Vec2Model

checkpoint = 'arxyzan/data2vec-wav2vec2-base'

# Option 1: load using AutoModel
data2vec_wav2vec2 = AutoModel.from_pretrained(checkpoint)

# Option 2: load directly by Wav2Vec2Model
data2vec_wav2vec2 = Wav2Vec2Model.from_pretrained(checkpoint)
```

## Fine-tuning

1. Fine-tune using the checkpoints mentioned above:
```python
# Text classification using Roberta model from HuggingFace
from transformers import RobertaModel, RobertaForSequenceClassification

checkpoint = 'arxyzan/data2vec-roberta-base'
# this is exactly a roberta model but trained with data2vec
data2vec_roberta = RobertaModel.from_pretrained(checkpoint)
text_classifier = RobertaForSequenceClassification(data2vec_roberta.config)
# assign `data2vec-roberta` weights to the roberta block of the classifier
text_classifier.roberta = data2vec_roberta
...
```
2. In case you trained a model using this codebase, you can fine-tune it by taking the encoder's state dict out of the checkpoint, which gives you a regular HuggingFace model that you can then fine-tune for any downstream task as you'd normally do for HuggingFace models:
134 | ```python 135 | # load a checkpoint for finetuning 136 | from transformers import RobertaModel, RobertaConfig 137 | roberta = RobertaModel(RobertaConfig()) 138 | checkpoint = torch.load('path/to/data2vec.pt') 139 | roberta_state_dict = checkpoint['encoder'] 140 | # load roberta weights from the encoder part of the data2vec model 141 | encoder = roberta.load_state_dict(roberta_state_dict) 142 | 143 | # Now fine-tune a regular HuggingFace RoBERTa model 144 | ... 145 | ``` 146 | 147 | ## Contributions 148 | Any contribution regarding training, development (for Data2Vec2) and issues are welcome! 149 | -------------------------------------------------------------------------------- /audio/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/audio/__init__.py -------------------------------------------------------------------------------- /audio/configs/wav2vec2-pretraining.yaml: -------------------------------------------------------------------------------- 1 | modality: 'audio' 2 | device: 'cuda' 3 | model: 4 | encoder_checkpoint: 'facebook/wav2vec2-base-960h' 5 | embed_dim: 768 6 | average_top_k_layers: 8 7 | head_layers: 2 8 | num_classes: 1000 9 | normalize_targets: false 10 | ema_decay: 0.9998 11 | ema_end_decay: 0.9999 12 | ema_anneal_end_step: 300000 13 | dataset: 14 | path: 'timit_asr' 15 | optimizer: 16 | lr: 0.0005 17 | train: 18 | batch_size: 16 19 | eval_batch_size: 16 20 | num_epochs: 1000 21 | log_dir: 'audio/logs' 22 | save_ckpt_freq: 10 23 | checkpoints_dir: 'audio/checkpoints/wav2vec2-pretraining' 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /audio/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from datasets import load_dataset 4 | from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices 5 | from transformers import Wav2Vec2FeatureExtractor 6 | 7 | 8 | class TIMIT(Dataset): 9 | def __init__(self, cfg, split, **kwargs): 10 | super(TIMIT, self).__init__() 11 | path = cfg.dataset.path 12 | self.data = load_dataset(path, 'clean')[split] 13 | self.feature_extractor = Wav2Vec2FeatureExtractor(cfg.model.encoder_checkpoint) 14 | self.__dict__.update(kwargs) 15 | 16 | def __len__(self): 17 | return len(self.data) 18 | 19 | def __getitem__(self, index): 20 | x = self.data[index]['audio'] 21 | x = self.feature_extractor(x['array'], sampling_rate=x['sampling_rate'], padding=True, return_tensors='pt')['input_values'] 22 | return {'input_values': x[0]} 23 | 24 | 25 | class DataCollatorForWav2Vec2Pretraining: # copied from transformers/examples/pytorch/speech-pretraining 26 | """ 27 | Data collator that will dynamically pad the inputs received and prepare masked indices for self-supervised 28 | pretraining. Args: model (:class:`~transformers.Wav2Vec2ForPreTraining`): The Wav2Vec2 model used for 29 | pretraining. The data collator needs to have access to config and ``_get_feat_extract_output_lengths`` function 30 | for correct padding. feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`): The processor used for 31 | processing the data. 
padding (:obj:`bool`, :obj:`str` or 32 | :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): Select a 33 | strategy to pad the returned sequences (according to the model's padding side and padding index) among: * 34 | :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 35 | sequence if provided). * :obj:`'max_length'`: Pad to a maximum length specified with the argument 36 | :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not provided. * 37 | :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 38 | lengths). max_length (:obj:`int`, `optional`): Maximum length of the ``input_values`` of the returned list and 39 | optionally padding length (see above). pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence 40 | to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA 41 | hardware with compute capability >= 7.5 (Volta). 42 | """ 43 | 44 | def __init__(self, model, feature_extractor, padding, max_length=None, pad_to_multiple_of=None): 45 | self.model = model 46 | self.feature_extractor = feature_extractor 47 | self.padding = padding 48 | self.max_length = max_length 49 | self.pad_to_multiple_of = pad_to_multiple_of 50 | 51 | def __call__(self, features): 52 | # reformat list to dict and set to pytorch format 53 | batch = self.feature_extractor.pad( 54 | features, 55 | padding=self.padding, 56 | pad_to_multiple_of=self.pad_to_multiple_of, 57 | return_tensors="pt", 58 | ) 59 | 60 | device = batch["input_values"].device 61 | batch_size = batch["input_values"].shape[0] 62 | 63 | mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1]) 64 | # make sure masked sequence length is a Python scalar 65 | mask_indices_seq_length = int(mask_indices_seq_length) 66 | 67 | # make sure that no loss is computed on padded inputs 68 | if batch.get("attention_mask") is not None: 69 | # compute real output lengths according to convolution formula 70 | batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask( 71 | mask_indices_seq_length, batch["attention_mask"] 72 | ) 73 | 74 | features_shape = (batch_size, mask_indices_seq_length) 75 | 76 | # sample randomly masked indices 77 | mask_time_indices = _compute_mask_indices( 78 | features_shape, 79 | self.model.config.mask_time_prob, 80 | self.model.config.mask_time_length, 81 | attention_mask=batch.get("sub_attention_mask"), 82 | ) 83 | mask_time_indices = torch.tensor(mask_time_indices, dtype=torch.long, device=device) 84 | src = batch['input_values'] 85 | 86 | return src, mask_time_indices 87 | 88 | 89 | if __name__ == '__main__': 90 | from torch.utils.data import DataLoader 91 | from omegaconf import OmegaConf 92 | from transformers import Wav2Vec2Model, Wav2Vec2Config 93 | 94 | cfg = OmegaConf.load('configs/wav2vec2-pretraining.yaml') 95 | model = Wav2Vec2Model(Wav2Vec2Config()) 96 | feature_extractor = Wav2Vec2FeatureExtractor() 97 | dataset = TIMIT(cfg, 'train') 98 | collate_fn = DataCollatorForWav2Vec2Pretraining(model, feature_extractor, padding='longest') 99 | loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn) 100 | itr = iter(loader) 101 | sample = next(itr) 102 | print(sample) 103 | -------------------------------------------------------------------------------- /audio/encoder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModel, AutoConfig 3 | import torch.nn as nn 4 | 5 | 6 | class Encoder(nn.Module): 7 | """ 8 | Encoder model using HuggingFace for audio i.e, Wav2Vec2 9 | 10 | Args: 11 | cfg: An omegaconf.DictConf instance containing all the configurations. 12 | **kwargs: extra args which are set as model properties 13 | """ 14 | 15 | def __init__(self, cfg, **kwargs): 16 | super(Encoder, self).__init__() 17 | self.cfg = cfg 18 | checkpoint = cfg.model.encoder_checkpoint 19 | model_config = AutoConfig.from_pretrained(checkpoint) 20 | self.encoder = AutoModel.from_config(model_config) 21 | self.__dict__.update(kwargs) 22 | 23 | def forward(self, inputs, mask=None, **kwargs): 24 | """ 25 | Forward inputs through the encoder and extract transformer/attention layers outputs 26 | 27 | Args: 28 | inputs: raw audio array 29 | mask: bool masked indices 30 | **kwargs: keyword args specific to the encoder's forward method 31 | 32 | Returns: 33 | A dictionary of the encoder outputs including transformer layers outputs and attentions outputs 34 | """ 35 | outputs = self.encoder(inputs, mask_time_indices=mask, output_hidden_states=True, 36 | output_attentions=True, **kwargs) 37 | encoder_states = outputs['hidden_states'][:-1] # encoder layers outputs separately 38 | encoder_out = outputs['hidden_states'][-1] # last encoder output (accumulated) 39 | attentions = outputs['attentions'] 40 | return { 41 | 'encoder_states': encoder_states, 42 | 'encoder_out': encoder_out, 43 | 'attentions': attentions 44 | } 45 | 46 | 47 | if __name__ == '__main__': 48 | from dataset import TIMIT, DataCollatorForWav2Vec2Pretraining 49 | from omegaconf import OmegaConf 50 | from transformers import Wav2Vec2FeatureExtractor 51 | from torch.utils.data import DataLoader 52 | 53 | cfg = OmegaConf.load('configs/wav2vec2-pretraining.yaml') 54 | feature_extractor = Wav2Vec2FeatureExtractor() 55 | model = Encoder(cfg) 56 | dataset = TIMIT(cfg, 'train') 57 | collate_fn = DataCollatorForWav2Vec2Pretraining(model.encoder, feature_extractor, padding='longest') 58 | loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn) 59 | itr = iter(loader) 60 | inputs, mask = next(itr) 61 | features = model(inputs, mask) 62 | print(features) 63 | -------------------------------------------------------------------------------- /audio/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | import torch.optim as optim 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | from omegaconf import DictConfig 9 | from tqdm import tqdm 10 | 11 | from audio.encoder import Encoder 12 | from audio.dataset import TIMIT, DataCollatorForWav2Vec2Pretraining 13 | from data2vec import Data2Vec 14 | from utils import AverageMeter, maybe_save_checkpoint 15 | 16 | 17 | class AudioTrainer: 18 | def __init__(self, cfg: DictConfig): 19 | self.cfg = cfg 20 | self.num_epochs = self.cfg.train.num_epochs 21 | self.device = self.cfg.device 22 | self.ckpt_dir = cfg.train.checkpoints_dir 23 | self.save_ckpt_freq = cfg.train.save_ckpt_freq 24 | # Model, Optim, Criterion 25 | self.encoder = Encoder(cfg=cfg) 26 | self.model = Data2Vec(encoder=self.encoder, cfg=cfg) 27 | self.model.to(self.device) 28 | self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr) 29 | self.criterion = 
nn.MSELoss(reduction='none') 30 | self.criterion.to(self.device) 31 | # Datasets & Data Loaders 32 | self.train_dataset = TIMIT(cfg, 'train') 33 | self.test_dataset = TIMIT(cfg, 'test') 34 | self.feature_extractor = self.train_dataset.feature_extractor 35 | self.data_collator = DataCollatorForWav2Vec2Pretraining(self.encoder.encoder, self.feature_extractor, 36 | padding='longest') 37 | self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size, 38 | collate_fn=self.data_collator) 39 | self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size, 40 | collate_fn=self.data_collator) 41 | # Tensorboard 42 | self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir) 43 | 44 | # Trackers 45 | self.loss_tracker = AverageMeter('loss') 46 | 47 | def train_step(self, batch): 48 | """ 49 | Train one batch of data and return loss. 50 | 51 | Args: 52 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len] 53 | 54 | Returns: 55 | Loss value 56 | """ 57 | src, mask = batch 58 | src, mask = src.to(self.device), mask.to(self.device) 59 | # src is not masked so can be used as trg. (src will be masked in the encoder forward) 60 | x, y = self.model(src, src, mask) 61 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).div(x.size(0)) 62 | loss.backward() 63 | self.optimizer.step() 64 | self.optimizer.zero_grad() 65 | 66 | return loss.item() 67 | 68 | def test_step(self, batch): 69 | """ 70 | Test a model on one batch of data and return loss. 71 | 72 | Args: 73 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len] 74 | 75 | Returns: 76 | Loss value 77 | """ 78 | src, mask = batch 79 | src, mask = src.to(self.device), mask.to(self.device) 80 | # src is not masked so can be used as trg. (src will be masked in the encoder forward) 81 | x, y = self.model(src, src, mask=mask) 82 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).div(x.size(0)) 83 | 84 | return loss.item() 85 | 86 | def train_epoch(self, epoch_num): 87 | """ 88 | Train the model for one epoch and verbose using the progress bar. 89 | 90 | Args: 91 | epoch_num: number of the current epoch 92 | 93 | Returns: 94 | The average loss through the whole epoch 95 | """ 96 | self.model.train() 97 | self.loss_tracker.reset() 98 | with tqdm(self.train_loader, unit="batch", desc=f'Epoch: {epoch_num}/{self.num_epochs} ', 99 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator: 100 | for batch in iterator: 101 | loss = self.train_step(batch) 102 | self.model.ema_step() 103 | self.loss_tracker.update(loss) 104 | avg_loss = self.loss_tracker.avg 105 | iterator.set_postfix(loss=avg_loss) 106 | 107 | return avg_loss 108 | 109 | def evaluate(self): 110 | """ 111 | Evaluate the model on the test set 112 | 113 | Returns: 114 | The average loss through the whole test dataset 115 | """ 116 | self.model.eval() 117 | self.loss_tracker.reset() 118 | with tqdm(self.test_loader, unit="batch", desc=f'Evaluating... ', 119 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator: 120 | with torch.no_grad(): 121 | for batch in iterator: 122 | loss = self.test_step(batch) 123 | self.loss_tracker.update(loss) 124 | avg_loss = self.loss_tracker.avg 125 | iterator.set_postfix(loss=avg_loss) 126 | 127 | return avg_loss 128 | 129 | def train(self): 130 | """ 131 | Train and evaluate the model on the datasets and save checkpoints and write summaries to TensorBoard. 
132 | 133 | """ 134 | for epoch in range(1, self.num_epochs + 1): 135 | print() 136 | train_loss = self.train_epoch(epoch) 137 | val_loss = self.evaluate() 138 | 139 | # tensorboard 140 | self.tensorboard.add_scalar('train_loss', train_loss, epoch) 141 | self.tensorboard.add_scalar('val_loss', val_loss, epoch) 142 | 143 | maybe_save_checkpoint(self.model, self.optimizer, self.ckpt_dir, epoch, self.save_ckpt_freq) 144 | -------------------------------------------------------------------------------- /data2vec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/data2vec.png -------------------------------------------------------------------------------- /data2vec/__init__.py: -------------------------------------------------------------------------------- 1 | from .data2vec import Data2Vec 2 | from .ema import EMA 3 | -------------------------------------------------------------------------------- /data2vec/data2vec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .ema import EMA 5 | 6 | 7 | class Data2Vec(nn.Module): 8 | """ 9 | Data2Vec main module. 10 | 11 | Args: 12 | encoder (nn.Module): The encoder module like BEiT, ViT, etc. 13 | cfg (omegaconf.DictConfig): The config containing model properties 14 | """ 15 | MODALITIES = ['vision', 'text', 'audio'] 16 | 17 | def __init__(self, encoder, cfg, **kwargs): 18 | super(Data2Vec, self).__init__() 19 | self.modality = cfg.modality 20 | self.embed_dim = cfg.model.embed_dim 21 | self.encoder = encoder 22 | self.__dict__.update(kwargs) 23 | 24 | self.cfg = cfg 25 | self.ema = EMA(self.encoder, cfg) # EMA acts as the teacher 26 | self.regression_head = self._build_regression_head() 27 | 28 | self.cfg = cfg 29 | self.ema_decay = self.cfg.model.ema_decay 30 | self.ema_end_decay = self.cfg.model.ema_end_decay 31 | self.ema_anneal_end_step = self.cfg.model.ema_anneal_end_step 32 | 33 | def _build_regression_head(self): 34 | """ 35 | Construct the regression head consisting of linear and activation layers. 36 | 37 | Each modality might have its own regression block. 38 | 39 | Returns: 40 | A nn.Module layer or block of layers 41 | """ 42 | if self.modality == 'text': 43 | return nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim * 2), 44 | nn.GELU(), 45 | nn.Linear(self.embed_dim * 2, self.embed_dim)) 46 | 47 | if self.modality in ['audio', 'vision']: 48 | return nn.Linear(self.embed_dim, self.embed_dim) 49 | 50 | def ema_step(self): 51 | """ 52 | One EMA step for the offline model until the ending decay value is reached 53 | """ 54 | if self.ema_decay != self.ema_end_decay: 55 | if self.ema.num_updates >= self.ema_anneal_end_step: 56 | decay = self.ema_end_decay 57 | else: 58 | decay = self.ema.get_annealed_rate( 59 | self.ema_decay, 60 | self.ema_end_decay, 61 | self.ema.num_updates, 62 | self.ema_anneal_end_step, 63 | ) 64 | self.ema.decay = decay 65 | if self.ema.decay < 1: 66 | self.ema.step(self.encoder) 67 | 68 | def forward(self, src, trg=None, mask=None, **kwargs): 69 | """ 70 | Data2Vec forward method. 71 | 72 | Args: 73 | src: src tokens (masked inputs for training) 74 | trg: trg tokens (unmasked inputs for training but left as `None` otherwise) 75 | mask: bool masked indices, Note: if a modality requires the inputs to be masked before forward this param 76 | has no effect. 
(see the Encoder for each modality to see if it uses mask or not) 77 | 78 | Returns: 79 | Either encoder outputs or a tuple of encoder + EMA outputs 80 | 81 | """ 82 | # model forward in online mode (student) 83 | x = self.encoder(src, mask, **kwargs)['encoder_out'] # fetch the last layer outputs 84 | if trg is None: 85 | return x 86 | 87 | # model forward in offline mode (teacher) 88 | with torch.no_grad(): 89 | self.ema.model.eval() 90 | y = self.ema.model(trg, ~mask, **kwargs)['encoder_states'] # fetch the last transformer layers outputs 91 | y = y[-self.cfg.model.average_top_k_layers:] # take the last k transformer layers 92 | 93 | # Follow the same layer normalization procedure for text and vision 94 | if self.modality in ['vision', 'text']: 95 | y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y] 96 | y = sum(y) / len(y) 97 | if self.cfg.model.normalize_targets: 98 | y = F.layer_norm(y.float(), y.shape[-1:]) 99 | 100 | # Use instance normalization for audio 101 | elif self.modality == 'audio': 102 | y = [F.instance_norm(tl.float()) for tl in y] 103 | y = sum(y) / len(y) 104 | if self.cfg.model.normalize_targets: 105 | y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2) 106 | 107 | x = x[mask] 108 | y = y[mask] 109 | 110 | x = self.regression_head(x) 111 | 112 | return x, y 113 | -------------------------------------------------------------------------------- /data2vec/ema.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class EMA: 9 | """ 10 | Modified version of class fairseq.models.ema.EMAModule. 11 | 12 | Args: 13 | model (nn.Module): 14 | cfg (DictConfig): 15 | device (str): 16 | skip_keys (list): The keys to skip assigning averaged weights to. 17 | """ 18 | 19 | def __init__(self, model: nn.Module, cfg, skip_keys=None): 20 | self.model = self.deepcopy_model(model) 21 | self.model.requires_grad_(False) 22 | self.cfg = cfg 23 | self.device = cfg.device 24 | self.model.to(self.device) 25 | self.skip_keys = skip_keys or set() 26 | self.decay = self.cfg.model.ema_decay 27 | self.num_updates = 0 28 | 29 | @staticmethod 30 | def deepcopy_model(model): 31 | try: 32 | model = copy.deepcopy(model) 33 | return model 34 | except RuntimeError: 35 | tmp_path = 'tmp_model_for_ema_deepcopy.pt' 36 | torch.save(model, tmp_path) 37 | model = torch.load(tmp_path) 38 | os.remove(tmp_path) 39 | return model 40 | 41 | def step(self, new_model: nn.Module): 42 | """ 43 | One EMA step 44 | 45 | Args: 46 | new_model (nn.Module): Online model to fetch new weights from 47 | 48 | """ 49 | ema_state_dict = {} 50 | ema_params = self.model.state_dict() 51 | for key, param in new_model.state_dict().items(): 52 | ema_param = ema_params[key].float() 53 | if key in self.skip_keys: 54 | ema_param = param.to(dtype=ema_param.dtype).clone() 55 | else: 56 | ema_param.mul_(self.decay) 57 | ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - self.decay) 58 | ema_state_dict[key] = ema_param 59 | self.model.load_state_dict(ema_state_dict, strict=False) 60 | self.num_updates += 1 61 | 62 | def restore(self, model: nn.Module): 63 | """ 64 | Reassign weights from another model 65 | 66 | Args: 67 | model (nn.Module): model to load weights from. 
68 | 69 | Returns: 70 | model with new weights 71 | """ 72 | d = self.model.state_dict() 73 | model.load_state_dict(d, strict=False) 74 | return model 75 | 76 | def state_dict(self): 77 | return self.model.state_dict() 78 | 79 | @staticmethod 80 | def get_annealed_rate(start, end, curr_step, total_steps): 81 | """ 82 | Calculate EMA annealing rate 83 | """ 84 | r = end - start 85 | pct_remaining = 1 - curr_step / total_steps 86 | return end - r * pct_remaining 87 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | transformers 4 | datasets 5 | tqdm 6 | omegaconf 7 | tensorboard 8 | timm 9 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/text/__init__.py -------------------------------------------------------------------------------- /text/configs/roberta-pretraining.yaml: -------------------------------------------------------------------------------- 1 | modality: 'text' 2 | device: 'cuda' 3 | train: 4 | batch_size: 32 5 | eval_batch_size: 32 6 | num_epochs: 20 7 | checkpoints_dir: 'text/checkpoints/roberta-pretrain' 8 | log_dir: 'text/logs/roberta-pretrain' 9 | save_ckpt_freq: 20 10 | criterion: 11 | loss_beta: 4 12 | optimizer: 13 | lr: 0.0002 14 | weight_decay: 0.01 15 | dataset: 16 | name: 'wikitext-103-v1' 17 | mlm_probability: 0.15 18 | valid_seq_lenghts: [12, 512] 19 | clean_dataset: false 20 | model: 21 | average_top_k_layers: 10 22 | embed_dim: 768 23 | num_classes: null 24 | encoder_checkpoint: 'roberta-base' 25 | normalize_targets: false 26 | ema_decay: 0.999 27 | ema_end_decay: 0.9999 28 | ema_anneal_end_step: 300000 29 | 30 | -------------------------------------------------------------------------------- /text/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from datasets import load_dataset 4 | from tqdm import tqdm 5 | 6 | 7 | class WikiText(Dataset): 8 | """ 9 | A Dataset instance for WikiText dataset loaded from HuggingFace datasets. 10 | 11 | Args: 12 | cfg (DictConfig): config object 13 | split: Split to load ['train', 'test'] 14 | tokenizer: A HuggingFace Tokenizer model like BPE 15 | **kwargs: extra args which are set as dataset properties 16 | """ 17 | 18 | def __init__(self, cfg, split, tokenizer, **kwargs): 19 | super(WikiText, self).__init__() 20 | self.cfg = cfg 21 | self.path = cfg.dataset.name 22 | self.mlm_probability = cfg.dataset.mlm_probability 23 | raw_data = load_dataset('wikitext', self.path)[split] 24 | self.data = self.clean_dataset(raw_data) if self.cfg.dataset.clean_dataset else raw_data 25 | self.tokenizer = tokenizer 26 | self.__dict__.update(kwargs) 27 | 28 | def clean_dataset(self, data): 29 | """ 30 | Cleanup dataset by removing invalid sized samples, etc. 
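        Note: sample validity is judged by the raw character length of each sample's `text` field
        (not its token count), against the inclusive [min, max] bounds given by `valid_seq_lenghts`.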
31 | """ 32 | print('Cleaning dataset ...') 33 | min_seq_len, max_seq_len = self.cfg.data.valid_seq_lenghts 34 | texts = [] 35 | with tqdm(data, desc='Removing invalid sized inputs: ') as tbar: 36 | for i, x in enumerate(tbar): 37 | if len(x['text']) in range(min_seq_len, max_seq_len + 1): 38 | texts.append(x) 39 | return texts 40 | 41 | def __len__(self): 42 | return len(self.data) 43 | 44 | def __getitem__(self, index): 45 | """ 46 | Only return tokens from raw text with no additions e.g, padding, bos/eos, etc. 47 | Args: 48 | index: sample index to pick from dataset 49 | 50 | Returns: 51 | tokenized outputs 52 | """ 53 | raw_text = self.data[index]['text'] 54 | tokens = self.tokenizer(raw_text, return_attention_mask=False) 55 | return tokens 56 | 57 | def _mask_tokens(self, inputs, special_tokens_mask=None): 58 | """ 59 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Ported 60 | from `transformers.data.DataCollatorForLanguageModeling.torch_mask_tokens()` 61 | Args: 62 | inputs: batch of input tokens 63 | special_tokens_mask: 64 | 65 | Returns: 66 | a dict batch of masked and padded inputs/labels 67 | 68 | """ 69 | labels = inputs.clone() 70 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 71 | probability_matrix = torch.full(labels.shape, self.mlm_probability) 72 | if special_tokens_mask is None: 73 | special_tokens_mask = [ 74 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in 75 | labels.tolist() 76 | ] 77 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 78 | else: 79 | special_tokens_mask = special_tokens_mask.bool() 80 | 81 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 82 | masked_indices = torch.bernoulli(probability_matrix).bool() 83 | labels[~masked_indices] = self.tokenizer.pad_token_id 84 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 85 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 86 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 87 | 88 | # 10% of the time, we replace masked input tokens with random word 89 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 90 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) 91 | inputs[indices_random] = random_words[indices_random] 92 | 93 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 94 | return inputs, labels, masked_indices 95 | 96 | def collate_fn(self, batch): 97 | """ 98 | Collate the batch of data using BERT masking strategy. carefully ported from 99 | transformers.data.DataCollatorForLanguageModeling 100 | Args: 101 | batch: batch of data 102 | 103 | Returns: 104 | same batch of data masked and padded 105 | """ 106 | 107 | batch = self.tokenizer.pad(batch, return_tensors="pt") 108 | # If special token mask has been preprocessed, pop it from the dict. 
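        # (if it was not preprocessed, `_mask_tokens` rebuilds it from the tokenizer below)
        # The (src, trg, masked_indices) tuple returned at the end of this method is what
        # `TextTrainer.train_step` unpacks as (src, trg, mask):
        #   src            - input_ids after BERT-style corruption (80% [MASK], 10% random token, 10% kept)
        #   trg            - the original token ids, with non-masked positions set to the pad token id
        #   masked_indices - bool tensor marking the positions that contribute to the Data2Vec loss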
109 | special_tokens_mask = batch.pop("special_tokens_mask", None) 110 | src, trg, masked_indices = self._mask_tokens( 111 | batch["input_ids"], special_tokens_mask=special_tokens_mask 112 | ) 113 | return src, trg, masked_indices 114 | 115 | 116 | if __name__ == '__main__': 117 | from transformers.models.roberta import RobertaTokenizer 118 | from torch.utils.data import DataLoader 119 | from omegaconf import OmegaConf 120 | 121 | cfg = OmegaConf.load('configs/roberta-pretraining.yaml') 122 | dataset = WikiText(cfg, 'train', RobertaTokenizer.from_pretrained('roberta-base')) 123 | dataloader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn) 124 | data_iter = iter(dataloader) 125 | batch = next(data_iter) 126 | print(batch) 127 | -------------------------------------------------------------------------------- /text/encoder.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoConfig, AutoTokenizer 2 | import torch.nn as nn 3 | 4 | 5 | class Encoder(nn.Module): 6 | """ 7 | Encoder model using HuggingFace for NLP 8 | 9 | To load your desired model specify model checkpoint under cfg.model.encoder_checkpoint 10 | 11 | Args: 12 | cfg: An omegaconf.DictConf instance containing all the configurations. 13 | **kwargs: extra args which are set as model properties 14 | """ 15 | 16 | def __init__(self, cfg, **kwargs): 17 | super(Encoder, self).__init__() 18 | self.cfg = cfg 19 | checkpoint = cfg.model.encoder_checkpoint 20 | model_config = AutoConfig.from_pretrained(checkpoint) 21 | self.encoder = AutoModel.from_config(model_config) 22 | self.__dict__.update(kwargs) 23 | 24 | def forward(self, inputs, mask=None, **kwargs): 25 | """ 26 | Forward inputs through the encoder and extract transformer/attention layers outputs 27 | 28 | Args: 29 | inputs: source tokens 30 | mask: bool masked indices 31 | kwargs: keyword args specific to the encoder's forward method 32 | 33 | Returns: 34 | A dictionary of the encoder outputs including transformer layers outputs and attentions outputs 35 | 36 | """ 37 | # Note: inputs are already masked for MLM so mask is not used 38 | outputs = self.encoder(inputs, output_hidden_states=True, output_attentions=True, **kwargs) 39 | encoder_states = outputs['hidden_states'][:-1] # encoder layers outputs separately 40 | encoder_out = outputs['hidden_states'][-1] # last encoder output (accumulated) 41 | attentions = outputs['attentions'] 42 | return { 43 | 'encoder_states': encoder_states, 44 | 'encoder_out': encoder_out, 45 | 'attentions': attentions 46 | } 47 | 48 | 49 | if __name__ == '__main__': 50 | from omegaconf import OmegaConf 51 | cfg = OmegaConf.load('configs/roberta-pretraining.yaml') 52 | tokenizer = AutoTokenizer.from_pretrained('roberta-base') 53 | model = Encoder(cfg) 54 | inputs = tokenizer("The capital of France is .", return_tensors="pt") 55 | outputs = model(inputs['input_ids']) 56 | print(outputs) 57 | -------------------------------------------------------------------------------- /text/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train Data2Vec for text. The encoder is loaded from huggingface specified in the config file. 
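This trainer is instantiated by `train.py` when the config's `modality` is set to 'text' and expects
the config layout shown in `text/configs/roberta-pretraining.yaml`.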
3 | """ 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | import torch.optim as optim 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from omegaconf import DictConfig 12 | from tqdm import tqdm 13 | 14 | from text.encoder import Encoder, AutoTokenizer 15 | from text.dataset import WikiText 16 | from data2vec import Data2Vec 17 | from utils import AverageMeter, maybe_save_checkpoint 18 | 19 | 20 | class TextTrainer: 21 | """ 22 | A Trainer class to train NLP model on Data2Vec. 23 | 24 | Args: 25 | cfg (DictConfig): the config object containing all properties 26 | """ 27 | 28 | def __init__(self, cfg: DictConfig): 29 | self.cfg = cfg 30 | self.num_epochs = self.cfg.train.num_epochs 31 | self.device = self.cfg.device 32 | self.ckpt_dir = cfg.train.checkpoints_dir 33 | self.save_ckpt_freq = cfg.train.save_ckpt_freq 34 | # Model, Optim, Criterion 35 | self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint) 36 | self.encoder = Encoder(cfg=cfg) 37 | self.model = Data2Vec(encoder=self.encoder, cfg=cfg) 38 | self.model.to(self.device) 39 | self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr) 40 | self.criterion = nn.SmoothL1Loss(reduction='none', beta=cfg.criterion.loss_beta) 41 | self.criterion.to(self.device) 42 | # Datasets & Data Loaders 43 | self.train_dataset = WikiText(cfg, 'train', self.tokenizer) 44 | self.test_dataset = WikiText(cfg, 'test', self.tokenizer) 45 | self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size, 46 | collate_fn=self.train_dataset.collate_fn) 47 | self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size, 48 | collate_fn=self.test_dataset.collate_fn) 49 | # Tensorboard 50 | self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir) 51 | 52 | # Trackers 53 | self.loss_tracker = AverageMeter('loss') 54 | 55 | def train_step(self, batch): 56 | """ 57 | Train one batch of data and return loss. 58 | 59 | Args: 60 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len] 61 | 62 | Returns: 63 | Loss value 64 | """ 65 | src, trg, mask = batch 66 | src, trg, mask = src.to(self.device), trg.to(self.device), mask.to(self.device) 67 | 68 | x, y = self.model(src, trg, mask) 69 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0)) 70 | loss.backward() 71 | self.optimizer.step() 72 | self.optimizer.zero_grad() 73 | 74 | return loss.item() 75 | 76 | def test_step(self, batch): 77 | """ 78 | Test a model on one batch of data and return loss. 79 | 80 | Args: 81 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len] 82 | 83 | Returns: 84 | Loss value 85 | """ 86 | src = batch['input_ids'].to(self.device) 87 | trg = batch['labels'].to(self.device) 88 | mask = batch['masked_indices'].to(self.device) 89 | 90 | x, y = self.model(src, trg, mask=mask) 91 | loss = self.criterion(x, y) 92 | 93 | return loss.item() 94 | 95 | def train_epoch(self, epoch_num): 96 | """ 97 | Train the model for one epoch and verbose using the progress bar. 
98 | 99 | Args: 100 | epoch_num: number of the current epoch 101 | 102 | Returns: 103 | The average loss through the whole epoch 104 | """ 105 | self.model.train() 106 | self.loss_tracker.reset() 107 | with tqdm(self.train_loader, unit="batch", desc=f'Epoch: {epoch_num}/{self.num_epochs} ', 108 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator: 109 | for batch in iterator: 110 | loss = self.train_step(batch) 111 | self.model.ema_step() 112 | self.loss_tracker.update(loss) 113 | avg_loss = self.loss_tracker.avg 114 | iterator.set_postfix(loss=avg_loss) 115 | 116 | return avg_loss 117 | 118 | def evaluate(self): 119 | """ 120 | Evaluate the model on the test set 121 | 122 | Returns: 123 | The average loss through the whole test dataset 124 | """ 125 | self.model.eval() 126 | self.loss_tracker.reset() 127 | with tqdm(self.test_loader, unit="batch", desc=f'Evaluating... ', 128 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator: 129 | with torch.no_grad(): 130 | for batch in iterator: 131 | loss = self.test_step(batch) 132 | self.loss_tracker.update(loss) 133 | avg_loss = self.loss_tracker.avg 134 | iterator.set_postfix(loss=avg_loss) 135 | 136 | return avg_loss 137 | 138 | def train(self): 139 | """ 140 | Train and evaluate the model on the datasets and save checkpoints and write summaries to TensorBoard. 141 | 142 | """ 143 | for epoch in range(1, self.num_epochs + 1): 144 | print() 145 | train_loss = self.train_epoch(epoch) 146 | val_loss = self.evaluate() 147 | 148 | # tensorboard 149 | self.tensorboard.add_scalar('train_loss', train_loss, epoch) 150 | self.tensorboard.add_scalar('val_loss', val_loss, epoch) 151 | 152 | # save checkpoint 153 | maybe_save_checkpoint(self.model, self.optimizer, self.ckpt_dir, epoch, self.save_ckpt_freq) 154 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import omegaconf 2 | 3 | from text.trainer import TextTrainer 4 | from vision.trainer import VisionTrainer 5 | from audio.trainer import AudioTrainer 6 | 7 | if __name__ == '__main__': 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--config', type=str, help='path to yaml config file') 12 | args = parser.parse_args() 13 | 14 | cfg_path = args.config 15 | cfg = omegaconf.OmegaConf.load(cfg_path) 16 | modality = cfg.modality 17 | 18 | trainers_dict = { 19 | 'text': TextTrainer, 20 | 'vision': VisionTrainer, 21 | 'audio': AudioTrainer 22 | } 23 | assert modality in trainers_dict.keys(), f'invalid modality `{cfg.modality}`, expected {list(trainers_dict.keys())}' 24 | trainer = trainers_dict[modality](cfg) 25 | trainer.train() 26 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | 8 | def __init__(self, name, fmt=':f'): 9 | self.name = name 10 | self.fmt = fmt 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def __str__(self): 26 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 27 | return 
fmtstr.format(**self.__dict__) 28 | 29 | 30 | def maybe_save_checkpoint(model, optimizer, path, epoch_num, save_freq): 31 | """ 32 | Save a checkpoint specific to Data2Vec 33 | Args: 34 | model: a nn.Module instance 35 | optimizer 36 | path: path to save checkpoint to 37 | epoch_num: current epoch number 38 | save_freq: save frequency based on epoch number 39 | 40 | """ 41 | if not os.path.exists(path): 42 | os.makedirs(path) 43 | path = os.path.join(path, f'{epoch_num}.pt') 44 | if epoch_num % save_freq == 0: 45 | checkpoint = {'data2vec': model.state_dict(), 46 | 'encoder': model.encoder.encoder.state_dict(), 47 | 'optimizer': optimizer.state_dict()} 48 | torch.save(checkpoint, path) 49 | print(f'Saved checkpoint to `{path}`') 50 | -------------------------------------------------------------------------------- /vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/vision/__init__.py -------------------------------------------------------------------------------- /vision/configs/beit-pretraining.yaml: -------------------------------------------------------------------------------- 1 | modality: 'vision' 2 | device: 'cuda' 3 | model: 4 | encoder_checkpoint: 'microsoft/beit-base-patch16-224-pt22k' 5 | embed_dim: 768 6 | average_top_k_layers: 6 7 | head_layers: 2 8 | num_classes: 1000 9 | normalize_targets: false 10 | ema_decay: 0.9998 11 | ema_end_decay: 0.9999 12 | ema_anneal_end_step: 300000 13 | dataset: 14 | path: 15 | train: 'vision/dummy_data' 16 | test: 'vision/dummy_data' 17 | input_size: 224 18 | interpolation: 'bicubic' 19 | patch_size: 16 20 | num_patches: 14 21 | num_mask_patches: 120 22 | max_mask_patches_per_block: 196 23 | min_mask_patches_per_block: 16 24 | imagenet_default_mean_and_std: false 25 | train: 26 | num_epochs: 800 27 | batch_size: 16 28 | eval_batch_size: 16 29 | shuffle: true 30 | save_ckpt_freq: 20 31 | checkpoints_dir: 'vision/checkpoints/beit-pretrain' 32 | log_dir: 'vision/logs/beit-pretrain' 33 | criterion: 34 | loss_beta: 2 35 | optimizer: 36 | lr: 2e-3 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /vision/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import ImageFolder 3 | 4 | from .transforms import MIMTransform 5 | 6 | 7 | class MIMPretrainingDataset(ImageFolder): 8 | """ 9 | Dataset for Masked Image Modeling derived from BEiT. 10 | 11 | Given an image, the common transforms and augmentations are applied like random crop, color jitter, etc., then the 12 | image is split into 14x14 patches and some patches are masked randomly. The input image to the model is the masked 13 | image and the target image is the full image 14 | 15 | Args: 16 | cfg (DictConfig): config containing model, dataset, etc. 
properties 17 | split: either `train` or `test` 18 | **kwargs: extra args which are set as dataset properties 19 | 20 | """ 21 | 22 | def __init__(self, cfg, split, **kwargs): 23 | super(MIMPretrainingDataset, self).__init__(root=cfg.dataset.path[split]) 24 | self.transform = MIMTransform(cfg.dataset) 25 | self.input_size = cfg.dataset.input_size 26 | self.device = cfg.device 27 | self.__dict__.update(kwargs) 28 | 29 | def __getitem__(self, index): 30 | """ 31 | Load image from disk, transform the image (augmentation and randomly mask some patches) 32 | 33 | Args: 34 | index: index to the image 35 | 36 | Returns: 37 | a tuple of masked image, unmasked target image and bool masked positions 38 | """ 39 | path, target = self.samples[index] 40 | image = self.loader(path) 41 | image, mask = self.transform(image) 42 | mask = mask.reshape(1, 14, 14, 1, 1) 43 | image = image.reshape(-1, 14, 14, 16, 16) 44 | masked_image = (image * mask).reshape(-1, self.input_size, self.input_size) 45 | target_image = image.reshape(-1, self.input_size, self.input_size) 46 | return masked_image, target_image, mask.flatten().bool() 47 | 48 | 49 | if __name__ == '__main__': 50 | from omegaconf import OmegaConf 51 | from torch.utils.data import DataLoader 52 | 53 | cfg = OmegaConf.load('configs/beit-pretraining.yaml') 54 | cfg.dataset.path = 'dummy_data' 55 | dataset = MIMPretrainingDataset(cfg, split='train') 56 | loader = DataLoader(dataset, batch_size=4) 57 | src, trg = next(iter(loader)) 58 | print(src) 59 | -------------------------------------------------------------------------------- /vision/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModel, AutoConfig 3 | import torch.nn as nn 4 | 5 | 6 | class Encoder(nn.Module): 7 | """ 8 | Encoder model using HuggingFace Transformers for vision e.g, BeiT 9 | 10 | Args: 11 | cfg: An omegaconf.DictConf instance containing all the configurations. 
12 | **kwargs: extra args which are set as dataset properties 13 | """ 14 | 15 | def __init__(self, cfg, **kwargs): 16 | super(Encoder, self).__init__() 17 | self.cfg = cfg 18 | checkpoint = cfg.model.encoder_checkpoint 19 | model_config = AutoConfig.from_pretrained(checkpoint) 20 | self.encoder = AutoModel.from_config(model_config) 21 | self.vocab_size = model_config.vocab_size 22 | self.mask_token = self.encoder.embeddings.mask_token 23 | self.__dict__.update(kwargs) 24 | 25 | def forward(self, inputs, mask=None, **kwargs): 26 | """ 27 | Forward inputs through the encoder and extract transformer/attention layers outputs 28 | 29 | Args: 30 | inputs: input pixels with shape [batch_size, channels, height, width] 31 | mask: bool masked indices 32 | **kwargs: keyword args specific to the encoder's forward method 33 | 34 | Returns: 35 | A dictionary of the encoder outputs including transformer layers outputs and attentions outputs 36 | """ 37 | # Note: inputs are already masked for MIM so mask is not used 38 | outputs = self.encoder(pixel_values=inputs, output_hidden_states=True, output_attentions=True, **kwargs) 39 | encoder_states = outputs['hidden_states'][:-1] # encoder layers outputs separately 40 | encoder_out = outputs['hidden_states'][-1] # last encoder output (accumulated) 41 | attentions = outputs['attentions'] 42 | 43 | # remove cls token from outputs 44 | encoder_states = [output[:, 1:, :] for output in encoder_states] 45 | encoder_out = encoder_out[:, 1:, :] 46 | attentions = [output[:, 1:, 1:] for output in attentions] 47 | 48 | return { 49 | 'encoder_states': encoder_states, 50 | 'encoder_out': encoder_out, 51 | 'attentions': attentions 52 | } 53 | 54 | 55 | if __name__ == '__main__': 56 | from omegaconf import OmegaConf 57 | import numpy as np 58 | from PIL import Image 59 | import requests 60 | from torchvision import transforms as T 61 | 62 | cfg = OmegaConf.load('configs/beit-pretraining.yaml') 63 | model = Encoder(cfg) 64 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 65 | image = Image.open(requests.get(url, stream=True).raw) 66 | image = T.Compose([T.Resize((224, 224)), 67 | T.ToTensor(), 68 | T.Normalize(mean=.5, std=.5)])(image).unsqueeze(0) 69 | outputs = model(image) 70 | print(outputs) 71 | -------------------------------------------------------------------------------- /vision/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | from torch.utils.tensorboard import SummaryWriter 8 | from tqdm import tqdm 9 | 10 | from vision.encoder import Encoder 11 | from vision.dataset import MIMPretrainingDataset 12 | from data2vec import Data2Vec 13 | from utils import AverageMeter, maybe_save_checkpoint 14 | 15 | 16 | class VisionTrainer: 17 | def __init__(self, cfg): 18 | self.cfg = cfg 19 | self.device = cfg.device 20 | self.num_epochs = cfg.train.num_epochs 21 | self.ckpt_dir = cfg.train.checkpoints_dir 22 | self.save_ckpt_freq = cfg.train.save_ckpt_freq 23 | # Model, Criterion, Optimizer 24 | self.encoder = Encoder(cfg=cfg) 25 | self.model = Data2Vec(encoder=self.encoder, cfg=cfg) 26 | self.model.to(self.device) 27 | self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr) 28 | self.criterion = nn.SmoothL1Loss(reduction='none', beta=cfg.criterion.loss_beta) 29 | self.criterion.to(self.device) 30 | # Datasets & Data Loaders 31 | self.train_dataset = 
MIMPretrainingDataset(cfg, split='train') 32 | self.test_dataset = MIMPretrainingDataset(cfg, split='test') 33 | self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size, shuffle=cfg.train.shuffle) 34 | self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size, shuffle=cfg.train.shuffle) 35 | 36 | # Tensorboard 37 | self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir) 38 | 39 | # Trackers 40 | self.loss_tracker = AverageMeter('loss') 41 | 42 | def train_step(self, batch): 43 | """ 44 | Train one batch of data 45 | Args: 46 | batch: A batch of data, src, trg of shape [N, C, H, W] and mask of shape [N, num_total_patches] 47 | 48 | Returns: 49 | Loss value 50 | """ 51 | src, trg, mask = batch 52 | src = src.to(self.device) 53 | trg = trg.to(self.device) 54 | mask = mask.to(self.device) 55 | 56 | x, y = self.model(src, trg, mask) 57 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0)) 58 | loss.backward() 59 | self.optimizer.step() 60 | self.optimizer.zero_grad() 61 | 62 | return loss.item() 63 | 64 | def test_step(self, batch): 65 | """ 66 | Evaluate one batch of data 67 | Args: 68 | batch: A batch of data, src, trg of shape [N, C, H, W] and mask of shape [N, num_total_patches] 69 | 70 | Returns: 71 | Loss value 72 | """ 73 | src, trg, mask = batch 74 | src = src.to(self.device) 75 | trg = trg.to(self.device) 76 | mask = mask.to(self.device) 77 | 78 | x, y = self.model(src, trg, mask) 79 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0)) 80 | 81 | return loss.item() 82 | 83 | def train_epoch(self, epoch_num): 84 | """ 85 | Train the model for one epoch 86 | Args: 87 | epoch_num: number of the current epoch 88 | 89 | Returns: 90 | Average loss through the whole epoch 91 | """ 92 | self.model.train() 93 | self.loss_tracker.reset() 94 | with tqdm(self.train_loader, unit="batch", desc=f'Epoch: {epoch_num}/{self.num_epochs} ', 95 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator: 96 | for batch in iterator: 97 | loss = self.train_step(batch) 98 | self.model.ema_step() 99 | self.loss_tracker.update(loss) 100 | avg_loss = self.loss_tracker.avg 101 | iterator.set_postfix(loss=avg_loss) 102 | 103 | return avg_loss 104 | 105 | def evaluate(self): 106 | """ 107 | Evaluate the model on the test data 108 | Returns: 109 | Average loss on the test set 110 | """ 111 | self.model.eval() 112 | self.loss_tracker.reset() 113 | with tqdm(self.test_loader, unit="batch", desc=f'Evaluating... ', 114 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator: 115 | with torch.no_grad(): 116 | for batch in iterator: 117 | loss = self.test_step(batch) 118 | self.loss_tracker.update(loss) 119 | avg_loss = self.loss_tracker.avg 120 | iterator.set_postfix(loss=avg_loss) 121 | 122 | return avg_loss 123 | 124 | def train(self): 125 | """ 126 | Train and evaluate the model on the datasets and save checkpoints and write summaries to TensorBoard. 
127 | 128 | """ 129 | for epoch in range(1, self.num_epochs + 1): 130 | print() 131 | train_loss = self.train_epoch(epoch) 132 | val_loss = self.evaluate() 133 | 134 | # tensorboard 135 | self.tensorboard.add_scalar('train_loss', train_loss, epoch) 136 | self.tensorboard.add_scalar('val_loss', val_loss, epoch) 137 | 138 | # save checkpoint 139 | maybe_save_checkpoint(self.model, self.optimizer, self.ckpt_dir, epoch, self.save_ckpt_freq) 140 | -------------------------------------------------------------------------------- /vision/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms and masking strategies for self-supervised pretraining of vision models. 3 | All code copied from the repositories below: 4 | https://github.com/microsoft/unilm/tree/master/beit 5 | https://github.com/rwightman/pytorch-image-models/tree/master/timm 6 | https://github.com/facebookresearch/deit 7 | """ 8 | import warnings 9 | import random 10 | import math 11 | from PIL import Image 12 | import numpy as np 13 | import torch 14 | from torchvision import transforms 15 | import torchvision.transforms.functional as F 16 | from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, 17 | IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD) 18 | from timm.data.transforms import RandomResizedCropAndInterpolation 19 | 20 | _pil_interpolation_to_str = { 21 | Image.NEAREST: 'PIL.Image.NEAREST', 22 | Image.BILINEAR: 'PIL.Image.BILINEAR', 23 | Image.BICUBIC: 'PIL.Image.BICUBIC', 24 | Image.LANCZOS: 'PIL.Image.LANCZOS', 25 | Image.HAMMING: 'PIL.Image.HAMMING', 26 | Image.BOX: 'PIL.Image.BOX', 27 | } 28 | 29 | 30 | def _pil_interp(method): 31 | if method == 'bicubic': 32 | return Image.BICUBIC 33 | elif method == 'lanczos': 34 | return Image.LANCZOS 35 | elif method == 'hamming': 36 | return Image.HAMMING 37 | else: 38 | # default bilinear, do we want to allow nearest? 39 | return Image.BILINEAR 40 | 41 | 42 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 43 | 44 | 45 | class MIMTransform(object): 46 | """ 47 | Masked Image Modeling transforms based on BEiT.
copied from https://github.com/microsoft/unilm/tree/master/beit 48 | """ 49 | 50 | def __init__(self, cfg): 51 | imagenet_default_mean_and_std = cfg.imagenet_default_mean_and_std 52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 54 | patch_size = cfg.patch_size 55 | num_patches = cfg.num_patches 56 | assert patch_size * num_patches == cfg.input_size 57 | 58 | self.common_transform = transforms.Compose([ 59 | transforms.ColorJitter(0.4, 0.4, 0.4), 60 | transforms.RandomHorizontalFlip(p=0.5), 61 | RandomResizedCropAndInterpolation(size=cfg.input_size, interpolation=cfg.interpolation), 62 | transforms.ToTensor(), 63 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)) 64 | ]) 65 | 66 | self.masked_position_generator = MaskingGenerator( 67 | cfg.num_patches, num_masking_patches=cfg.num_mask_patches, 68 | max_num_patches=cfg.max_mask_patches_per_block, 69 | min_num_patches=cfg.min_mask_patches_per_block, 70 | ) 71 | 72 | def __call__(self, image): 73 | return self.common_transform(image), self.masked_position_generator() 74 | 75 | 76 | class MaskingGenerator: 77 | def __init__( 78 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None, 79 | min_aspect=0.3, max_aspect=None): 80 | if not isinstance(input_size, tuple): 81 | input_size = (input_size,) * 2 82 | self.height, self.width = input_size 83 | 84 | self.num_masking_patches = num_masking_patches 85 | 86 | self.min_num_patches = min_num_patches 87 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 88 | 89 | max_aspect = max_aspect or 1 / min_aspect 90 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 91 | 92 | def __repr__(self): 93 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 94 | self.height, self.width, self.min_num_patches, self.max_num_patches, 95 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 96 | return repr_str 97 | 98 | def get_shape(self): 99 | return self.height, self.width 100 | 101 | def _mask(self, mask, max_mask_patches): 102 | delta = 0 103 | for attempt in range(10): 104 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 105 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 106 | h = int(round(math.sqrt(target_area * aspect_ratio))) 107 | w = int(round(math.sqrt(target_area / aspect_ratio))) 108 | if w < self.width and h < self.height: 109 | top = random.randint(0, self.height - h) 110 | left = random.randint(0, self.width - w) 111 | 112 | num_masked = mask[top: top + h, left: left + w].sum() 113 | # Overlap 114 | if 0 < h * w - num_masked <= max_mask_patches: 115 | for i in range(top, top + h): 116 | for j in range(left, left + w): 117 | if mask[i, j] == 0: 118 | mask[i, j] = 1 119 | delta += 1 120 | 121 | if delta > 0: 122 | break 123 | return delta 124 | 125 | def __call__(self): 126 | mask = torch.zeros(self.get_shape(), dtype=torch.int) 127 | mask_count = 0 128 | while mask_count < self.num_masking_patches: 129 | max_mask_patches = self.num_masking_patches - mask_count 130 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 131 | 132 | delta = self._mask(mask, max_mask_patches) 133 | if delta == 0: 134 | break 135 | else: 136 | mask_count += delta 137 | 138 | return mask 139 | --------------------------------------------------------------------------------
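To make the masking pipeline above easier to poke at, here is a minimal usage sketch for `MIMTransform` and `MaskingGenerator`. It is not part of the repository: the config values (16x16 patches, a 14x14 grid, 75 masked patches) are illustrative stand-ins for whatever `configs/beit-pretraining.yaml` actually defines, chosen only to satisfy the `patch_size * num_patches == input_size` assertion in `MIMTransform`.

```python
# Minimal usage sketch for vision/transforms.py -- not part of the repo.
# The config values below are illustrative; the real ones live in configs/beit-pretraining.yaml.
from omegaconf import OmegaConf
from PIL import Image

from vision.transforms import MIMTransform, MaskingGenerator

cfg = OmegaConf.create({
    'input_size': 224,                  # must equal patch_size * num_patches (asserted in MIMTransform)
    'patch_size': 16,
    'num_patches': 14,                  # patches per side -> a 14x14 mask grid
    'interpolation': 'bicubic',
    'imagenet_default_mean_and_std': True,
    'num_mask_patches': 75,             # total number of patches to mask
    'max_mask_patches_per_block': None,
    'min_mask_patches_per_block': 16,
})

transform = MIMTransform(cfg)
image = Image.new('RGB', (256, 256))    # stand-in for a real training image
pixels, mask = transform(image)
print(pixels.shape)                     # torch.Size([3, 224, 224])
print(mask.shape, int(mask.sum()))      # torch.Size([14, 14]), roughly num_mask_patches

# The block-wise masking strategy can also be used on its own:
generator = MaskingGenerator(14, num_masking_patches=75, min_num_patches=16)
print(generator())                      # 14x14 int tensor where 1 marks a masked patch
```

The `(pixels, mask)` pair returned by `MIMTransform` is exactly what `MIMPretrainingDataset.__getitem__` consumes to build the `(masked_image, target_image, mask)` triple that `VisionTrainer.train_step` feeds to `Data2Vec`.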