├── IJEPA.png
├── LICENSE
├── README.md
├── finetune_IJEPA.py
├── model.py
└── pretrain_IJEPA.py

/IJEPA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaasher/I-JEPA/98b4ed2c0232a210ed149821e1d8897678d61eb6/IJEPA.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Gabriel Asher

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# I-JEPA
Implementation of I-JEPA from [Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture](https://arxiv.org/abs/2301.08243)

Uses @lucidrains' x-transformers (https://github.com/lucidrains/x-transformers)

Basic schematic of the architecture:

![screenshot](IJEPA.png)

To run, execute `python pretrain_IJEPA.py` or `python finetune_IJEPA.py` on your command line.

TODO:
- linear probing setup

Citation:

```
@article{assran2023self,
  title={Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture},
  author={Assran, Mahmoud and Duval, Quentin and Misra, Ishan and Bojanowski, Piotr and Vincent, Pascal and Rabbat, Michael and LeCun, Yann and Ballas, Nicolas},
  journal={arXiv preprint arXiv:2301.08243},
  year={2023}
}
```
--------------------------------------------------------------------------------
/finetune_IJEPA.py:
--------------------------------------------------------------------------------
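'''
Finetune a classification head on top of a pretrained I-JEPA encoder.

Note: the '.ckpt' value passed to IJEPA_FT in __main__ below is a placeholder. Run
`python pretrain_IJEPA.py` first, then point `pretrained_model_path` at the checkpoint
Lightning saved for that run (under lightning_logs/ by default). A minimal programmatic
sketch (illustrative only; the checkpoint path is an assumption):

    from finetune_IJEPA import IJEPA_FT, D2VDataModule
    model = IJEPA_FT(pretrained_model_path='path/to/pretrained.ckpt', num_classes=5)
    trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10)
    trainer.fit(model, D2VDataModule(dataset_path='data'))
'''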
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    ModelSummary,
)
from pytorch_lightning.loggers import WandbLogger
from model import IJEPA_base
from pretrain_IJEPA import IJEPA


'''Dummy Dataset'''
class IJEPADataset(Dataset):
    def __init__(self,
                 dataset_path,
                 stage='train',
                 ):
        super().__init__()
        img1 = torch.randn(3, 224, 224)
        self.data = img1.repeat(100, 1, 1, 1)
        label = torch.tensor([0., 0., 0., 1., 0.])
        self.label = label.repeat(100, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

'''
Placeholder for datamodule in pytorch lightning
'''
class D2VDataModule(pl.LightningDataModule):
    def __init__(self,
                 dataset_path,
                 batch_size=16,
                 num_workers=4,
                 pin_memory=True,
                 shuffle=True
                 ):
        super().__init__()

        self.dataset_path = dataset_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle = shuffle

    def setup(self, stage=None):
        self.train_dataset = IJEPADataset(dataset_path=self.dataset_path, stage='train')
        self.val_dataset = IJEPADataset(dataset_path=self.dataset_path, stage='val')

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )

'''
Finetune IJEPA
'''
class IJEPA_FT(pl.LightningModule):
    #take pretrained model path, number of classes, learning rate, weight decay, and drop path as input
    def __init__(self, pretrained_model_path, num_classes, lr=1e-3, weight_decay=0, drop_path=0.1):

        super().__init__()
        self.save_hyperparameters()

        #set parameters
        self.lr = lr
        self.weight_decay = weight_decay
        self.drop_path = drop_path

        #define model layers
        self.pretrained_model = IJEPA.load_from_checkpoint(pretrained_model_path)
        self.pretrained_model.model.mode = "test"
        self.pretrained_model.model.layer_dropout = self.drop_path
        self.average_pool = nn.AvgPool1d(kernel_size=self.pretrained_model.num_tokens)

        #mlp head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.pretrained_model.embed_dim),
            nn.Linear(self.pretrained_model.embed_dim, num_classes),
        )

        #define loss
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.pretrained_model.model(x)

        x = x.permute(0, 2, 1)
        x = self.average_pool(x) #conduct average pool like in paper
        x = x.squeeze(-1)
        x = self.mlp_head(x) #pass through mlp head
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y) #calculate loss
        accuracy = (y_hat.argmax(dim=1) == y.argmax(dim=1)).float().mean() #calculate accuracy
        self.log('train_accuracy', accuracy)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        accuracy = (y_hat.argmax(dim=1) == y.argmax(dim=1)).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx):
        return self(batch[0]) #batch is (image, label); run the classifier on the image

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer
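'''
Linear-probing variant (illustrative sketch addressing the README TODO; not part of the
original script): freeze the pretrained encoder and train only the classification head.
'''
class IJEPA_LinearProbe(IJEPA_FT):
    def __init__(self, pretrained_model_path, num_classes, lr=1e-3, weight_decay=0):
        super().__init__(pretrained_model_path, num_classes, lr=lr, weight_decay=weight_decay, drop_path=0.)
        #freeze the pretrained backbone so only the head is trained
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        #optimize only the mlp head parameters
        return torch.optim.Adam(self.mlp_head.parameters(), lr=self.lr, weight_decay=self.weight_decay)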

if __name__ == '__main__':
    dataset = D2VDataModule(dataset_path='data')

    model = IJEPA_FT(pretrained_model_path='.ckpt', num_classes=5)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_summary = ModelSummary(max_depth=2)

    trainer = pl.Trainer(
        accelerator='cpu',
        precision=16,
        max_epochs=10,
        callbacks=[lr_monitor, model_summary],
        gradient_clip_val=.1,
    )

    trainer.fit(model, dataset)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from einops import rearrange, repeat
from x_transformers import Encoder, Decoder
import copy

'''
PatchEmbed class, likely adapted from https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632
- This class converts the image into patches using a convolutional layer
'''
class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=64):
        super().__init__()
        if isinstance(img_size, int):
            img_size = img_size, img_size
        if isinstance(patch_size, int):
            patch_size = patch_size, patch_size
        #calculate the number of patches
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])

        #convolutional layer to convert the image into patches
        self.conv = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.conv(x)
        #flatten the patches
        x = rearrange(x, 'b e h w -> b (h w) e')
        return x

'''Lightweight Predictor Module using a ViT to predict target patches from context patches'''
class Predictor(nn.Module):
    def __init__(self, embed_dim, num_heads, depth):
        super().__init__()

        self.predictor = Decoder(dim=embed_dim, depth=depth, heads=num_heads)

    def forward(self, context_encoding, target_masks):
        x = torch.cat((context_encoding, target_masks), dim=1)
        x = self.predictor(x)
        #return last len(target_masks) tokens
        l = x.shape[1]
        return x[:, l - target_masks.shape[1]:, :]
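'''
Shape sketch (illustrative, using the defaults img_size=224, patch_size=16, embed_dim=64):
PatchEmbed maps (B, 3, 224, 224) -> (B, 196, 64), i.e. a 14 x 14 grid of patch tokens.
The Predictor concatenates the context tokens with the masked target tokens along the
sequence dimension, runs the decoder over the combined sequence, and keeps only the
last len(target_masks) positions as the predicted target embeddings, e.g.:

    predictor = Predictor(embed_dim=64, num_heads=8, depth=6)
    context_encoding = torch.randn(2, 150, 64)       #150 encoded context patches
    target_masks = torch.randn(2, 24, 64)            #24 masked target positions
    out = predictor(context_encoding, target_masks)  #-> shape (2, 24, 64)
'''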
'''Main Model Class'''
class IJEPA_base(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, pred_depth, num_heads, post_emb_norm=False, M=4, mode="train", layer_dropout=0.):
        super().__init__()
        self.M = M
        self.mode = mode
        self.layer_dropout = layer_dropout

        #define the patch embedding and positional embedding
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.patch_dim = (self.patch_embed.patch_shape[0], self.patch_embed.patch_shape[1])
        self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1]
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))

        #define the mask token
        self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        #define the encoder and decoder, as well as the layer normalization and dropout
        self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity()
        self.norm = nn.LayerNorm(embed_dim)
        self.teacher_encoder = Encoder(
            dim=embed_dim,
            heads=num_heads,
            depth=enc_depth,
            layer_dropout=self.layer_dropout,
        )
        self.student_encoder = copy.deepcopy(self.teacher_encoder).cuda()
        self.predictor = Predictor(embed_dim, num_heads, pred_depth)

    @torch.no_grad()
    def get_target_block(self, target_encoder, x, patch_dim, aspect_ratio, scale, M):
        #get the target block
        target_encoder = target_encoder.eval()
        x = target_encoder(x)
        x = self.norm(x)
        #get the patch dimensions
        patch_h, patch_w = patch_dim
        #get the number of patches
        num_patches = patch_h * patch_w
        #get the number of patches in the target block
        num_patches_block = int(patch_h * patch_w * scale)
        #get the height and width of the target block with aspect ratio
        block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio)))
        block_w = int(aspect_ratio * block_h)
        #get the patches in the target block
        target_block = torch.zeros((M, x.shape[0], block_h*block_w, x.shape[2]))
        target_patches = []
        all_patches = []
        for z in range(M):
            #get the starting patch
            start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item()
            start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item()
            start_patch = start_patch_h * patch_w + start_patch_w

            patches = []
            #get the patches in the target block
            for i in range(block_h):
                for j in range(block_w):
                    patches.append(start_patch + i * patch_w + j)
                    if start_patch + i * patch_w + j not in all_patches:
                        all_patches.append(start_patch + i * patch_w + j)

            #get the target block
            target_patches.append(patches)
            target_block[z] = x[:, patches, :]
        return target_block.cuda(), target_patches, all_patches

    def get_context_block(self, x, patch_dim, aspect_ratio, scale, target_patches):
        patch_h, patch_w = patch_dim
        #get the number of patches in the context block
        num_patches_block = int(patch_h * patch_w * scale)
        #get the height and width of the context block with aspect ratio
        block_h = int(torch.sqrt(torch.tensor(num_patches_block / aspect_ratio)))
        block_w = int(aspect_ratio * block_h)
        #get the starting patch
        start_patch_h = torch.randint(0, patch_h - block_h+1, (1,)).item()
        start_patch_w = torch.randint(0, patch_w - block_w+1, (1,)).item()
        start_patch = start_patch_h * patch_w + start_patch_w
        #get the patches in the context_block
        patches = []
        for i in range(block_h):
            for j in range(block_w):
                if start_patch + i * patch_w + j not in target_patches: #remove the target patches
                    patches.append(start_patch + i * patch_w + j)
        return x[:, patches, :]
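
    '''
    Worked example of the block sampling (illustrative, for the default 14 x 14 = 196
    patch grid): with target_scale = 0.15 and target_aspect_ratio = 0.75,
    num_patches_block = int(196 * 0.15) = 29, block_h = int(sqrt(29 / 0.75)) = 6 and
    block_w = int(0.75 * 6) = 4, so each of the M target blocks covers 6 x 4 = 24 patches.
    With context_scale = 0.9 and context_aspect_ratio = 1 the context block covers
    13 x 13 = 169 patches, from which any patches already claimed by a target block are
    then removed.
    '''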

    def forward(self, x, target_aspect_ratio=1, target_scale=1, context_aspect_ratio=1, context_scale=1):
        #get the patch embeddings
        x = self.patch_embed(x)
        b, n, e = x.shape
        #add the positional embeddings
        x = x + self.pos_embedding
        #normalize the embeddings
        x = self.post_emb_norm(x)
        #if mode is test, return the full embedding
        if self.mode == 'test':
            return self.student_encoder(x)
        #get target embeddings
        target_blocks, target_patches, all_patches = self.get_target_block(self.teacher_encoder, x, self.patch_dim, target_aspect_ratio, target_scale, self.M)
        m, b, n, e = target_blocks.shape
        #get context embedding
        context_block = self.get_context_block(x, self.patch_dim, context_aspect_ratio, context_scale, all_patches)
        context_encoding = self.student_encoder(context_block)
        context_encoding = self.norm(context_encoding)

        prediction_blocks = torch.zeros((m, b, n, e)).cuda()
        #get the prediction blocks, predict each target block separately
        for i in range(m):
            target_masks = self.mask_token.repeat(b, n, 1)
            target_pos_embedding = self.pos_embedding[:, target_patches[i], :]
            target_masks = target_masks + target_pos_embedding
            prediction_blocks[i] = self.predictor(context_encoding, target_masks)

        return prediction_blocks, target_blocks
--------------------------------------------------------------------------------
/pretrain_IJEPA.py:
--------------------------------------------------------------------------------
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    ModelSummary,
)
from pytorch_lightning.loggers import WandbLogger
from model import IJEPA_base


'''Dummy Dataset'''
class IJEPADataset(Dataset):
    def __init__(self,
                 dataset_path,
                 stage='train',
                 ):
        super().__init__()
        img1 = torch.randn(3, 224, 224)
        self.data = img1.repeat(100, 1, 1, 1)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


'''
Placeholder for datamodule in pytorch lightning
'''
class D2VDataModule(pl.LightningDataModule):
    def __init__(self,
                 dataset_path,
                 batch_size=16,
                 num_workers=4,
                 pin_memory=True,
                 shuffle=True
                 ):
        super().__init__()

        self.dataset_path = dataset_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle = shuffle

    def setup(self, stage=None):
        self.train_dataset = IJEPADataset(dataset_path=self.dataset_path, stage='train')
        self.val_dataset = IJEPADataset(dataset_path=self.dataset_path, stage='val')

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )
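'''
Swapping in real data (illustrative sketch, not part of the original script): the model
only needs a dataset whose __getitem__ returns a single (3, 224, 224) image tensor.
Assuming torchvision is installed and the data lives in an ImageFolder-style directory
(a hypothetical layout), something like the following could replace IJEPADataset:

    from torchvision import datasets, transforms

    class ImageOnlyDataset(Dataset):
        def __init__(self, root):
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])
            self.folder = datasets.ImageFolder(root, transform=transform)

        def __len__(self):
            return len(self.folder)

        def __getitem__(self, index):
            img, _ = self.folder[index]  #labels are unused during pretraining
            return img
'''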
'''
pytorch lightning model
'''
class IJEPA(pl.LightningModule):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=64,
        enc_heads=8,
        enc_depth=8,
        decoder_depth=6,
        lr=1e-6,
        weight_decay=0.05,
        target_aspect_ratio=(0.75, 1.5),
        target_scale=(0.15, .2),
        context_aspect_ratio=1,
        context_scale=(0.85, 1.0),
        M=4,  #number of different target blocks
        m=0.996,  #momentum
        m_start_end=(.996, 1.)
    ):
        super().__init__()
        self.save_hyperparameters()

        #define models
        self.model = IJEPA_base(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                                enc_depth=enc_depth, num_heads=enc_heads, pred_depth=decoder_depth, M=M)

        #define hyperparameters
        self.M = M
        self.lr = lr
        self.weight_decay = weight_decay
        self.m = m
        self.target_aspect_ratio = target_aspect_ratio
        self.target_scale = target_scale
        self.context_aspect_ratio = context_aspect_ratio
        self.context_scale = context_scale
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_tokens = (img_size // patch_size) ** 2
        self.m_start_end = m_start_end

        #define loss
        self.criterion = nn.MSELoss()

    def forward(self, x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale):
        return self.model(x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)

    '''Update momentum for teacher encoder'''
    def update_momentum(self, m):
        student_model = self.model.student_encoder.eval()
        teacher_model = self.model.teacher_encoder.eval()
        with torch.no_grad():
            for student_param, teacher_param in zip(student_model.parameters(), teacher_model.parameters()):
                teacher_param.data.mul_(other=m).add_(other=student_param.data, alpha=1 - m)

    def training_step(self, batch, batch_idx):
        x = batch
        #generate random target and context aspect ratio and scale
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = np.random.uniform(self.context_scale[0], self.context_scale[1])

        y_student, y_teacher = self(x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)
        loss = self.criterion(y_student, y_teacher)
        self.log('train_loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        x = batch
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = np.random.uniform(self.context_scale[0], self.context_scale[1])

        y_student, y_teacher = self(x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)
        loss = self.criterion(y_student, y_teacher)
        self.log('val_loss', loss)

        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx):
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = 1
        self.model.mode = "test"

        return self(batch, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale) #just get the full student embedding

    def on_after_backward(self):
        self.update_momentum(self.m)
        self.m += (self.m_start_end[1] - self.m_start_end[0]) / self.trainer.estimated_stepping_batches

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }
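
'''
EMA schedule (worked example, illustrative): after every backward pass update_momentum
blends the teacher towards the student, teacher = m * teacher + (1 - m) * student, and
m is annealed linearly from m_start_end[0] to m_start_end[1]. With the defaults
(0.996, 1.0) and, say, 10,000 estimated training steps, m grows by
(1.0 - 0.996) / 10,000 = 4e-7 per step, so the teacher is effectively frozen by the end
of training.
'''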

if __name__ == '__main__':
    dataset = D2VDataModule(dataset_path='data')

    model = IJEPA(img_size=224, patch_size=16, in_chans=3, embed_dim=64, enc_heads=8, enc_depth=8, decoder_depth=6, lr=1e-3)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    model_summary = ModelSummary(max_depth=2)

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        precision=16,
        max_epochs=10,
        callbacks=[lr_monitor, model_summary],
        gradient_clip_val=.1,
    )

    trainer.fit(model, dataset)
--------------------------------------------------------------------------------