├── mmcr ├── __init__.py ├── cifar_stl │ ├── model_select.py │ ├── loss_mmcr.py │ ├── models.py │ ├── train.py │ ├── knn.py │ ├── train_linear_classifier.py │ └── data.py └── imagenet │ ├── loss_mmcr.py │ ├── loss_mmcr_momentum.py │ ├── misc.py │ ├── knn.py │ ├── train.py │ ├── train_linear_classifier.py │ ├── models.py │ ├── data.py │ └── distributed.py ├── setup.py ├── .gitignore ├── LICENSE ├── model_select_cifar_stl.py ├── linear_classifier_imagenet.py ├── pretrain_cifar_stl.py ├── pretrain_imagenet.py ├── README.md └── environment.yml /mmcr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='mmcr', 5 | packages=find_packages(), 6 | ) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/* 2 | slurm 3 | training_checkpoints 4 | *__pycache__* 5 | *.egg-info 6 | *dev.py 7 | *.ipynb 8 | .vscode 9 | .empty 10 | .ipynb_checkpoints 11 | notebooks/* 12 | notebooks -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Thomas Yerxa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /model_select_cifar_stl.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("..") 4 | 5 | import submitit 6 | from mmcr.cifar_stl.model_select import select_model 7 | from argparse import ArgumentParser 8 | 9 | parser = ArgumentParser() 10 | parser.add_argument("--batch_size", type=int, default=1024) 11 | parser.add_argument("--dataset", type=str, default="cifar10") 12 | parser.add_argument("--lr", type=float, default=0.1) 13 | parser.add_argument("--epochs", type=int, default=50) 14 | parser.add_argument("--checkpoint_dir", type=str) 15 | parser.add_argument( 16 | "--save_path", type=str, default="./training_checkpoints/cifar_stl/" 17 | ) 18 | 19 | args = parser.parse_args() 20 | 21 | # submitit stuff 22 | slurm_folder = "./slurm/classifier/%j" 23 | executor = submitit.AutoExecutor(folder=slurm_folder) 24 | executor.update_parameters(mem_gb=128, timeout_min=10000) 25 | executor.update_parameters(slurm_array_parallelism=1024) 26 | executor.update_parameters(gpus_per_node=1) 27 | executor.update_parameters(cpus_per_task=13) 28 | executor.update_parameters(slurm_partition="gpu") 29 | executor.update_parameters(constraint="a100-80gb") 30 | executor.update_parameters(name="model_select") 31 | 32 | job = executor.submit( 33 | select_model, 34 | checkpoint_directory=args.checkpoint_dir, 35 | dataset=args.dataset, 36 | batch_size=args.batch_size, 37 | lr=args.lr, 38 | epochs=args.epochs, 39 | save_dir=args.save_path, 40 | ) 41 | -------------------------------------------------------------------------------- /linear_classifier_imagenet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append("..") 4 | 5 | import submitit 6 | from mmcr.imagenet.train_linear_classifier import train_classifier 7 | from torch import nn 8 | from argparse import ArgumentParser 9 | 10 | parser = ArgumentParser() 11 | parser.add_argument("--batch_size", type=int, default=2048) 12 | parser.add_argument("--lr", type=float, default=0.3) 13 | parser.add_argument("--epochs", type=int, default=100) 14 | parser.add_argument("--model_path", type=str) 15 | parser.add_argument("--save_path", type=str, default="./training_checkpoints/imagenet/") 16 | parser.add_argument("--save_name", type=str, default="classifier") 17 | parser.add_argument('--use_zip', action="store_true") 18 | 19 | args = parser.parse_args() 20 | 21 | # submitit stuff 22 | slurm_folder = "./slurm/classifier/%j" 23 | executor = submitit.AutoExecutor(folder=slurm_folder) 24 | executor.update_parameters(mem_gb=128, timeout_min=10000) 25 | executor.update_parameters(slurm_array_parallelism=1024) 26 | executor.update_parameters(gpus_per_node=1) 27 | executor.update_parameters(cpus_per_task=13) 28 | executor.update_parameters(slurm_partition="gpu") 29 | executor.update_parameters(constraint="a100-80gb") 30 | executor.update_parameters(name="classifier_train") 31 | 32 | job = executor.submit( 33 | train_classifier, 34 | model_path=args.model_path, 35 | batch_size=args.batch_size, 36 | lr=args.lr, 37 | epochs=args.epochs, 38 | save_path=args.save_path, 39 | save_name=args.save_name, 40 | use_zip=args.use_zip 41 | ) -------------------------------------------------------------------------------- /pretrain_cifar_stl.py: -------------------------------------------------------------------------------- 1 | from mmcr.cifar_stl.train import train 2 | 3 | 
from argparse import ArgumentParser 4 | import submitit 5 | 6 | parser = ArgumentParser() 7 | parser.add_argument("--dataset", type=str, default="cifar10") 8 | parser.add_argument("--batch_size", type=int, default=32) 9 | parser.add_argument("--n_aug", type=int, default=40) 10 | parser.add_argument("--lr", type=float, default=1e-3) 11 | parser.add_argument("--lmbda", type=float, default=0.0) 12 | parser.add_argument("--epochs", type=int, default=500) 13 | parser.add_argument("--num_workers", type=int, default=16) 14 | parser.add_argument("--save_freq", type=int, default=5) 15 | parser.add_argument( 16 | "--save_folder", 17 | type=str, 18 | default="./training_checkpoints/cifar_stl", 19 | ) 20 | args = parser.parse_args() 21 | 22 | # submitit job management 23 | executor = submitit.AutoExecutor(folder="./slurm/pretrain/%j", slurm_max_num_timeout=30) 24 | 25 | executor.update_parameters( 26 | mem_gb=128, 27 | gpus_per_node=1, 28 | tasks_per_node=1, 29 | cpus_per_task=args.num_workers, 30 | nodes=1, 31 | name="MMCR", 32 | timeout_min=60 * 72, 33 | slurm_partition="gpu", 34 | constraint="a100-80gb", 35 | slurm_array_parallelism=512, 36 | ) 37 | 38 | job = executor.submit( 39 | train, 40 | dataset=args.dataset, 41 | n_aug=args.n_aug, 42 | batch_size=args.batch_size, 43 | lr=args.lr, 44 | epochs=args.epochs, 45 | lmbda=args.lmbda, 46 | save_folder=args.save_folder, 47 | save_freq=args.save_freq, 48 | ) 49 | -------------------------------------------------------------------------------- /mmcr/cifar_stl/model_select.py: -------------------------------------------------------------------------------- 1 | import os 2 | from mmcr.cifar_stl.train_linear_classifier import train_classifier 3 | import torch 4 | 5 | 6 | def select_model( 7 | checkpoint_directory: str, 8 | dataset: str, 9 | save_dir: str, 10 | batch_size: int = 1024, 11 | epochs: int = 50, 12 | lr: float = 0.1, 13 | ): 14 | checkpoints = os.listdir(checkpoint_directory) 15 | accs = [] 16 | max_acc = 0 17 | for checkpoint in checkpoints: 18 | acc_chkp = float(checkpoint.split("_")[-1][:-4]) 19 | if acc_chkp > max_acc: 20 | max_acc = acc_chkp 21 | 22 | # will train classifiers for models with monitor accuracy within 1% of max accuracy 23 | checkpoints_to_test = [] 24 | for checkpoint in checkpoints: 25 | acc_chkp = float(checkpoint.split("_")[-1][:-4]) 26 | if acc_chkp >= max_acc - 1.0: 27 | checkpoints_to_test.append(checkpoint) 28 | 29 | print("Number of checkpoints to test: ", len(checkpoints_to_test)) 30 | best_acc = 0.0 31 | for checkpoint in checkpoints_to_test: 32 | model, acc = train_classifier( 33 | checkpoint_directory + checkpoint, 34 | dataset=dataset, 35 | batch_size=batch_size, 36 | epochs=epochs, 37 | lr=lr, 38 | ) 39 | 40 | if acc > best_acc: 41 | best_acc = acc 42 | best_model = model 43 | 44 | print() 45 | print("New best accuracy: ", best_acc) 46 | print() 47 | 48 | if save_dir is not None: 49 | to_save = {"model": best_model.state_dict(), "acc": best_acc} 50 | torch.save(to_save, save_dir + "best_model.pth") 51 | 52 | return best_acc 53 | -------------------------------------------------------------------------------- /mmcr/imagenet/loss_mmcr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import torch.nn.functional as F 4 | import einops 5 | import random 6 | from typing import Tuple 7 | 8 | import sys 9 | 10 | 11 | class MMCR_Loss(nn.Module): 12 | def __init__(self, lmbda: float, n_aug: int, distributed: bool = True): 13 | 
super(MMCR_Loss, self).__init__()
14 |         self.lmbda = lmbda
15 |         self.n_aug = n_aug
16 |         self.distributed = distributed
17 | 
18 |     def forward(self, z: Tensor) -> Tuple[Tensor, dict]:
19 |         z = F.normalize(z, dim=-1)
20 |         z_local_ = einops.rearrange(z, "(B N) C -> B C N", N=self.n_aug)
21 | 
22 |         # gather across devices into list
23 |         if self.distributed:
24 |             z_list = [
25 |                 torch.zeros_like(z_local_)
26 |                 for i in range(torch.distributed.get_world_size())
27 |             ]
28 |             torch.distributed.all_gather(z_list, z_local_, async_op=False)
29 |             z_list[torch.distributed.get_rank()] = z_local_  # keep the local tensor so gradients flow
30 | 
31 |             # append all
32 |             z_local = torch.cat(z_list)
33 | 
34 |         else:
35 |             z_local = z_local_
36 | 
37 |         centroids = torch.mean(z_local, dim=-1)
38 |         if self.lmbda != 0.0:
39 |             local_nuc = torch.linalg.svdvals(z_local).sum()
40 |         else:
41 |             local_nuc = torch.tensor(0.0)
42 |         global_nuc = torch.linalg.svdvals(centroids).sum()
43 | 
44 |         batch_size = z_local.shape[0]
45 |         loss = self.lmbda * local_nuc / batch_size - global_nuc
46 | 
47 |         loss_dict = {
48 |             "loss": loss.item(),
49 |             "local_nuc": local_nuc.item(),
50 |             "global_nuc": global_nuc.item(),
51 |         }
52 | 
53 |         return loss, loss_dict
54 | 
--------------------------------------------------------------------------------
/mmcr/cifar_stl/loss_mmcr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | import torch.nn.functional as F
4 | import einops
5 | import random
6 | from typing import Tuple
7 | 
8 | import sys
9 | 
10 | 
11 | class MMCR_Loss(nn.Module):
12 |     def __init__(self, lmbda: float, n_aug: int, distributed: bool = False):
13 |         super(MMCR_Loss, self).__init__()
14 |         self.lmbda = lmbda
15 |         self.n_aug = n_aug
16 |         self.distributed = distributed
17 |         self.first_time = True
18 | 
19 |     def forward(self, z: Tensor) -> Tuple[Tensor, dict]:
20 |         z = F.normalize(z, dim=-1)
21 |         z_local_ = einops.rearrange(z, "(B N) C -> B C N", N=self.n_aug)
22 | 
23 |         # gather across devices into list
24 |         if self.distributed:
25 |             z_list = [
26 |                 torch.zeros_like(z_local_)
27 |                 for i in range(torch.distributed.get_world_size())
28 |             ]
29 |             torch.distributed.all_gather(z_list, z_local_, async_op=False)
30 |             z_list[torch.distributed.get_rank()] = z_local_  # keep the local tensor so gradients flow
31 | 
32 |             # append all
33 |             z_local = torch.cat(z_list)
34 | 
35 |         else:
36 |             z_local = z_local_
37 | 
38 |         centroids = torch.mean(z_local, dim=-1)
39 |         if self.lmbda != 0.0:
40 |             local_nuc = torch.linalg.svdvals(z_local).sum()
41 |         else:
42 |             local_nuc = torch.tensor(0.0)
43 |         global_nuc = torch.linalg.svdvals(centroids).sum()
44 | 
45 |         batch_size = z_local.shape[0]
46 |         loss = self.lmbda * local_nuc / batch_size - global_nuc
47 | 
48 |         loss_dict = {
49 |             "loss": loss.item(),
50 |             "local_nuc": local_nuc.item(),
51 |             "global_nuc": global_nuc.item(),
52 |         }
53 | 
54 |         self.first_time = False
55 | 
56 |         return loss, loss_dict
57 | 
--------------------------------------------------------------------------------
/mmcr/cifar_stl/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models.resnet import resnet50
5 | from torch import Tensor
6 | from typing import Tuple
7 | 
8 | 
9 | class Model(nn.Module):
10 |     def __init__(self, projector_dims: list = [512, 128], dataset: str = "cifar10"):
11 |         super(Model, self).__init__()
12 | 
13 |         self.f = []
14 |         for name, module in resnet50().named_children():
15 |             if name == "conv1":
16 |                 module = nn.Conv2d(
17 |                     3, 64, kernel_size=3, stride=1, padding=1, bias=False
18 |                 )
19 |             if dataset in ("cifar10", "cifar100"):  # 32x32 inputs: drop fc and the initial max-pool
20 |                 if not isinstance(module, nn.Linear) and not isinstance(
21 |                     module, nn.MaxPool2d
22 |                 ):
23 |                     self.f.append(module)
24 |             elif dataset == "stl10":
25 |                 if not isinstance(module, nn.Linear):
26 |                     self.f.append(module)
27 |         # encoder
28 |         self.f = nn.Sequential(*self.f)
29 | 
30 |         # projection head (following exactly the Barlow Twins official repo)
31 |         projector_dims = [2048] + projector_dims
32 |         layers = []
33 |         for i in range(len(projector_dims) - 2):
34 |             layers.append(
35 |                 nn.Linear(projector_dims[i], projector_dims[i + 1], bias=False)
36 |             )
37 |             layers.append(nn.BatchNorm1d(projector_dims[i + 1]))
38 |             layers.append(nn.ReLU())
39 |         layers.append(nn.Linear(projector_dims[-2], projector_dims[-1], bias=False))
40 |         self.g = nn.Sequential(*layers)
41 | 
42 |     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
43 |         x = self.f(x)
44 |         feature = torch.flatten(x, start_dim=1)
45 |         out = self.g(feature)
46 | 
47 |         return feature, out
48 | 
--------------------------------------------------------------------------------
/pretrain_imagenet.py:
--------------------------------------------------------------------------------
1 | from mmcr.imagenet.train import train
2 | from mmcr.imagenet.distributed import init_dist_node
3 | 
4 | from argparse import ArgumentParser
5 | import submitit
6 | 
7 | parser = ArgumentParser()
8 | parser.add_argument("--batch_size", type=int, default=2048)
9 | parser.add_argument("--dataset", type=str, default="imagenet")
10 | parser.add_argument("--n_aug", type=int, default=2)
11 | parser.add_argument("--lr", type=float, default=0.6)
12 | parser.add_argument("--tau", type=float, default=0.99)
13 | parser.add_argument("--lmbda", type=float, default=0.0)
14 | parser.add_argument("--epochs", type=int, default=100)
15 | parser.add_argument("--imagenet_path", type=str, default="./datasets/ILSVRC_2012")
16 | parser.add_argument("--zip_path", type=str, default="./datasets/ILSVRC_2012.zip")
17 | parser.add_argument("--num_workers", type=int, default=16)
18 | parser.add_argument("--save_freq", type=int, default=20)
19 | parser.add_argument("--knn_monitor", action="store_true")
20 | parser.add_argument('--use_zip', action="store_true")
21 | parser.add_argument(
22 |     "--save_folder", type=str, default="./training_checkpoints/imagenet/two_views"
23 | )
24 | 
25 | parser.add_argument("--objective", type=str, default="MMCR_Momentum")
26 | 
27 | 
28 | parser.add_argument("--n_nodes", type=int, default=4)
29 | parser.add_argument("--n_gpus", type=int, default=4)
30 | 
31 | args = parser.parse_args()
32 | 
33 | 
34 | class SLURM_Trainer(object):
35 |     def __init__(self, args):
36 |         self.args = args
37 | 
38 |     def __call__(self):
39 |         # set up distributed environment
40 |         init_dist_node(self.args)
41 |         train(None, self.args)
42 | 
43 | 
44 | # submitit job management
45 | executor = submitit.AutoExecutor(folder="./slurm/pretrain/%j", slurm_max_num_timeout=30)
46 | 
47 | executor.update_parameters(
48 |     mem_gb=128 * args.n_gpus,
49 |     gpus_per_node=args.n_gpus,
50 |     tasks_per_node=args.n_gpus,
51 |     cpus_per_task=args.num_workers,
52 |     nodes=args.n_nodes,
53 |     name="MMCR",
54 |     timeout_min=60 * 24 * 5,
55 |     slurm_partition="gpu",
56 |     constraint="a100-80gb",
57 |     slurm_array_parallelism=512,
58 | )
59 | 
60 | trainer = SLURM_Trainer(args)
61 | job = executor.submit(trainer)
62 | 
--------------------------------------------------------------------------------
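Note: the default ImageNet objective (`--objective MMCR_Momentum`) pairs the online encoder with an EMA ("momentum") copy whose parameters are updated after every batch by the `MomentumUpdate` callback in `mmcr/imagenet/misc.py`. A minimal sketch of that update rule, with `tau` corresponding to the `--tau` flag (the standalone function and module names here are illustrative, not part of the repo):
```
import torch
from torch import nn

@torch.no_grad()
def ema_update(online: nn.Module, momentum: nn.Module, tau: float = 0.99) -> None:
    # momentum params decay toward the online params: p_m <- tau * p_m + (1 - tau) * p_o
    for p_o, p_m in zip(online.parameters(), momentum.parameters()):
        p_m.data = tau * p_m.data + (1.0 - tau) * p_o.data
```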
/mmcr/imagenet/loss_mmcr_momentum.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | import torch.nn.functional as F
4 | 
5 | from typing import List, Tuple
6 | import einops
7 | 
8 | 
9 | class MMCR_Momentum_Loss(nn.Module):
10 |     def __init__(
11 |         self,
12 |         lmbda: float,
13 |         n_aug: int,
14 |         distributed: bool = True,
15 |     ):
16 |         super(MMCR_Momentum_Loss, self).__init__()
17 |         self.lmbda = lmbda
18 |         self.n_aug = n_aug
19 |         self.distributed = distributed
20 | 
21 |     def forward(self, z_list: List[Tensor]) -> Tuple[Tensor, dict]:
22 |         z, z_m = z_list[0], z_list[1]
23 |         z = F.normalize(z, dim=-1)
24 |         z_m = F.normalize(z_m, dim=-1)
25 | 
26 |         z_local_ = einops.rearrange(z, "(B N) C -> B C N", N=self.n_aug)
27 |         z_local_m = einops.rearrange(z_m, "(B N) C -> B C N", N=self.n_aug)
28 | 
29 |         # gather across devices into list
30 |         if self.distributed:
31 |             z_list = [
32 |                 torch.zeros_like(z_local_)
33 |                 for i in range(torch.distributed.get_world_size())
34 |             ]
35 |             torch.distributed.all_gather(z_list, z_local_, async_op=False)
36 |             z_list[torch.distributed.get_rank()] = z_local_  # keep the local tensor so gradients flow
37 | 
38 |             # gather momentum outputs
39 |             z_m_list = [
40 |                 torch.zeros_like(z_local_m)
41 |                 for i in range(torch.distributed.get_world_size())
42 |             ]
43 |             torch.distributed.all_gather(z_m_list, z_local_m, async_op=False)
44 |             z_m_list[torch.distributed.get_rank()] = z_local_m
45 | 
46 |             # append all
47 |             z_local = torch.cat(z_list)
48 |             z_m_local = torch.cat(z_m_list)
49 | 
50 |         else:
51 |             z_local = z_local_
52 |             z_m_local = z_local_m
53 | 
54 |         if self.lmbda == 0:
55 |             local_nuc = torch.tensor(0.0)
56 |         else:
57 |             local_nuc = torch.linalg.svdvals(z_local).sum()
58 | 
59 |         centroids = (torch.mean(z_local, dim=-1) + torch.mean(z_m_local, dim=-1)) * 0.5
60 | 
61 |         # filter rows containing infs or nans before the SVD
62 |         selected = centroids.isfinite().all(dim=1)
63 |         if not selected.all():
64 |             print("filtered nan")
65 | 
66 |         centroids = centroids[selected]
67 | 
68 |         global_nuc = torch.linalg.svdvals(centroids).sum()
69 | 
70 |         batch_size = z_local.shape[0]
71 |         loss = -1.0 * global_nuc + self.lmbda * local_nuc / batch_size
72 | 
73 |         loss_dict = {
74 |             "loss": loss.item(),
75 |             "local_nuc": local_nuc.item(),
76 |             "global_nuc": global_nuc.item(),
77 |         }
78 | 
79 |         return loss, loss_dict
80 | 
--------------------------------------------------------------------------------
/mmcr/cifar_stl/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import einops
4 | 
5 | from mmcr.cifar_stl.data import get_datasets
6 | from mmcr.cifar_stl.models import Model
7 | from mmcr.cifar_stl.knn import test_one_epoch
8 | from mmcr.cifar_stl.loss_mmcr import MMCR_Loss
9 | 
10 | 
11 | def train(
12 |     dataset: str,
13 |     n_aug: int,
14 |     batch_size: int,
15 |     lr: float,
16 |     epochs: int,
17 |     lmbda: float,
18 |     save_folder: str,
19 |     save_freq: int,
20 | ):
21 |     train_dataset, memory_dataset, test_dataset = get_datasets(
22 |         dataset=dataset, n_aug=n_aug
23 |     )
24 |     model = Model(projector_dims=[512, 128], dataset=dataset)
25 |     train_loader = torch.utils.data.DataLoader(
26 |         train_dataset, batch_size=batch_size, shuffle=True, num_workers=12
27 |     )
28 |     memory_loader = torch.utils.data.DataLoader(
29 |         memory_dataset, batch_size=128, shuffle=True, num_workers=12
30 |     )
31 |     test_loader = torch.utils.data.DataLoader(
32 |         test_dataset, batch_size=128, shuffle=False, num_workers=12
33 |     )
34 | 
35 |     optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)
36 |     loss_function = MMCR_Loss(lmbda=lmbda, n_aug=n_aug, distributed=False)
37 | 
38 |     model = model.cuda()
39 |     top_acc = 0.0
40 |     for epoch in range(epochs):
41 |         model.train()
42 |         total_loss, total_num, train_bar = 0.0, 0, tqdm(train_loader)
43 |         for step, data_tuple in enumerate(train_bar):
44 |             optimizer.zero_grad()
45 | 
46 |             # forward pass
47 |             img_batch, labels = data_tuple
48 |             img_batch = einops.rearrange(img_batch, "B N C H W -> (B N) C H W")
49 |             features, out = model(img_batch.cuda(non_blocking=True))
50 |             loss, loss_dict = loss_function(out)
51 | 
52 |             # backward pass
53 |             loss.backward()
54 |             optimizer.step()
55 | 
56 |             # update the training bar
57 |             total_num += data_tuple[0].size(0)
58 |             total_loss += loss.item() * data_tuple[0].size(0)
59 |             train_bar.set_description(
60 |                 "Train Epoch: [{}/{}] Loss: {:.4f}".format(
61 |                     epoch, epochs, total_loss / total_num
62 |                 )
63 |             )
64 | 
65 |         if epoch % 1 == 0:
66 |             acc_1, acc_5 = test_one_epoch(
67 |                 model,
68 |                 memory_loader,
69 |                 test_loader,
70 |             )
71 |             if acc_1 > top_acc:
72 |                 top_acc = acc_1
73 | 
74 |             if epoch % save_freq == 0 or acc_1 == top_acc:
75 |                 torch.save(
76 |                     model.state_dict(),
77 |                     f"{save_folder}/{dataset}_{n_aug}_{epoch}_acc_{acc_1:0.2f}.pth",
78 |                 )
79 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Coding of Natural Images using Maximum Manifold Capacity Representations
2 | This is a PyTorch implementation of the paper [Efficient Coding of Natural Images using Maximum Manifold Capacity Representations](https://openreview.net/pdf?id=og9V7NgOrQ).
3 | 
4 | ## Environment
5 | To install dependencies, create a conda environment from the provided `environment.yml` file and install the project package by running `pip install -e .` in the base directory.
6 | We used PyTorch 1.11 for all experiments and [Composer from MosaicML](https://docs.mosaicml.com/projects/composer/en/stable/index.html) for distributed pretraining on ImageNet datasets.
7 | 
8 | ## Datasets
9 | We provide code for pretraining and linear evaluation on CIFAR-10/100, STL-10, and ImageNet-100/1k.
10 | The code expects all dataset files to be located in the `./datasets` directory.
11 | For ImageNet datasets we also provide an implementation for reading images from a ZIP archive rather than opening each image file individually.
12 | This reduces the I/O overhead of dataloading, but requires zipping the datasets before training, which can take up to several hours for ImageNet-1k.
13 | The use of zipped dataloading can be toggled on/off via the parameter `use_zip` (see below).
14 | 
15 | 
16 | ## Pretraining
17 | The code is set up to run on a SLURM cluster and uses [submitit](https://github.com/facebookincubator/submitit) for job submission.
18 | 
19 | ### ImageNet
20 | To pretrain on ImageNet with default settings, run:
21 | ```
22 | python3 pretrain_imagenet.py
23 | ```
24 | By default, training uses 4 nodes, each with 4 A100 GPUs (though 8-view training requires 8 nodes).
25 | Hyperparameters can be adjusted on the command line, e.g., to run with 4 views rather than 2:
26 | ```
27 | python3 pretrain_imagenet.py --n_aug 4
28 | ```
29 | See `pretrain_imagenet.py` for details.
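
The objective itself is compact: each image's `n_aug` augmented embeddings are treated as a manifold, and training maximizes the nuclear norm of the matrix of manifold centroids (minus an optional `lmbda`-weighted per-manifold nuclear norm). A minimal sketch of the `lmbda = 0` case, mirroring `mmcr/imagenet/loss_mmcr.py` (shapes here are illustrative):
```
import torch
import torch.nn.functional as F
import einops

B, N, C = 16, 2, 512                              # images, views per image, embedding dim
z = F.normalize(torch.randn(B * N, C), dim=-1)    # unit-norm projector outputs
z = einops.rearrange(z, "(B N) C -> B C N", N=N)  # one C x N manifold per image
centroids = z.mean(dim=-1)                        # manifold centroids, shape [B, C]
loss = -torch.linalg.svdvals(centroids).sum()     # maximize the centroid nuclear norm
```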
30 | 31 | 32 | ### CIFAR/STL 33 | To pretrain on either CIFAR or STL instead run 34 | ``` 35 | python3 pretrain_cifar_stl.py 36 | ``` 37 | Use command line arguments to specify the pretraining dataset and other hyperparameters (see `pretrain_cifar_stl.py` for details). 38 | Pretraining on these smaller datasets uses a single A100 GPU. 39 | 40 | ## Evaluation 41 | We run frozen linear evaluation for all datasets on a single A100 GPU. 42 | 43 | ### ImageNet 44 | To run frozen-linear evaluation on an ImageNet dataset run 45 | ``` 46 | python3 linear_classifier_imagenet.py --model_path /path/to/checkpoint_file 47 | ``` 48 | ```checkpoint_file``` should contain a checkpoint that is generated during an ImageNet pretraining run. 49 | Other hyperparameters can be adjusted via command line arguments similarly to above. 50 | 51 | 52 | ### CIFAR/STL 53 | For CIFAR/STL we run frozen linear evaluations on a large number of checkpoints saved during pretraining to perform model selection. 54 | To run model selection run the command: 55 | ``` 56 | python3 model_select_cifar_stl.py --checkpoint_dir /path/to/checkpoint_directory 57 | ``` 58 | where `checkpoint_directory` contains all checkpoints generated by running pretraining on CIFAR/STL as specified above. 59 | -------------------------------------------------------------------------------- /mmcr/cifar_stl/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | ### KNN based evaluation, for use during unsupervised pretraining to track progress ### 10 | # adapted from: https://github.com/yaohungt/Barlow-Twins-HSIC/blob/main/main.py 11 | def test_one_epoch( 12 | net: nn.Module, 13 | memory_data_loader: DataLoader, 14 | test_data_loader: DataLoader, 15 | temperature: float = 0.5, 16 | k: int = 200, 17 | ): 18 | net.eval() 19 | total_top1, total_top5, total_num, feature_bank, target_bank = 0.0, 0.0, 0, [], [] 20 | with torch.no_grad(): 21 | # generate feature bank and target bank 22 | for data_tuple in tqdm(memory_data_loader): 23 | data, target = data_tuple 24 | target_bank.append(target) 25 | features, out = net(data.cuda(non_blocking=True)) 26 | feature = F.normalize(features, dim=-1) 27 | feature_bank.append(feature) 28 | # [D, N] 29 | feature_bank = torch.cat(feature_bank, dim=0).t().contiguous() 30 | # [N] 31 | feature_labels = ( 32 | torch.cat(target_bank, dim=0).contiguous().to(feature_bank.device) 33 | ) 34 | # loop test data to predict the label by weighted knn search 35 | test_bar = tqdm(test_data_loader) 36 | for data_tuple in test_bar: 37 | data, target = data_tuple 38 | data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) 39 | features, out = net(data) 40 | feature = F.normalize(features, dim=-1) 41 | 42 | total_num += data.size(0) 43 | # compute cos similarity between each feature vector and feature bank ---> [B, N] 44 | sim_matrix = torch.mm(feature, feature_bank) 45 | # [B, K] 46 | sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1) 47 | # [B, K] 48 | sim_labels = torch.gather( 49 | feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices 50 | ) 51 | sim_weight = (sim_weight / temperature).exp() 52 | 53 | # counts for each class 54 | one_hot_label = torch.zeros( 55 | data.size(0) * k, 1000, device=sim_labels.device 56 | ) 57 | # [B*K, C] 58 | one_hot_label = one_hot_label.scatter( 59 | dim=-1, 
index=sim_labels.view(-1, 1), value=1.0
60 |             )
61 |             # weighted score ---> [B, C]
62 |             pred_scores = torch.sum(
63 |                 one_hot_label.view(data.size(0), -1, 1000)
64 |                 * sim_weight.unsqueeze(dim=-1),
65 |                 dim=1,
66 |             )
67 | 
68 |             pred_labels = pred_scores.argsort(dim=-1, descending=True)
69 |             total_top1 += torch.sum(
70 |                 (pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()
71 |             ).item()
72 |             total_top5 += torch.sum(
73 |                 (pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()
74 |             ).item()
75 | 
76 |             test_bar.set_description(
77 |                 "Test Epoch: Acc@1:{:.2f}% Acc@5:{:.2f}%".format(
78 |                     total_top1 / total_num * 100, total_top5 / total_num * 100
79 |                 )
80 |             )
81 | 
82 |     # guard against division by zero when the test loader is empty
83 |     if total_num == 0:
84 |         total_num += 1
85 |     net.train()
86 | 
87 |     return total_top1 / total_num * 100, total_top5 / total_num * 100
88 | 
--------------------------------------------------------------------------------
/mmcr/imagenet/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim
3 | from composer import State, Callback, Logger
4 | 
5 | 
6 | ## momentum update via callback
7 | class MomentumUpdate(Callback):
8 |     def __init__(self, tau=0.99) -> None:
9 |         super(MomentumUpdate, self).__init__()
10 |         self.tau = tau
11 |         self.first_time = True
12 | 
13 |     def batch_end(self, state: State, logger: Logger) -> None:
14 |         online_f, momentum_f = (
15 |             state.model.module.module.f,
16 |             state.model.module.module.mom_f,
17 |         )
18 |         online_g, momentum_g = (
19 |             state.model.module.module.g,
20 |             state.model.module.module.mom_g,
21 |         )
22 | 
23 |         with torch.no_grad():
24 |             for op, mp in zip(online_f.parameters(), momentum_f.parameters()):
25 |                 mp.data = self.tau * mp.data + (1 - self.tau) * op.data
26 | 
27 |             for op, mp in zip(online_g.parameters(), momentum_g.parameters()):
28 |                 mp.data = self.tau * mp.data + (1 - self.tau) * op.data
29 | 
30 |         if self.first_time:
31 |             print("Performing momentum update")
32 |             self.first_time = False
33 | 
34 |         return None
35 | 
36 | 
37 | # modified from https://github.com/mosaicml/composer/blob/80d3293df833edfdb4249daee3f0ddcd25259fa2/composer/callbacks/lr_monitor.py#L11 to print to stdout
38 | class LogLR(Callback):
39 |     def __init__(self):
40 |         pass
41 | 
42 |     def epoch_start(self, state: State, logger: Logger):
43 |         for optimizer in state.optimizers:
44 |             lrs = [group["lr"] for group in optimizer.param_groups]
45 |             name = optimizer.__class__.__name__
46 |             # log one entry per parameter group
47 |             for idx, lr in enumerate(lrs):
48 |                 print({f"lr-{name}/group{idx}": lr})
49 | 
50 | 
51 | # modified from: https://github.com/facebookresearch/barlowtwins/blob/main/main.py
52 | class LARS(optim.Optimizer):
53 |     def __init__(
54 |         self,
55 |         params,
56 |         lr,
57 |         weight_decay=1e-6,
58 |         momentum=0.9,
59 |         eta=0.001,
60 |         weight_decay_filter=True,
61 |         lars_adaptation_filter=True,
62 |     ):
63 |         defaults = dict(
64 |             lr=lr,
65 |             weight_decay=weight_decay,
66 |             momentum=momentum,
67 |             eta=eta,
68 |             weight_decay_filter=weight_decay_filter,
69 |             lars_adaptation_filter=lars_adaptation_filter,
70 |         )
71 |         super().__init__(params, defaults)
72 | 
73 |     def exclude_bias_and_norm(self, p):
74 |         return p.ndim == 1
75 | 
76 |     @torch.no_grad()
77 |     def step(self):
78 |         for g in self.param_groups:
79 |             for p in g["params"]:
80 |                 dp = p.grad
81 | 
82 |                 if dp is None:
83 |                     continue
84 | 
85 |                 if not g["weight_decay_filter"] or not self.exclude_bias_and_norm(p):
86 |                     dp = dp.add(p, alpha=g["weight_decay"])
87 | 
88 |                 if not g["lars_adaptation_filter"] or not self.exclude_bias_and_norm(p):
89 |                     param_norm = torch.norm(p)
90 |                     update_norm = torch.norm(dp)
91 |                     one = torch.ones_like(param_norm)
92 |                     q = torch.where(
93 |                         param_norm > 0.0,
94 |                         torch.where(
95 |                             update_norm > 0, (g["eta"] * param_norm / update_norm), one
96 |                         ),
97 |                         one,
98 |                     )
99 |                     dp = dp.mul(q)
100 | 
101 |                 param_state = self.state[p]
102 |                 if "mu" not in param_state:
103 |                     param_state["mu"] = torch.zeros_like(p)
104 |                 mu = param_state["mu"]
105 |                 mu.mul_(g["momentum"]).add_(dp)
106 | 
107 |                 p.add_(mu, alpha=-g["lr"])
108 | 
109 | 
110 | def collate_fn(batch):
111 |     new_img = [torch.as_tensor(img) for img, _ in batch]
112 |     new_target = [torch.as_tensor(lbl) for _, lbl in batch]
113 |     new_img = torch.cat(new_img, dim=0)
114 |     new_target = torch.stack(new_target, dim=0)
115 |     return new_img, new_img  # labels unused during pretraining; images returned twice to keep an (X, y) shape
116 | 
117 | 
118 | def get_num_samples_in_batch(multicrop_batch):
119 |     return multicrop_batch[0][0][0].shape[0]
120 | 
--------------------------------------------------------------------------------
/mmcr/imagenet/knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | 
6 | from tqdm import tqdm
7 | from composer import Callback, State, Logger
8 | 
9 | 
10 | ### KNN based evaluation, for use during unsupervised pretraining to track progress ###
11 | # adapted from: https://github.com/yaohungt/Barlow-Twins-HSIC/blob/main/main.py
12 | def test_one_epoch(
13 |     net: nn.Module,
14 |     memory_data_loader: DataLoader,
15 |     test_data_loader: DataLoader,
16 |     temperature: float = 0.5,
17 |     k: int = 200,
18 | ):
19 |     net.eval()
20 |     total_top1, total_top5, total_num, feature_bank, target_bank = 0.0, 0.0, 0, [], []
21 |     with torch.no_grad():
22 |         # generate feature bank and target bank
23 |         for data_tuple in tqdm(memory_data_loader):
24 |             data, target = data_tuple
25 |             target_bank.append(target)
26 |             features, out = net(data.cuda(non_blocking=True))
27 |             feature = F.normalize(features, dim=-1)
28 |             feature_bank.append(feature)
29 |         # [D, N]
30 |         feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
31 |         # [N]
32 |         feature_labels = (
33 |             torch.cat(target_bank, dim=0).contiguous().to(feature_bank.device)
34 |         )
35 |         # loop test data to predict the label by weighted knn search
36 |         test_bar = tqdm(test_data_loader)
37 |         for data_tuple in test_bar:
38 |             data, target = data_tuple
39 |             data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
40 |             features, out = net(data)
41 |             feature = F.normalize(features, dim=-1)
42 | 
43 |             total_num += data.size(0)
44 |             # compute cos similarity between each feature vector and feature bank ---> [B, N]
45 |             sim_matrix = torch.mm(feature, feature_bank)
46 |             # [B, K]
47 |             sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
48 |             # [B, K]
49 |             sim_labels = torch.gather(
50 |                 feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices
51 |             )
52 |             sim_weight = (sim_weight / temperature).exp()
53 | 
54 |             # counts for each class
55 |             one_hot_label = torch.zeros(
56 |                 data.size(0) * k, 1000, device=sim_labels.device
57 |             )
58 |             # [B*K, C]
59 |             one_hot_label = one_hot_label.scatter(
60 |                 dim=-1, index=sim_labels.view(-1, 1), value=1.0
61 |             )
62 |             # weighted score ---> [B, C]
63 |             pred_scores = torch.sum(
64 |                 one_hot_label.view(data.size(0), -1, 1000)
65 |                 * sim_weight.unsqueeze(dim=-1),
66 |                 dim=1,
67 |             )
68 | 
69 |             pred_labels = pred_scores.argsort(dim=-1, descending=True)
70 |             total_top1 += torch.sum(
71 |                 (pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()
72 |             ).item()
73 |             total_top5 += torch.sum(
74 |                 (pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()
75 |             ).item()
76 | 
77 |             test_bar.set_description(
78 |                 "Test Epoch: Acc@1:{:.2f}% Acc@5:{:.2f}%".format(
79 |                     total_top1 / total_num * 100, total_top5 / total_num * 100
80 |                 )
81 |             )
82 | 
83 |     # guard against division by zero when the test loader is empty
84 |     if total_num == 0:
85 |         total_num += 1
86 |     net.train()
87 | 
88 |     return total_top1 / total_num * 100, total_top5 / total_num * 100
89 | 
90 | 
91 | ### COMPOSER EVALUATION VIA CALLBACK ###
92 | class KnnMonitor(Callback):
93 |     def __init__(self, memory_loader: DataLoader, test_loader: DataLoader):
94 |         super(KnnMonitor, self).__init__()
95 |         self.memory_loader = memory_loader
96 |         self.test_loader = test_loader
97 |         self.count_knn_eval = 0
98 |         self.distributed = False
99 |         self.top_acc = 0.0
100 |         self.epochs_to_classify = []
101 | 
102 |     def epoch_end(self, state: State, logger: Logger):
103 |         if self.count_knn_eval % 10 == 0:
104 |             if self.distributed:
105 |                 net = nn.parallel.DistributedDataParallel(state.model.module.module)
106 |             else:
107 |                 net = state.model.module
108 | 
109 |             top_1, top_5 = test_one_epoch(
110 |                 memory_data_loader=self.memory_loader,
111 |                 test_data_loader=self.test_loader,
112 |                 net=net,
113 |             )
114 | 
115 |             print(f"top_1={top_1}")
116 |             print(f"top_5={top_5}")
117 | 
118 |             if top_1 > self.top_acc:
119 |                 self.top_acc = top_1
120 | 
121 |         self.count_knn_eval += 1
122 | 
--------------------------------------------------------------------------------
/mmcr/cifar_stl/train_linear_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from torchvision.models.resnet import resnet50
4 | from torchvision import transforms
5 | from torch.utils.data import DataLoader
6 | from PIL import Image
7 | 
8 | import numpy as np
9 | from tqdm import tqdm
10 | from typing import OrderedDict
11 | 
12 | from mmcr.cifar_stl.data import get_datasets
13 | from mmcr.cifar_stl.models import Model
14 | 
15 | 
16 | # train or test linear classifier for one epoch
17 | def train_val(net, data_loader, train_optimizer, epoch):
18 |     is_train = train_optimizer is not None
19 |     net.train() if is_train else net.eval()
20 | 
21 |     total_loss, total_correct_1, total_correct_5, total_num, data_bar = (
22 |         0.0,
23 |         0.0,
24 |         0.0,
25 |         0,
26 |         tqdm(data_loader),
27 |     )
28 |     with torch.enable_grad() if is_train else torch.no_grad():
29 |         loss_criterion = nn.CrossEntropyLoss()
30 |         for data, target in data_bar:
31 |             data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
32 |             out = net(data)
33 |             loss = loss_criterion(out, target)
34 | 
35 |             if is_train:
36 |                 train_optimizer.zero_grad()
37 |                 loss.backward()
38 |                 train_optimizer.step()
39 | 
40 |             total_num += data.size(0)
41 |             total_loss += loss.item() * data.size(0)
42 |             prediction = torch.argsort(out, dim=-1, descending=True)
43 |             total_correct_1 += torch.sum(
44 |                 (prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()
45 |             ).item()
46 |             total_correct_5 += torch.sum(
47 |                 (prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()
48 |             ).item()
49 | 
50 |             data_bar.set_description(
51 |                 "{} Epoch: [{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%".format(
52 |                     "Train" if is_train else "Test",
53 |                     epoch,
54 |                     total_loss / total_num,
55 |                 
total_correct_1 / total_num * 100, 56 | total_correct_5 / total_num * 100, 57 | ) 58 | ) 59 | 60 | return ( 61 | total_loss / total_num, 62 | total_correct_1 / total_num * 100, 63 | total_correct_5 / total_num * 100, 64 | ) 65 | 66 | 67 | def train_classifier( 68 | model_path: str, 69 | dataset: str = "cifar10", 70 | batch_size: int = 512, 71 | epochs: int = 50, 72 | lr: float = 1e-2, 73 | save_path=None, 74 | save_name=None, 75 | ): 76 | top_acc = 0.0 77 | train_data, _, test_data = get_datasets( 78 | dataset, 1, "./datasets", batch_transform=False, supervised=True 79 | ) 80 | 81 | train_loader = DataLoader( 82 | train_data, batch_size=batch_size, shuffle=True, num_workers=13, pin_memory=True 83 | ) 84 | test_loader = DataLoader( 85 | test_data, batch_size=batch_size, shuffle=False, num_workers=13, pin_memory=True 86 | ) 87 | 88 | # load pretrained weights 89 | pretrained_model = Model(dataset=dataset) 90 | sd = torch.load(model_path, map_location="cpu") 91 | pretrained_model.load_state_dict(sd) 92 | dataset_num_classes = {"cifar10": 10, "stl10": 10, "cifar100": 100} 93 | model = Net(pretrained_model.f, dataset_num_classes[dataset]) 94 | 95 | # only fully connected requires grad 96 | model.requires_grad_(False) 97 | model.fc.requires_grad_(True) 98 | model = model.cuda() 99 | 100 | optimizer = optim.Adam(model.fc.parameters(), lr=lr, weight_decay=1e-6) 101 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) 102 | 103 | if save_path is not None and save_name is None: 104 | save_name = str(np.random.rand() * 1e5) 105 | print(save_name) 106 | for epoch in range(1, epochs + 1): 107 | # train one epoch 108 | train_loss, train_acc_1, train_acc_5 = train_val( 109 | model, train_loader, optimizer, epoch 110 | ) 111 | # test one epoch 112 | test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None, epoch) 113 | scheduler.step() 114 | 115 | if test_acc_1 > top_acc: 116 | top_acc = test_acc_1 117 | 118 | if test_acc_1 == top_acc and save_path is not None and save_name is not None: 119 | save_str = save_path + save_name + ".pt" 120 | torch.save(model.state_dict(), save_str) 121 | 122 | return model, top_acc 123 | 124 | 125 | # a wrapper class for the resnet50 model 126 | class Net(nn.Module): 127 | def __init__(self, f, num_classes): 128 | super().__init__() 129 | self.f = f 130 | self.fc = nn.Linear(2048, num_classes) 131 | 132 | def forward(self, x): 133 | f = self.f(x) 134 | f = f.view(f.size(0), -1) 135 | return self.fc(f) 136 | -------------------------------------------------------------------------------- /mmcr/imagenet/train.py: -------------------------------------------------------------------------------- 1 | from mmcr.imagenet.misc import ( 2 | LARS, 3 | MomentumUpdate, 4 | LogLR, 5 | collate_fn, 6 | get_num_samples_in_batch, 7 | ) 8 | from mmcr.imagenet.loss_mmcr_momentum import MMCR_Momentum_Loss 9 | from mmcr.imagenet.loss_mmcr import MMCR_Loss 10 | from mmcr.imagenet.data import get_datasets 11 | from mmcr.imagenet.knn import KnnMonitor 12 | from mmcr.imagenet.models import ( 13 | MomentumModel, 14 | MomentumComposerWrapper, 15 | Model, 16 | ComposerWrapper, 17 | ) 18 | 19 | import torch 20 | import composer 21 | from composer.optim.scheduler import CosineAnnealingWithWarmupScheduler 22 | import submitit 23 | 24 | import os 25 | 26 | 27 | def train(gpu, args, **kwargs): 28 | # composer doesn't require init_dist_gpu() function call 29 | job_env = submitit.JobEnvironment() 30 | args.gpu = job_env.local_rank 31 | args.rank = job_env.global_rank 32 | 
33 | # better port 34 | tmp_port = os.environ["SLURM_JOB_ID"] 35 | tmp_port = int(tmp_port[-4:]) + 50000 36 | args.port = tmp_port 37 | 38 | os.environ["RANK"] = str(job_env.global_rank) 39 | os.environ["WORLD_SIZE"] = str(args.n_gpus * args.n_nodes) 40 | os.environ["LOCAL_RANK"] = str(job_env.local_rank) 41 | os.environ["LOCAL_WORLD_SIZE"] = str(args.n_gpus) 42 | os.environ["NODE_RANK"] = str(int(os.getenv("SLURM_NODEID"))) 43 | os.environ["MASTER_ADDR"] = args.host_name_ 44 | os.environ["MASTER_PORT"] = str(args.port) 45 | os.environ["PYTHONUNBUFFERED"] = "1" 46 | 47 | args.torch_cuda_device_count = torch.cuda.device_count() 48 | args.slurm_nodeid = int(os.getenv("SLURM_NODEID")) 49 | args.slurm_nnodes = int(os.getenv("SLURM_NNODES")) 50 | 51 | print(args) 52 | 53 | # datasets 54 | train_data, memory_data, test_data = get_datasets( 55 | n_aug=args.n_aug, dataset=args.dataset, use_zip=args.use_zip 56 | ) 57 | 58 | # samplers 59 | train_sampler = torch.utils.data.DistributedSampler( 60 | train_data, 61 | num_replicas=args.world_size, 62 | rank=args.rank, 63 | ) 64 | memory_sampler = torch.utils.data.DistributedSampler( 65 | memory_data, 66 | num_replicas=args.world_size, 67 | rank=args.rank, 68 | ) 69 | test_sampler = torch.utils.data.DistributedSampler( 70 | test_data, 71 | num_replicas=args.world_size, 72 | rank=args.rank, 73 | ) 74 | 75 | # dataloaders 76 | batch_size = int(args.batch_size / args.n_gpus / args.n_nodes) 77 | train_loader = torch.utils.data.DataLoader( 78 | dataset=train_data, 79 | batch_size=batch_size, 80 | num_workers=args.num_workers, 81 | pin_memory=True, 82 | drop_last=True, 83 | collate_fn=collate_fn, 84 | sampler=train_sampler, 85 | ) 86 | 87 | memory_loader = torch.utils.data.DataLoader( 88 | dataset=memory_data, 89 | batch_size=512, 90 | num_workers=args.num_workers, 91 | pin_memory=True, 92 | drop_last=True, 93 | sampler=memory_sampler, 94 | ) 95 | 96 | test_loader = torch.utils.data.DataLoader( 97 | dataset=test_data, 98 | batch_size=128, 99 | num_workers=args.num_workers, 100 | pin_memory=True, 101 | drop_last=True, 102 | sampler=test_sampler, 103 | ) 104 | 105 | # objective/model 106 | args.distributed = args.n_gpus * args.n_nodes > 1 107 | projector_dims = [8192, 8192, 512] 108 | if args.objective == "MMCR_Momentum": 109 | objective = MMCR_Momentum_Loss(args.lmbda, args.n_aug, args.distributed) 110 | objective = torch.nn.SyncBatchNorm.convert_sync_batchnorm(objective) 111 | model = MomentumModel(projector_dims=projector_dims) 112 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 113 | wrapped_model = MomentumComposerWrapper(module=model, objective=objective) 114 | elif args.objective == "MMCR": 115 | objective = MMCR_Loss(args.lmbda, args.n_aug, args.distributed) 116 | objective = torch.nn.SyncBatchNorm.convert_sync_batchnorm(objective) 117 | model = Model(projector_dims=projector_dims) 118 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 119 | wrapped_model = ComposerWrapper(module=model, objective=objective) 120 | 121 | # optimizer 122 | lr = args.lr * args.batch_size / 256 123 | optimizer = LARS( 124 | model.parameters(), 125 | lr=lr, 126 | weight_decay=1e-6, 127 | momentum=0.9, 128 | weight_decay_filter=True, 129 | lars_adaptation_filter=True, 130 | ) 131 | 132 | # scheduler 133 | scheduler = CosineAnnealingWithWarmupScheduler(t_warmup="10ep", alpha_f=0.001) 134 | 135 | # callbacks 136 | callback_list = [LogLR()] 137 | if args.objective == "MMCR_Momentum": 138 | callback_list.append(MomentumUpdate(tau=args.tau)) 139 | if 
args.knn_monitor:
140 |         callback_list.append(KnnMonitor(memory_loader, test_loader))
141 | 
142 |     # dspec
143 |     train_dspec = composer.DataSpec(
144 |         train_loader, get_num_samples_in_batch=get_num_samples_in_batch
145 |     )
146 | 
147 |     print(model)
148 | 
149 |     # trainer
150 |     trainer = composer.Trainer(
151 |         train_dataloader=train_dspec,
152 |         optimizers=optimizer,
153 |         model=wrapped_model,
154 |         max_duration=args.epochs,
155 |         precision="amp",
156 |         algorithms=[
157 |             composer.algorithms.ChannelsLast(),
158 |         ],
159 |         device="gpu",
160 |         callbacks=callback_list,
161 |         schedulers=scheduler,
162 |         save_interval=args.save_freq,
163 |         save_overwrite=True,
164 |         save_folder=args.save_folder,
165 |     )
166 | 
167 |     trainer.fit()
168 | 
--------------------------------------------------------------------------------
/mmcr/imagenet/train_linear_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from torchvision.models.resnet import resnet50
4 | from torchvision import datasets, transforms
5 | from torch.utils.data import DataLoader
6 | from PIL import Image
7 | 
8 | import numpy as np
9 | from tqdm import tqdm
10 | from typing import OrderedDict
11 | 
12 | from mmcr.imagenet.data import ZipImageNet, ImageNetValTransform
13 | 
14 | 
15 | # train or test linear classifier for one epoch
16 | def train_val(net, data_loader, train_optimizer, epoch):
17 |     is_train = train_optimizer is not None
18 |     net.train() if is_train else net.eval()
19 | 
20 |     total_loss, total_correct_1, total_correct_5, total_num, data_bar = (
21 |         0.0,
22 |         0.0,
23 |         0.0,
24 |         0,
25 |         tqdm(data_loader),
26 |     )
27 |     with torch.enable_grad() if is_train else torch.no_grad():
28 |         loss_criterion = nn.CrossEntropyLoss()
29 |         for data, target in data_bar:
30 |             data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
31 |             out = net(data)
32 |             loss = loss_criterion(out, target)
33 | 
34 |             if is_train:
35 |                 train_optimizer.zero_grad()
36 |                 loss.backward()
37 |                 train_optimizer.step()
38 | 
39 |             total_num += data.size(0)
40 |             total_loss += loss.item() * data.size(0)
41 |             prediction = torch.argsort(out, dim=-1, descending=True)
42 |             total_correct_1 += torch.sum(
43 |                 (prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()
44 |             ).item()
45 |             total_correct_5 += torch.sum(
46 |                 (prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()
47 |             ).item()
48 | 
49 |             data_bar.set_description(
50 |                 "{} Epoch: [{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%".format(
51 |                     "Train" if is_train else "Test",
52 |                     epoch,
53 |                     total_loss / total_num,
54 |                     total_correct_1 / total_num * 100,
55 |                     total_correct_5 / total_num * 100,
56 |                 )
57 |             )
58 | 
59 |     return (
60 |         total_loss / total_num,
61 |         total_correct_1 / total_num * 100,
62 |         total_correct_5 / total_num * 100,
63 |     )
64 | 
65 | 
66 | def train_classifier(
67 |     model_path: str,
68 |     batch_size: int = 512,
69 |     epochs: int = 50,
70 |     lr: float = 1e-2,
71 |     use_zip: bool = False,
72 |     save_path=None,
73 |     save_name=None,
74 | ):
75 |     top_acc = 0.0
76 |     train_transform = transforms.Compose(
77 |         [
78 |             transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
79 |             transforms.RandomHorizontalFlip(p=0.5),
80 |             transforms.ToTensor(),
81 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
82 |         ]
83 |     )
84 |     test_transform = transforms.Compose(
85 |         [
86 |             transforms.Resize(256),
87 |             transforms.CenterCrop(224),
88 |             transforms.ToTensor(),
89 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
90 |         ]
91 |     )
92 |     if use_zip:
93 |         train_data = ZipImageNet(
94 |             zip_path="./datasets/ILSVRC_2012.zip",
95 |             root="./datasets/ILSVRC_2012",
96 |             split="train",
97 |             transform=train_transform,
98 |         )
99 |         test_data = ZipImageNet(
100 |             zip_path="./datasets/ILSVRC_2012.zip",
101 |             root="./datasets/ILSVRC_2012",
102 |             split="val",
103 |             transform=test_transform,
104 |         )
105 |     else:
106 |         train_data = datasets.ImageNet(
107 |             root="./datasets/ILSVRC_2012",
108 |             split="train",
109 |             transform=train_transform,
110 |         )
111 |         test_data = datasets.ImageNet(
112 |             root="./datasets/ILSVRC_2012",
113 |             split="val",
114 |             transform=test_transform,
115 |         )
116 | 
117 |     train_loader = DataLoader(
118 |         train_data, batch_size=batch_size, shuffle=True, num_workers=13, pin_memory=True
119 |     )
120 |     test_loader = DataLoader(
121 |         test_data, batch_size=batch_size, shuffle=False, num_workers=13, pin_memory=True
122 |     )
123 | 
124 |     # load pretrained weights (fully connected layer excluded)
125 |     model = resnet50()
126 | 
127 |     sd = torch.load(model_path, map_location="cpu")["state"]["model"]
128 |     new_sd = OrderedDict()
129 |     for k, v in sd.items():
130 |         # skip projector, momentum networks, and fully connected
131 |         if "g." in k or "mom_" in k or "fc" in k:
132 |             continue
133 |         parts = k.split(".")
134 |         idx = parts.index("f")
135 |         new_k = ".".join(parts[idx + 1 :])
136 |         new_sd[new_k] = v
137 |     model.load_state_dict(new_sd, strict=False)
138 | 
139 |     # only fully connected requires grad
140 |     model.requires_grad_(False)
141 |     model.fc.requires_grad_(True)
142 |     model = model.cuda()
143 | 
144 |     optimizer = optim.SGD(model.fc.parameters(), lr=lr, momentum=0.9, weight_decay=1e-6)
145 |     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
146 | 
147 |     if save_path is not None and save_name is None:
148 |         save_name = str(np.random.rand() * 1e5)
149 |     for epoch in range(1, epochs + 1):
150 |         # train one epoch
151 |         train_loss, train_acc_1, train_acc_5 = train_val(
152 |             model, train_loader, optimizer, epoch
153 |         )
154 |         # test one epoch
155 |         test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None, epoch)
156 |         scheduler.step()
157 | 
158 |         if test_acc_1 > top_acc and save_path is not None and save_name is not None:
159 |             top_acc = test_acc_1
160 |             save_str = save_path + save_name + ".pt"
161 |             torch.save(model.state_dict(), save_str)
162 | 
163 |     return None
164 | 
--------------------------------------------------------------------------------
/mmcr/imagenet/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | from torchvision.models.resnet import resnet50
6 | from torch import Tensor
7 | from typing import Tuple
8 | import einops
9 | 
10 | from composer.models import ComposerModel
11 | from typing import Any, Tuple
12 | 
13 | import sys
14 | 
15 | sys.path.append("..")
16 | 
17 | 
18 | class MomentumModel(nn.Module):
19 |     def __init__(
20 |         self,
21 |         projector_dims: list = [8192, 8192, 512],
22 |         bias_last=False,
23 |         bias_proj=False,
24 |     ):
25 |         super(MomentumModel, self).__init__()
26 |         # ensures the encoder output is 2048-dimensional for all datasets
27 |         self.f = resnet50(zero_init_residual=True)
28 |         self.f.fc = nn.Identity()
29 | 
30 |         # projection head (following exactly the Barlow Twins official repo)
31 |         projector_dims = [2048] + projector_dims
32 |         layers = []
33 |         for i in range(len(projector_dims) - 2):
34 |             layers.append(
35 |                 nn.Linear(projector_dims[i], projector_dims[i + 1], bias=bias_proj)
36 |             )
37 |             layers.append(nn.BatchNorm1d(projector_dims[i + 1]))
38 |             layers.append(nn.ReLU())
39 |         layers.append(nn.Linear(projector_dims[-2], projector_dims[-1], bias=bias_last))
40 |         self.g = nn.Sequential(*layers)
41 | 
42 |         # initialize momentum background and projector
43 |         self.mom_f = resnet50(zero_init_residual=True)
44 |         self.mom_f.fc = nn.Identity()
45 | 
46 |         # projection head (following exactly the Barlow Twins official repo)
47 |         layers = []
48 |         for i in range(len(projector_dims) - 2):
49 |             layers.append(
50 |                 nn.Linear(projector_dims[i], projector_dims[i + 1], bias=bias_proj)
51 |             )
52 |             layers.append(nn.BatchNorm1d(projector_dims[i + 1]))
53 |             layers.append(nn.ReLU())
54 |         layers.append(nn.Linear(projector_dims[-2], projector_dims[-1], bias=bias_last))
55 |         self.mom_g = nn.Sequential(*layers)
56 | 
57 |         params_f_online, params_f_mom = self.f.parameters(), self.mom_f.parameters()
58 |         params_g_online, params_g_mom = self.g.parameters(), self.mom_g.parameters()
59 | 
60 |         for po, pm in zip(params_f_online, params_f_mom):
61 |             pm.data.copy_(po.data)
62 |             pm.requires_grad = False
63 | 
64 |         for po, pm in zip(params_g_online, params_g_mom):
65 |             pm.data.copy_(po.data)
66 |             pm.requires_grad = False
67 | 
68 |     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
69 |         x_ = self.f(x)
70 |         feature = torch.flatten(x_, start_dim=1)
71 |         out = self.g(feature)
72 | 
73 |         x_momentum = self.mom_f(x)
74 |         feature_momentum = torch.flatten(x_momentum, start_dim=1)
75 |         out_momentum = self.mom_g(feature_momentum)
76 | 
77 |         return feature, out, feature_momentum, out_momentum
78 | 
79 | 
80 | class MomentumComposerWrapper(ComposerModel):
81 |     def __init__(self, module: torch.nn.Module, objective):
82 |         super().__init__()
83 | 
84 |         self.module = module
85 |         self.objective = objective
86 |         self.c = 0  # counts the number of forward calls
87 | 
88 |     def loss(self, outputs: Any, batch: Any, *args, **kwargs) -> Tensor:
89 |         loss, loss_dict = self.objective(outputs)
90 |         self.loss_dict = loss_dict
91 |         self.c += 1
92 |         return loss
93 | 
94 |     def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
95 |         if isinstance(batch, Tensor):
96 |             inputs = batch
97 |         else:
98 |             inputs, _ = batch
99 | 
100 |         features, outputs, features_momentum, outputs_momentum = self.module(inputs)
101 |         if isinstance(batch, Tensor):
102 |             return features, outputs
103 |         else:
104 |             return [outputs, outputs_momentum]
105 | 
106 |     def get_backbone(self):
107 |         return self.module
108 | 
109 | 
110 | class Model(nn.Module):
111 |     def __init__(
112 |         self,
113 |         projector_dims: list = [8192, 8192, 512],
114 |         bias_last=False,
115 |         bias_proj=False,
116 |     ):
117 |         super(Model, self).__init__()
118 | 
119 |         # ensures the encoder output is 2048-dimensional for all datasets
120 |         self.f = resnet50(zero_init_residual=True)
121 |         self.f.fc = nn.Identity()
122 | 
123 |         # projection head (following exactly the Barlow Twins official repo)
124 |         projector_dims = [2048] + projector_dims
125 |         layers = []
126 |         for i in range(len(projector_dims) - 2):
127 |             layers.append(
128 |                 nn.Linear(projector_dims[i], projector_dims[i + 1], bias=bias_proj)
129 |             )
130 |             layers.append(nn.BatchNorm1d(projector_dims[i + 1]))
131 |             layers.append(nn.ReLU())
132 |         layers.append(nn.Linear(projector_dims[-2], projector_dims[-1], bias=bias_last))
133 |         self.g = nn.Sequential(*layers)
134 | 
135 |     def forward(self, x: 
Tensor) -> Tuple[Tensor, Tensor]: 136 | x_ = self.f(x) 137 | feature = torch.flatten(x_, start_dim=1) 138 | out = self.g(feature) 139 | 140 | return feature, out 141 | 142 | 143 | class ComposerWrapper(ComposerModel): 144 | def __init__(self, module: torch.nn.Module, objective): 145 | super().__init__() 146 | 147 | self.module = module 148 | self.objective = objective 149 | self.c = 0 150 | 151 | def loss(self, outputs: Any, batch: Any, *args, **kwargs) -> Tensor: 152 | loss, loss_dict = self.objective(outputs) 153 | self.loss_dict = loss_dict 154 | self.c += 1 155 | return loss 156 | 157 | def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor: 158 | if isinstance(batch, Tensor): 159 | inputs = batch 160 | else: 161 | inputs, _ = batch 162 | 163 | features, outputs = self.module(inputs) 164 | if isinstance(batch, Tensor): 165 | return features, outputs 166 | else: 167 | return outputs 168 | 169 | def get_backbone(self): 170 | return self.module 171 | -------------------------------------------------------------------------------- /mmcr/cifar_stl/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | 5 | sys.path.append("..") 6 | from PIL import Image 7 | from torchvision import transforms 8 | import torchvision 9 | import torchvision.transforms.functional as TF 10 | from torchvision.datasets import CIFAR10 11 | import torch 12 | 13 | import random 14 | from PIL import Image, ImageOps, ImageFilter 15 | 16 | 17 | def get_datasets(dataset, n_aug, batch_transform=True, supervised=False, **kwargs): 18 | data_dir = "./datasets/" 19 | if dataset == "stl10": 20 | train_split = "train" if supervised else "train+unlabeled" 21 | train_data = torchvision.datasets.STL10( 22 | root=data_dir, 23 | split=train_split, 24 | transform=StlBatchTransform( 25 | train_transform=True, n_transform=n_aug, batch_transform=batch_transform 26 | ), 27 | download=False, 28 | ) 29 | memory_data = torchvision.datasets.STL10( 30 | root=data_dir, 31 | split="train", 32 | transform=StlBatchTransform( 33 | train_transform=False, batch_transform=False, n_transform=n_aug 34 | ), 35 | download=False, 36 | ) 37 | test_data = torchvision.datasets.STL10( 38 | root=data_dir, 39 | split="test", 40 | transform=StlBatchTransform( 41 | train_transform=False, batch_transform=False, n_transform=n_aug 42 | ), 43 | download=False, 44 | ) 45 | elif dataset == "cifar10": 46 | train_data = torchvision.datasets.CIFAR10( 47 | root=data_dir, 48 | train=True, 49 | transform=CifarBatchTransform( 50 | train_transform=True, 51 | batch_transform=batch_transform, 52 | n_transform=n_aug, 53 | **kwargs, 54 | ), 55 | download=False, 56 | ) 57 | memory_data = torchvision.datasets.CIFAR10( 58 | root=data_dir, 59 | train=True, 60 | transform=CifarBatchTransform( 61 | train_transform=False, 62 | batch_transform=False, 63 | n_transform=n_aug, 64 | **kwargs, 65 | ), 66 | download=False, 67 | ) 68 | test_data = torchvision.datasets.CIFAR10( 69 | root=data_dir, 70 | train=False, 71 | transform=CifarBatchTransform( 72 | train_transform=False, 73 | batch_transform=False, 74 | n_transform=n_aug, 75 | **kwargs, 76 | ), 77 | download=False, 78 | ) 79 | elif dataset == "cifar100": 80 | train_data = torchvision.datasets.CIFAR100( 81 | root=data_dir, 82 | train=True, 83 | transform=CifarBatchTransform( 84 | train_transform=True, 85 | batch_transform=batch_transform, 86 | n_transform=n_aug, 87 | **kwargs, 88 | ), 89 | download=False, 90 | ) 91 | memory_data = 
torchvision.datasets.CIFAR100( 92 | root=data_dir, 93 | train=True, 94 | transform=CifarBatchTransform( 95 | train_transform=False, 96 | batch_transform=False, 97 | n_transform=n_aug, 98 | **kwargs, 99 | ), 100 | download=False, 101 | ) 102 | test_data = torchvision.datasets.CIFAR100( 103 | root=data_dir, 104 | train=False, 105 | transform=CifarBatchTransform( 106 | train_transform=False, 107 | batch_transform=False, 108 | n_transform=n_aug, 109 | **kwargs, 110 | ), 111 | download=False, 112 | ) 113 | 114 | return train_data, memory_data, test_data 115 | 116 | 117 | class StlBatchTransform: 118 | def __init__(self, n_transform, train_transform=True, batch_transform=True): 119 | if train_transform is True: 120 | self.transform = transforms.Compose( 121 | [ 122 | transforms.RandomApply( 123 | [ 124 | transforms.ColorJitter( 125 | brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1 126 | ) 127 | ], 128 | p=0.8, 129 | ), 130 | transforms.RandomGrayscale(p=0.1), 131 | transforms.RandomResizedCrop( 132 | 64, 133 | scale=(0.2, 1.0), 134 | ratio=(0.75, (4 / 3)), 135 | interpolation=Image.BICUBIC, 136 | ), 137 | transforms.RandomHorizontalFlip(p=0.5), 138 | transforms.ToTensor(), 139 | transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27)), 140 | ] 141 | ) 142 | else: 143 | self.transform = transforms.Compose( 144 | [ 145 | transforms.Resize(70, interpolation=Image.BICUBIC), 146 | transforms.CenterCrop(64), 147 | transforms.ToTensor(), 148 | transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27)), 149 | ] 150 | ) 151 | self.n_transform = n_transform 152 | self.batch_transform = batch_transform 153 | 154 | def __call__(self, x): 155 | if self.batch_transform: 156 | C, H, W = TF.to_tensor(x).shape 157 | C_aug, H_aug, W_aug = self.transform(x).shape 158 | 159 | y = torch.zeros(self.n_transform, C_aug, H_aug, W_aug) 160 | for i in range(self.n_transform): 161 | y[i, :, :, :] = self.transform(x) 162 | return y 163 | else: 164 | return self.transform(x) 165 | 166 | 167 | class CifarBatchTransform: 168 | def __init__( 169 | self, 170 | n_transform, 171 | train_transform=True, 172 | batch_transform=True, 173 | **kwargs, 174 | ): 175 | if train_transform is True: 176 | lst_of_transform = [ 177 | transforms.RandomResizedCrop(32), 178 | transforms.RandomHorizontalFlip(p=0.5), 179 | transforms.RandomApply( 180 | [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8 181 | ), 182 | transforms.RandomGrayscale(p=0.2), 183 | transforms.ToTensor(), 184 | transforms.Normalize( 185 | [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010] 186 | ), 187 | ] 188 | 189 | self.transform = transforms.Compose(lst_of_transform) 190 | else: 191 | self.transform = transforms.Compose( 192 | [ 193 | transforms.ToTensor(), 194 | transforms.Normalize( 195 | [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010] 196 | ), 197 | ] 198 | ) 199 | self.n_transform = n_transform 200 | self.batch_transform = batch_transform 201 | 202 | def __call__(self, x): 203 | if self.batch_transform: 204 | C, H, W = TF.to_tensor(x).shape 205 | C_aug, H_aug, W_aug = self.transform(x).shape 206 | 207 | y = torch.zeros(self.n_transform, C_aug, H_aug, W_aug) 208 | for i in range(self.n_transform): 209 | y[i, :, :, :] = self.transform(x) 210 | return y 211 | else: 212 | return self.transform(x) 213 | -------------------------------------------------------------------------------- /mmcr/imagenet/data.py: -------------------------------------------------------------------------------- 1 | from zipfile import ZipFile 2 | import random 3 | import torch 4 | 5 
| import torchvision 6 | from torchvision import transforms 7 | from PIL import Image, ImageOps, ImageFilter 8 | 9 | 10 | class ZipImageNet(torchvision.datasets.ImageNet): 11 | """ 12 | Loads ImageNet files from a zip archive. 13 | """ 14 | 15 | def __init__(self, zip_path: str, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | self.zip_path = zip_path 18 | self.zip_archive = None 19 | 20 | def __getitem__(self, index: int): 21 | """ 22 | Args: 23 | index (int): Index 24 | Returns: 25 | tuple: (sample, target) where target is class_index of the target class. 26 | """ 27 | path, target = self.samples[index] 28 | parts = path.split("/") 29 | idx = parts.index("ILSVRC_2012") 30 | _path = "/".join(parts[idx:]) 31 | if self.zip_archive is None: 32 | self.zip_archive = ZipFile(self.zip_path) 33 | fh = self.zip_archive.open(_path) 34 | image = Image.open(fh) 35 | sample = image.convert("RGB") 36 | if self.transform is not None: 37 | sample = self.transform(sample) 38 | if self.target_transform is not None: 39 | target = self.target_transform(target) 40 | 41 | return sample, target 42 | 43 | 44 | class Zip_ImageFolder(torchvision.datasets.ImageFolder): 45 | def __init__(self, zip_path, *args, **kwargs): 46 | super().__init__(*args, **kwargs) 47 | self.zip_path = zip_path 48 | self.zip_archive = None 49 | 50 | def __getitem__(self, index: int): 51 | path, target = self.samples[index] 52 | if self.zip_archive is None: 53 | self.zip_archive = ZipFile(self.zip_path) 54 | 55 | path_split = path.split("/") 56 | fh = self.zip_archive.open( 57 | path_split[-3] + "/" + path_split[-2] + "/" + path_split[-1] 58 | ) 59 | 60 | image = Image.open(fh) 61 | sample = image.convert("RGB") 62 | 63 | if self.transform is not None: 64 | sample = self.transform(sample) 65 | 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | 69 | return sample, target 70 | 71 | 72 | # augmentation pipeline adapted from: https://github.com/facebookresearch/barlowtwins/blob/main/main.py 73 | class ImageNetValTransform: 74 | def __init__(self): 75 | self.transform = transforms.Compose( 76 | [ 77 | transforms.Resize(256), 78 | transforms.CenterCrop(224), 79 | transforms.ToTensor(), 80 | transforms.Normalize( 81 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 82 | ), 83 | ] 84 | ) 85 | 86 | def __call__(self, x): 87 | return self.transform(x) 88 | 89 | 90 | class Solarization(object): 91 | def __init__(self, p): 92 | self.p = p 93 | 94 | def __call__(self, img): 95 | if random.random() < self.p: 96 | return ImageOps.solarize(img) 97 | else: 98 | return img 99 | 100 | 101 | class GaussianBlur(object): 102 | def __init__(self, p): 103 | self.p = p 104 | 105 | def __call__(self, img): 106 | if random.random() < self.p: 107 | sigma = random.random() * 1.9 + 0.1  # sigma sampled uniformly from [0.1, 2.0) 108 | return img.filter(ImageFilter.GaussianBlur(sigma)) 109 | else: 110 | return img 111 | 112 | 113 | class Barlow_Transform: 114 | def __init__(self, n_transform): 115 | self.n_aug = n_transform 116 | self.transform = transforms.Compose( 117 | [ 118 | transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC), 119 | transforms.RandomHorizontalFlip(p=0.5), 120 | transforms.RandomApply( 121 | [ 122 | transforms.ColorJitter( 123 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 124 | ) 125 | ], 126 | p=0.8, 127 | ), 128 | transforms.RandomGrayscale(p=0.2), 129 | GaussianBlur(p=1.0), 130 | Solarization(p=0.0), 131 | transforms.ToTensor(), 132 | transforms.Normalize( 133 | mean=[0.485, 0.456, 0.406], std=[0.229,
0.224, 0.225] 134 | ), 135 | ] 136 | ) 137 | self.transform_prime = transforms.Compose( 138 | [ 139 | transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC), 140 | transforms.RandomHorizontalFlip(p=0.5), 141 | transforms.RandomApply( 142 | [ 143 | transforms.ColorJitter( 144 | brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 145 | ) 146 | ], 147 | p=0.8, 148 | ), 149 | transforms.RandomGrayscale(p=0.2), 150 | GaussianBlur(p=0.1), 151 | Solarization(p=0.2), 152 | transforms.ToTensor(), 153 | transforms.Normalize( 154 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 155 | ), 156 | ] 157 | ) 158 | 159 | def __call__(self, x): 160 | y1 = self.transform(x)  # one pass to discover the augmented shape 161 | y = torch.zeros(self.n_aug, y1.shape[0], y1.shape[1], y1.shape[2]) 162 | 163 | for i in range(self.n_aug // 2):  # views are filled in (transform, transform_prime) pairs, so n_aug is assumed even 164 | y1 = self.transform(x) 165 | y2 = self.transform_prime(x) 166 | y[2 * i, :, :, :] = y1 167 | y[2 * i + 1, :, :, :] = y2 168 | return y 169 | 170 | 171 | def get_datasets(n_aug, dataset="imagenet", use_zip=True, **kwargs): 172 | if dataset == "imagenet": 173 | imagenet_path = "./datasets/ILSVRC_2012" 174 | zip_path = "./datasets/ILSVRC_2012.zip" 175 | if use_zip: 176 | train_data = ZipImageNet( 177 | zip_path=zip_path, 178 | root=imagenet_path, 179 | split="train", 180 | transform=Barlow_Transform(n_transform=n_aug), 181 | ) 182 | memory_data = ZipImageNet( 183 | zip_path=zip_path, 184 | root=imagenet_path, 185 | split="train", 186 | transform=ImageNetValTransform(), 187 | ) 188 | test_data = ZipImageNet( 189 | zip_path=zip_path, 190 | root=imagenet_path, 191 | split="val", 192 | transform=ImageNetValTransform(), 193 | ) 194 | else: 195 | train_data = torchvision.datasets.ImageNet( 196 | root=imagenet_path, 197 | split="train", 198 | transform=Barlow_Transform(n_transform=n_aug), 199 | ) 200 | memory_data = torchvision.datasets.ImageNet( 201 | root=imagenet_path, 202 | split="train", 203 | transform=ImageNetValTransform(), 204 | ) 205 | test_data = torchvision.datasets.ImageNet( 206 | root=imagenet_path, 207 | split="val", 208 | transform=ImageNetValTransform(), 209 | ) 210 | elif dataset == "imagenet_100": 211 | imagenet_100_path = "./datasets/imagenet_100/" 212 | if use_zip: 213 | train_data = Zip_ImageFolder( 214 | zip_path=imagenet_100_path + "train.zip", 215 | root=imagenet_100_path + "train/", 216 | transform=Barlow_Transform(n_transform=n_aug), 217 | ) 218 | memory_data = Zip_ImageFolder( 219 | zip_path=imagenet_100_path + "train.zip", 220 | root=imagenet_100_path + "train/", 221 | transform=ImageNetValTransform(), 222 | ) 223 | test_data = Zip_ImageFolder( 224 | zip_path=imagenet_100_path + "val.zip", 225 | root=imagenet_100_path + "val/", 226 | transform=ImageNetValTransform(), 227 | ) 228 | else: 229 | train_data = torchvision.datasets.ImageFolder( 230 | root=imagenet_100_path + "train/", 231 | transform=Barlow_Transform(n_transform=n_aug), 232 | ) 233 | memory_data = torchvision.datasets.ImageFolder( 234 | root=imagenet_100_path + "train/", 235 | transform=ImageNetValTransform(), 236 | ) 237 | test_data = torchvision.datasets.ImageFolder( 238 | root=imagenet_100_path + "val/", 239 | transform=ImageNetValTransform(), 240 | ) 241 | 242 | return train_data, memory_data, test_data 243 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mmcr 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - 
_libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_kmp_llvm 10 | - absl-py=0.15.0=pyhd3eb1b0_0 11 | - adversarial-robustness-toolbox=1.10.3=pyhd8ed1ab_0 12 | - aiohttp=3.8.1=py310h7f8727e_1 13 | - aiosignal=1.2.0=pyhd3eb1b0_0 14 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 15 | - argon2-cffi-bindings=21.2.0=py310h7f8727e_0 16 | - asttokens=2.0.5=pyhd3eb1b0_0 17 | - async-timeout=4.0.1=pyhd3eb1b0_0 18 | - attrs=21.4.0=pyhd3eb1b0_0 19 | - backcall=0.2.0=pyhd3eb1b0_0 20 | - beautifulsoup4=4.11.1=py310h06a4308_0 21 | - black=22.3.0=pyhd8ed1ab_0 22 | - blas=1.0=mkl 23 | - bleach=4.1.0=pyhd3eb1b0_0 24 | - blinker=1.4=py310h06a4308_0 25 | - bottleneck=1.3.5=py310ha9d4c09_0 26 | - brotli=1.0.9=he6710b0_2 27 | - brotlipy=0.7.0=py310h7f8727e_1002 28 | - bzip2=1.0.8=h7b6447c_0 29 | - c-ares=1.18.1=h7f8727e_0 30 | - ca-certificates=2022.12.7=ha878542_0 31 | - cachetools=4.2.2=pyhd3eb1b0_0 32 | - certifi=2022.12.7=pyhd8ed1ab_0 33 | - cffi=1.15.0=py310hd667e15_1 34 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 35 | - click=8.0.4=py310h06a4308_0 36 | - cloudpickle=2.0.0=pyhd3eb1b0_0 37 | - colorama=0.4.5=py310h06a4308_0 38 | - cryptography=37.0.1=py310h9ce1e76_0 39 | - cudatoolkit=11.3.1=h2bc3f7f_2 40 | - cycler=0.11.0=pyhd3eb1b0_0 41 | - cython=0.29.32=py310h6a678d5_0 42 | - dataclasses=0.8=pyh6d0b6a4_7 43 | - dbus=1.13.18=hb2f20db_0 44 | - debugpy=1.5.1=py310h295c915_0 45 | - decorator=5.1.1=pyhd3eb1b0_0 46 | - defusedxml=0.7.1=pyhd3eb1b0_0 47 | - docker-pycreds=0.4.0=pyhd3eb1b0_0 48 | - einops=0.4.1=pyhd8ed1ab_0 49 | - entrypoints=0.4=py310h06a4308_0 50 | - executing=0.8.3=pyhd3eb1b0_0 51 | - expat=2.4.4=h295c915_0 52 | - ffmpeg=4.2.2=h20bf706_0 53 | - filelock=3.9.0=pyhd8ed1ab_0 54 | - fontconfig=2.13.1=h6c09931_0 55 | - fonttools=4.25.0=pyhd3eb1b0_0 56 | - freetype=2.11.0=h70c0345_0 57 | - frozenlist=1.2.0=py310h7f8727e_1 58 | - giflib=5.2.1=h7b6447c_0 59 | - gitdb=4.0.7=pyhd3eb1b0_0 60 | - gitpython=3.1.18=pyhd3eb1b0_1 61 | - glib=2.69.1=h4ff587b_1 62 | - gmp=6.2.1=h295c915_3 63 | - gnutls=3.6.15=he1e5248_0 64 | - google-auth=2.6.0=pyhd3eb1b0_0 65 | - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0 66 | - grpcio=1.42.0=py310hce63b2e_0 67 | - gst-plugins-base=1.14.0=h8213a91_2 68 | - gstreamer=1.14.0=h28cd5cc_2 69 | - huggingface_hub=0.12.0=pyhd8ed1ab_0 70 | - icu=58.2=he6710b0_3 71 | - idna=3.3=pyhd3eb1b0_0 72 | - importlib-metadata=4.11.3=py310h06a4308_0 73 | - intel-openmp=2021.4.0=h06a4308_3561 74 | - ipykernel=6.9.1=py310h06a4308_0 75 | - ipython=8.4.0=py310h06a4308_0 76 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 77 | - ipywidgets=7.6.5=pyhd3eb1b0_1 78 | - jedi=0.18.1=py310h06a4308_1 79 | - jinja2=3.0.3=pyhd3eb1b0_0 80 | - joblib=1.1.0=pyhd3eb1b0_0 81 | - jpeg=9e=h7f8727e_0 82 | - jsonschema=4.4.0=py310h06a4308_0 83 | - jupyter=1.0.0=py310h06a4308_7 84 | - jupyter_client=7.2.2=py310h06a4308_0 85 | - jupyter_console=6.4.3=pyhd3eb1b0_0 86 | - jupyter_core=4.10.0=py310h06a4308_0 87 | - jupyterlab_pygments=0.1.2=py_0 88 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 89 | - kiwisolver=1.4.2=py310h295c915_0 90 | - lame=3.100=h7b6447c_0 91 | - lcms2=2.12=h3be6417_0 92 | - ld_impl_linux-64=2.38=h1181459_1 93 | - libblas=3.9.0=12_linux64_mkl 94 | - libcblas=3.9.0=12_linux64_mkl 95 | - libffi=3.3=he6710b0_2 96 | - libgcc-ng=12.2.0=h65d4601_19 97 | - libgfortran-ng=7.5.0=ha8ba4b0_17 98 | - libgfortran4=7.5.0=ha8ba4b0_17 99 | - libidn2=2.3.2=h7f8727e_0 100 | - liblapack=3.9.0=12_linux64_mkl 101 | - libllvm11=11.1.0=h3826bc1_1 102 | - libopus=1.3.1=h7b6447c_0 103 | - libpng=1.6.37=hbc83047_0 104 | - libprotobuf=3.20.1=h4ff587b_0 
105 | - libsodium=1.0.18=h7b6447c_0 106 | - libstdcxx-ng=11.2.0=h1234567_1 107 | - libtasn1=4.16.0=h27cfd23_0 108 | - libtiff=4.2.0=h2818925_1 109 | - libunistring=0.9.10=h27cfd23_0 110 | - libuuid=1.0.3=h7f8727e_2 111 | - libuv=1.40.0=h7b6447c_0 112 | - libvpx=1.7.0=h439df22_0 113 | - libwebp=1.2.2=h55f646e_0 114 | - libwebp-base=1.2.2=h7f8727e_0 115 | - libxcb=1.15=h7f8727e_0 116 | - libxml2=2.9.14=h74e7548_0 117 | - llvm-openmp=14.0.6=h9e868ea_0 118 | - llvmlite=0.38.0=py310h4ff587b_0 119 | - lz4-c=1.9.3=h295c915_1 120 | - markdown=3.3.4=py310h06a4308_0 121 | - markupsafe=2.1.1=py310h7f8727e_0 122 | - matplotlib=3.5.1=py310h06a4308_1 123 | - matplotlib-base=3.5.1=py310ha18d171_1 124 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 125 | - mistune=0.8.4=py310h7f8727e_1000 126 | - mkl=2021.4.0=h06a4308_640 127 | - mkl-service=2.4.0=py310h7f8727e_0 128 | - mkl_fft=1.3.1=py310hd6ae3a3_0 129 | - mkl_random=1.2.2=py310h00e6091_0 130 | - multidict=6.0.2=py310h5764c6d_1 131 | - munkres=1.1.4=py_0 132 | - mypy_extensions=0.4.3=py310h06a4308_0 133 | - nbclient=0.5.13=py310h06a4308_0 134 | - nbconvert=6.4.4=py310h06a4308_0 135 | - nbformat=5.3.0=py310h06a4308_0 136 | - ncurses=6.3=h5eee18b_3 137 | - nest-asyncio=1.5.5=py310h06a4308_0 138 | - nettle=3.7.3=hbbd107a_1 139 | - notebook=6.4.11=py310h06a4308_0 140 | - numba=0.55.1=py310h00e6091_0 141 | - numexpr=2.8.3=py310hcea2de6_0 142 | - numpy=1.21.6=py310h45f3432_0 143 | - oauthlib=3.2.0=pyhd8ed1ab_0 144 | - openh264=2.1.1=h4ff587b_0 145 | - openssl=1.1.1t=h0b41bf4_0 146 | - packaging=21.3=pyhd3eb1b0_0 147 | - pandas=1.4.2=py310h295c915_0 148 | - pandocfilters=1.5.0=pyhd3eb1b0_0 149 | - parso=0.8.3=pyhd3eb1b0_0 150 | - pathspec=0.9.0=pyhd8ed1ab_0 151 | - pathtools=0.1.2=pyhd3eb1b0_1 152 | - pcre=8.45=h295c915_0 153 | - pexpect=4.8.0=pyhd3eb1b0_3 154 | - pickleshare=0.7.5=pyhd3eb1b0_1003 155 | - pillow=9.2.0=py310hace64e9_1 156 | - pip=22.1.2=py310h06a4308_0 157 | - platformdirs=2.4.0=pyhd3eb1b0_0 158 | - prometheus_client=0.13.1=pyhd3eb1b0_0 159 | - promise=2.3=py310h06a4308_0 160 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 161 | - prompt_toolkit=3.0.20=hd3eb1b0_0 162 | - protobuf=3.20.1=py310h295c915_0 163 | - psutil=5.9.0=py310h5eee18b_0 164 | - ptyprocess=0.7.0=pyhd3eb1b0_2 165 | - pure_eval=0.2.2=pyhd3eb1b0_0 166 | - pyasn1=0.4.8=pyhd3eb1b0_0 167 | - pyasn1-modules=0.2.8=py_0 168 | - pycocotools=2.0.6=py310hde88566_0 169 | - pycparser=2.21=pyhd3eb1b0_0 170 | - pygments=2.11.2=pyhd3eb1b0_0 171 | - pyjwt=2.4.0=py310h06a4308_0 172 | - pyopenssl=22.0.0=pyhd3eb1b0_0 173 | - pyparsing=3.0.4=pyhd3eb1b0_0 174 | - pyqt=5.9.2=py310h295c915_6 175 | - pyrsistent=0.18.0=py310h7f8727e_0 176 | - pysocks=1.7.1=py310h06a4308_0 177 | - python=3.10.4=h12debd9_0 178 | - python-dateutil=2.8.2=pyhd3eb1b0_0 179 | - python-fastjsonschema=2.15.1=pyhd3eb1b0_0 180 | - python_abi=3.10=2_cp310 181 | - pytorch=1.11.0=py3.10_cuda11.3_cudnn8.2.0_0 182 | - pytorch-model-summary=0.1.1=py_0 183 | - pytorch-mutex=1.0=cuda 184 | - pytz=2022.1=py310h06a4308_0 185 | - pyzmq=23.2.0=py310h6a678d5_0 186 | - qt=5.9.7=h5867ecd_1 187 | - qtconsole=5.3.1=py310h06a4308_0 188 | - qtpy=2.0.1=pyhd3eb1b0_0 189 | - readline=8.1.2=h7f8727e_1 190 | - requests=2.28.1=py310h06a4308_0 191 | - requests-oauthlib=1.3.0=py_0 192 | - rsa=4.7.2=pyhd3eb1b0_1 193 | - scikit-learn=1.0.1=py310h00e6091_0 194 | - scipy=1.7.3=py310hfa59a62_0 195 | - seaborn=0.11.2=pyhd3eb1b0_0 196 | - send2trash=1.8.0=pyhd3eb1b0_1 197 | - sentry-sdk=1.7.1=pyhd8ed1ab_0 198 | - setproctitle=1.2.2=py310h7f8727e_0 199 | - 
setuptools=61.2.0=py310h06a4308_0 200 | - shortuuid=1.0.9=pyha770c72_1 201 | - sip=4.19.13=py310h295c915_0 202 | - six=1.16.0=pyhd3eb1b0_1 203 | - smmap=4.0.0=pyhd3eb1b0_0 204 | - soupsieve=2.3.1=pyhd3eb1b0_0 205 | - sqlite=3.38.5=hc218d9a_0 206 | - stack_data=0.2.0=pyhd3eb1b0_0 207 | - submitit=1.2.1=pyh44b312d_0 208 | - tbb=2021.5.0=hd09550d_0 209 | - tensorboard=2.9.0=pyhd8ed1ab_0 210 | - tensorboard-data-server=0.6.0=py310hca6d32c_0 211 | - tensorboard-plugin-wit=1.6.0=py_0 212 | - terminado=0.13.1=py310h06a4308_0 213 | - testpath=0.6.0=py310h06a4308_0 214 | - threadpoolctl=2.2.0=pyh0d69192_0 215 | - timm=0.6.12=pyhd8ed1ab_0 216 | - tk=8.6.12=h1ccaba5_0 217 | - tomli=2.0.1=py310h06a4308_0 218 | - torchvision=0.12.0=py310_cu113 219 | - tornado=6.1=py310h7f8727e_0 220 | - tqdm=4.64.0=pyhd8ed1ab_0 221 | - traitlets=5.1.1=pyhd3eb1b0_0 222 | - typed-ast=1.4.3=py310h7f8727e_1 223 | - typing-extensions=4.1.1=hd3eb1b0_0 224 | - typing_extensions=4.1.1=pyh06a4308_0 225 | - tzdata=2022a=hda174b7_0 226 | - urllib3=1.26.9=py310h06a4308_0 227 | - wandb=0.12.21=pyhd8ed1ab_0 228 | - wcwidth=0.2.5=pyhd3eb1b0_0 229 | - webencodings=0.5.1=py310h06a4308_1 230 | - werkzeug=2.0.3=pyhd3eb1b0_0 231 | - wheel=0.37.1=pyhd3eb1b0_0 232 | - widgetsnbextension=3.5.2=py310h06a4308_0 233 | - x264=1!157.20191217=h7b6447c_0 234 | - xz=5.2.5=h7f8727e_1 235 | - yacs=0.1.8=pyhd8ed1ab_0 236 | - yaml=0.2.5=h7b6447c_0 237 | - yarl=1.6.3=py310h7f8727e_1 238 | - zeromq=4.3.4=h2531618_0 239 | - zipp=3.8.0=py310h06a4308_0 240 | - zlib=1.2.12=h7f8727e_2 241 | - zstd=1.5.2=ha4553b6_0 242 | - pip: 243 | - addict==2.4.0 244 | - albumentations==1.0.0 245 | - coolname==1.1.0 246 | - data-science-types==0.2.23 247 | - docstring-parser==0.14.1 248 | - exceptiongroup==1.1.1 249 | - h5py==3.8.0 250 | - imageio==2.19.3 251 | - iniconfig==2.0.0 252 | - mat73==0.60 253 | - mmcv-full==1.7.0 254 | - mosaicml==0.8.0 255 | - networkx==2.8.4 256 | - opencv-python-headless==4.6.0.66 257 | - pluggy==1.0.0 258 | - py-cpuinfo==8.0.0 259 | - pydeprecate==0.3.2 260 | - pyrtools==1.0.1 261 | - pytest==7.3.1 262 | - pytorch-ranger==0.1.1 263 | - pywavelets==1.3.0 264 | - pyyaml==6.0 265 | - ruamel-yaml==0.17.21 266 | - ruamel-yaml-clib==0.2.6 267 | - scikit-image==0.19.3 268 | - surgeon-pytorch==0.0.4 269 | - tifffile==2022.5.4 270 | - torch-optimizer==0.1.0 271 | - torch-tb-profiler==0.4.0 272 | - torchmetrics==0.7.3 273 | - yahp==0.1.1 274 | - yapf==0.33.0 275 | -------------------------------------------------------------------------------- /mmcr/imagenet/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, random 3 | import numpy as np 4 | import torch.distributed as dist 5 | import torch.backends.cudnn as cudnn 6 | 7 | from collections import defaultdict, deque 8 | import time, datetime, signal 9 | import subprocess 10 | from pathlib import Path 11 | import submitit 12 | import random 13 | 14 | """ 15 | Misc functions. 
16 | Mostly copy-paste from torchvision references or other public repos like DETR: 17 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 18 | """ 19 | 20 | 21 | def setup_for_distributed(is_master): 22 | """ 23 | This function disables printing when not in master process 24 | """ 25 | import builtins as __builtin__ 26 | 27 | builtin_print = __builtin__.print 28 | 29 | def print(*args, **kwargs): 30 | force = kwargs.pop("force", False) 31 | if is_master or force: 32 | builtin_print(*args, **kwargs) 33 | 34 | __builtin__.print = print 35 | 36 | 37 | def fix_random_seeds(seed=31): 38 | """ 39 | Fix random seeds. 40 | """ 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | 45 | 46 | def get_shared_folder() -> Path: 47 | user = os.getenv("USER") 48 | if Path("/data/sarkar-vision/slurm_jobs/").is_dir(): 49 | p = Path(f"/data/sarkar-vision/slurm_jobs/{user}") 50 | p.mkdir(exist_ok=True) 51 | return p 52 | raise RuntimeError("No shared folder available") 53 | 54 | 55 | def init_dist_node(args): 56 | if "SLURM_JOB_ID" in os.environ: 57 | args.ngpus_per_node = torch.cuda.device_count() 58 | 59 | # requeue job on SLURM preemption 60 | signal.signal(signal.SIGUSR1, handle_sigusr1) 61 | signal.signal(signal.SIGTERM, handle_sigterm) 62 | 63 | # find a common host name on all nodes 64 | cmd = "scontrol show hostnames " + os.getenv("SLURM_JOB_NODELIST") 65 | stdout = subprocess.check_output(cmd.split()) 66 | host_name = stdout.decode().splitlines()[0] 67 | args.dist_url = f"tcp://{host_name}:{random.randint(49152, 65535)}" 68 | args.host_name_ = host_name 69 | 70 | # distributed parameters 71 | args.rank = int(os.getenv("SLURM_NODEID")) * args.ngpus_per_node 72 | args.world_size = int(os.getenv("SLURM_NNODES")) * args.ngpus_per_node 73 | 74 | else: 75 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 76 | args.ngpus_per_node = torch.cuda.device_count() 77 | 78 | args.rank = 0 79 | args.dist_url = f"tcp://localhost:{args.port}" 80 | args.world_size = args.ngpus_per_node 81 | 82 | 83 | def init_dist_gpu(gpu, args, fix_seed=False): 84 | if args.slurm or True: 85 | job_env = submitit.JobEnvironment() 86 | args.gpu = job_env.local_rank 87 | args.rank = job_env.global_rank 88 | else: 89 | args.gpu = gpu 90 | args.rank += gpu 91 | 92 | dist.init_process_group( 93 | backend="gloo", 94 | init_method=args.dist_url, 95 | world_size=args.world_size, 96 | rank=args.rank, 97 | ) 98 | 99 | if fix_seed: 100 | fix_random_seeds() 101 | torch.cuda.set_device(args.gpu) 102 | cudnn.benchmark = True 103 | dist.barrier() 104 | 105 | args.main = args.rank == 0 106 | setup_for_distributed(args.main) 107 | 108 | 109 | def handle_sigusr1(signum, frame): 110 | os.system(f'scontrol requeue {os.getenv("SLURM_JOB_ID")}') 111 | exit() 112 | 113 | 114 | def handle_sigterm(signum, frame): 115 | pass 116 | 117 | 118 | def is_dist_avail_and_initialized(): 119 | if not dist.is_available(): 120 | return False 121 | if not dist.is_initialized(): 122 | return False 123 | return True 124 | 125 | 126 | class SmoothedValue(object): 127 | """Track a series of values and provide access to smoothed values over a 128 | window or the global series average. 
129 | """ 130 | 131 | def __init__(self, window_size=20, fmt=None): 132 | if fmt is None: 133 | fmt = "{median:.6f} ({global_avg:.6f})" 134 | self.deque = deque(maxlen=window_size) 135 | self.total = 0.0 136 | self.count = 0 137 | self.fmt = fmt 138 | 139 | def update(self, value, n=1): 140 | self.deque.append(value) 141 | self.count += n 142 | self.total += value * n 143 | 144 | def synchronize_between_processes(self): 145 | """ 146 | Warning: does not synchronize the deque! 147 | """ 148 | if not is_dist_avail_and_initialized(): 149 | return 150 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 151 | dist.barrier() 152 | dist.all_reduce(t) 153 | t = t.tolist() 154 | self.count = int(t[0]) 155 | self.total = t[1] 156 | 157 | @property 158 | def median(self): 159 | d = torch.tensor(list(self.deque)) 160 | return d.median().item() 161 | 162 | @property 163 | def avg(self): 164 | d = torch.tensor(list(self.deque), dtype=torch.float32) 165 | return d.mean().item() 166 | 167 | @property 168 | def global_avg(self): 169 | return self.total / self.count 170 | 171 | @property 172 | def max(self): 173 | return max(self.deque) 174 | 175 | @property 176 | def value(self): 177 | return self.deque[-1] 178 | 179 | def __str__(self): 180 | return self.fmt.format( 181 | median=self.median, 182 | avg=self.avg, 183 | global_avg=self.global_avg, 184 | max=self.max, 185 | value=self.value, 186 | ) 187 | 188 | 189 | class MetricLogger(object): 190 | def __init__(self, delimiter="\t"): 191 | self.meters = defaultdict(SmoothedValue) 192 | self.delimiter = delimiter 193 | 194 | def update(self, **kwargs): 195 | for k, v in kwargs.items(): 196 | if isinstance(v, torch.Tensor): 197 | v = v.item() 198 | assert isinstance(v, (float, int)) 199 | self.meters[k].update(v) 200 | 201 | def __getattr__(self, attr): 202 | if attr in self.meters: 203 | return self.meters[attr] 204 | if attr in self.__dict__: 205 | return self.__dict__[attr] 206 | raise AttributeError( 207 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 208 | ) 209 | 210 | def __str__(self): 211 | loss_str = [] 212 | for name, meter in self.meters.items(): 213 | loss_str.append("{}: {}".format(name, str(meter))) 214 | return self.delimiter.join(loss_str) 215 | 216 | def synchronize_between_processes(self): 217 | for meter in self.meters.values(): 218 | meter.synchronize_between_processes() 219 | 220 | def add_meter(self, name, meter): 221 | self.meters[name] = meter 222 | 223 | def log_every(self, iterable, print_freq, header=None): 224 | i = 0 225 | if not header: 226 | header = "" 227 | start_time = time.time() 228 | end = time.time() 229 | iter_time = SmoothedValue(fmt="{avg:.6f}") 230 | data_time = SmoothedValue(fmt="{avg:.6f}") 231 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 232 | if torch.cuda.is_available(): 233 | log_msg = self.delimiter.join( 234 | [ 235 | header, 236 | "[{0" + space_fmt + "}/{1}]", 237 | "eta: {eta}", 238 | "{meters}", 239 | "time: {time}", 240 | "data: {data}", 241 | "max mem: {memory:.0f}", 242 | ] 243 | ) 244 | else: 245 | log_msg = self.delimiter.join( 246 | [ 247 | header, 248 | "[{0" + space_fmt + "}/{1}]", 249 | "eta: {eta}", 250 | "{meters}", 251 | "time: {time}", 252 | "data: {data}", 253 | ] 254 | ) 255 | MB = 1024.0 * 1024.0 256 | for obj in iterable: 257 | data_time.update(time.time() - end) 258 | yield obj 259 | iter_time.update(time.time() - end) 260 | if i % print_freq == 0 or i == len(iterable) - 1: 261 | eta_seconds = iter_time.global_avg * 
(len(iterable) - i) 262 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 263 | if torch.cuda.is_available(): 264 | print( 265 | log_msg.format( 266 | i, 267 | len(iterable), 268 | eta=eta_string, 269 | meters=str(self), 270 | time=str(iter_time), 271 | data=str(data_time), 272 | memory=torch.cuda.max_memory_allocated() / MB, 273 | ) 274 | ) 275 | else: 276 | print( 277 | log_msg.format( 278 | i, 279 | len(iterable), 280 | eta=eta_string, 281 | meters=str(self), 282 | time=str(iter_time), 283 | data=str(data_time), 284 | ) 285 | ) 286 | i += 1 287 | end = time.time() 288 | total_time = time.time() - start_time 289 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 290 | print( 291 | "{} Total time: {} ({:.6f} s / it)".format( 292 | header, total_time_str, total_time / len(iterable) 293 | ) 294 | ) 295 | --------------------------------------------------------------------------------
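
How the pieces above fit together: ComposerWrapper (and its momentum variant) hands the projector outputs straight to the objective, and for (input, label) batches forward returns only those outputs. The sketch below is illustrative, not code from this repo — the module path is assumed from the tree above, the toy objective merely mimics the (loss, loss_dict) return convention the wrappers expect, and running it requires torch plus the mosaicml (Composer) dependency pinned in environment.yml.

import torch
from mmcr.imagenet.models import Model, ComposerWrapper  # module path assumed

def toy_objective(outputs):
    # stand-in for the MMCR objective; returns (loss, loss_dict) as loss() expects
    loss = outputs.pow(2).mean()
    return loss, {"loss": loss.item()}

model = ComposerWrapper(Model(projector_dims=[8192, 8192, 512]), objective=toy_objective)

# The batch transforms emit (n_aug, C, H, W) per image, so a DataLoader yields
# (B, n_aug, C, H, W); folding the views into the batch dimension before the
# encoder is one common convention.
x = torch.randn(2 * 4, 3, 224, 224)   # 2 images x 4 augmentations, views folded in
outputs = model((x, torch.zeros(8)))  # projector outputs, shape (8, 512)
loss = model.loss(outputs, (x, None))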
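The momentum networks are initialized as frozen copies of the online networks (requires_grad = False), but the update rule itself is not in this section; it presumably lives alongside the momentum objective (loss_mmcr_momentum.py) or in the training loop. For reference, a standard exponential-moving-average update consistent with that initialization is sketched below; the decay value tau is illustrative.

import torch

@torch.no_grad()
def ema_update(model, tau: float = 0.99):
    # mom <- tau * mom + (1 - tau) * online, for both backbone and projector
    for online, mom in [(model.f, model.mom_f), (model.g, model.mom_g)]:
        for po, pm in zip(online.parameters(), mom.parameters()):
            pm.data.mul_(tau).add_(po.data, alpha=1.0 - tau)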
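On the data side, each transform with batch_transform enabled returns a stack of views per image rather than a single tensor: __call__ yields an (n_transform, C, H, W) tensor, which a standard DataLoader then collates into (B, n_transform, C, H, W). A quick shape check, using a random image as a stand-in for a dataset sample:

import numpy as np
from PIL import Image
from mmcr.cifar_stl.data import CifarBatchTransform

img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
t = CifarBatchTransform(n_transform=8, train_transform=True, batch_transform=True)
views = t(img)
print(views.shape)  # torch.Size([8, 3, 32, 32])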
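Finally, the SmoothedValue and MetricLogger utilities in distributed.py are self-contained and can be exercised outside of distributed training; a minimal sketch:

from mmcr.imagenet.distributed import MetricLogger

logger = MetricLogger(delimiter="  ")
for step in logger.log_every(range(100), print_freq=10, header="toy:"):
    logger.update(loss=1.0 / (step + 1), lr=0.1)  # accepts floats, ints, or 0-d tensors
print(logger)  # per-meter "median (global_avg)" summaries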