├── .gitignore
├── LICENSE
├── README.md
├── audio
│   ├── __init__.py
│   ├── configs
│   │   └── wav2vec2-pretraining.yaml
│   ├── dataset.py
│   ├── encoder.py
│   └── trainer.py
├── data2vec.png
├── data2vec
│   ├── __init__.py
│   ├── data2vec.py
│   └── ema.py
├── requirements.txt
├── text
│   ├── __init__.py
│   ├── configs
│   │   └── roberta-pretraining.yaml
│   ├── dataset.py
│   ├── encoder.py
│   └── trainer.py
├── train.py
├── utils.py
└── vision
    ├── __init__.py
    ├── configs
    │   └── beit-pretraining.yaml
    ├── dataset.py
    ├── encoder.py
    ├── trainer.py
    └── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .idea/
131 | logs/
132 | weights/
133 | checkpoints/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Aryan Shekarlaban
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # data2vec-pytorch
2 | ##### PyTorch implementation of "[data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555)" from Meta AI (FAIR)
3 | #### Disclaimer: The goal of this repo is to make data2vec easier to understand, so it is not recommended for actual model pretraining; for that, use the official implementation in fairseq or the checkpoints provided on HuggingFace.
4 | Data2Vec is the first high-performance self-supervised algorithm that learns the same way in multiple modalities, including speech, vision and text.
5 | Most machines learn exclusively from labeled data. However, through self-supervised learning, machines are able to learn about the world just by observing it
6 | and then figuring out the structure of images, speech or text. This is a more scalable and efficient approach for machines to tackle new complex tasks,
7 | such as understanding text for more spoken languages.
8 |
9 | ![data2vec](data2vec.png)
10 |
11 | In summary, the method is as follows:
12 | 1. The encoder extracts features from the masked inputs. These features are outputs of every transformer/linear layer.
13 | 2. The teacher, which is an EMA instance of the encoder (in eval mode), extracts features from the unmasked inputs.
14 | 3. Optional normalizations are applied to the layers/outputs of the teacher.
15 | 4. Encoder outputs are regressed by a projection block/layer.
16 | 5. The loss is calculated between the encoder (student) outputs and the teacher outputs. (A minimal sketch of a single training step is shown below.)
17 |
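Concretely, a single pretraining step boils down to something like the sketch below. This is a simplified, hedged illustration of the `Trainer` classes in this repo (text modality for brevity); the CPU override, the dummy batch, and its shapes are placeholders, not part of the codebase:

```python
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from data2vec import Data2Vec
from text.encoder import Encoder

cfg = OmegaConf.load('text/configs/roberta-pretraining.yaml')
cfg.device = 'cpu'  # assumption: run this sketch on CPU

encoder = Encoder(cfg=cfg)                  # student; the teacher is an EMA copy created inside Data2Vec
model = Data2Vec(encoder=encoder, cfg=cfg)
criterion = nn.SmoothL1Loss(reduction='none', beta=cfg.criterion.loss_beta)
optimizer = torch.optim.Adam(model.parameters(), cfg.optimizer.lr)

# dummy batch: unmasked tokens (trg), a boolean mask over positions, and masked tokens (src)
trg = torch.randint(0, 1000, (2, 16))
mask = torch.zeros(2, 16, dtype=torch.bool)
mask[:, ::4] = True
src = trg.clone()
src[mask] = 0                               # stand-in for the tokenizer's mask token id

x, y = model(src, trg, mask)                # x: student outputs at masked positions, y: teacher targets
loss = criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
loss.backward()
optimizer.step()
optimizer.zero_grad()
model.ema_step()                            # update the teacher weights via EMA
```
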
18 | You can read the paper for more detail.
19 |
20 | ## Implementation
21 | Data2Vec is already implemented in [fairseq](https://github.com/pytorch/fairseq/tree/main/examples/data2vec), with a separate implementation for each modality (text, vision, audio). According to the paper:
22 | > Our primary goal is to design a single learning mechanism for different modalities.
23 | Despite the unified learning regime, we still use modality-specific feature extractors and masking strategies.
24 | This makes sense given the vastly different nature of the input data.
25 |
26 | This implementation differs in that it provides a single Data2Vec model powered by a swappable encoder (implemented with PyTorch + HuggingFace Transformers) and unifies the whole concept in a single module.
27 | The key requirement is that feature extraction and masking remain modality-specific.
28 |
29 | - **Masking:** For each modality, the Dataset instance must return the masked source, the target and the mask tensor (see the contract sketch after this list).
30 |
31 | - **Feature Extraction:** Features are the outputs from the transformer/attention layers, so the forward method must return the outputs from all encoder blocks of the transformer model. HuggingFace Transformers/fairseq models return each transformer layer's output out of the box.
32 |
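For illustration, a custom dataset only has to honor that contract, i.e. each batch the trainers unpack is `(masked_src, trg, mask)`. A minimal sketch (the class name, the zero "mask token" and the random masking scheme are placeholders, not part of this repo):

```python
import torch
from torch.utils.data import Dataset


class MyMaskedDataset(Dataset):
    """Illustrative only: a custom dataset just has to yield (masked_src, trg, mask) items."""

    def __init__(self, samples, mask_ratio=0.15):
        self.samples = samples        # e.g., a list of pre-tokenized 1-D LongTensors
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        trg = self.samples[index]                          # unmasked target (teacher input)
        mask = torch.rand(trg.shape) < self.mask_ratio     # bool mask over positions
        masked_src = trg.clone()
        masked_src[mask] = 0                               # stand-in for a modality-specific mask token/patch
        return masked_src, trg, mask
```
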
33 | This implementation uses HuggingFace Transformers models as encoders for Data2Vec, which you can inspect in the `encoder.py` files for each modality. However, you can provide your own encoder model; just make sure that it is Transformer-based (as the paper requires) and that it returns the outputs of every encoder layer (a minimal sketch follows below).
34 |
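The bundled encoders all return the same dictionary, so a custom encoder only needs to reproduce that contract. Below is a minimal sketch along the lines of `text/encoder.py` (assuming a HuggingFace text model; `MyEncoder` and the default checkpoint are illustrative, and `cfg.model.embed_dim` must match the encoder's hidden size):

```python
import torch.nn as nn
from transformers import AutoConfig, AutoModel


class MyEncoder(nn.Module):
    """Illustrative wrapper: any Transformer encoder works as long as it exposes per-layer outputs."""

    def __init__(self, checkpoint='roberta-base'):  # `checkpoint` is just an example value
        super().__init__()
        self.encoder = AutoModel.from_config(AutoConfig.from_pretrained(checkpoint))

    def forward(self, inputs, mask=None, **kwargs):
        outputs = self.encoder(inputs, output_hidden_states=True, output_attentions=True, **kwargs)
        return {
            'encoder_states': outputs['hidden_states'][:-1],  # per-layer outputs (teacher targets)
            'encoder_out': outputs['hidden_states'][-1],      # last layer output (student features)
            'attentions': outputs['attentions'],
        }
```
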
35 | **Note**: This implementation's goal is to provide the necessary building blocks of Data2Vec so anyone can adapt it to their own use case with ease. To keep it as clean and simple as possible, some functionality like mixed precision, distributed training, etc. is not included. If you only need to train a standard large-scale Data2Vec model, use the [official repo](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
36 |
37 | ## Train
38 | First things first, install the requirements:
39 | ```bash
40 | pip install -r requirements.txt
41 | ```
42 |
43 | #### **NLP**
44 | Train a Language Model based on RoBERTa (HuggingFace) on WikiText103
45 |
46 | Configure the related properties in `text/configs/roberta-pretraining.yaml` and run:
47 | ```bash
48 | python train.py --config text/configs/roberta-pretraining.yaml
49 | ```
50 |
51 | #### **Vision**
52 | Run a Masked Image modeling training based on BEiT (HuggingFace)
53 |
54 | Pass the path to the image dataset in the config file at `vision/configs/beit-pretraining.yaml` under `dataset > path > train/test`, modify other properties as desired, and run the following:
55 | ```bash
56 | python train.py --config vision/configs/beit-pretraining.yaml
57 | ```
58 |
59 | #### **Speech**
60 | Audio pretraining based on Wav2Vec2 (HuggingFace) on the `timit` dataset. If you want to use another dataset such as `librispeech`, add it in `audio/dataset.py` (minor changes to the TIMIT class would do the job, because both are loaded from HuggingFace datasets; a rough sketch is shown below).
61 |
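For example, a LibriSpeech variant of the TIMIT class might look like the sketch below (untested; `librispeech_asr`, the `'clean'` config and split names such as `'train.100'` are assumptions to verify against the dataset card on HuggingFace):

```python
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import Wav2Vec2FeatureExtractor


class LibriSpeech(Dataset):
    """Sketch mirroring the TIMIT class in audio/dataset.py; adjust names/splits to the actual dataset card."""

    def __init__(self, cfg, split, **kwargs):
        super().__init__()
        # assumption: cfg.dataset.path is set to 'librispeech_asr' and `split` to e.g. 'train.100'
        self.data = load_dataset(cfg.dataset.path, 'clean', split=split)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cfg.model.encoder_checkpoint)
        self.__dict__.update(kwargs)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]['audio']
        x = self.feature_extractor(x['array'], sampling_rate=x['sampling_rate'], return_tensors='pt')['input_values']
        return {'input_values': x[0]}
```
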
62 | Configure other properties as you desire and run the following:
63 | ```bash
64 | python train.py --config audio/configs/wav2vec2-pretraining.yaml
65 | ```
66 |
67 | ## Pre-trained Weights
68 | **Note:** The weights of the models below were carefully ported from the original checkpoints in the `fairseq` version.
69 |
70 | #### **RoBERTa**
71 | Data2Vec model trained with RoBERTa as the encoder ([data2vec-roberta-base](https://huggingface.co/arxyzan/data2vec-roberta-base))
72 | ```python
73 | from transformers import AutoModel, AutoConfig
74 | from transformers import RobertaModel
75 |
76 | checkpoint = 'arxyzan/data2vec-roberta-base'
77 |
78 | # Option 1: load using AutoModel
79 | data2vec_roberta = AutoModel.from_pretrained(checkpoint)
80 |
81 | # Option 2: load directly by RobertaModel
82 | data2vec_roberta = RobertaModel.from_pretrained(checkpoint)
83 |
84 | ```
85 |
86 | #### **BEiT**
87 | Data2Vec model trained with BEiT as the encoder ([data2vec-beit-base](https://huggingface.co/arxyzan/data2vec-beit-base))
88 | ```python
89 | from transformers import AutoModel, AutoConfig
90 | from transformers import BeitModel
91 |
92 | checkpoint = 'arxyzan/data2vec-beit-base'
93 |
94 | # Option 1: load using AutoModel
95 | data2vec_beit = AutoModel.from_pretrained(checkpoint)
96 |
97 | # Option 2: load directly by BeitModel
98 | data2vec_beit = BeitModel.from_pretrained(checkpoint)
99 |
100 | ```
101 |
102 | #### **Wav2Vec2**
103 | Data2Vec model trained with Wav2Vec2 as the encoder ([data2vec-wav2vec2-base](https://huggingface.co/arxyzan/data2vec-wav2vec2-base))
104 | ```python
105 | from transformers import AutoModel, AutoConfig
106 | from transformers import Wav2Vec2Model
107 |
108 | checkpoint = 'arxyzan/data2vec-wav2vec2-base'
109 |
110 | # Option 1: load using AutoModel
111 | data2vec_wav2vec2 = AutoModel.from_pretrained(checkpoint)
112 |
113 | # Option 2: load directly by Wav2Vec2Model
114 | data2vec_wav2vec2 = Wav2Vec2Model.from_pretrained(checkpoint)
115 |
116 | ```
117 |
118 | ## Fine-tuning
119 |
120 | 1. Fine-tune using the checkpoints mentioned above:
121 | ```python
122 | # Text classification using Roberta model from HuggingFace
123 | from transformers import RobertaModel, RobertaForSequenceClassification
124 |
125 | checkpoint = 'arxyzan/data2vec-roberta-base'
126 | # this is exactly a roberta model but trained with data2vec
127 | data2vec_roberta = RobertaModel.from_pretrained(checkpoint)
128 | text_classifier = RobertaForSequenceClassification(data2vec_roberta.config)
129 | # assign `data2vec-roberta` weights to the roberta block of the classifier
130 | text_classifier.roberta = data2vec_roberta
131 | ...
132 | ```
133 | 2. If you trained a model using this codebase, you can fine-tune it by extracting the encoder's state dict from the checkpoint, which gives you a regular HuggingFace model that you can then fine-tune on any downstream task as you normally would.
134 | ```python
135 | # load a checkpoint for finetuning
136 | import torch
137 | from transformers import RobertaModel, RobertaConfig
138 | roberta = RobertaModel(RobertaConfig())
139 | checkpoint = torch.load('path/to/data2vec.pt')
140 | # load roberta weights from the encoder part of the data2vec checkpoint
141 | roberta.load_state_dict(checkpoint['encoder'])
142 | 
143 | # Now fine-tune a regular HuggingFace RoBERTa model
144 | ...
145 | ```
146 |
147 | ## Contributions
148 | Any contributions regarding training, development (for Data2Vec2), and issues are welcome!
149 |
--------------------------------------------------------------------------------
/audio/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/audio/__init__.py
--------------------------------------------------------------------------------
/audio/configs/wav2vec2-pretraining.yaml:
--------------------------------------------------------------------------------
1 | modality: 'audio'
2 | device: 'cuda'
3 | model:
4 | encoder_checkpoint: 'facebook/wav2vec2-base-960h'
5 | embed_dim: 768
6 | average_top_k_layers: 8
7 | head_layers: 2
8 | num_classes: 1000
9 | normalize_targets: false
10 | ema_decay: 0.9998
11 | ema_end_decay: 0.9999
12 | ema_anneal_end_step: 300000
13 | dataset:
14 | path: 'timit_asr'
15 | optimizer:
16 | lr: 0.0005
17 | train:
18 | batch_size: 16
19 | eval_batch_size: 16
20 | num_epochs: 1000
21 | log_dir: 'audio/logs'
22 | save_ckpt_freq: 10
23 | checkpoints_dir: 'audio/checkpoints/wav2vec2-pretraining'
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/audio/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from datasets import load_dataset
4 | from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
5 | from transformers import Wav2Vec2FeatureExtractor
6 |
7 |
8 | class TIMIT(Dataset):
9 | def __init__(self, cfg, split, **kwargs):
10 | super(TIMIT, self).__init__()
11 | path = cfg.dataset.path
12 | self.data = load_dataset(path, 'clean')[split]
13 |         self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(cfg.model.encoder_checkpoint)
14 | self.__dict__.update(kwargs)
15 |
16 | def __len__(self):
17 | return len(self.data)
18 |
19 | def __getitem__(self, index):
20 | x = self.data[index]['audio']
21 | x = self.feature_extractor(x['array'], sampling_rate=x['sampling_rate'], padding=True, return_tensors='pt')['input_values']
22 | return {'input_values': x[0]}
23 |
24 |
25 | class DataCollatorForWav2Vec2Pretraining: # copied from transformers/examples/pytorch/speech-pretraining
26 | """
27 |     Data collator that dynamically pads the received inputs and prepares masked indices for self-supervised pretraining.
28 | 
29 |     Args:
30 |         model (:class:`~transformers.Wav2Vec2ForPreTraining`): The Wav2Vec2 model used for pretraining. The data collator
31 |             needs access to its config and the ``_get_feat_extract_output_lengths`` function for correct padding.
32 |         feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`): The processor used for processing the data.
33 |         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`,
34 |             defaults to :obj:`True`): Strategy to pad the returned sequences (according to the model's padding side and
35 |             padding index): :obj:`True` or :obj:`'longest'` pads to the longest sequence in the batch (or applies no
36 |             padding if only a single sequence is provided); :obj:`'max_length'` pads to :obj:`max_length` or to the
37 |             maximum acceptable input length for the model if that argument is not provided; :obj:`False` or
38 |             :obj:`'do_not_pad'` (default) applies no padding (i.e., the batch may contain sequences of different lengths).
39 |         max_length (:obj:`int`, `optional`): Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
40 |         pad_to_multiple_of (:obj:`int`, `optional`): If set, pad the sequence to a multiple of the provided value. This is
41 |             especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
42 | """
43 |
44 | def __init__(self, model, feature_extractor, padding, max_length=None, pad_to_multiple_of=None):
45 | self.model = model
46 | self.feature_extractor = feature_extractor
47 | self.padding = padding
48 | self.max_length = max_length
49 | self.pad_to_multiple_of = pad_to_multiple_of
50 |
51 | def __call__(self, features):
52 | # reformat list to dict and set to pytorch format
53 | batch = self.feature_extractor.pad(
54 | features,
55 | padding=self.padding,
56 | pad_to_multiple_of=self.pad_to_multiple_of,
57 | return_tensors="pt",
58 | )
59 |
60 | device = batch["input_values"].device
61 | batch_size = batch["input_values"].shape[0]
62 |
63 | mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
64 | # make sure masked sequence length is a Python scalar
65 | mask_indices_seq_length = int(mask_indices_seq_length)
66 |
67 | # make sure that no loss is computed on padded inputs
68 | if batch.get("attention_mask") is not None:
69 | # compute real output lengths according to convolution formula
70 | batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
71 | mask_indices_seq_length, batch["attention_mask"]
72 | )
73 |
74 | features_shape = (batch_size, mask_indices_seq_length)
75 |
76 | # sample randomly masked indices
77 | mask_time_indices = _compute_mask_indices(
78 | features_shape,
79 | self.model.config.mask_time_prob,
80 | self.model.config.mask_time_length,
81 | attention_mask=batch.get("sub_attention_mask"),
82 | )
83 |         mask_time_indices = torch.tensor(mask_time_indices, dtype=torch.bool, device=device)  # bool, so it can be used for boolean masking downstream (e.g. x[mask], ~mask)
84 | src = batch['input_values']
85 |
86 | return src, mask_time_indices
87 |
88 |
89 | if __name__ == '__main__':
90 | from torch.utils.data import DataLoader
91 | from omegaconf import OmegaConf
92 | from transformers import Wav2Vec2Model, Wav2Vec2Config
93 |
94 | cfg = OmegaConf.load('configs/wav2vec2-pretraining.yaml')
95 | model = Wav2Vec2Model(Wav2Vec2Config())
96 | feature_extractor = Wav2Vec2FeatureExtractor()
97 | dataset = TIMIT(cfg, 'train')
98 | collate_fn = DataCollatorForWav2Vec2Pretraining(model, feature_extractor, padding='longest')
99 | loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
100 | itr = iter(loader)
101 | sample = next(itr)
102 | print(sample)
103 |
--------------------------------------------------------------------------------
/audio/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModel, AutoConfig
3 | import torch.nn as nn
4 |
5 |
6 | class Encoder(nn.Module):
7 | """
8 |     Encoder model using HuggingFace Transformers for audio, e.g., Wav2Vec2
9 |
10 | Args:
11 | cfg: An omegaconf.DictConf instance containing all the configurations.
12 | **kwargs: extra args which are set as model properties
13 | """
14 |
15 | def __init__(self, cfg, **kwargs):
16 | super(Encoder, self).__init__()
17 | self.cfg = cfg
18 | checkpoint = cfg.model.encoder_checkpoint
19 | model_config = AutoConfig.from_pretrained(checkpoint)
20 | self.encoder = AutoModel.from_config(model_config)
21 | self.__dict__.update(kwargs)
22 |
23 | def forward(self, inputs, mask=None, **kwargs):
24 | """
25 | Forward inputs through the encoder and extract transformer/attention layers outputs
26 |
27 | Args:
28 | inputs: raw audio array
29 | mask: bool masked indices
30 | **kwargs: keyword args specific to the encoder's forward method
31 |
32 | Returns:
33 | A dictionary of the encoder outputs including transformer layers outputs and attentions outputs
34 | """
35 | outputs = self.encoder(inputs, mask_time_indices=mask, output_hidden_states=True,
36 | output_attentions=True, **kwargs)
37 | encoder_states = outputs['hidden_states'][:-1] # encoder layers outputs separately
38 | encoder_out = outputs['hidden_states'][-1] # last encoder output (accumulated)
39 | attentions = outputs['attentions']
40 | return {
41 | 'encoder_states': encoder_states,
42 | 'encoder_out': encoder_out,
43 | 'attentions': attentions
44 | }
45 |
46 |
47 | if __name__ == '__main__':
48 | from dataset import TIMIT, DataCollatorForWav2Vec2Pretraining
49 | from omegaconf import OmegaConf
50 | from transformers import Wav2Vec2FeatureExtractor
51 | from torch.utils.data import DataLoader
52 |
53 | cfg = OmegaConf.load('configs/wav2vec2-pretraining.yaml')
54 | feature_extractor = Wav2Vec2FeatureExtractor()
55 | model = Encoder(cfg)
56 | dataset = TIMIT(cfg, 'train')
57 | collate_fn = DataCollatorForWav2Vec2Pretraining(model.encoder, feature_extractor, padding='longest')
58 | loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
59 | itr = iter(loader)
60 | inputs, mask = next(itr)
61 | features = model(inputs, mask)
62 | print(features)
63 |
--------------------------------------------------------------------------------
/audio/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.utils.data import DataLoader
5 | import torch.optim as optim
6 | from torch.utils.tensorboard import SummaryWriter
7 |
8 | from omegaconf import DictConfig
9 | from tqdm import tqdm
10 |
11 | from audio.encoder import Encoder
12 | from audio.dataset import TIMIT, DataCollatorForWav2Vec2Pretraining
13 | from data2vec import Data2Vec
14 | from utils import AverageMeter, maybe_save_checkpoint
15 |
16 |
17 | class AudioTrainer:
18 | def __init__(self, cfg: DictConfig):
19 | self.cfg = cfg
20 | self.num_epochs = self.cfg.train.num_epochs
21 | self.device = self.cfg.device
22 | self.ckpt_dir = cfg.train.checkpoints_dir
23 | self.save_ckpt_freq = cfg.train.save_ckpt_freq
24 | # Model, Optim, Criterion
25 | self.encoder = Encoder(cfg=cfg)
26 | self.model = Data2Vec(encoder=self.encoder, cfg=cfg)
27 | self.model.to(self.device)
28 | self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr)
29 | self.criterion = nn.MSELoss(reduction='none')
30 | self.criterion.to(self.device)
31 | # Datasets & Data Loaders
32 | self.train_dataset = TIMIT(cfg, 'train')
33 | self.test_dataset = TIMIT(cfg, 'test')
34 | self.feature_extractor = self.train_dataset.feature_extractor
35 | self.data_collator = DataCollatorForWav2Vec2Pretraining(self.encoder.encoder, self.feature_extractor,
36 | padding='longest')
37 | self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size,
38 | collate_fn=self.data_collator)
39 | self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size,
40 | collate_fn=self.data_collator)
41 | # Tensorboard
42 | self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir)
43 |
44 | # Trackers
45 | self.loss_tracker = AverageMeter('loss')
46 |
47 | def train_step(self, batch):
48 | """
49 | Train one batch of data and return loss.
50 |
51 | Args:
52 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len]
53 |
54 | Returns:
55 | Loss value
56 | """
57 | src, mask = batch
58 | src, mask = src.to(self.device), mask.to(self.device)
59 | # src is not masked so can be used as trg. (src will be masked in the encoder forward)
60 | x, y = self.model(src, src, mask)
61 |         loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
62 | loss.backward()
63 | self.optimizer.step()
64 | self.optimizer.zero_grad()
65 |
66 | return loss.item()
67 |
68 | def test_step(self, batch):
69 | """
70 | Test a model on one batch of data and return loss.
71 |
72 | Args:
73 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len]
74 |
75 | Returns:
76 | Loss value
77 | """
78 | src, mask = batch
79 | src, mask = src.to(self.device), mask.to(self.device)
80 | # src is not masked so can be used as trg. (src will be masked in the encoder forward)
81 | x, y = self.model(src, src, mask=mask)
82 |         loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
83 |
84 | return loss.item()
85 |
86 | def train_epoch(self, epoch_num):
87 | """
88 | Train the model for one epoch and verbose using the progress bar.
89 |
90 | Args:
91 | epoch_num: number of the current epoch
92 |
93 | Returns:
94 | The average loss through the whole epoch
95 | """
96 | self.model.train()
97 | self.loss_tracker.reset()
98 | with tqdm(self.train_loader, unit="batch", desc=f'Epoch: {epoch_num}/{self.num_epochs} ',
99 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
100 | for batch in iterator:
101 | loss = self.train_step(batch)
102 | self.model.ema_step()
103 | self.loss_tracker.update(loss)
104 | avg_loss = self.loss_tracker.avg
105 | iterator.set_postfix(loss=avg_loss)
106 |
107 | return avg_loss
108 |
109 | def evaluate(self):
110 | """
111 | Evaluate the model on the test set
112 |
113 | Returns:
114 | The average loss through the whole test dataset
115 | """
116 | self.model.eval()
117 | self.loss_tracker.reset()
118 | with tqdm(self.test_loader, unit="batch", desc=f'Evaluating... ',
119 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
120 | with torch.no_grad():
121 | for batch in iterator:
122 | loss = self.test_step(batch)
123 | self.loss_tracker.update(loss)
124 | avg_loss = self.loss_tracker.avg
125 | iterator.set_postfix(loss=avg_loss)
126 |
127 | return avg_loss
128 |
129 | def train(self):
130 | """
131 | Train and evaluate the model on the datasets and save checkpoints and write summaries to TensorBoard.
132 |
133 | """
134 | for epoch in range(1, self.num_epochs + 1):
135 | print()
136 | train_loss = self.train_epoch(epoch)
137 | val_loss = self.evaluate()
138 |
139 | # tensorboard
140 | self.tensorboard.add_scalar('train_loss', train_loss, epoch)
141 | self.tensorboard.add_scalar('val_loss', val_loss, epoch)
142 |
143 | maybe_save_checkpoint(self.model, self.optimizer, self.ckpt_dir, epoch, self.save_ckpt_freq)
144 |
--------------------------------------------------------------------------------
/data2vec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/data2vec.png
--------------------------------------------------------------------------------
/data2vec/__init__.py:
--------------------------------------------------------------------------------
1 | from .data2vec import Data2Vec
2 | from .ema import EMA
3 |
--------------------------------------------------------------------------------
/data2vec/data2vec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .ema import EMA
5 |
6 |
7 | class Data2Vec(nn.Module):
8 | """
9 | Data2Vec main module.
10 |
11 | Args:
12 | encoder (nn.Module): The encoder module like BEiT, ViT, etc.
13 | cfg (omegaconf.DictConfig): The config containing model properties
14 | """
15 | MODALITIES = ['vision', 'text', 'audio']
16 |
17 | def __init__(self, encoder, cfg, **kwargs):
18 | super(Data2Vec, self).__init__()
19 | self.modality = cfg.modality
20 | self.embed_dim = cfg.model.embed_dim
21 | self.encoder = encoder
22 | self.__dict__.update(kwargs)
23 |
24 | self.cfg = cfg
25 | self.ema = EMA(self.encoder, cfg) # EMA acts as the teacher
26 | self.regression_head = self._build_regression_head()
27 |
28 | self.cfg = cfg
29 | self.ema_decay = self.cfg.model.ema_decay
30 | self.ema_end_decay = self.cfg.model.ema_end_decay
31 | self.ema_anneal_end_step = self.cfg.model.ema_anneal_end_step
32 |
33 | def _build_regression_head(self):
34 | """
35 | Construct the regression head consisting of linear and activation layers.
36 |
37 | Each modality might have its own regression block.
38 |
39 | Returns:
40 | A nn.Module layer or block of layers
41 | """
42 | if self.modality == 'text':
43 | return nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim * 2),
44 | nn.GELU(),
45 | nn.Linear(self.embed_dim * 2, self.embed_dim))
46 |
47 | if self.modality in ['audio', 'vision']:
48 | return nn.Linear(self.embed_dim, self.embed_dim)
49 |
50 | def ema_step(self):
51 | """
52 | One EMA step for the offline model until the ending decay value is reached
53 | """
54 | if self.ema_decay != self.ema_end_decay:
55 | if self.ema.num_updates >= self.ema_anneal_end_step:
56 | decay = self.ema_end_decay
57 | else:
58 | decay = self.ema.get_annealed_rate(
59 | self.ema_decay,
60 | self.ema_end_decay,
61 | self.ema.num_updates,
62 | self.ema_anneal_end_step,
63 | )
64 | self.ema.decay = decay
65 | if self.ema.decay < 1:
66 | self.ema.step(self.encoder)
67 |
68 | def forward(self, src, trg=None, mask=None, **kwargs):
69 | """
70 | Data2Vec forward method.
71 |
72 | Args:
73 | src: src tokens (masked inputs for training)
74 | trg: trg tokens (unmasked inputs for training but left as `None` otherwise)
75 | mask: bool masked indices, Note: if a modality requires the inputs to be masked before forward this param
76 | has no effect. (see the Encoder for each modality to see if it uses mask or not)
77 |
78 | Returns:
79 | Either encoder outputs or a tuple of encoder + EMA outputs
80 |
81 | """
82 | # model forward in online mode (student)
83 | x = self.encoder(src, mask, **kwargs)['encoder_out'] # fetch the last layer outputs
84 | if trg is None:
85 | return x
86 |
87 | # model forward in offline mode (teacher)
88 | with torch.no_grad():
89 | self.ema.model.eval()
90 | y = self.ema.model(trg, ~mask, **kwargs)['encoder_states'] # fetch the last transformer layers outputs
91 | y = y[-self.cfg.model.average_top_k_layers:] # take the last k transformer layers
92 |
93 | # Follow the same layer normalization procedure for text and vision
94 | if self.modality in ['vision', 'text']:
95 | y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
96 | y = sum(y) / len(y)
97 | if self.cfg.model.normalize_targets:
98 | y = F.layer_norm(y.float(), y.shape[-1:])
99 |
100 | # Use instance normalization for audio
101 | elif self.modality == 'audio':
102 | y = [F.instance_norm(tl.float()) for tl in y]
103 | y = sum(y) / len(y)
104 | if self.cfg.model.normalize_targets:
105 | y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
106 |
107 | x = x[mask]
108 | y = y[mask]
109 |
110 | x = self.regression_head(x)
111 |
112 | return x, y
113 |
--------------------------------------------------------------------------------
/data2vec/ema.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class EMA:
9 | """
10 | Modified version of class fairseq.models.ema.EMAModule.
11 |
12 | Args:
13 |         model (nn.Module): The online (student) model to copy weights from.
14 |         cfg (DictConfig): Config object; ``cfg.device`` and ``cfg.model.ema_decay`` are read here.
15 |         device (str): Device to keep the EMA model on, taken from ``cfg.device``.
16 |         skip_keys (list): The keys to skip assigning averaged weights to.
17 | """
18 |
19 | def __init__(self, model: nn.Module, cfg, skip_keys=None):
20 | self.model = self.deepcopy_model(model)
21 | self.model.requires_grad_(False)
22 | self.cfg = cfg
23 | self.device = cfg.device
24 | self.model.to(self.device)
25 | self.skip_keys = skip_keys or set()
26 | self.decay = self.cfg.model.ema_decay
27 | self.num_updates = 0
28 |
29 | @staticmethod
30 | def deepcopy_model(model):
31 | try:
32 | model = copy.deepcopy(model)
33 | return model
34 | except RuntimeError:
35 | tmp_path = 'tmp_model_for_ema_deepcopy.pt'
36 | torch.save(model, tmp_path)
37 | model = torch.load(tmp_path)
38 | os.remove(tmp_path)
39 | return model
40 |
41 | def step(self, new_model: nn.Module):
42 | """
43 | One EMA step
44 |
45 | Args:
46 | new_model (nn.Module): Online model to fetch new weights from
47 |
48 | """
49 | ema_state_dict = {}
50 | ema_params = self.model.state_dict()
51 | for key, param in new_model.state_dict().items():
52 | ema_param = ema_params[key].float()
53 | if key in self.skip_keys:
54 | ema_param = param.to(dtype=ema_param.dtype).clone()
55 | else:
56 | ema_param.mul_(self.decay)
57 | ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - self.decay)
58 | ema_state_dict[key] = ema_param
59 | self.model.load_state_dict(ema_state_dict, strict=False)
60 | self.num_updates += 1
61 |
62 | def restore(self, model: nn.Module):
63 | """
64 | Reassign weights from another model
65 |
66 | Args:
67 | model (nn.Module): model to load weights from.
68 |
69 | Returns:
70 | model with new weights
71 | """
72 | d = self.model.state_dict()
73 | model.load_state_dict(d, strict=False)
74 | return model
75 |
76 | def state_dict(self):
77 | return self.model.state_dict()
78 |
79 | @staticmethod
80 | def get_annealed_rate(start, end, curr_step, total_steps):
81 | """
82 | Calculate EMA annealing rate
83 | """
84 | r = end - start
85 | pct_remaining = 1 - curr_step / total_steps
86 | return end - r * pct_remaining
87 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | transformers
4 | datasets
5 | tqdm
6 | omegaconf
7 | tensorboard
8 | timm
9 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/text/__init__.py
--------------------------------------------------------------------------------
/text/configs/roberta-pretraining.yaml:
--------------------------------------------------------------------------------
1 | modality: 'text'
2 | device: 'cuda'
3 | train:
4 | batch_size: 32
5 | eval_batch_size: 32
6 | num_epochs: 20
7 | checkpoints_dir: 'text/checkpoints/roberta-pretrain'
8 | log_dir: 'text/logs/roberta-pretrain'
9 | save_ckpt_freq: 20
10 | criterion:
11 | loss_beta: 4
12 | optimizer:
13 | lr: 0.0002
14 | weight_decay: 0.01
15 | dataset:
16 | name: 'wikitext-103-v1'
17 | mlm_probability: 0.15
18 |   valid_seq_lengths: [12, 512]
19 | clean_dataset: false
20 | model:
21 | average_top_k_layers: 10
22 | embed_dim: 768
23 | num_classes: null
24 | encoder_checkpoint: 'roberta-base'
25 | normalize_targets: false
26 | ema_decay: 0.999
27 | ema_end_decay: 0.9999
28 | ema_anneal_end_step: 300000
29 |
30 |
--------------------------------------------------------------------------------
/text/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from datasets import load_dataset
4 | from tqdm import tqdm
5 |
6 |
7 | class WikiText(Dataset):
8 | """
9 | A Dataset instance for WikiText dataset loaded from HuggingFace datasets.
10 |
11 | Args:
12 | cfg (DictConfig): config object
13 | split: Split to load ['train', 'test']
14 | tokenizer: A HuggingFace Tokenizer model like BPE
15 | **kwargs: extra args which are set as dataset properties
16 | """
17 |
18 | def __init__(self, cfg, split, tokenizer, **kwargs):
19 | super(WikiText, self).__init__()
20 | self.cfg = cfg
21 | self.path = cfg.dataset.name
22 | self.mlm_probability = cfg.dataset.mlm_probability
23 | raw_data = load_dataset('wikitext', self.path)[split]
24 | self.data = self.clean_dataset(raw_data) if self.cfg.dataset.clean_dataset else raw_data
25 | self.tokenizer = tokenizer
26 | self.__dict__.update(kwargs)
27 |
28 | def clean_dataset(self, data):
29 | """
30 | Cleanup dataset by removing invalid sized samples, etc.
31 | """
32 | print('Cleaning dataset ...')
33 |         min_seq_len, max_seq_len = self.cfg.dataset.valid_seq_lengths
34 | texts = []
35 | with tqdm(data, desc='Removing invalid sized inputs: ') as tbar:
36 | for i, x in enumerate(tbar):
37 | if len(x['text']) in range(min_seq_len, max_seq_len + 1):
38 | texts.append(x)
39 | return texts
40 |
41 | def __len__(self):
42 | return len(self.data)
43 |
44 | def __getitem__(self, index):
45 | """
46 | Only return tokens from raw text with no additions e.g, padding, bos/eos, etc.
47 | Args:
48 | index: sample index to pick from dataset
49 |
50 | Returns:
51 | tokenized outputs
52 | """
53 | raw_text = self.data[index]['text']
54 | tokens = self.tokenizer(raw_text, return_attention_mask=False)
55 | return tokens
56 |
57 | def _mask_tokens(self, inputs, special_tokens_mask=None):
58 | """
59 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Ported
60 | from `transformers.data.DataCollatorForLanguageModeling.torch_mask_tokens()`
61 | Args:
62 | inputs: batch of input tokens
63 | special_tokens_mask:
64 |
65 | Returns:
66 | a dict batch of masked and padded inputs/labels
67 |
68 | """
69 | labels = inputs.clone()
70 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
71 | probability_matrix = torch.full(labels.shape, self.mlm_probability)
72 | if special_tokens_mask is None:
73 | special_tokens_mask = [
74 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
75 | labels.tolist()
76 | ]
77 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
78 | else:
79 | special_tokens_mask = special_tokens_mask.bool()
80 |
81 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
82 | masked_indices = torch.bernoulli(probability_matrix).bool()
83 | labels[~masked_indices] = self.tokenizer.pad_token_id
84 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
85 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
86 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
87 |
88 | # 10% of the time, we replace masked input tokens with random word
89 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
90 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
91 | inputs[indices_random] = random_words[indices_random]
92 |
93 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
94 | return inputs, labels, masked_indices
95 |
96 | def collate_fn(self, batch):
97 | """
98 | Collate the batch of data using BERT masking strategy. carefully ported from
99 | transformers.data.DataCollatorForLanguageModeling
100 | Args:
101 | batch: batch of data
102 |
103 | Returns:
104 | same batch of data masked and padded
105 | """
106 |
107 | batch = self.tokenizer.pad(batch, return_tensors="pt")
108 | # If special token mask has been preprocessed, pop it from the dict.
109 | special_tokens_mask = batch.pop("special_tokens_mask", None)
110 | src, trg, masked_indices = self._mask_tokens(
111 | batch["input_ids"], special_tokens_mask=special_tokens_mask
112 | )
113 | return src, trg, masked_indices
114 |
115 |
116 | if __name__ == '__main__':
117 | from transformers.models.roberta import RobertaTokenizer
118 | from torch.utils.data import DataLoader
119 | from omegaconf import OmegaConf
120 |
121 | cfg = OmegaConf.load('configs/roberta-pretraining.yaml')
122 | dataset = WikiText(cfg, 'train', RobertaTokenizer.from_pretrained('roberta-base'))
123 | dataloader = DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)
124 | data_iter = iter(dataloader)
125 | batch = next(data_iter)
126 | print(batch)
127 |
--------------------------------------------------------------------------------
/text/encoder.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoConfig, AutoTokenizer
2 | import torch.nn as nn
3 |
4 |
5 | class Encoder(nn.Module):
6 | """
7 | Encoder model using HuggingFace for NLP
8 |
9 | To load your desired model specify model checkpoint under cfg.model.encoder_checkpoint
10 |
11 | Args:
12 | cfg: An omegaconf.DictConf instance containing all the configurations.
13 | **kwargs: extra args which are set as model properties
14 | """
15 |
16 | def __init__(self, cfg, **kwargs):
17 | super(Encoder, self).__init__()
18 | self.cfg = cfg
19 | checkpoint = cfg.model.encoder_checkpoint
20 | model_config = AutoConfig.from_pretrained(checkpoint)
21 | self.encoder = AutoModel.from_config(model_config)
22 | self.__dict__.update(kwargs)
23 |
24 | def forward(self, inputs, mask=None, **kwargs):
25 | """
26 | Forward inputs through the encoder and extract transformer/attention layers outputs
27 |
28 | Args:
29 | inputs: source tokens
30 | mask: bool masked indices
31 | kwargs: keyword args specific to the encoder's forward method
32 |
33 | Returns:
34 | A dictionary of the encoder outputs including transformer layers outputs and attentions outputs
35 |
36 | """
37 | # Note: inputs are already masked for MLM so mask is not used
38 | outputs = self.encoder(inputs, output_hidden_states=True, output_attentions=True, **kwargs)
39 | encoder_states = outputs['hidden_states'][:-1] # encoder layers outputs separately
40 | encoder_out = outputs['hidden_states'][-1] # last encoder output (accumulated)
41 | attentions = outputs['attentions']
42 | return {
43 | 'encoder_states': encoder_states,
44 | 'encoder_out': encoder_out,
45 | 'attentions': attentions
46 | }
47 |
48 |
49 | if __name__ == '__main__':
50 | from omegaconf import OmegaConf
51 | cfg = OmegaConf.load('configs/roberta-pretraining.yaml')
52 | tokenizer = AutoTokenizer.from_pretrained('roberta-base')
53 | model = Encoder(cfg)
54 |     inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
55 | outputs = model(inputs['input_ids'])
56 | print(outputs)
57 |
--------------------------------------------------------------------------------
/text/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Train Data2Vec for text. The encoder is loaded from huggingface specified in the config file.
3 | """
4 | import os
5 | import torch
6 | import torch.nn as nn
7 | from torch.utils.data import DataLoader
8 | import torch.optim as optim
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from omegaconf import DictConfig
12 | from tqdm import tqdm
13 |
14 | from text.encoder import Encoder, AutoTokenizer
15 | from text.dataset import WikiText
16 | from data2vec import Data2Vec
17 | from utils import AverageMeter, maybe_save_checkpoint
18 |
19 |
20 | class TextTrainer:
21 | """
22 | A Trainer class to train NLP model on Data2Vec.
23 |
24 | Args:
25 | cfg (DictConfig): the config object containing all properties
26 | """
27 |
28 | def __init__(self, cfg: DictConfig):
29 | self.cfg = cfg
30 | self.num_epochs = self.cfg.train.num_epochs
31 | self.device = self.cfg.device
32 | self.ckpt_dir = cfg.train.checkpoints_dir
33 | self.save_ckpt_freq = cfg.train.save_ckpt_freq
34 | # Model, Optim, Criterion
35 | self.tokenizer = AutoTokenizer.from_pretrained(cfg.model.encoder_checkpoint)
36 | self.encoder = Encoder(cfg=cfg)
37 | self.model = Data2Vec(encoder=self.encoder, cfg=cfg)
38 | self.model.to(self.device)
39 | self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr)
40 | self.criterion = nn.SmoothL1Loss(reduction='none', beta=cfg.criterion.loss_beta)
41 | self.criterion.to(self.device)
42 | # Datasets & Data Loaders
43 | self.train_dataset = WikiText(cfg, 'train', self.tokenizer)
44 | self.test_dataset = WikiText(cfg, 'test', self.tokenizer)
45 | self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size,
46 | collate_fn=self.train_dataset.collate_fn)
47 | self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size,
48 | collate_fn=self.test_dataset.collate_fn)
49 | # Tensorboard
50 | self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir)
51 |
52 | # Trackers
53 | self.loss_tracker = AverageMeter('loss')
54 |
55 | def train_step(self, batch):
56 | """
57 | Train one batch of data and return loss.
58 |
59 | Args:
60 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len]
61 |
62 | Returns:
63 | Loss value
64 | """
65 | src, trg, mask = batch
66 | src, trg, mask = src.to(self.device), trg.to(self.device), mask.to(self.device)
67 |
68 | x, y = self.model(src, trg, mask)
69 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
70 | loss.backward()
71 | self.optimizer.step()
72 | self.optimizer.zero_grad()
73 |
74 | return loss.item()
75 |
76 | def test_step(self, batch):
77 | """
78 | Test a model on one batch of data and return loss.
79 |
80 | Args:
81 | batch: A batch of data, inputs, labels and mask with shape [batch_size, seq_len]
82 |
83 | Returns:
84 | Loss value
85 | """
86 |         # collate_fn returns a (src, trg, mask) tuple, same as in train_step
87 |         src, trg, mask = batch
88 |         src, trg, mask = src.to(self.device), trg.to(self.device), mask.to(self.device)
89 | 
90 |         x, y = self.model(src, trg, mask=mask)
91 |         loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
92 |
93 | return loss.item()
94 |
95 | def train_epoch(self, epoch_num):
96 | """
97 | Train the model for one epoch and verbose using the progress bar.
98 |
99 | Args:
100 | epoch_num: number of the current epoch
101 |
102 | Returns:
103 | The average loss through the whole epoch
104 | """
105 | self.model.train()
106 | self.loss_tracker.reset()
107 | with tqdm(self.train_loader, unit="batch", desc=f'Epoch: {epoch_num}/{self.num_epochs} ',
108 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
109 | for batch in iterator:
110 | loss = self.train_step(batch)
111 | self.model.ema_step()
112 | self.loss_tracker.update(loss)
113 | avg_loss = self.loss_tracker.avg
114 | iterator.set_postfix(loss=avg_loss)
115 |
116 | return avg_loss
117 |
118 | def evaluate(self):
119 | """
120 | Evaluate the model on the test set
121 |
122 | Returns:
123 | The average loss through the whole test dataset
124 | """
125 | self.model.eval()
126 | self.loss_tracker.reset()
127 | with tqdm(self.test_loader, unit="batch", desc=f'Evaluating... ',
128 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
129 | with torch.no_grad():
130 | for batch in iterator:
131 | loss = self.test_step(batch)
132 | self.loss_tracker.update(loss)
133 | avg_loss = self.loss_tracker.avg
134 | iterator.set_postfix(loss=avg_loss)
135 |
136 | return avg_loss
137 |
138 | def train(self):
139 | """
140 | Train and evaluate the model on the datasets and save checkpoints and write summaries to TensorBoard.
141 |
142 | """
143 | for epoch in range(1, self.num_epochs + 1):
144 | print()
145 | train_loss = self.train_epoch(epoch)
146 | val_loss = self.evaluate()
147 |
148 | # tensorboard
149 | self.tensorboard.add_scalar('train_loss', train_loss, epoch)
150 | self.tensorboard.add_scalar('val_loss', val_loss, epoch)
151 |
152 | # save checkpoint
153 | maybe_save_checkpoint(self.model, self.optimizer, self.ckpt_dir, epoch, self.save_ckpt_freq)
154 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import omegaconf
2 |
3 | from text.trainer import TextTrainer
4 | from vision.trainer import VisionTrainer
5 | from audio.trainer import AudioTrainer
6 |
7 | if __name__ == '__main__':
8 | import argparse
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--config', type=str, help='path to yaml config file')
12 | args = parser.parse_args()
13 |
14 | cfg_path = args.config
15 | cfg = omegaconf.OmegaConf.load(cfg_path)
16 | modality = cfg.modality
17 |
18 | trainers_dict = {
19 | 'text': TextTrainer,
20 | 'vision': VisionTrainer,
21 | 'audio': AudioTrainer
22 | }
23 | assert modality in trainers_dict.keys(), f'invalid modality `{cfg.modality}`, expected {list(trainers_dict.keys())}'
24 | trainer = trainers_dict[modality](cfg)
25 | trainer.train()
26 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 |
8 | def __init__(self, name, fmt=':f'):
9 | self.name = name
10 | self.fmt = fmt
11 | self.reset()
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
25 | def __str__(self):
26 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
27 | return fmtstr.format(**self.__dict__)
28 |
29 |
30 | def maybe_save_checkpoint(model, optimizer, path, epoch_num, save_freq):
31 | """
32 | Save a checkpoint specific to Data2Vec
33 | Args:
34 | model: a nn.Module instance
35 |         optimizer: optimizer instance whose state will be saved with the checkpoint
36 | path: path to save checkpoint to
37 | epoch_num: current epoch number
38 | save_freq: save frequency based on epoch number
39 |
40 | """
41 | if not os.path.exists(path):
42 | os.makedirs(path)
43 | path = os.path.join(path, f'{epoch_num}.pt')
44 | if epoch_num % save_freq == 0:
45 | checkpoint = {'data2vec': model.state_dict(),
46 | 'encoder': model.encoder.encoder.state_dict(),
47 | 'optimizer': optimizer.state_dict()}
48 | torch.save(checkpoint, path)
49 | print(f'Saved checkpoint to `{path}`')
50 |
--------------------------------------------------------------------------------
/vision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arxyzan/data2vec-pytorch/61c0aa8d500c4101f8626b080bfe2b6263f5695e/vision/__init__.py
--------------------------------------------------------------------------------
/vision/configs/beit-pretraining.yaml:
--------------------------------------------------------------------------------
1 | modality: 'vision'
2 | device: 'cuda'
3 | model:
4 | encoder_checkpoint: 'microsoft/beit-base-patch16-224-pt22k'
5 | embed_dim: 768
6 | average_top_k_layers: 6
7 | head_layers: 2
8 | num_classes: 1000
9 | normalize_targets: false
10 | ema_decay: 0.9998
11 | ema_end_decay: 0.9999
12 | ema_anneal_end_step: 300000
13 | dataset:
14 | path:
15 | train: 'vision/dummy_data'
16 | test: 'vision/dummy_data'
17 | input_size: 224
18 | interpolation: 'bicubic'
19 | patch_size: 16
20 | num_patches: 14
21 | num_mask_patches: 120
22 | max_mask_patches_per_block: 196
23 | min_mask_patches_per_block: 16
24 | imagenet_default_mean_and_std: false
25 | train:
26 | num_epochs: 800
27 | batch_size: 16
28 | eval_batch_size: 16
29 | shuffle: true
30 | save_ckpt_freq: 20
31 | checkpoints_dir: 'vision/checkpoints/beit-pretrain'
32 | log_dir: 'vision/logs/beit-pretrain'
33 | criterion:
34 | loss_beta: 2
35 | optimizer:
36 | lr: 2e-3
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/vision/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import ImageFolder
3 |
4 | from .transforms import MIMTransform
5 |
6 |
7 | class MIMPretrainingDataset(ImageFolder):
8 | """
9 | Dataset for Masked Image Modeling derived from BEiT.
10 |
11 | Given an image, the common transforms and augmentations are applied like random crop, color jitter, etc., then the
12 | image is split into 14x14 patches and some patches are masked randomly. The input image to the model is the masked
13 | image and the target image is the full image
14 |
15 | Args:
16 | cfg (DictConfig): config containing model, dataset, etc. properties
17 | split: either `train` or `test`
18 | **kwargs: extra args which are set as dataset properties
19 |
20 | """
21 |
22 | def __init__(self, cfg, split, **kwargs):
23 | super(MIMPretrainingDataset, self).__init__(root=cfg.dataset.path[split])
24 | self.transform = MIMTransform(cfg.dataset)
25 | self.input_size = cfg.dataset.input_size
26 | self.device = cfg.device
27 | self.__dict__.update(kwargs)
28 |
29 | def __getitem__(self, index):
30 | """
31 | Load image from disk, transform the image (augmentation and randomly mask some patches)
32 |
33 | Args:
34 | index: index to the image
35 |
36 | Returns:
37 | a tuple of masked image, unmasked target image and bool masked positions
38 | """
39 | path, target = self.samples[index]
40 | image = self.loader(path)
41 | image, mask = self.transform(image)
42 | mask = mask.reshape(1, 14, 14, 1, 1)
43 | image = image.reshape(-1, 14, 14, 16, 16)
44 | masked_image = (image * mask).reshape(-1, self.input_size, self.input_size)
45 | target_image = image.reshape(-1, self.input_size, self.input_size)
46 | return masked_image, target_image, mask.flatten().bool()
47 |
48 |
49 | if __name__ == '__main__':
50 | from omegaconf import OmegaConf
51 | from torch.utils.data import DataLoader
52 |
53 | cfg = OmegaConf.load('configs/beit-pretraining.yaml')
54 |     cfg.dataset.path = {'train': 'dummy_data', 'test': 'dummy_data'}  # relative to vision/ when run directly
55 | dataset = MIMPretrainingDataset(cfg, split='train')
56 | loader = DataLoader(dataset, batch_size=4)
57 |     src, trg, mask = next(iter(loader))
58 | print(src)
59 |
--------------------------------------------------------------------------------
/vision/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModel, AutoConfig
3 | import torch.nn as nn
4 |
5 |
6 | class Encoder(nn.Module):
7 | """
8 | Encoder model using HuggingFace Transformers for vision e.g, BeiT
9 |
10 | Args:
11 | cfg: An omegaconf.DictConf instance containing all the configurations.
12 | **kwargs: extra args which are set as dataset properties
13 | """
14 |
15 | def __init__(self, cfg, **kwargs):
16 | super(Encoder, self).__init__()
17 | self.cfg = cfg
18 | checkpoint = cfg.model.encoder_checkpoint
19 | model_config = AutoConfig.from_pretrained(checkpoint)
20 | self.encoder = AutoModel.from_config(model_config)
21 | self.vocab_size = model_config.vocab_size
22 | self.mask_token = self.encoder.embeddings.mask_token
23 | self.__dict__.update(kwargs)
24 |
25 | def forward(self, inputs, mask=None, **kwargs):
26 | """
27 | Forward inputs through the encoder and extract transformer/attention layers outputs
28 |
29 | Args:
30 | inputs: input pixels with shape [batch_size, channels, height, width]
31 | mask: bool masked indices
32 | **kwargs: keyword args specific to the encoder's forward method
33 |
34 | Returns:
35 | A dictionary of the encoder outputs including transformer layers outputs and attentions outputs
36 | """
37 | # Note: inputs are already masked for MIM so mask is not used
38 | outputs = self.encoder(pixel_values=inputs, output_hidden_states=True, output_attentions=True, **kwargs)
39 | encoder_states = outputs['hidden_states'][:-1] # encoder layers outputs separately
40 | encoder_out = outputs['hidden_states'][-1] # last encoder output (accumulated)
41 | attentions = outputs['attentions']
42 |
43 | # remove cls token from outputs
44 | encoder_states = [output[:, 1:, :] for output in encoder_states]
45 | encoder_out = encoder_out[:, 1:, :]
46 |         attentions = [output[:, :, 1:, 1:] for output in attentions]  # attention maps have shape [batch, heads, seq, seq]
47 |
48 | return {
49 | 'encoder_states': encoder_states,
50 | 'encoder_out': encoder_out,
51 | 'attentions': attentions
52 | }
53 |
54 |
55 | if __name__ == '__main__':
56 | from omegaconf import OmegaConf
57 | import numpy as np
58 | from PIL import Image
59 | import requests
60 | from torchvision import transforms as T
61 |
62 | cfg = OmegaConf.load('configs/beit-pretraining.yaml')
63 | model = Encoder(cfg)
64 | url = "http://images.cocodataset.org/val2017/000000039769.jpg"
65 | image = Image.open(requests.get(url, stream=True).raw)
66 | image = T.Compose([T.Resize((224, 224)),
67 | T.ToTensor(),
68 | T.Normalize(mean=.5, std=.5)])(image).unsqueeze(0)
69 | outputs = model(image)
70 | print(outputs)
71 |
--------------------------------------------------------------------------------
/vision/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 | from torch.utils.tensorboard import SummaryWriter
8 | from tqdm import tqdm
9 |
10 | from vision.encoder import Encoder
11 | from vision.dataset import MIMPretrainingDataset
12 | from data2vec import Data2Vec
13 | from utils import AverageMeter, maybe_save_checkpoint
14 |
15 |
16 | class VisionTrainer:
17 | def __init__(self, cfg):
18 | self.cfg = cfg
19 | self.device = cfg.device
20 | self.num_epochs = cfg.train.num_epochs
21 | self.ckpt_dir = cfg.train.checkpoints_dir
22 | self.save_ckpt_freq = cfg.train.save_ckpt_freq
23 | # Model, Criterion, Optimizer
24 | self.encoder = Encoder(cfg=cfg)
25 | self.model = Data2Vec(encoder=self.encoder, cfg=cfg)
26 | self.model.to(self.device)
27 | self.optimizer = optim.Adam(self.model.parameters(), cfg.optimizer.lr)
28 | self.criterion = nn.SmoothL1Loss(reduction='none', beta=cfg.criterion.loss_beta)
29 | self.criterion.to(self.device)
30 | # Datasets & Data Loaders
31 | self.train_dataset = MIMPretrainingDataset(cfg, split='train')
32 | self.test_dataset = MIMPretrainingDataset(cfg, split='test')
33 | self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.train.batch_size, shuffle=cfg.train.shuffle)
34 | self.test_loader = DataLoader(self.test_dataset, batch_size=cfg.train.eval_batch_size, shuffle=cfg.train.shuffle)
35 |
36 | # Tensorboard
37 | self.tensorboard = SummaryWriter(log_dir=self.cfg.train.log_dir)
38 |
39 | # Trackers
40 | self.loss_tracker = AverageMeter('loss')
41 |
42 | def train_step(self, batch):
43 | """
44 | Train one batch of data
45 | Args:
46 | batch: A batch of data, src, trg of shape [N, C, H, W] and mask of shape [N, num_total_patches]
47 |
48 | Returns:
49 | Loss value
50 | """
51 | src, trg, mask = batch
52 | src = src.to(self.device)
53 | trg = trg.to(self.device)
54 | mask = mask.to(self.device)
55 |
56 | x, y = self.model(src, trg, mask)
57 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
58 | loss.backward()
59 | self.optimizer.step()
60 | self.optimizer.zero_grad()
61 |
62 | return loss.item()
63 |
64 | def test_step(self, batch):
65 | """
66 | Evaluate one batch of data
67 | Args:
68 | batch: A batch of data, src, trg of shape [N, C, H, W] and mask of shape [N, num_total_patches]
69 |
70 | Returns:
71 | Loss value
72 | """
73 | src, trg, mask = batch
74 | src = src.to(self.device)
75 | trg = trg.to(self.device)
76 | mask = mask.to(self.device)
77 |
78 | x, y = self.model(src, trg, mask)
79 | loss = self.criterion(x.float(), y.float()).sum(dim=-1).sum().div(x.size(0))
80 |
81 | return loss.item()
82 |
83 | def train_epoch(self, epoch_num):
84 | """
85 | Train the model for one epoch
86 | Args:
87 | epoch_num: number of the current epoch
88 |
89 | Returns:
90 | Average loss through the whole epoch
91 | """
92 | self.model.train()
93 | self.loss_tracker.reset()
94 | with tqdm(self.train_loader, unit="batch", desc=f'Epoch: {epoch_num}/{self.num_epochs} ',
95 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
96 | for batch in iterator:
97 | loss = self.train_step(batch)
98 | self.model.ema_step()
99 | self.loss_tracker.update(loss)
100 | avg_loss = self.loss_tracker.avg
101 | iterator.set_postfix(loss=avg_loss)
102 |
103 | return avg_loss
104 |
105 | def evaluate(self):
106 | """
107 | Evaluate the model on the test data
108 | Returns:
109 | Average loss on the test set
110 | """
111 | self.model.eval()
112 | self.loss_tracker.reset()
113 | with tqdm(self.test_loader, unit="batch", desc='Evaluating... ',
114 | bar_format='{desc:<16}{percentage:3.0f}%|{bar:70}{r_bar}', ascii=" #") as iterator:
115 | with torch.no_grad():
116 | for batch in iterator:
117 | loss = self.test_step(batch)
118 | self.loss_tracker.update(loss)
119 | avg_loss = self.loss_tracker.avg
120 | iterator.set_postfix(loss=avg_loss)
121 |
122 | return avg_loss
123 |
124 | def train(self):
125 | """
126 | Train and evaluate the model on the datasets and save checkpoints and write summaries to TensorBoard.
127 |
128 | """
129 | for epoch in range(1, self.num_epochs + 1):
130 | print()
131 | train_loss = self.train_epoch(epoch)
132 | val_loss = self.evaluate()
133 |
134 | # tensorboard
135 | self.tensorboard.add_scalar('train_loss', train_loss, epoch)
136 | self.tensorboard.add_scalar('val_loss', val_loss, epoch)
137 |
138 | # save checkpoint
139 | maybe_save_checkpoint(self.model, self.optimizer, self.ckpt_dir, epoch, self.save_ckpt_freq)
140 |
--------------------------------------------------------------------------------
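
`VisionTrainer` reads everything it needs (device, epochs, checkpoint paths, batch sizes) from the config object, so launching a run is just a matter of loading the YAML and calling `train()`. A hedged usage sketch in the spirit of the repo's `train.py` entry point (whose exact CLI may differ):

import torch
from omegaconf import OmegaConf
from vision.trainer import VisionTrainer

if __name__ == '__main__':
    cfg = OmegaConf.load('vision/configs/beit-pretraining.yaml')
    cfg.device = 'cuda' if torch.cuda.is_available() else 'cpu'   # VisionTrainer reads cfg.device
    trainer = VisionTrainer(cfg)
    trainer.train()
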
/vision/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Transforms and masking strategies for Self Supervised Pretraining of Vision models.
3 | All code copied from the repos below:
4 | https://github.com/microsoft/unilm/tree/master/beit
5 | https://github.com/rwightman/pytorch-image-models/tree/master/timm
6 | https://github.com/facebookresearch/deit
7 | """
8 | import warnings
9 | import random
10 | import math
11 | from PIL import Image
12 | import numpy as np
13 | import torch
14 | from torchvision import transforms
15 | import torchvision.transforms.functional as F
16 | from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD,
17 | IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD)
18 | from timm.data.transforms import RandomResizedCropAndInterpolation
19 |
20 | _pil_interpolation_to_str = {
21 | Image.NEAREST: 'PIL.Image.NEAREST',
22 | Image.BILINEAR: 'PIL.Image.BILINEAR',
23 | Image.BICUBIC: 'PIL.Image.BICUBIC',
24 | Image.LANCZOS: 'PIL.Image.LANCZOS',
25 | Image.HAMMING: 'PIL.Image.HAMMING',
26 | Image.BOX: 'PIL.Image.BOX',
27 | }
28 |
29 |
30 | def _pil_interp(method):
31 | if method == 'bicubic':
32 | return Image.BICUBIC
33 | elif method == 'lanczos':
34 | return Image.LANCZOS
35 | elif method == 'hamming':
36 | return Image.HAMMING
37 | else:
38 | # default bilinear, do we want to allow nearest?
39 | return Image.BILINEAR
40 |
41 |
42 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
43 |
44 |
45 | class MIMTransform(object):
46 | """
47 | Masked Image Modeling transforms based on BEiT. Copied from https://github.com/microsoft/unilm/tree/master/beit
48 | """
49 |
50 | def __init__(self, cfg):
51 | imagenet_default_mean_and_std = cfg.imagenet_default_mean_and_std
52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
54 | patch_size = cfg.patch_size
55 | num_patches = cfg.num_patches
56 | assert patch_size * num_patches == cfg.input_size
57 |
58 | self.common_transform = transforms.Compose([
59 | transforms.ColorJitter(0.4, 0.4, 0.4),
60 | transforms.RandomHorizontalFlip(p=0.5),
61 | RandomResizedCropAndInterpolation(size=cfg.input_size, interpolation=cfg.interpolation),
62 | transforms.ToTensor(),
63 | transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))
64 | ])
65 |
66 | self.masked_position_generator = MaskingGenerator(
67 | cfg.num_patches, num_masking_patches=cfg.num_mask_patches,
68 | max_num_patches=cfg.max_mask_patches_per_block,
69 | min_num_patches=cfg.min_mask_patches_per_block,
70 | )
71 |
72 | def __call__(self, image):
73 | return self.common_transform(image), self.masked_position_generator()
74 |
75 |
76 | class MaskingGenerator:
77 | def __init__(
78 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
79 | min_aspect=0.3, max_aspect=None):
80 | if not isinstance(input_size, tuple):
81 | input_size = (input_size,) * 2
82 | self.height, self.width = input_size
83 |
84 | self.num_masking_patches = num_masking_patches
85 |
86 | self.min_num_patches = min_num_patches
87 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
88 |
89 | max_aspect = max_aspect or 1 / min_aspect
90 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
91 |
92 | def __repr__(self):
93 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
94 | self.height, self.width, self.min_num_patches, self.max_num_patches,
95 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
96 | return repr_str
97 |
98 | def get_shape(self):
99 | return self.height, self.width
100 |
101 | def _mask(self, mask, max_mask_patches):
102 | delta = 0
103 | for attempt in range(10):
104 | target_area = random.uniform(self.min_num_patches, max_mask_patches)
105 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
106 | h = int(round(math.sqrt(target_area * aspect_ratio)))
107 | w = int(round(math.sqrt(target_area / aspect_ratio)))
108 | if w < self.width and h < self.height:
109 | top = random.randint(0, self.height - h)
110 | left = random.randint(0, self.width - w)
111 |
112 | num_masked = mask[top: top + h, left: left + w].sum()
113 | # Overlap
114 | if 0 < h * w - num_masked <= max_mask_patches:
115 | for i in range(top, top + h):
116 | for j in range(left, left + w):
117 | if mask[i, j] == 0:
118 | mask[i, j] = 1
119 | delta += 1
120 |
121 | if delta > 0:
122 | break
123 | return delta
124 |
125 | def __call__(self):
126 | mask = torch.zeros(self.get_shape(), dtype=torch.int)
127 | mask_count = 0
128 | while mask_count < self.num_masking_patches:
129 | max_mask_patches = self.num_masking_patches - mask_count
130 | max_mask_patches = min(max_mask_patches, self.max_num_patches)
131 |
132 | delta = self._mask(mask, max_mask_patches)
133 | if delta == 0:
134 | break
135 | else:
136 | mask_count += delta
137 |
138 | return mask
139 |
--------------------------------------------------------------------------------
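
To see the block-wise masking in isolation, `MaskingGenerator` can be exercised on its own. The parameter values below are illustrative BEiT-style defaults (14x14 grid, ~75 masked patches), not values read from the YAML:

from vision.transforms import MaskingGenerator

gen = MaskingGenerator(input_size=14, num_masking_patches=75,
                       min_num_patches=16, max_num_patches=None)
mask = gen()                        # [14, 14] int tensor of 0/1 blocks
print(mask.sum().item())            # at most 75 masked patches (block sampling may stop short)
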