├── .gitignore ├── .vscode └── launch.json ├── README.md ├── bin ├── slurm_run.sh ├── train.sh └── write_dataset.py ├── configs ├── cifar.gin ├── dres.gin ├── mae_ffcv.gin ├── mae_if.gin ├── resnet18.gin ├── simclr_ffcv.gin ├── simclr_if.gin └── vitt.gin ├── dataset ├── build_dataset.py ├── ffcv_transform.py ├── multiloader.py └── transform.py ├── docs ├── config.md └── smalldata.md ├── environment.txt ├── layers ├── aim_vit.py ├── backbone.py ├── build_model.py ├── operation.py ├── target.py └── vit_rope.py ├── main_pretrain.py ├── main_pretrain_ema.py ├── model ├── __init__.py ├── aim.py ├── base.py ├── dino.py ├── mae.py ├── mcl.py ├── moco.py ├── msmae.py ├── simclr.py ├── simsiam.py ├── sit.py └── vcl.py ├── pics └── n01518878_10165.JPEG ├── profiler.py ├── requirements.txt ├── submitit_pretrain.py └── util ├── __init__.py ├── clustering.py ├── crop.py ├── datasets.py ├── dres.py ├── helper.py ├── lars.py ├── lr_decay.py ├── lr_sched.py ├── misc.py ├── pos_embed.py └── prob.py /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | outputs 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyderworkspace 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .dmypy.json 118 | dmypy.json 119 | 120 | # Pyre type checker 121 | .pyre/ 122 | 123 | # pytype static type analyzer 124 | .pytype/ 125 | 126 | # Cython debug symbols 127 | cython_debug/ -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
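// The configurations below debug either the current file or main_pretrain.py, launched through torch.distributed.launch with a single process for quick local runs.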
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | }, 14 | { 15 | "name": "SimCLR", 16 | "type": "debugpy", 17 | "request": "launch", 18 | "module": "torch.distributed.launch", 19 | "console": "integratedTerminal", 20 | "args": [ 21 | "--nproc_per_node=1", 22 | "${workspaceFolder}/main_pretrain.py", 23 | "--batch_size=20", "--opt=adamw", "--blr=5e-4", "--epochs=100", 24 | "--data_path=../data/", "--data_set=cifar10", 25 | "--cfgs", "configs/cifar.gin", "configs/vitt.gin", 26 | "--gin", "build_dataset.transform_fn=@DataAugmentationDINO", "build_model.model_fn=@SimCLR", "build_model.embed_dim=192" 27 | ], 28 | }, 29 | { 30 | "name": "VCL", 31 | "type": "debugpy", 32 | "request": "launch", 33 | "module": "torch.distributed.launch", 34 | "console": "integratedTerminal", 35 | "args": [ 36 | "--nproc_per_node=1", 37 | "${workspaceFolder}/main_pretrain.py", 38 | "--batch_size=500", "--opt=adamw", "--blr=5e-4", "--epochs=100", 39 | "--data_path=../data/", "--data_set=cifar10", 40 | "--cfgs", "configs/cifar.gin", "configs/vitt.gin", 41 | "--gin", "build_dataset.transform_fn=@DataAugmentationDINO", "build_model.model_fn=@VCL", "build_model.embed_dim=192" 42 | ], 43 | }, 44 | { 45 | "name": "amae", 46 | "type": "debugpy", 47 | "request": "launch", 48 | "module": "torch.distributed.launch", 49 | "console": "integratedTerminal", 50 | "args": [ 51 | "--nproc_per_node=1", 52 | "${workspaceFolder}/main_pretrain.py", 53 | "--batch_size=20", "--opt=adamw", "--blr=5e-4", "--epochs=100", 54 | "--data_path=../data/", "--data_set=cifar10", 55 | "--cfgs", "${workspaceFolder}/configs/cifar.gin", 56 | "--gin", "build_dataset.transform_fn=@SimpleAugmentation", "SimpleAugmentation.img_size=32", "build_model.model_fn=@amae_tiny", "build_model.patch_size=4", "build_model.img_size=32", "build_model.decoder_patch_size=2" 57 | ], 58 | }, 59 | { 60 | "name": "aim", 61 | "type": "debugpy", 62 | "request": "launch", 63 | "module": "torch.distributed.launch", 64 | "console": "integratedTerminal", 65 | "args": [ 66 | "--nproc_per_node=1", 67 | "${workspaceFolder}/main_pretrain.py", 68 | "--batch_size=20", "--opt=adamw", "--blr=5e-4", "--epochs=100", 69 | "--data_path=../data/", "--data_set=cifar10", 70 | "--cfgs", "${workspaceFolder}/configs/cifar.gin", 71 | "--gin", "build_dataset.transform_fn=@SimpleAugmentation", "SimpleAugmentation.img_size=32", "build_model.model_fn=@aim_tiny", 72 | ], 73 | } 74 | ] 75 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Toward Training Self-supervised Models with Limited Budget 2 | 3 | This repository focuses on enabling efficient training for self-supervised learning (SSL). Often referred to as the "dark matter" of intelligence, SSL empowers AI systems to learn without supervision, drawing insights from their environments in ways reminiscent of human learning. While numerous advanced SSL algorithms have been proposed, many achieving state-of-the-art (SOTA) results, their adoption is often hindered by prohibitively high training costs. This limitation stifles innovation from academia and individual researchers. 
Designed to be beginner-friendly, this repository allows users to reproduce SSL algorithms and quickly validate new ideas. Key features: 4 | - Efficient data loading with [ffcv](https://github.com/erow/ffcv). 5 | - Flexible configuration with [gin-config](docs/config.md). 6 | - A collection of SSL algorithms. 7 | - Evaluation with [vitookit](https://github.com/erow/vitookit). 8 | - All models are available at [WANDB](https://wandb.ai/erow/FastSSL). 9 | - A [guideline](docs/smalldata.md) for training SSL models on CIFAR10 in a few minutes! 10 | 11 | # Environment Setup 12 | 13 | Create a new environment with conda or micromamba: 14 | ```bash 15 | conda create -y -n FastSSL python=3.10 cupy pkg-config 'libjpeg-turbo=3.0.0' opencv numba pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia -c conda-forge 16 | conda activate FastSSL 17 | pip install -r requirements.txt 18 | ``` 19 | Alternatively, you can use a Docker image from [GitHub Packages](https://github.com/erow/aisurrey-docker) to ensure your environment matches mine exactly: 20 | ``` 21 | docker pull ghcr.io/erow/aisurrey-docker:sha256-d835a01e444257345d78c95cec157eb604a73935f70f9e7928cdd08d97411fa7.sig 22 | ``` 23 | 24 | # Usage 25 | 26 | ## torchrun 27 | 28 | To train an MAE, run the following command: 29 | ```bash 30 | torchrun --nproc_per_node 8 main_pretrain.py --data_path=${train_path} --data_set=ffcv \ 31 | --epochs 800 --warmup_epochs 40 --blr 1.5e-4 --weight_decay 0.05 --batch_size 512 \ 32 | --cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline --ckpt_freq=100 --output_dir outputs/IN1K_base 33 | ``` 34 | Optional arguments: `--compile` to compile the model, `--ckpt_freq` to save a checkpoint every `ckpt_freq` epochs, and `--online_prob` to evaluate an online linear classifier during training. 35 | 36 | 37 | ## HPC 38 | 39 | The original settings for [ViT-Large](https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md) are a batch size of 4096 and 800 epochs, which takes about 42 hours on 64 V100 GPUs. 40 | 41 | ```bash 42 | WANDB_NAME=mae_1k python submitit_pretrain.py \ 43 | --job_dir ${JOB_DIR} \ 44 | -p gpu --ngpus 8 --nodes 8 \ 45 | --batch_size 64 \ 46 | --epochs 800 \ 47 | --warmup_epochs 40 \ 48 | --blr 1.5e-4 --weight_decay 0.05 \ 49 | --cfgs configs/mae_ffcv.gin --gin build_model.model_fn=@base/MaskedAutoencoderViT build_dataset.transform_fn=@SimplePipeline \ 50 | --data_path=${train_path} --data_set=ffcv 51 | ``` 52 | 53 | # Cite Me!
54 | 55 | ```bib 56 | @misc{wu2024dailymaepretrainingmaskedautoencoders, 57 | title={DailyMAE: Towards Pretraining Masked Autoencoders in One Day}, 58 | author={Jiantao Wu and Shentong Mo and Sara Atito and Zhenhua Feng and Josef Kittler and Muhammad Awais}, 59 | year={2024}, 60 | eprint={2404.00509}, 61 | archivePrefix={arXiv}, 62 | primaryClass={cs.LG}, 63 | url={https://arxiv.org/abs/2404.00509}, 64 | } 65 | ``` -------------------------------------------------------------------------------- /bin/slurm_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Parameters 3 | #SBATCH --cpus-per-task=10 4 | #SBATCH --error=outputs/slurm/%j_%t_log.err 5 | #SBATCH --gpus-per-node=8 6 | #SBATCH --job-name=run 7 | #SBATCH --mem=400GB 8 | #SBATCH --nodes=1 9 | #SBATCH --ntasks-per-node=8 10 | #SBATCH --open-mode=append 11 | #SBATCH --output=outputs/slurm/%j_%t_log.out 12 | 13 | ARGS=${@} 14 | torchrun --nproc_per_node=8 $ARGS 15 | -------------------------------------------------------------------------------- /bin/train.sh: -------------------------------------------------------------------------------- 1 | MODEL=$1 2 | PARAMS=${@:2} 3 | 4 | 5 | # name=${MODEL}_rand 6 | # WANDB_NAME=${name} python submitit_pretrain.py --comment ${name} -p long -t 6000 --job_dir outputs/pretrain/${name} --data_path ~/data/edata/IN100K_rand.ffcv $PARAMS --batch_size 128 --ckpt_freq 50 --online_prob --epochs 300 7 | 8 | name=${MODEL}_c100 9 | WANDB_NAME=${name} python submitit_pretrain.py --comment ${name} -p long -t 6000 --job_dir outputs/pretrain/${name} --data_path ~/data/ffcv/IN100_train_500.ffcv $PARAMS --batch_size 128 --ckpt_freq 50 --online_prob --epochs 300 10 | 11 | # name=${MODEL}_sas 12 | # WANDB_NAME=${name} python submitit_pretrain.py --comment ${name} -p long -t 6000 --job_dir outputs/pretrain/${name} --data_path ~/data/edata/SASIN100_resnet18.ffcv $PARAMS --batch_size 128 --ckpt_freq 50 --online_prob --epochs 300 13 | 14 | -------------------------------------------------------------------------------- /bin/write_dataset.py: -------------------------------------------------------------------------------- 1 | """example usage: 2 | export IMAGENET_DIR=/path/to/pytorch/format/imagenet/directory/ 3 | export WRITE_DIR=/your/path/here/ 4 | write_dataset train 500 0.50 90 5 | write_path=$WRITE_DIR/train500_0.5_90.ffcv 6 | echo "Writing ImageNet train dataset to ${write_path}" 7 | python examples/write_dataset.py \ 8 | --cfg.data_dir=$IMAGENET_DIR \ 9 | --cfg.write_path=$write_path \ 10 | --cfg.max_resolution=500 \ 11 | --cfg.write_mode=smart \ 12 | --cfg.compress_probability=0.50 \ 13 | --cfg.jpeg_quality=90 14 | """ 15 | import os 16 | from PIL import Image 17 | from torch.utils.data import Subset 18 | from ffcv.writer import DatasetWriter 19 | from ffcv.fields import IntField, RGBImageField 20 | import torchvision 21 | from torchvision.datasets import ImageFolder 22 | import torchvision.datasets as torch_datasets 23 | 24 | from argparse import ArgumentParser 25 | from fastargs import Section, Param 26 | from fastargs.validation import And, OneOf 27 | from fastargs.decorators import param, section 28 | from fastargs import get_current_config 29 | import cv2 30 | import numpy as np 31 | 32 | import torch 33 | from torchvision import transforms 34 | import timm 35 | import random 36 | 37 | 38 | Section('cfg', 'arguments to give the writer').params( 39 | dataset=Param(And(str, OneOf(['cifar', 'imagenet'])), 'Which dataset to write', default='imagenet'), 
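# Input/output locations: data_dir points at a torchvision-style dataset root, write_path at the target .ffcv file.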
40 | data_dir=Param(str, 'Where to find the PyTorch dataset', required=True), 41 | write_path=Param(str, 'Where to write the new dataset', required=True), 42 | write_mode=Param(str, 'Mode: raw, smart or jpg', required=False, default='smart'), 43 | max_resolution=Param(int, 'Max image side length. 0 any size.', required=False,default=0), 44 | num_workers=Param(int, 'Number of workers to use', default=16), 45 | chunk_size=Param(int, 'Chunk size for writing', default=100), 46 | jpeg_quality=Param(float, 'Quality of jpeg images', default=90), 47 | subset=Param(float, 'How many images to use (the fraction of the dataset, 1 for all)', default=0.1), 48 | compress_probability=Param(float, 'compress probability', default=0.5), 49 | threshold=Param(int, 'threshold for smart mode to compress by jpeg', default=286432), 50 | proxy=Param(str, 'proxy model to use', default='resnet18'), 51 | sub_mode=Param(And(str, OneOf(['sas', 'random'])), 'Subset mode', default='sas'), 52 | ) 53 | 54 | @section('cfg') 55 | @param('dataset') 56 | @param('data_dir') 57 | @param('write_path') 58 | @param('max_resolution') 59 | @param('num_workers') 60 | @param('chunk_size') 61 | @param('subset') 62 | @param('jpeg_quality') 63 | @param('write_mode') 64 | @param('compress_probability') 65 | @param('threshold') 66 | @param('proxy') 67 | @param('sub_mode') 68 | def main(dataset, data_dir, write_path, max_resolution, num_workers, 69 | chunk_size, subset, jpeg_quality, write_mode, 70 | compress_probability, threshold, proxy,sub_mode): 71 | 72 | if dataset == 'imagenet': 73 | my_dataset = ImageFolder(root=data_dir) 74 | elif dataset == 'cifar': 75 | my_dataset = torch_datasets.CIFAR10(root=data_dir, train=True, download=True) 76 | else: 77 | raise ValueError('Unknown dataset') 78 | 79 | if sub_mode=="sas": 80 | device = "cuda:0" 81 | 82 | # Approximate Latent Classes 83 | from sas.approx_latent_classes import clip_approx 84 | from sas.subset_dataset import SASSubsetDataset 85 | ds_size = len(my_dataset) 86 | rand_labeled_examples_indices = random.sample(range(ds_size), 10000) 87 | rand_labeled_examples_labels = [my_dataset.samples[i][1] for i in rand_labeled_examples_indices] 88 | 89 | if os.path.exists('/tmp/clip_partition.npy'): 90 | partition = np.load('/tmp/clip_partition.npy', allow_pickle=True).item() 91 | else: 92 | partition = clip_approx( 93 | img_trainset=my_dataset, 94 | labeled_example_indices=rand_labeled_examples_indices, 95 | labeled_examples_labels=rand_labeled_examples_labels, 96 | num_classes=1000, 97 | device=device, verbose=True 98 | ) 99 | np.save('/tmp/clip_partition.npy', partition) 100 | 101 | # Get Subset 102 | proxy_model = timm.create_model(proxy, pretrained=True).to(device) 103 | 104 | augmentation_distance = None 105 | if os.path.exists('/tmp/augmentation_distance.npy'): 106 | augmentation_distance = np.load('/tmp/augmentation_distance.npy',allow_pickle=True).item() 107 | 108 | default_tfms = my_dataset.transform 109 | my_dataset.transform = torchvision.transforms.Compose([ 110 | transforms.Resize((224, 224)), 111 | torchvision.transforms.ToTensor(), 112 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 113 | ]) 114 | subset_dataset = SASSubsetDataset( 115 | dataset=my_dataset, 116 | subset_fraction=subset, 117 | num_downstream_classes=1000, 118 | device=device, 119 | proxy_model=proxy_model, 120 | approx_latent_class_partition=partition, 121 | augmentation_distance = augmentation_distance, 122 | verbose=False 123 | ) 124 | if augmentation_distance is None: 125 | 
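# Cache the freshly computed pairwise augmentation distances so later runs can reuse them via the np.load above.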
np.save('/tmp/augmentation_distance.npy', subset_dataset.augmentation_distance) 126 | my_dataset.transform = default_tfms 127 | subset_indecies = subset_dataset.subset_indices 128 | elif sub_mode=="random": 129 | subset_indecies = random.sample(range(len(my_dataset)), int(subset*len(my_dataset))) 130 | else: 131 | raise ValueError('Unknown sub_mode') 132 | my_dataset = Subset(my_dataset, subset_indecies) 133 | 134 | writer = DatasetWriter(write_path, { 135 | 'image': RGBImageField(write_mode=write_mode, 136 | max_resolution=None if max_resolution==0 else max_resolution, 137 | compress_probability=compress_probability, 138 | jpeg_quality=jpeg_quality, 139 | smart_threshold=threshold), 140 | 'label': IntField(), 141 | }, num_workers=num_workers) 142 | 143 | writer.from_indexed_dataset(my_dataset, chunksize=chunk_size,shuffle_indices=False) 144 | 145 | if __name__ == '__main__': 146 | config = get_current_config() 147 | parser = ArgumentParser() 148 | config.augment_argparse(parser) 149 | config.collect_argparse_args(parser) 150 | config.validate(mode='stderr') 151 | config.summary() 152 | 153 | args=config.get().cfg 154 | assert args.write_path.endswith('.ffcv'), 'write_path must end with .ffcv' 155 | file=open(args.write_path.replace(".ffcv",".meta"), 'w') 156 | file.write(str(args.__dict__)) 157 | main() 158 | -------------------------------------------------------------------------------- /configs/cifar.gin: -------------------------------------------------------------------------------- 1 | build_dataset.transform_fn=@DataAugmentationDINO 2 | build_dataset.mean=(0.4914, 0.4822, 0.4465) 3 | build_dataset.std=(0.2470, 0.2435, 0.2616) 4 | DataAugmentationDINO.img_size=32 5 | DataAugmentationDINO.local_crops_number=0 -------------------------------------------------------------------------------- /configs/dres.gin: -------------------------------------------------------------------------------- 1 | 2 | DynamicMasking.start_ramp=0 3 | DynamicMasking.end_ramp=100 4 | DynamicMasking.scheme=1 -------------------------------------------------------------------------------- /configs/mae_ffcv.gin: -------------------------------------------------------------------------------- 1 | # Parameters for build_dataset: 2 | # ============================================================================== 3 | # build_dataset.transform_fn = @SimplePipeline 4 | 5 | # Parameters for build_model: 6 | # ============================================================================== 7 | build_model.model_fn = @base/MaskedAutoencoderViT 8 | 9 | # Parameters for MaskedAutoencoderViT: 10 | # ============================================================================== 11 | 12 | tiny/MaskedAutoencoderViT.img_size=224 13 | tiny/MaskedAutoencoderViT.patch_size=16 14 | tiny/MaskedAutoencoderViT.embed_dim=192 15 | tiny/MaskedAutoencoderViT.depth=12 16 | tiny/MaskedAutoencoderViT.num_heads=12 17 | tiny/MaskedAutoencoderViT.decoder_embed_dim=96 18 | tiny/MaskedAutoencoderViT.decoder_depth=1 19 | tiny/MaskedAutoencoderViT.decoder_num_heads=3 20 | 21 | small/MaskedAutoencoderViT.img_size=224 22 | small/MaskedAutoencoderViT.patch_size=16 23 | small/MaskedAutoencoderViT.embed_dim=384 24 | small/MaskedAutoencoderViT.depth=12 25 | small/MaskedAutoencoderViT.num_heads=6 26 | small/MaskedAutoencoderViT.decoder_embed_dim=512 27 | small/MaskedAutoencoderViT.decoder_depth=4 28 | small/MaskedAutoencoderViT.decoder_num_heads=16 29 | 30 | base/MaskedAutoencoderViT.img_size=224 31 | base/MaskedAutoencoderViT.patch_size=16 32 | 
base/MaskedAutoencoderViT.embed_dim=768 33 | base/MaskedAutoencoderViT.depth=12 34 | base/MaskedAutoencoderViT.num_heads=6 35 | base/MaskedAutoencoderViT.decoder_embed_dim=512 36 | base/MaskedAutoencoderViT.decoder_depth=8 37 | base/MaskedAutoencoderViT.decoder_num_heads=16 38 | 39 | large/MaskedAutoencoderViT.img_size=224 40 | large/MaskedAutoencoderViT.patch_size=16 41 | large/MaskedAutoencoderViT.embed_dim=1024 42 | large/MaskedAutoencoderViT.depth=24 43 | large/MaskedAutoencoderViT.num_heads=8 44 | large/MaskedAutoencoderViT.decoder_embed_dim=512 45 | large/MaskedAutoencoderViT.decoder_depth=8 46 | large/MaskedAutoencoderViT.decoder_num_heads=16 47 | 48 | huge/MaskedAutoencoderViT.img_size=224 49 | huge/MaskedAutoencoderViT.patch_size=16 50 | huge/MaskedAutoencoderViT.embed_dim=1280 51 | huge/MaskedAutoencoderViT.depth=32 52 | huge/MaskedAutoencoderViT.num_heads=16 53 | huge/MaskedAutoencoderViT.decoder_embed_dim=512 54 | huge/MaskedAutoencoderViT.decoder_depth=8 55 | huge/MaskedAutoencoderViT.decoder_num_heads=16 56 | 57 | # Parameters for SimplePipeline: 58 | # ============================================================================== 59 | SimplePipeline.img_size = 224 60 | SimplePipeline.scale = (0.2, 1.0) 61 | -------------------------------------------------------------------------------- /configs/mae_if.gin: -------------------------------------------------------------------------------- 1 | # Parameters for build_dataset: 2 | # ============================================================================== 3 | build_dataset.transform_fn = @SimpleAugmentation 4 | 5 | # Parameters for build_model: 6 | # ============================================================================== 7 | build_model.model_fn = @base/MaskedAutoencoderViT 8 | 9 | # Parameters for MaskedAutoencoderViT: 10 | # ============================================================================== 11 | 12 | tiny/MaskedAutoencoderViT.img_size=224 13 | tiny/MaskedAutoencoderViT.patch_size=16 14 | tiny/MaskedAutoencoderViT.embed_dim=192 15 | tiny/MaskedAutoencoderViT.depth=12 16 | tiny/MaskedAutoencoderViT.num_heads=12 17 | tiny/MaskedAutoencoderViT.decoder_embed_dim=96 18 | tiny/MaskedAutoencoderViT.decoder_depth=1 19 | tiny/MaskedAutoencoderViT.decoder_num_heads=3 20 | 21 | small/MaskedAutoencoderViT.img_size=224 22 | small/MaskedAutoencoderViT.patch_size=16 23 | small/MaskedAutoencoderViT.embed_dim=384 24 | small/MaskedAutoencoderViT.depth=12 25 | small/MaskedAutoencoderViT.num_heads=6 26 | small/MaskedAutoencoderViT.decoder_embed_dim=512 27 | small/MaskedAutoencoderViT.decoder_depth=4 28 | small/MaskedAutoencoderViT.decoder_num_heads=16 29 | 30 | base/MaskedAutoencoderViT.img_size=224 31 | base/MaskedAutoencoderViT.patch_size=16 32 | base/MaskedAutoencoderViT.embed_dim=768 33 | base/MaskedAutoencoderViT.depth=12 34 | base/MaskedAutoencoderViT.num_heads=6 35 | base/MaskedAutoencoderViT.decoder_embed_dim=512 36 | base/MaskedAutoencoderViT.decoder_depth=8 37 | base/MaskedAutoencoderViT.decoder_num_heads=16 38 | 39 | large/MaskedAutoencoderViT.img_size=224 40 | large/MaskedAutoencoderViT.patch_size=16 41 | large/MaskedAutoencoderViT.embed_dim=1024 42 | large/MaskedAutoencoderViT.depth=24 43 | large/MaskedAutoencoderViT.num_heads=8 44 | large/MaskedAutoencoderViT.decoder_embed_dim=512 45 | large/MaskedAutoencoderViT.decoder_depth=8 46 | large/MaskedAutoencoderViT.decoder_num_heads=16 47 | 48 | huge/MaskedAutoencoderViT.img_size=224 49 | huge/MaskedAutoencoderViT.patch_size=16 50 | 
huge/MaskedAutoencoderViT.embed_dim=1280 51 | huge/MaskedAutoencoderViT.depth=32 52 | huge/MaskedAutoencoderViT.num_heads=16 53 | huge/MaskedAutoencoderViT.decoder_embed_dim=512 54 | huge/MaskedAutoencoderViT.decoder_depth=8 55 | huge/MaskedAutoencoderViT.decoder_num_heads=16 56 | 57 | # Parameters for SimpleAugmentation: 58 | # ============================================================================== 59 | SimpleAugmentation.img_size = 224 60 | SimpleAugmentation.scale = (0.2, 1.0) 61 | -------------------------------------------------------------------------------- /configs/resnet18.gin: -------------------------------------------------------------------------------- 1 | create_backbone.name='resnet18' 2 | create_backbone.output_stride=8 3 | create_backbone.stem_type='deep' -------------------------------------------------------------------------------- /configs/simclr_ffcv.gin: -------------------------------------------------------------------------------- 1 | # Parameters for build_dataset: 2 | # ============================================================================== 3 | build_dataset.transform_fn = @MultiviewPipeline 4 | 5 | # Parameters for build_model: 6 | # ============================================================================== 7 | build_model.model_fn = @SimCLR 8 | 9 | # Parameters for SimCLR: 10 | # ============================================================================== 11 | 12 | 13 | # Parameters for MultiviewPipeline: 14 | # ============================================================================== 15 | MultiviewPipeline.img_size = 224 16 | MultiviewPipeline.local_crops_number = 0 -------------------------------------------------------------------------------- /configs/simclr_if.gin: -------------------------------------------------------------------------------- 1 | # Parameters for build_dataset: 2 | # ============================================================================== 3 | build_dataset.transform_fn = @DataAugmentationDINO 4 | 5 | # Parameters for build_model: 6 | # ============================================================================== 7 | build_model.model_fn = @SimCLR 8 | 9 | # Parameters for SimCLR: 10 | # ============================================================================== 11 | 12 | 13 | # Parameters for DataAugmentationDINO: 14 | # ============================================================================== 15 | DataAugmentationDINO.img_size = 224 16 | DataAugmentationDINO.local_crops_number = 0 -------------------------------------------------------------------------------- /configs/vitt.gin: -------------------------------------------------------------------------------- 1 | create_backbone.name='vit_tiny_patch16_224' 2 | create_backbone.patch_size=4 3 | create_backbone.depth=12 4 | create_backbone.img_size=32 5 | create_backbone.num_heads=12 -------------------------------------------------------------------------------- /dataset/build_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets 2 | 3 | import numpy as np 4 | import torch 5 | import gin 6 | from PIL import Image 7 | 8 | from torchvision.datasets import VisionDataset 9 | from torchvision.datasets.folder import default_loader 10 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple 11 | import os 12 | 13 | # from dataset.cache_dataset import CacheDataset 14 | 15 | from dataset.transform import SimpleAugmentation 16 | def find_classes(directory: str): 17 | 
"""Finds the class folders in a dataset. 18 | 19 | See :class:`DatasetFolder` for details. 20 | """ 21 | 22 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir() and entry.name[0]=='n') 23 | if not classes: 24 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 25 | 26 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 27 | return classes, class_to_idx 28 | 29 | def find_samples(path,split=None): 30 | classes,class_to_idx = find_classes(path) 31 | if split: 32 | split_list = open(split,'r').readlines() 33 | split_list = [i.strip('\n')for i in split_list] 34 | for c in classes: 35 | if not c in split_list: 36 | del class_to_idx[c] 37 | samples = [] 38 | for c,idx in class_to_idx.items(): 39 | for file in os.listdir(os.path.join(path,c)): 40 | samples.append((os.path.join(path,c,file),idx)) 41 | 42 | return samples,class_to_idx 43 | 44 | 45 | class CacheFolders(VisionDataset): 46 | def __init__( 47 | self, 48 | root: str, 49 | samples, 50 | class_to_idx, 51 | loader: Callable[[str], Any] = default_loader, 52 | transform: Optional[Callable] = None, 53 | target_transform: Optional[Callable] = None, 54 | num_cache: int = 10_000, 55 | ): 56 | super().__init__(root, transform=transform, target_transform=target_transform) 57 | classes = list(class_to_idx.keys()) 58 | print(f"find classes: {len(classes)}") 59 | # self.root = root 60 | self.samples = samples 61 | 62 | self.loader = loader 63 | self.classes = classes 64 | self.class_to_idx = class_to_idx 65 | 66 | self.cache = dict() 67 | self.num_cache=num_cache 68 | 69 | def __getitem__(self, index: int): 70 | """ 71 | Args: 72 | index (int): Index 73 | 74 | Returns: 75 | tuple: (sample, target) where target is class_index of the target class. 
76 | """ 77 | 78 | path, target = self.samples[index] 79 | # path = os.path.join(self.root, path) 80 | 81 | if index in self.cache: 82 | sample = self.cache[index] 83 | else: 84 | sample = self.loader(path) 85 | if len(self.cache) < self.num_cache: 86 | self.cache[index] = sample 87 | if self.transform is not None: 88 | sample = self.transform(sample) 89 | if self.target_transform is not None: 90 | target = self.target_transform(target) 91 | 92 | return sample, target 93 | 94 | def __len__(self): 95 | return len(self.samples) 96 | 97 | class H5File(): 98 | def __init__(self,root, transform=None, target_transform=None): 99 | import h5py 100 | self.transform = transform 101 | self.target_transform = target_transform 102 | 103 | hdf = h5py.File(root, 'r', ) 104 | self.hdf = hdf 105 | self.images = hdf['images'] 106 | self.targets = hdf['labels'] 107 | 108 | def __len__(self): 109 | return len(self.images) 110 | 111 | def __getitem__(self, index): 112 | img, label = self.images[index],self.targets[index] 113 | img = Image.fromarray(img) 114 | if self.transform is not None: 115 | img = self.transform(img) 116 | if self.target_transform is not None: 117 | label = self.target_transform(label) 118 | return img, label 119 | 120 | 121 | @gin.configurable(denylist=["args"]) 122 | def build_dataset(args,transform_fn=SimpleAugmentation, 123 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], 124 | cached=False): 125 | transform_train = transform_fn(mean=mean,std=std) 126 | args.data_set = args.data_set.lower() 127 | if args.data_set == 'imnet': 128 | # simple augmentation 129 | if cached: 130 | samples,class_to_idx = find_samples(os.path.join(args.data_path,'train')) 131 | dataset_train = CacheFolders('./',samples,class_to_idx,transform=transform_train) 132 | else: 133 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 134 | elif args.data_set == 'ffcv': 135 | from dataset.multiloader import MultiLoader, OrderOption 136 | order = OrderOption.RANDOM if args.distributed else OrderOption.QUASI_RANDOM 137 | dataset_train = MultiLoader(args.data_path, pipelines=transform_train, 138 | batch_size=args.batch_size, num_workers=args.num_workers, 139 | batches_ahead=4, 140 | order=order, distributed=args.distributed,seed=args.seed) 141 | elif args.data_set == 'cifar10': 142 | dataset_train = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform_train) 143 | elif args.data_set == 'imnet64': 144 | dataset_train = H5File(args.data_path,transform=transform_train) 145 | return dataset_train 146 | 147 | -------------------------------------------------------------------------------- /dataset/ffcv_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gin 3 | 4 | from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage,RandomHorizontalFlip, View, Convert 5 | from ffcv.transforms.color_jitter import RandomColorJitter 6 | from ffcv.transforms.solarization import RandomSolarization 7 | from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, SimpleRGBImageDecoder, CenterCropRGBImageDecoder 8 | 9 | import torch 10 | import torchvision.transforms.v2 as tfms 11 | from torch import nn 12 | 13 | IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255 14 | IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255 15 | 16 | @gin.configurable 17 | def SimplePipeline(img_size=224,scale=(0.2,1), ratio=(3.0/4.0, 4.0/3.0), 18 | 
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 19 | image_pipeline = [ 20 | RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,ratio=ratio,), 21 | RandomHorizontalFlip(), 22 | ToTensor(), 23 | ToDevice(torch.device('cuda')), 24 | ToTorchImage(), 25 | Convert(torch.float16), 26 | tfms.Normalize(mean=[255*i for i in mean], std=[255*i for i in std], inplace=True), 27 | ] 28 | label_pipeline = [IntDecoder(), ToTensor(),ToDevice(torch.device('cuda')), View(-1)] 29 | # Pipeline for each data field 30 | pipelines = { 31 | 'image': image_pipeline, 32 | 'label': label_pipeline 33 | } 34 | return pipelines 35 | 36 | from torchvision.transforms import InterpolationMode 37 | 38 | class ThreeAugmentation(nn.Module): 39 | """Apply a single transformation randomly picked from a list. This transform does not support torchscript.""" 40 | 41 | def __init__(self, ): 42 | super().__init__() 43 | self.gaussian_blur = tfms.GaussianBlur(3,sigma=(0.1,2)) 44 | self.solarize = tfms.RandomSolarize(0,1) 45 | self.grayscale = tfms.RandomGrayscale(p=1) 46 | 47 | def __call__(self, x): 48 | op_index = torch.randint(0,3,(len(x),)) 49 | for i,op in enumerate([self.gaussian_blur, 50 | self.solarize, 51 | self.grayscale]): 52 | tf_mask = op_index == i 53 | x[tf_mask] = op(x[tf_mask]) 54 | return x 55 | 56 | def extra_repr(self) -> str: 57 | return "GaussianBlur, Solarize, Grayscale" 58 | 59 | @gin.configurable 60 | def ThreeAugmentPipeline(img_size=224,scale=(0.08,1), color_jitter=None,device='cuda'): 61 | """ 62 | ThreeAugmentPipeline: https://github.com/facebookresearch/deit/blob/main/augment.py 63 | """ 64 | if color_jitter is not None: assert color_jitter >= 0 and color_jitter <= 1 65 | device = torch.device(device) 66 | image_pipeline = ( 67 | # first_tfl 68 | [ RandomResizedCropRGBImageDecoder((img_size, img_size), scale=scale,), 69 | RandomHorizontalFlip(),]+ 70 | # second_tfl 71 | ( [RandomColorJitter(brightness=color_jitter, contrast=color_jitter, saturation=color_jitter,hue=0, p=0.5)] if color_jitter else []) + 72 | # final_tfl 73 | [ 74 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 75 | ToTensor(), ToTorchImage(), 76 | ToDevice(device), 77 | ThreeAugmentation(), 78 | ]) 79 | 80 | label_pipeline = [IntDecoder(), ToTensor(),ToDevice(device),View(-1)] 81 | # Pipeline for each data field 82 | pipelines = { 83 | 'image': image_pipeline, 84 | 'label': label_pipeline 85 | } 86 | return pipelines 87 | 88 | @gin.configurable 89 | def ColorJitterPipeline(img_size=224,scale=(0.08, 1.0),device='cuda'): 90 | device = torch.device(device) 91 | image_pipeline = [ 92 | RandomHorizontalFlip(), 93 | RandomColorJitter(0.8, 0.4, 0.4, 0.2, p=0.1), 94 | RandomSolarization(128,p=0.2), 95 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 96 | ToTensor(), ToTorchImage(), 97 | ToDevice(device,non_blocking=True), 98 | tfms.RandomGrayscale(p=0.1), 99 | tfms.GaussianBlur(3, sigma=(0.1, 2)), 100 | ] 101 | label_pipeline = [IntDecoder(), ToTensor(),ToDevice(device),View(-1)] 102 | # Pipeline for each data field 103 | from ffcv.pipeline import PipelineSpec 104 | pipelines = { 105 | 'image': PipelineSpec("image",RandomResizedCropRGBImageDecoder((img_size, img_size),scale=scale),transforms=image_pipeline), 106 | } 107 | pipelines['label'] = label_pipeline 108 | return pipelines 109 | 110 | @gin.configurable 111 | def MultiviewPipeline(img_size=224,scale=(0.4, 1.0),local_crops_number=0, 112 | local_img_size=96,device='cuda'): 113 | k = local_img_size/img_size 114 | local_scale=(scale[0]*k, scale[1]*k) 115 | 116 |
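# Two global views with asymmetric augmentation (the second view additionally applies solarization),
# plus `local_crops_number` local crops that reuse a lighter pipeline.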
image_pipeline = [ 117 | RandomHorizontalFlip(), 118 | RandomColorJitter(0.8, 0.4, 0.4, 0.2, p=0.1), 119 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 120 | ToTensor(), ToTorchImage(), 121 | ToDevice(torch.device(device),non_blocking=True), 122 | tfms.RandomGrayscale(p=0.1), 123 | tfms.GaussianBlur(3, sigma=(0.1, 2)), 124 | ] 125 | image_pipeline2 = [ 126 | RandomHorizontalFlip(), 127 | RandomColorJitter(0.8, 0.4, 0.4, 0.2, p=0.1), 128 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 129 | ToTensor(), ToTorchImage(), 130 | Convert(torch.float16), 131 | ToDevice(torch.device(device),non_blocking=True), 132 | tfms.RandomGrayscale(p=0.1), 133 | tfms.GaussianBlur(3, sigma=(0.1, 2)), 134 | tfms.RandomSolarize(0,0.2), # asymmetric augmentation 135 | ] 136 | def _local_pipeline(): 137 | return [ 138 | RandomHorizontalFlip(), 139 | RandomColorJitter(0.8, 0.4, 0.4, 0.2, 0.1), 140 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 141 | ToTensor(), ToTorchImage(), 142 | Convert(torch.float16), 143 | ToDevice(torch.device(device),non_blocking=True), 144 | ] 145 | label_pipeline = [IntDecoder(), ToTensor(),View(-1)] 146 | # Pipeline for each data field 147 | from ffcv.pipeline import PipelineSpec 148 | pipelines = { 149 | 'image': PipelineSpec("image",RandomResizedCropRGBImageDecoder((img_size, img_size),scale=scale),transforms=image_pipeline), 150 | 'image2': PipelineSpec("image",RandomResizedCropRGBImageDecoder((img_size, img_size),scale=scale),transforms=image_pipeline2), 151 | } 152 | for i in range(local_crops_number): 153 | pipelines[f"local_{i}"] = PipelineSpec("image",RandomResizedCropRGBImageDecoder((local_img_size, local_img_size),scale=local_scale),transforms=_local_pipeline()) 154 | pipelines['label'] = label_pipeline 155 | return pipelines 156 | 157 | @gin.configurable 158 | def MultiviewPipeline(img_size=224,scale=(0.4, 1.0),local_crops_number=8, 159 | local_img_size=96, 160 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 161 | mean=np.array(mean)*255 162 | std = np.array(std)*255 163 | k = local_img_size/img_size 164 | local_scale=(scale[0]*k, scale[1]*k) 165 | 166 | image_pipeline = [ 167 | RandomHorizontalFlip(), 168 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 169 | ToTensor(), ToTorchImage(), 170 | ToDevice(torch.device(device),non_blocking=True), 171 | ] 172 | image_pipeline2 = [ 173 | RandomHorizontalFlip(), 174 | RandomColorJitter(0.8, 0.4, 0.4, 0.2, p=0.1), 175 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 176 | ToTensor(), ToTorchImage(), 177 | Convert(torch.float16), 178 | ToDevice(torch.device(device),non_blocking=True), 179 | tfms.RandomGrayscale(p=0.1), 180 | tfms.GaussianBlur(3, sigma=(0.1, 2)), 181 | tfms.RandomSolarize(0,0.2), # asymmetric augmentation 182 | ] 183 | def _local_pipeline(): 184 | return [ 185 | RandomHorizontalFlip(), 186 | RandomColorJitter(0.8, 0.4, 0.4, 0.2, 0.1), 187 | NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32), 188 | ToTensor(), ToTorchImage(), 189 | Convert(torch.float16), 190 | ToDevice(torch.device(device),non_blocking=True), 191 | ] 192 | label_pipeline = [IntDecoder(), ToTensor(),View(-1)] 193 | # Pipeline for each data field 194 | from ffcv.pipeline import PipelineSpec 195 | pipelines = { 196 | 'image': PipelineSpec("image",RandomResizedCropRGBImageDecoder((img_size, img_size),scale=scale),transforms=image_pipeline), 197 | 'image2': PipelineSpec("image",RandomResizedCropRGBImageDecoder((img_size, img_size),scale=scale),transforms=image_pipeline2), 198 | } 199 | for 
i in range(local_crops_number): 200 | pipelines[f"local_{i}"] = PipelineSpec("image",RandomResizedCropRGBImageDecoder((local_img_size, local_img_size),scale=local_scale),transforms=_local_pipeline()) 201 | pipelines['label'] = label_pipeline 202 | return pipelines 203 | -------------------------------------------------------------------------------- /dataset/multiloader.py: -------------------------------------------------------------------------------- 1 | 2 | from multiprocessing import cpu_count 3 | from typing import Any, Callable, Literal, Mapping, Sequence, Type, Union 4 | import numpy as np 5 | import torch 6 | from ffcv.fields.base import Field 7 | from ffcv.loader import Loader 8 | from ffcv.loader.loader import (DEFAULT_PROCESS_CACHE, ORDER_MAP, ORDER_TYPE, 9 | OrderOption) 10 | from ffcv.memory_managers import (MemoryManager, OSCacheManager, 11 | ProcessCacheManager) 12 | from ffcv.pipeline import Compiler, Pipeline, PipelineSpec 13 | from ffcv.pipeline.graph import Graph 14 | from ffcv.pipeline.operation import Operation 15 | from ffcv.reader import Reader 16 | from ffcv.traversal_order.base import TraversalOrder 17 | 18 | 19 | class MultiLoader(Loader): 20 | def __init__(self, 21 | fname: str, 22 | batch_size: int, 23 | num_workers: int = -1, 24 | cache_type: int = DEFAULT_PROCESS_CACHE, 25 | order: Union[ORDER_TYPE, TraversalOrder] = OrderOption.SEQUENTIAL, 26 | distributed: bool = False, 27 | seed: int = None, # For ordering of samples 28 | indices: Sequence[int] = None, # For subset selection 29 | pipelines: Mapping[str, 30 | Sequence[Union[Operation, torch.nn.Module]]] = {}, 31 | custom_fields: Mapping[str, Type[Field]] = {}, 32 | drop_last: bool = True, 33 | batches_ahead: int = 3, 34 | recompile: bool = False, # Recompile at every epoch 35 | ): 36 | 37 | if distributed and order == OrderOption.RANDOM and (seed is None): 38 | print('Warning: no ordering seed was specified with distributed=True. 
' 39 | 'Setting seed to 0 to match PyTorch distributed sampler.') 40 | seed = 0 41 | elif seed is None: 42 | tinfo = np.iinfo('int32') 43 | seed = np.random.randint(0, tinfo.max) 44 | 45 | # We store the original user arguments to be able to pass it to the 46 | # filtered version of the datasets 47 | self._args = { 48 | 'fname': fname, 49 | 'batch_size': batch_size, 50 | 'num_workers': num_workers, 51 | 'os_cache': cache_type, 52 | 'order': order, 53 | 'distributed': distributed, 54 | 'seed': seed, 55 | 'indices': indices, 56 | 'pipelines': pipelines, 57 | 'drop_last': drop_last, 58 | 'batches_ahead': batches_ahead, 59 | 'recompile': recompile 60 | } 61 | self.fname: str = fname 62 | self.batch_size: int = batch_size 63 | self.batches_ahead = batches_ahead 64 | self.seed: int = seed 65 | self.reader: Reader = Reader(self.fname, custom_fields) 66 | self.num_workers: int = num_workers 67 | self.drop_last: bool = drop_last 68 | self.distributed: bool = distributed 69 | self.code = None 70 | self.recompile = recompile 71 | 72 | if self.num_workers < 1: 73 | self.num_workers = cpu_count() 74 | 75 | Compiler.set_num_threads(self.num_workers) 76 | 77 | if indices is None: 78 | self.indices = np.arange(self.reader.num_samples, dtype='uint64') 79 | else: 80 | self.indices = np.array(indices) 81 | 82 | 83 | if cache_type == 0: 84 | self.memory_manager: MemoryManager = OSCacheManager(self.reader) 85 | elif cache_type == 1: 86 | self.memory_manager: MemoryManager = ProcessCacheManager( 87 | self.reader) 88 | elif cache_type == 2: 89 | from ffcv.memory_managers.shared_cache import SharedMemoryManager 90 | self.memory_manager: MemoryManager = SharedMemoryManager(self.reader) 91 | else: 92 | raise ValueError("Unknown cache type. Use 0 for process cache, 1 for os cache, or 2 for no cache.") 93 | 94 | if order in ORDER_MAP: 95 | self.traversal_order: TraversalOrder = ORDER_MAP[order](self) 96 | elif isinstance(order, TraversalOrder): 97 | self.traversal_order: TraversalOrder = order(self) 98 | elif issubclass(order, TraversalOrder): 99 | self.traversal_order: TraversalOrder = order(self) 100 | else: 101 | raise ValueError(f"Order {order} is not a supported order type or a subclass of TraversalOrder") 102 | 103 | memory_read = self.memory_manager.compile_reader() 104 | self.next_epoch: int = 0 105 | 106 | self.pipelines = {} 107 | self.pipeline_specs = {} 108 | self.field_name_to_f_ix = {} 109 | custom_pipeline_specs = {} 110 | 111 | # Creating PipelineSpec objects from the pipeline dict passed 112 | # by the user 113 | for output_name, spec in pipelines.items(): 114 | if isinstance(spec, PipelineSpec): 115 | pass 116 | elif isinstance(spec, Sequence): 117 | spec = PipelineSpec(output_name, decoder=None, transforms=spec) 118 | elif spec is None: 119 | continue # This is a disabled field 120 | else: 121 | msg = f"The pipeline for {output_name} has to be " 122 | msg += f"either a PipelineSpec or a sequence of operations" 123 | raise ValueError(msg) 124 | custom_pipeline_specs[output_name] = spec 125 | 126 | # Adding the default pipelines 127 | default_name_to_f_ix={} 128 | for f_ix, (field_name, field) in enumerate(self.reader.handlers.items()): 129 | default_name_to_f_ix[field_name] = f_ix 130 | 131 | # We add the custom fields after the default ones 132 | # This is to preserve backwards compatibility and make sure the order 133 | # is intuitive 134 | for field_name, spec in custom_pipeline_specs.items(): 135 | # redirect 136 | self.field_name_to_f_ix[field_name] = default_name_to_f_ix[spec.source] 137 | 138 
| if field_name not in self.pipeline_specs: 139 | self.pipeline_specs[field_name] = spec 140 | 141 | self.graph = Graph(self.pipeline_specs, self.reader.handlers, 142 | self.field_name_to_f_ix, self.reader.metadata, 143 | memory_read) 144 | 145 | self.generate_code() 146 | self.first_traversal_order = self.next_traversal_order() -------------------------------------------------------------------------------- /dataset/transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions. 3 | 4 | Mostly copy-paste from torchvision references or other public repos like DETR: 5 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 6 | """ 7 | 8 | import os 9 | import sys 10 | import time 11 | import math 12 | import random 13 | import datetime 14 | import subprocess 15 | from collections import defaultdict, deque 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | import torch.distributed as dist 21 | from PIL import ImageFilter, ImageOps 22 | from torchvision import transforms 23 | from torchvision.transforms.v2 import Normalize, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, RandomApply, RandomGrayscale, Compose, ToDtype 24 | import gin 25 | from PIL import Image 26 | 27 | from typing import List 28 | 29 | IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) 30 | IMAGENET_STD = np.array([0.229, 0.224, 0.225]) 31 | DEFAULT_CROP_RATIO = 224/256 32 | 33 | class ToDevice(nn.Module): 34 | def __init__(self, device): 35 | super().__init__() 36 | self.device = device 37 | 38 | def forward(self, x:torch.Tensor): 39 | return x.to(self.device,non_blocking=True) 40 | 41 | class GaussianBlur(nn.Module): 42 | """ 43 | Apply Gaussian Blur to the PIL image. 44 | """ 45 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 46 | self.prob = p 47 | self.radius_min = radius_min 48 | self.radius_max = radius_max 49 | 50 | def __call__(self, img): 51 | do_it = random.random() <= self.prob 52 | if not do_it: 53 | return img 54 | 55 | return img.filter( 56 | ImageFilter.GaussianBlur( 57 | radius=random.uniform(self.radius_min, self.radius_max) 58 | ) 59 | ) 60 | 61 | class Solarization(nn.Module): 62 | """ 63 | Apply Solarization to the PIL image. 64 | """ 65 | def __init__(self, p): 66 | self.p = p 67 | 68 | def __call__(self, img): 69 | if random.random() < self.p: 70 | return ImageOps.solarize(img) 71 | else: 72 | return img 73 | 74 | @gin.configurable() 75 | class SimpleAugmentation(nn.Module): 76 | def __init__(self,img_size=224,scale=(0.2, 1.0), 77 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 78 | super().__init__() 79 | # simple augmentation 80 | self.transforms = Compose([ 81 | RandomResizedCrop(img_size, scale=scale, interpolation=Image.BICUBIC), # 3 is bicubic 82 | RandomHorizontalFlip(), 83 | ToTensor(), 84 | # ToDevice('cuda'), 85 | Normalize(mean=mean,std=std)]) 86 | def forward(self,x): 87 | return self.transforms(x) 88 | 89 | def change_resolution(self,img_size): 90 | decoder = self.transforms[0] 91 | decoder.size=(img_size,img_size) 92 | 93 | 94 | @gin.configurable() 95 | class DataAugmentationDINO(nn.Module): 96 | def __init__(self,img_size=224, global_crops_scale=(0.4, 1.), local_crops_scale=(0.05, 0.4), local_crops_number=8, color_jitter=True, 97 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 98 | """Multi-view data augmentation. 99 | 100 | Args: 101 | global_crops_scale (tuple, optional): _description_. Defaults to (0.4, 1.). 
102 | local_crops_scale (tuple, optional): _description_. Defaults to (0.05, 0.4). 103 | local_crops_number (int, optional): _description_. Defaults to 8. 104 | 105 | Return: 106 | [2 x global views, local_crops_number x local views] 107 | """ 108 | super().__init__() 109 | flip_and_color_jitter = Compose([ 110 | RandomHorizontalFlip(p=0.5), 111 | RandomApply( 112 | [ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 113 | p=0.8 114 | ), 115 | RandomGrayscale(p=0.2), 116 | ]) 117 | 118 | normalize = Compose([ 119 | ToTensor(), 120 | # ToDevice('cuda'), 121 | Normalize(mean, std), 122 | ]) 123 | 124 | # first global crop 125 | self.global_transfo1 = Compose([ 126 | RandomResizedCrop(img_size, scale=global_crops_scale, interpolation=Image.BICUBIC), 127 | flip_and_color_jitter, 128 | GaussianBlur(5), 129 | normalize, 130 | ]) 131 | # second global crop 132 | self.global_transfo2 = Compose([ 133 | RandomResizedCrop(img_size, scale=global_crops_scale, interpolation=Image.BICUBIC), 134 | flip_and_color_jitter, 135 | GaussianBlur(5), 136 | Solarization(0.2), 137 | normalize, 138 | ]) 139 | # transformation for the local small crops 140 | self.local_crops_number = local_crops_number 141 | self.local_transfo = Compose([ 142 | RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC), 143 | flip_and_color_jitter, 144 | RandomApply([GaussianBlur(5)],p=0.5), 145 | normalize, 146 | ]) 147 | 148 | def __call__(self, image): 149 | crops = [] 150 | crops.append(self.global_transfo1(image)) 151 | crops.append(self.global_transfo2(image)) 152 | for _ in range(self.local_crops_number): 153 | crops.append(self.local_transfo(image)) 154 | return crops -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | # Configure 2 | 3 | ```text 4 | usage: FastSSL pre-training script [--batch_size BATCH_SIZE] [--epochs EPOCHS] [--accum_iter ACCUM_ITER] [--ckpt_freq CKPT_FREQ] 5 | [--no_wandb] [--dynamic_resolution] [--online_prob] [--compile] [-w PRETRAINED_WEIGHTS] 6 | [--opt OPTIMIZER] [--opt_eps EPSILON] [--opt_betas BETA [BETA ...]] [--clip_grad NORM] [--momentum M] 7 | [--weight_decay WEIGHT_DECAY] [--lr LR] [--blr LR] [--min_lr LR] [--warmup_epochs N] 8 | [--num_classes NUM_CLASSES] [--data_set DATA_SET] [--data_path DATA_PATH] [--output_dir OUTPUT_DIR] 9 | [--device DEVICE] [--seed SEED] [--resume RESUME] [--start_epoch N] [--num_workers NUM_WORKERS] 10 | [--pin_mem] [--no_pin_mem] [--world_size WORLD_SIZE] [--local-rank LOCAL_RANK] [--dist_on_itp] 11 | [--dist_url DIST_URL] [--no_resume] [--cfgs CFGS [CFGS ...]] [--gin GIN [GIN ...]] 12 | ``` 13 | There are two types of arguments: program arguments and gin arguments. The program arguments basically controls the training process, such as epochs, optimizer, and output path. The program arguments are the essential parameters to launch the job and must be used. In contrast, the gin arguments are managed by [gin-config](https://github.com/google/gin-config) to configure the behaviour of models in a flexible way. The gin arguments are passed by `--gin k1=v1 [k2=v2 ...]`, or you can read the configure from files `--cfgs *.gin`. 14 | 15 | 16 | The library contains many models requiring different hyperparameters. 
To resolve conflicts between models (some of them share the same hyperparameter names) and to reduce coupling between them, gin is the best tool I know of for changing these parameters without modifying the main script. 17 | 18 | ## build_model 19 | `build_model.model_fn=@<ModelName>` selects which model to build (for example, `build_model.model_fn=@SimCLR`). You can pass further arguments to the model by adding `build_model.<param>=<value>` (for example, `build_model.embed_dim=192`). 20 | 21 | ### create_backbone 22 | `create_backbone` will call `timm.create_model` to create the backbone network used by SimCLR and other methods. 23 | 24 | ## build_dataset 25 | `build_dataset.transform_fn=@<TransformName>` selects the data augmentation, which should be one of: 26 | 27 | - `@SimpleAugmentation` or `@SimplePipeline`: simple transforms used in MAE. 28 | - `@DataAugmentationDINO` or `@MultiviewPipeline`: multiview data augmentation. -------------------------------------------------------------------------------- /docs/smalldata.md: -------------------------------------------------------------------------------- 1 | # A Guideline for Small Datasets 2 | 3 | ## Introduction 4 | In this guideline, we discuss the challenges of working with small datasets and provide some strategies to overcome them. Small datasets offer an affordable and quick way to validate your ideas, but the results do not necessarily generalize to larger datasets. Nevertheless, small datasets are useful for understanding learning algorithms and debugging code. Especially when you are a beginner in machine learning without a lot of GPU resources, small datasets are a good starting point. Our goal is to provide efficient training recipes for small datasets with competitive performance. Our discussion focuses on the CIFAR10 and mini-ImageNet datasets, but the strategies can be applied to other small datasets as well. 5 | 6 | ## Datasets 7 | 8 | We provide a simple script to test data loading and preprocessing. Run the following commands to download the datasets and profile the data pipelines. 9 | 10 | ```bash 11 | torchrun bin/data_profile.py --data_set cifar10 --data_path ../data/ --gin SimpleAugmentation.img_size=32 --export outputs/cifar10.txt 12 | 13 | torchrun bin/data_profile.py --data_set imnet --data_path ../data/miniImagenet/ --gin SimpleAugmentation.img_size=64 --export outputs/imnet.txt 14 | ``` 15 | 16 | **CIFAR10** is a dataset of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images. 17 | **IMNET64** is a variant of ImageNet with downsampled images. Here, we use the mini-ImageNet dataset with 64x64 images, which can be found at [image-net.org](https://www.image-net.org/download-images.php). The dataset contains 1,000 classes and 1,281,167 images. 18 | 19 | ## Results 20 | 21 | All models are available at https://wandb.ai/erow/FastSSL. 22 | 23 | ## Pretraining Models 24 | 25 | ### SimCLR 26 | SimCLR is a simple contrastive learning framework that learns representations by maximizing agreement between differently augmented views of the same data sample. We provide a simple implementation of SimCLR in the `models/simclr.py` file. You can run the following command to train the SimCLR model on CIFAR10.
27 | 28 | ```bash 29 | WANDB_NAME=simclr-cifar10 torchrun main_pretrain.py --data_set cifar10 --data_path ../data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --opt lion --blr=1e-4 --cfgs configs/cifar.gin configs/vitt.gin --gin build_model.embed_dim=192 build_model.model_fn=@SimCLR 30 | ``` 31 | 32 | ### MoCo 33 | Momentum contrastive learning is a simple contrastive learning framework that utilizes a momentum encoder to stabilize the training. We provide a simple implementation of MoCo in the `models/moco.py` file. You can run the following command to train the MoCo model on CIFAR10. 34 | 35 | ```bash 36 | WANDB_NAME=moco-cifar10 torchrun --master_port=12387 main_pretrain_ema.py --data_set cifar10 --data_path ../data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --opt lion --blr=1e-4 --cfgs configs/cifar.gin configs/vitt.gin --gin build_model.model_fn=@MoCo MoCo.embed_dim=192 MoCo.mlp_dim=512 MoCo.out_dim=128 37 | ``` 38 | 39 | ### DINO 40 | DINO is a self-distillation framework with momentum encoder that learns representations by maximizing agreement between differently augmented views of the same data sample. We provide a simple implementation of DINO in the `models/dino.py` file. You can run the following command to train the DINO model on CIFAR10. 41 | 42 | ```bash 43 | WANDB_NAME=dino-cifar10 torchrun main_pretrain_ema.py --data_set cifar10 --data_path ../data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --opt lion --blr=1e-4 --cfgs configs/cifar.gin configs/vitt.gin --gin build_model.model_fn=@DINO DINO.embed_dim=192 DINO.out_dim=1024 -m 0.996 44 | ``` 45 | 46 | ### SimSiam 47 | SimSiam is a negative free self-supervised learning framework that learns representations by maximizing agreement between differently augmented views of the same data sample. We provide a simple implementation of SimSiam in the `models/simsiam.py` file. You can run the following command to train the SimSiam model on CIFAR10. 48 | 49 | ```bash 50 | WANDB_NAME=simsiam-cifar10 torchrun main_pretrain.py --data_set cifar10 --data_path ../data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --opt lion --blr=1e-4 --cfgs configs/cifar.gin configs/vitt.gin --gin build_model.model_fn=@SimSiam SimSiam.embed_dim=192 SimSiam.proj_dim=192 SimSiam.mlp_dim=96 51 | ``` 52 | 53 | 54 | ### MAE 55 | Masked autoencoder is a simple pixel reconstruction model that learns to reconstruct the input image from the masked image. We provide a simple implementation of MAE in the `models/mae.py` file. You can run the following command to train the MAE model on CIFAR10. 
56 | 57 | ```bash 58 | WANDB_NAME=mae-cifar10 torchrun main_pretrain.py --data_set cifar10 --data_path ../data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --opt lion --blr=1e-4 --cfgs configs/cifar.gin --gin build_dataset.transform_fn=@SimpleAugmentation SimpleAugmentation.img_size=32 build_model.model_fn=@mae_tiny build_model.patch_size=4 build_model.img_size=32 59 | 60 | 61 | WANDB_NAME=amae-cifar10 torchrun main_pretrain.py --data_set cifar10 --data_path ../data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --opt lion --blr=1e-4 --cfgs configs/cifar.gin --gin build_dataset.transform_fn=@SimpleAugmentation SimpleAugmentation.img_size=32 build_model.model_fn=@amae_tiny build_model.patch_size=4 build_model.img_size=32 build_model.decoder_patch_size=2 build_model.sigma=20 62 | ``` 63 | 64 | ### AIM 65 | 66 | ```bash 67 | WANDB_NAME=aim-cifar10 torchrun main_pretrain.py --data_set cifar10 --data_path ../data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --opt lion --blr=1e-4 --clip_grad 1 --cfgs configs/cifar.gin --gin build_dataset.transform_fn=@SimpleAugmentation SimpleAugmentation.img_size=32 build_model.model_fn=@aim_tiny 68 | ``` 69 | 70 | ## Evaluation 71 | 72 | finetune 73 | ```bash 74 | CUDA_VISIBLE_DEVICES=6 WANDB_NAME=simclr-vitt python finetune.py --vit_mlp_ratio=4 --opt lion --lr=5e-5 -w ../FastSSL/outputs/simclr-cifar10-s0/weights.pth --prefix='backbone.(.*)' 75 | 76 | 77 | CUDA_VISIBLE_DEVICES=0 WANDB_NAME=mae-vitt torchrun finetune.py --vit_mlp_ratio=4 --opt lion --lr 1e-4 -w ../FastSSL/outputs/mae-cifar10/weights.pth 78 | 79 | vitrun eval_cls.py --data_set CIFAR10 --data_location ../data/ --gin build_model.model_name=\'vit_tiny_patch16_224\' build_model.patch_size=4 build_model.img_size=32 --input_size=32 --prefix='backbone.(.*)' 80 | 81 | 82 | ``` 83 | 84 | KNN for fast evaluation 85 | ```bash 86 | vitrun eval_knn.py --data_set CIFAR10 --data_location ../data/ --gin build_model.model_name=\'vit_tiny_patch16_224\' build_model.num_heads=12 build_model.patch_size=4 build_model.global_pool=\'avg\' build_model.img_size=32 --input_size=32 --prefix='' -w '' 87 | ``` 88 | 89 | 90 | # VISION TRANSFORMERS IN 2022 AN UPDATE ON TINY IMAGENET : https://arxiv.org/pdf/2205.10660 -------------------------------------------------------------------------------- /environment.txt: -------------------------------------------------------------------------------- 1 | name: FastSSL 2 | channels: 3 | - conda-forge 4 | - nvidia 5 | - pytorch 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_kmp_llvm 9 | - alsa-lib=1.2.11=hd590300_1 10 | - aom=3.8.1=h59595ed_0 11 | - attr=2.5.1=h166bdaf_1 12 | - blas=2.116=mkl 13 | - blas-devel=3.9.0=16_linux64_mkl 14 | - brotli-python=1.1.0=py310hc6cd4ac_1 15 | - bzip2=1.0.8=hd590300_5 16 | - c-ares=1.27.0=hd590300_0 17 | - ca-certificates=2024.2.2=hbcca054_0 18 | - cairo=1.18.0=h3faef2a_0 19 | - certifi=2024.2.2=pyhd8ed1ab_0 20 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 21 | - cuda-cudart=12.1.105=0 22 | - cuda-cupti=12.1.105=0 23 | - cuda-libraries=12.1.0=0 24 | - cuda-nvrtc=12.1.105=0 25 | - cuda-nvtx=12.1.105=0 26 | - cuda-opencl=12.4.99=0 27 | - cuda-runtime=12.1.0=0 28 | - cuda-version=12.4=h3060b56_3 29 | - cupy=13.0.0=py310h7aad9d2_3 30 | - cupy-core=13.0.0=py310had4011e_3 31 | - dav1d=1.2.1=hd590300_0 32 | - dbus=1.13.6=h5008d03_3 33 | - expat=2.6.1=h59595ed_0 34 | - fastrlock=0.8.2=py310hc6cd4ac_2 35 | - ffmpeg=6.1.1=gpl_h8007c5b_104 36 | - 
filelock=3.13.1=pyhd8ed1ab_0 37 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 38 | - font-ttf-inconsolata=3.000=h77eed37_0 39 | - font-ttf-source-code-pro=2.038=h77eed37_0 40 | - font-ttf-ubuntu=0.83=h77eed37_1 41 | - fontconfig=2.14.2=h14ed4e7_0 42 | - fonts-conda-ecosystem=1=0 43 | - fonts-conda-forge=1=0 44 | - freeglut=3.2.2=hac7e632_2 45 | - freetype=2.12.1=h267a509_2 46 | - fribidi=1.0.10=h36c2ea0_0 47 | - gettext=0.21.1=h27087fc_0 48 | - glib=2.80.0=hf2295e7_0 49 | - glib-tools=2.80.0=hde27a5a_0 50 | - gmp=6.3.0=h59595ed_1 51 | - gmpy2=2.1.2=py310h3ec546c_1 52 | - gnutls=3.7.9=hb077bed_0 53 | - graphite2=1.3.13=h58526e2_1001 54 | - gst-plugins-base=1.22.9=h8e1006c_0 55 | - gstreamer=1.22.9=h98fc4e7_0 56 | - harfbuzz=8.3.0=h3d44ed6_0 57 | - hdf5=1.14.3=nompi_h4f84152_100 58 | - icu=73.2=h59595ed_0 59 | - idna=3.6=pyhd8ed1ab_0 60 | - imath=3.1.11=hfc55251_0 61 | - jasper=4.2.2=he6dfbbe_0 62 | - jinja2=3.1.3=pyhd8ed1ab_0 63 | - keyutils=1.6.1=h166bdaf_0 64 | - krb5=1.21.2=h659d440_0 65 | - lame=3.100=h166bdaf_1003 66 | - lcms2=2.16=hb7c19ff_0 67 | - ld_impl_linux-64=2.40=h41732ed_0 68 | - lerc=4.0.0=h27087fc_0 69 | - libabseil=20240116.1=cxx17_h59595ed_2 70 | - libaec=1.1.2=h59595ed_1 71 | - libass=0.17.1=h8fe9dca_1 72 | - libblas=3.9.0=16_linux64_mkl 73 | - libcap=2.69=h0f662aa_0 74 | - libcblas=3.9.0=16_linux64_mkl 75 | - libclang=15.0.7=default_hb11cfb5_4 76 | - libclang13=15.0.7=default_ha2b6cf4_4 77 | - libcublas=12.1.0.26=0 78 | - libcufft=11.0.2.4=0 79 | - libcufile=1.9.0.20=0 80 | - libcups=2.3.3=h4637d8d_4 81 | - libcurand=10.3.5.119=0 82 | - libcurl=8.5.0=hca28451_0 83 | - libcusolver=11.4.4.55=0 84 | - libcusparse=12.0.2.55=0 85 | - libdeflate=1.19=hd590300_0 86 | - libdrm=2.4.120=hd590300_0 87 | - libedit=3.1.20191231=he28a2e2_2 88 | - libev=4.33=hd590300_2 89 | - libevent=2.1.12=hf998b51_1 90 | - libexpat=2.6.1=h59595ed_0 91 | - libffi=3.4.2=h7f98852_5 92 | - libflac=1.4.3=h59595ed_0 93 | - libgcc-ng=13.2.0=h807b86a_5 94 | - libgcrypt=1.10.3=hd590300_0 95 | - libgfortran-ng=13.2.0=h69a702a_5 96 | - libgfortran5=13.2.0=ha4646dd_5 97 | - libglib=2.80.0=hf2295e7_0 98 | - libglu=9.0.0=hac7e632_1003 99 | - libgpg-error=1.48=h71f35ed_0 100 | - libhwloc=2.9.3=default_h554bfaf_1009 101 | - libiconv=1.17=hd590300_2 102 | - libidn2=2.3.7=hd590300_0 103 | - libjpeg-turbo=3.0.0=hd590300_1 104 | - liblapack=3.9.0=16_linux64_mkl 105 | - liblapacke=3.9.0=16_linux64_mkl 106 | - libllvm14=14.0.6=hcd5def8_4 107 | - libllvm15=15.0.7=hb3ce162_4 108 | - libnghttp2=1.58.0=h47da74e_1 109 | - libnpp=12.0.2.50=0 110 | - libnsl=2.0.1=hd590300_0 111 | - libnvjitlink=12.1.105=0 112 | - libnvjpeg=12.1.1.14=0 113 | - libogg=1.3.4=h7f98852_1 114 | - libopencv=4.9.0=py310hfbccb02_9 115 | - libopenvino=2023.3.0=h2e90f83_2 116 | - libopenvino-auto-batch-plugin=2023.3.0=hd5fc58b_2 117 | - libopenvino-auto-plugin=2023.3.0=hd5fc58b_2 118 | - libopenvino-hetero-plugin=2023.3.0=h3ecfda7_2 119 | - libopenvino-intel-cpu-plugin=2023.3.0=h2e90f83_2 120 | - libopenvino-intel-gpu-plugin=2023.3.0=h2e90f83_2 121 | - libopenvino-ir-frontend=2023.3.0=h3ecfda7_2 122 | - libopenvino-onnx-frontend=2023.3.0=h469e5c9_2 123 | - libopenvino-paddle-frontend=2023.3.0=h469e5c9_2 124 | - libopenvino-pytorch-frontend=2023.3.0=h59595ed_2 125 | - libopenvino-tensorflow-frontend=2023.3.0=he1e0747_2 126 | - libopenvino-tensorflow-lite-frontend=2023.3.0=h59595ed_2 127 | - libopus=1.3.1=h7f98852_1 128 | - libpciaccess=0.18=hd590300_0 129 | - libpng=1.6.43=h2797004_0 130 | - libpq=16.2=h33b98f1_0 131 | - libprotobuf=4.25.2=h08a7969_1 132 | 
- libsndfile=1.2.2=hc60ed4a_1 133 | - libsqlite=3.45.2=h2797004_0 134 | - libssh2=1.11.0=h0841786_0 135 | - libstdcxx-ng=13.2.0=h7e041cc_5 136 | - libsystemd0=255=h3516f8a_1 137 | - libtasn1=4.19.0=h166bdaf_0 138 | - libtiff=4.6.0=ha9c0a0a_2 139 | - libunistring=0.9.10=h7f98852_0 140 | - libuuid=2.38.1=h0b41bf4_0 141 | - libva=2.21.0=hd590300_0 142 | - libvorbis=1.3.7=h9c3ff4c_0 143 | - libvpx=1.13.1=h59595ed_0 144 | - libwebp-base=1.3.2=hd590300_0 145 | - libxcb=1.15=h0b41bf4_0 146 | - libxcrypt=4.4.36=hd590300_1 147 | - libxkbcommon=1.6.0=hd429924_1 148 | - libxml2=2.12.5=h232c23b_0 149 | - libzlib=1.2.13=hd590300_5 150 | - llvm-openmp=15.0.7=h0cdce71_0 151 | - llvmlite=0.42.0=py310h1b8f574_1 152 | - lz4-c=1.9.4=hcb278e6_0 153 | - markupsafe=2.1.5=py310h2372a71_0 154 | - mkl=2022.1.0=h84fe81f_915 155 | - mkl-devel=2022.1.0=ha770c72_916 156 | - mkl-include=2022.1.0=h84fe81f_915 157 | - mpc=1.3.1=hfe3b2da_0 158 | - mpfr=4.2.1=h9458935_0 159 | - mpg123=1.32.4=h59595ed_0 160 | - mpmath=1.3.0=pyhd8ed1ab_0 161 | - mysql-common=8.0.33=hf1915f5_6 162 | - mysql-libs=8.0.33=hca2cd23_6 163 | - ncurses=6.4=h59595ed_2 164 | - nettle=3.9.1=h7ab15ed_0 165 | - networkx=3.2.1=pyhd8ed1ab_0 166 | - nspr=4.35=h27087fc_0 167 | - nss=3.98=h1d7d5a4_0 168 | - numba=0.59.0=py310h7dc5dd1_1 169 | - numpy=1.26.4=py310hb13e2d6_0 170 | - ocl-icd=2.3.2=hd590300_0 171 | - opencv=4.9.0=py310h949e142_9 172 | - openexr=3.2.2=haf962dd_1 173 | - openh264=2.4.1=h59595ed_0 174 | - openjpeg=2.5.2=h488ebb8_0 175 | - openssl=3.2.1=hd590300_0 176 | - p11-kit=0.24.1=hc5aa10d_0 177 | - pcre2=10.43=hcad00b1_0 178 | - pillow=10.2.0=py310h01dd4db_0 179 | - pip=24.0=pyhd8ed1ab_0 180 | - pixman=0.43.2=h59595ed_0 181 | - pkg-config=0.29.2=h36c2ea0_1008 182 | - pthread-stubs=0.4=h36c2ea0_1001 183 | - pugixml=1.14=h59595ed_0 184 | - pulseaudio-client=16.1=hb77b528_5 185 | - py-opencv=4.9.0=py310hadec9d8_9 186 | - pysocks=1.7.1=pyha2e5f31_6 187 | - python=3.10.13=hd12c33a_1_cpython 188 | - python_abi=3.10=4_cp310 189 | - pytorch=2.2.1=py3.10_cuda12.1_cudnn8.9.2_0 190 | - pytorch-cuda=12.1=ha16c6d3_5 191 | - pytorch-mutex=1.0=cuda 192 | - pyyaml=6.0.1=py310h2372a71_1 193 | - qt-main=5.15.8=h5810be5_19 194 | - readline=8.2=h8228510_1 195 | - requests=2.31.0=pyhd8ed1ab_0 196 | - setuptools=69.1.1=pyhd8ed1ab_0 197 | - snappy=1.1.10=h9fff704_0 198 | - svt-av1=1.8.0=h59595ed_0 199 | - sympy=1.12=pypyh9d50eac_103 200 | - tbb=2021.11.0=h00ab1b0_1 201 | - tk=8.6.13=noxft_h4845f30_101 202 | - torchtriton=2.2.0=py310 203 | - torchvision=0.17.1=py310_cu121 204 | - typing_extensions=4.10.0=pyha770c72_0 205 | - tzdata=2024a=h0c530f3_0 206 | - urllib3=2.2.1=pyhd8ed1ab_0 207 | - wheel=0.42.0=pyhd8ed1ab_0 208 | - x264=1!164.3095=h166bdaf_2 209 | - x265=3.5=h924138e_3 210 | - xcb-util=0.4.0=hd590300_1 211 | - xcb-util-image=0.4.0=h8ee46fc_1 212 | - xcb-util-keysyms=0.4.0=h8ee46fc_1 213 | - xcb-util-renderutil=0.3.9=hd590300_1 214 | - xcb-util-wm=0.4.1=h8ee46fc_1 215 | - xkeyboard-config=2.41=hd590300_0 216 | - xorg-fixesproto=5.0=h7f98852_1002 217 | - xorg-inputproto=2.3.2=h7f98852_1002 218 | - xorg-kbproto=1.0.7=h7f98852_1002 219 | - xorg-libice=1.1.1=hd590300_0 220 | - xorg-libsm=1.2.4=h7391055_0 221 | - xorg-libx11=1.8.7=h8ee46fc_0 222 | - xorg-libxau=1.0.11=hd590300_0 223 | - xorg-libxdmcp=1.1.3=h7f98852_0 224 | - xorg-libxext=1.3.4=h0b41bf4_2 225 | - xorg-libxfixes=5.0.3=h7f98852_1004 226 | - xorg-libxi=1.7.10=h7f98852_0 227 | - xorg-libxrender=0.9.11=hd590300_0 228 | - xorg-renderproto=0.11.1=h7f98852_1002 229 | - xorg-xextproto=7.3.0=h0b41bf4_1003 230 
| - xorg-xf86vidmodeproto=2.3.1=h7f98852_1002 231 | - xorg-xproto=7.0.31=h7f98852_1007 232 | - xz=5.2.6=h166bdaf_0 233 | - yaml=0.2.5=h7f98852_2 234 | - zlib=1.2.13=hd590300_5 235 | - zstd=1.5.5=hfc55251_0 236 | 237 | -------------------------------------------------------------------------------- /layers/backbone.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import timm 3 | 4 | @gin.configurable 5 | def create_backbone(name='resnet50', **kwargs): 6 | return timm.create_model(name, num_classes=0, **kwargs) -------------------------------------------------------------------------------- /layers/build_model.py: -------------------------------------------------------------------------------- 1 | import gin 2 | from timm.layers import convert_sync_batchnorm 3 | from model.mae import * 4 | 5 | @gin.configurable() 6 | def build_model(args, model_fn=mae_vit_base_patch16, **kwargs): 7 | model = model_fn(**kwargs) 8 | model = convert_sync_batchnorm(model) 9 | return model 10 | -------------------------------------------------------------------------------- /layers/operation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.distributed import group, ReduceOp, is_initialized 7 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 8 | from torch import Tensor 9 | 10 | 11 | # utils 12 | @torch.no_grad() 13 | def concat_all_gather(tensor): 14 | """ 15 | Performs all_gather operation on the provided tensors. 16 | *** Warning ***: torch.distributed.all_gather has no gradient. 17 | """ 18 | if not is_initialized(): 19 | return tensor 20 | tensors_gather = [torch.ones_like(tensor) 21 | for _ in range(torch.distributed.get_world_size())] 22 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 23 | 24 | output = torch.cat(tensors_gather, dim=0) 25 | return output 26 | 27 | def contrastive_loss(q, k,temperature=0.1): 28 | # NT-Xent (the normalized temperature-scaled cross entropy loss), applied in [Improved Deep Metric Learning with Multi-class N-pair Loss Objective] 29 | # normalize 30 | q = nn.functional.normalize(q, dim=1) 31 | k = nn.functional.normalize(k, dim=1) 32 | # gather all targets 33 | k = concat_all_gather(k) 34 | # Einstein sum is more intuitive 35 | logits = torch.einsum('nc,mc->nm', [q, k]) / temperature 36 | N = logits.shape[0] # batch size per GPU 37 | rank = torch.distributed.get_rank() if is_initialized() else 0 38 | labels = (torch.arange(N, dtype=torch.long) + N * rank).to(logits.device) 39 | return nn.CrossEntropyLoss()(logits, labels) 40 | 41 | 42 | class AllGatherGrad(torch.autograd.Function): 43 | @staticmethod 44 | def forward( 45 | ctx: Any, 46 | tensor: Tensor, 47 | group: Optional["torch.distributed.ProcessGroup"] = group.WORLD, 48 | ) -> Tensor: 49 | ctx.group = group 50 | 51 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] 52 | 53 | torch.distributed.all_gather(gathered_tensor, tensor, group=group) 54 | gathered_tensor = torch.stack(gathered_tensor, dim=0) 55 | 56 | return gathered_tensor 57 | 58 | @staticmethod 59 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]: 60 | # print("backward------------->") 61 | # print(grad_output) 62 | grad_output = torch.cat(grad_output) 63 | 64 | torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, 
group=ctx.group) 65 | 66 | return grad_output[torch.distributed.get_rank()], None 67 | 68 | def concat_all_gather_grad(tensor): 69 | """ 70 | Performs all_gather operation on the provided tensors. 71 | *** Warning ***: torch.distributed.all_gather has no gradient. 72 | """ 73 | if not is_initialized(): 74 | return tensor 75 | return AllGatherGrad.apply(tensor).flatten(0,1) 76 | 77 | def build_head(num_layers, input_dim, mlp_dim, output_dim, hidden_bn=True,activation=nn.ReLU, 78 | last_norm='bn',): 79 | mlp = [] 80 | for l in range(num_layers): 81 | dim1 = input_dim if l == 0 else mlp_dim 82 | dim2 = output_dim if l == num_layers - 1 else mlp_dim 83 | 84 | if l == num_layers-1: 85 | mlp.append(nn.Linear(dim1, dim2, bias=False)) 86 | else: 87 | mlp.append(nn.Linear(dim1, dim2, bias=True)) 88 | 89 | if l < num_layers - 1: 90 | if hidden_bn: 91 | mlp.append(nn.BatchNorm1d(dim2)) 92 | mlp.append(activation()) 93 | else: 94 | if last_norm=='bn': 95 | # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 96 | # for simplicity, we further removed gamma in BN 97 | mlp.append(nn.BatchNorm1d(dim2, affine=False)) 98 | elif last_norm=='ln': 99 | mlp.append(nn.LayerNorm(dim2)) 100 | elif last_norm=='none': 101 | pass 102 | else: 103 | raise NotImplementedError(f"last_norm={last_norm} not implemented") 104 | 105 | return nn.Sequential(*mlp) 106 | 107 | 108 | def patchify(imgs,p=16): 109 | """ 110 | imgs: (N, 3, H, W) 111 | x: (N, L, patch_size**2 *3) 112 | """ 113 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 114 | 115 | h = w = imgs.shape[2] // p 116 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 117 | x = torch.einsum('nchpwq->nhwpqc', x) 118 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 119 | return x 120 | 121 | def unpatchify(x,p=16): 122 | """ 123 | x: (N, L, patch_size**2 *3) 124 | imgs: (N, 3, H, W) 125 | """ 126 | h = w = int(x.shape[1]**.5) 127 | assert h * w == x.shape[1] 128 | 129 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 130 | x = torch.einsum('nhwpqc->nchpwq', x) 131 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 132 | return imgs 133 | 134 | def pixel_norm(target): 135 | """ 136 | target: (N, L, C) 137 | """ 138 | mean = target.mean(dim=-1, keepdim=True) 139 | var = target.var(dim=-1, keepdim=True) 140 | target = (target - mean) / (var + 1.e-6)**.5 141 | return target 142 | 143 | -------------------------------------------------------------------------------- /layers/target.py: -------------------------------------------------------------------------------- 1 | import math 2 | import einops 3 | import gin 4 | import torch 5 | from torch import nn 6 | from .operation import patchify, unpatchify 7 | 8 | 9 | @gin.configurable 10 | class TargetMSE(nn.Module): 11 | def __init__(self,norm_pix_loss=True,patch_size=16,ignore_mask=False): 12 | super().__init__() 13 | self.norm_pix_loss = norm_pix_loss 14 | self.patch_size = patch_size 15 | self.ignore_mask = ignore_mask 16 | 17 | def forward(self, imgs, pred,mask=None): 18 | """ 19 | imgs: [N, 3, H, W] 20 | pred: [N, L, p*p*3] 21 | mask: [N, L], 0 is keep, 1 is remove, 22 | """ 23 | if len(pred.shape)==4: 24 | pred = patchify(pred,self.patch_size) 25 | 26 | target = patchify(imgs,self.patch_size) 27 | if self.norm_pix_loss: 28 | mean = target.mean(dim=-1, keepdim=True) 29 | var = target.var(dim=-1, keepdim=True) 30 | target = (target - mean) / (var + 1.e-6)**.5 31 | 32 | loss = (pred - target) ** 2 33 | loss = loss.mean(dim=-1) # [N, L], mean 
loss per patch 34 | 35 | if mask is None or self.ignore_mask: 36 | loss = loss.mean() 37 | else: 38 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 39 | return loss 40 | 41 | @gin.configurable 42 | class TargetSSIM(nn.Module): 43 | def __init__(self,patch_size=16,ignore_mask=False): 44 | super().__init__() 45 | self.patch_size = patch_size 46 | self.ignore_mask = ignore_mask 47 | from torchmetrics.image import ssim 48 | self.ssim_loss = ssim.StructuralSimilarityIndexMeasure() 49 | 50 | def forward(self, imgs, pred,mask=None): 51 | """ 52 | imgs: [N, 3, H, W] 53 | pred: [N, 3, H, W] 54 | mask: [N, L], 0 is keep, 1 is remove, 55 | """ 56 | if len(pred.shape)==3: 57 | pred = unpatchify(pred,self.patch_size) 58 | 59 | target = imgs 60 | if mask is None or self.ignore_mask: 61 | loss = self.ssim_loss(pred,target) 62 | else: 63 | mask = mask.unsqueeze(-1).expand(-1,-1,3 * self.patch_size**2) 64 | mask = unpatchify(mask,self.patch_size) 65 | loss = self.ssim_loss(pred*mask,target*mask) 66 | return 1 - loss 67 | 68 | 69 | def get_gkern(kernlen, std): 70 | """Returns a 2D Gaussian kernel array.""" 71 | 72 | def _gaussian_fn(kernlen, std): 73 | n = torch.arange(0, kernlen).float() 74 | n -= n.mean() 75 | n /= std 76 | w = torch.exp(-0.5 * n**2) 77 | return w 78 | 79 | gkern1d = _gaussian_fn(kernlen, std) 80 | gkern2d = torch.outer(gkern1d, gkern1d) 81 | return gkern2d / gkern2d.sum() 82 | 83 | class HOGLayerC(nn.Module): 84 | # copy from https://github.com/facebookresearch/SlowFast/blob/2efb99faa254075b4e28d3d4f313052b51da05bc/slowfast/models/operators.py#L66 85 | def __init__(self, nbins=9, pool=7, gaussian_window=0): 86 | super(HOGLayerC, self).__init__() 87 | self.nbins = nbins 88 | self.pool = pool 89 | self.pi = math.pi 90 | weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 91 | weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1) 92 | weight_y = weight_x.transpose(2, 3) 93 | self.register_buffer("weight_x", weight_x) 94 | self.register_buffer("weight_y", weight_y) 95 | 96 | self.gaussian_window = gaussian_window 97 | if gaussian_window: 98 | gkern = get_gkern(gaussian_window, gaussian_window // 2) 99 | self.register_buffer("gkern", gkern) 100 | 101 | @torch.no_grad() 102 | def forward(self, x): 103 | # input is RGB image with shape [B 3 H W] 104 | x = F.pad(x, pad=(1, 1, 1, 1), mode="reflect") 105 | gx_rgb = F.conv2d( 106 | x, self.weight_x, bias=None, stride=1, padding=0, groups=3 107 | ) 108 | gy_rgb = F.conv2d( 109 | x, self.weight_y, bias=None, stride=1, padding=0, groups=3 110 | ) 111 | norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1) 112 | phase = torch.atan2(gx_rgb, gy_rgb) 113 | phase = phase / self.pi * self.nbins # [-9, 9] 114 | 115 | b, c, h, w = norm_rgb.shape 116 | out = torch.zeros( 117 | (b, c, self.nbins, h, w), dtype=torch.float, device=x.device 118 | ) 119 | phase = phase.view(b, c, 1, h, w) 120 | norm_rgb = norm_rgb.view(b, c, 1, h, w) 121 | if self.gaussian_window: 122 | if h != self.gaussian_window: 123 | assert h % self.gaussian_window == 0, "h {} gw {}".format( 124 | h, self.gaussian_window 125 | ) 126 | repeat_rate = h // self.gaussian_window 127 | temp_gkern = self.gkern.repeat([repeat_rate, repeat_rate]) 128 | else: 129 | temp_gkern = self.gkern 130 | norm_rgb *= temp_gkern 131 | 132 | out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb) 133 | 134 | out = out.unfold(3, self.pool, self.pool) 135 | out = out.unfold(4, self.pool, self.pool) 136 | out = out.sum(dim=[-1, -2]) 137 | 138 | out = 
torch.nn.functional.normalize(out, p=2, dim=2) 139 | 140 | return out # B 3 nbins H W 141 | @gin.configurable 142 | class TargetHOG(nn.Module): 143 | def __init__(self, patch_size=16,ignore_mask=False,pool=4): 144 | super().__init__() 145 | self.patch_size = patch_size 146 | self.ignore_mask = ignore_mask 147 | self.hog = HOGLayerC(pool=pool) 148 | 149 | self.feat_size = (patch_size//pool) 150 | 151 | def forward(self, imgs, pred,mask=None): 152 | """ 153 | imgs: [N, 3, H, W] 154 | pred: [N, L, p*p*3] 155 | mask: [N, L], 0 is keep, 1 is remove, 156 | """ 157 | if len(pred.shape)==4: 158 | pred = patchify(pred,self.patch_size) 159 | 160 | hog_feat = self.hog(imgs) # [N, 3, Orientation, H, W] 161 | hog_feat = einops.rearrange(hog_feat,'n l c (h p1) (w p2) -> n (h w) (l c p1 p2)',p1=self.feat_size,p2=self.feat_size) # [N,L, C] 162 | 163 | loss = (pred - hog_feat) ** 2 164 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 165 | 166 | if mask is None or self.ignore_mask: 167 | loss = loss.mean() 168 | else: 169 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 170 | return loss 171 | 172 | @gin.configurable 173 | def build_target(*args,target_fn=TargetMSE,**kwargs): 174 | return target_fn(*args,**kwargs) 175 | -------------------------------------------------------------------------------- /main_pretrain_ema.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import sys 4 | import main_pretrain 5 | from main_pretrain import main, get_args_parser 6 | import torch 7 | import util.misc as misc 8 | import util.lr_sched as lr_sched 9 | 10 | def train_one_epoch(model: torch.nn.Module,online_prob, 11 | data_loader, optimizer: torch.optim.Optimizer, 12 | device: torch.device, epoch: int, loss_scaler, 13 | log_writer=None, 14 | args=None): 15 | model.train(True) 16 | metric_logger = misc.MetricLogger(delimiter=" ") 17 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=20, fmt='{value:.6f}')) 18 | metric_logger.add_meter('m', misc.SmoothedValue(window_size=20, fmt='{value:.6f}')) 19 | header = 'Epoch: [{}]'.format(epoch) 20 | print_freq = 20 21 | 22 | accum_iter = args.accum_iter 23 | 24 | optimizer.zero_grad() 25 | 26 | if log_writer is not None: 27 | print('log_dir: {}'.format(log_writer.log_dir)) 28 | 29 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 30 | if args.data_set == "ffcv": 31 | samples = data[:-1] 32 | targets = data[-1] 33 | else: 34 | samples, targets = data 35 | 36 | if isinstance(samples,list) or isinstance(samples,tuple): 37 | samples = [i.to(device, non_blocking=True) for i in samples] 38 | if len(samples)==1: 39 | samples = samples[0] 40 | else: 41 | samples = samples.to(device, non_blocking=True) 42 | targets = targets.to(device, non_blocking=True).flatten() 43 | 44 | # we use a per iteration (instead of per epoch) lr scheduler 45 | if data_iter_step % accum_iter == 0: 46 | epoch_i = data_iter_step / len(data_loader) + epoch 47 | lr_sched.adjust_learning_rate(optimizer, epoch_i, args) 48 | m = lr_sched.adjust_moco_momentum(epoch_i, args) 49 | model.module.update(m) 50 | with torch.amp.autocast('cuda',dtype=torch.float16): 51 | loss, log = model(samples,targets=targets, epoch=epoch) 52 | 53 | loss_value = loss.item() 54 | 55 | if not math.isfinite(loss_value): 56 | print("Loss is {}, stopping training".format(loss_value)) 57 | torch.save(model.module, "nan_model.pt") 58 | sys.exit(1) 59 | 60 | loss /= accum_iter 61 | 
loss_scaler(loss, optimizer, parameters=model.parameters(), 62 | update_grad=(data_iter_step + 1) % accum_iter == 0) 63 | if (data_iter_step + 1) % accum_iter == 0: 64 | optimizer.zero_grad() 65 | 66 | if online_prob: 67 | log.update(online_prob.step(samples,targets)) 68 | 69 | torch.cuda.synchronize() 70 | 71 | metric_logger.update(loss=loss_value) 72 | 73 | lr = optimizer.param_groups[-1]["lr"] 74 | metric_logger.update(lr=lr) 75 | metric_logger.update(m=m) 76 | for k,v in log.items(): 77 | metric_logger.update(**{k:v}) 78 | 79 | 80 | loss_value_reduce = misc.all_reduce_mean(loss_value) 81 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 82 | """ We use epoch_1000x as the x-axis in tensorboard. 83 | This calibrates different curves when batch size changes. 84 | """ 85 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 86 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 87 | log_writer.add_scalar('epoch_1000x',epoch_1000x) 88 | log_writer.add_scalar('lr', lr, epoch_1000x) 89 | for k,v in log.items(): 90 | log_writer.add_scalar(f'{k}', v, epoch_1000x) 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | 97 | 98 | if __name__ == '__main__': 99 | # replace with the new train function with momentum 100 | from util.helper import aug_parse 101 | main_pretrain.train_one_epoch = train_one_epoch 102 | parser = get_args_parser() 103 | parser.add_argument("-m",type=float, default=0.996) 104 | args = aug_parse(parser) 105 | main(args) 106 | 107 | """ 108 | When EMA helps? 109 | - On the Pros and Cons of Momentum Encoder: https://arxiv.org/pdf/2208.05744 110 | 111 | 112 | """ -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import model.simclr 2 | import model.mae 3 | # import model.asymmae 4 | import model.moco 5 | import model.simsiam 6 | import model.dino 7 | import model.aim 8 | import model.vcl 9 | import model.mcl -------------------------------------------------------------------------------- /model/aim.py: -------------------------------------------------------------------------------- 1 | """AIM: Autoregressive Image Models 2 | Reference: https://github.com/apple/ml-aim 3 | 4 | # Note: 5 | 6 | ## key points: 7 | - Prefix causal attention: 8 | """ 9 | 10 | import math 11 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 12 | 13 | import gin 14 | from huggingface_hub import PyTorchModelHubMixin 15 | 16 | import torch 17 | from torch import nn 18 | 19 | import layers.aim_vit as layers 20 | from layers.operation import patchify 21 | 22 | __all__ = [ 23 | "Transformer", 24 | "AIMPretrain", 25 | "AIMForImageClassification", 26 | "aim_600M", 27 | "aim_1B", 28 | "aim_3B", 29 | "aim_7B", 30 | ] 31 | 32 | ArrayLike = Any 33 | Module = Callable[..., Any] 34 | 35 | 36 | class AIMMixin: 37 | preprocessor: Module 38 | trunk: Module 39 | head: Module 40 | 41 | def forward( 42 | self, 43 | x: ArrayLike, 44 | mask: Optional[ArrayLike] = None, 45 | max_block_id: Optional[int] = -1, 46 | ) -> ArrayLike: 47 | x = self.preprocessor(x) 48 | x, _ = self.trunk(x, mask=mask, max_block_id=max_block_id) 49 | logits = self.head(x) 50 | return logits 51 | 52 | def extract_features( 53 | self, 54 | x: ArrayLike, 55 | mask: 
Optional[ArrayLike] = None, 56 | max_block_id: Optional[int] = -1, 57 | ) -> List[ArrayLike]: 58 | x = self.preprocessor(x, mask=mask) 59 | feats = self.trunk( 60 | x, mask=mask, max_block_id=max_block_id, return_features=True 61 | ) 62 | return feats 63 | 64 | class Transformer(nn.Module): 65 | def __init__( 66 | self, 67 | attn_target: Callable[[bool], nn.Module], 68 | embed_dim: int, 69 | num_blocks: int, 70 | ffn_target: Callable[..., nn.Module] = layers.MLP, 71 | post_transformer_layer: Optional[nn.Module] = None, 72 | norm_layer: Callable[[int], nn.Module] = layers.LayerNorm, 73 | mlp_ratio: int = 4, 74 | mlp_hidden_dim: Optional[int] = None, 75 | ffn_dropout_rate: float = 0.0, 76 | use_bias: bool = False, 77 | post_trunk_norm: bool = True, 78 | ): 79 | super().__init__() 80 | if mlp_hidden_dim is None: 81 | mlp_hidden_dim = int(mlp_ratio * embed_dim) 82 | 83 | self.blocks = nn.ModuleList( 84 | [ 85 | layers.Block( 86 | dim=embed_dim, 87 | attn_target=attn_target, 88 | ffn_target=ffn_target, 89 | mlp_hidden_dim=mlp_hidden_dim, 90 | norm_layer=norm_layer, 91 | ffn_dropout_rate=ffn_dropout_rate, 92 | use_bias=use_bias, 93 | ) 94 | for _ in range(num_blocks) 95 | ] 96 | ) 97 | self.post_trunk_norm = norm_layer(embed_dim) if post_trunk_norm else None 98 | self.post_transformer_layer = post_transformer_layer 99 | 100 | def forward( 101 | self, 102 | tokens: torch.Tensor, 103 | mask: Optional[torch.Tensor] = None, 104 | max_block_id: Optional[int] = -1, 105 | return_features: bool = False, 106 | ) -> Union[Tuple[torch.Tensor, List[torch.Tensor]], List[torch.Tensor]]: 107 | # only evaluate up to the max block id 108 | if max_block_id is None: 109 | assert ( 110 | self.post_transformer_layer is not None 111 | ), "Unable to determine the max block id." 
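# max_block_id not given: evaluate up to the deepest block required by the post-transformer layer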
112 | max_block_id = self.post_transformer_layer.max_block_id 113 | 114 | features = [] 115 | for blk_id, blk in enumerate(self.blocks): 116 | tokens = blk(tokens, mask=mask) 117 | features.append(tokens) 118 | 119 | if blk_id == max_block_id: 120 | break 121 | 122 | if return_features: 123 | return features 124 | 125 | if self.post_trunk_norm is not None: 126 | tokens = self.post_trunk_norm(tokens) 127 | 128 | if self.post_transformer_layer is not None: 129 | tokens = self.post_transformer_layer(tokens, layer_features=features) 130 | 131 | return tokens, features 132 | 133 | 134 | 135 | @gin.configurable() 136 | class AIMPretrain(nn.Module): 137 | def __init__(self, preprocessor: nn.Module, trunk: nn.Module, head: layers.ReconstructionHead, 138 | norm_pix_loss=True, prefix_len=16, 139 | ): 140 | super().__init__() 141 | self.preprocessor = preprocessor 142 | self.trunk = trunk 143 | self.head = head 144 | self.norm_pix_loss = norm_pix_loss 145 | self.prefix_len = prefix_len 146 | 147 | 148 | def forward( 149 | self, 150 | imgs: ArrayLike, 151 | mask: Optional[ArrayLike] = None, 152 | max_block_id: Optional[int] = -1, 153 | **kwargs: Any, 154 | ) -> ArrayLike: 155 | 156 | x = self.preprocessor(imgs) 157 | L = x.shape[1] 158 | if mask is None: 159 | mask = torch.ones(L, L, dtype=torch.bool,device=x.device).tril(diagonal=0) 160 | mask[:, :self.prefix_len] = True # set the first prefix_len tokens to be visible 161 | 162 | x, _ = self.trunk(x, mask=mask, max_block_id=max_block_id) 163 | pred = self.head(x) 164 | 165 | # compute loss 166 | target = patchify(imgs,self.head.patch_size) 167 | if self.norm_pix_loss: 168 | mean = target.mean(dim=-1, keepdim=True) 169 | var = target.var(dim=-1, keepdim=True) 170 | target = (target - mean) / (var + 1.e-6)**.5 171 | 172 | loss = (pred[:,self.prefix_len:-1] - target[:,self.prefix_len+1:]) ** 2 # next token prediction 173 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 174 | loss = loss.mean() 175 | return loss, {} 176 | 177 | 178 | def extract_features( 179 | self, 180 | x: ArrayLike, 181 | mask: Optional[ArrayLike] = None, 182 | max_block_id: Optional[int] = -1, 183 | ) -> List[ArrayLike]: 184 | x = self.preprocessor(x, mask=mask) 185 | feats = self.trunk( 186 | x, mask=mask, max_block_id=max_block_id, return_features=True 187 | ) 188 | return feats 189 | 190 | @torch.no_grad() 191 | def representation(self, x: ArrayLike, **kwargs) -> ArrayLike: 192 | feats = self.extract_features(x, mask=None, max_block_id=-1) 193 | z = feats[-1] 194 | return dict(latent=z[:,-1]) 195 | 196 | class AIMForImageClassification(AIMMixin, PyTorchModelHubMixin, nn.Module): 197 | def __init__(self, config: Dict[str, Any]): 198 | super().__init__() 199 | self.preprocessor, self.trunk, self.head = aim_config(**config) 200 | 201 | 202 | 203 | 204 | def _get_attention_target(dim: int, num_heads: int) -> Callable[[bool], nn.Module]: 205 | def callback(use_bias: bool) -> nn.Module: 206 | return layers.Attention(dim=dim, num_heads=num_heads, use_bias=use_bias) 207 | 208 | return callback 209 | 210 | @gin.configurable() 211 | def aim_config( 212 | img_size: Union[int, Tuple[int, int]], 213 | patch_size: Union[int, Tuple[int, int]], 214 | embed_dim: int, 215 | num_blocks: int, 216 | num_heads: int, 217 | num_channels: int = 3, 218 | probe_layers: Union[int, Tuple[int, ...]] = 6, 219 | num_classes: int = 1000, 220 | mode = 'reconstruction', 221 | **kwargs: Any, 222 | ) -> Tuple[nn.Module, nn.Module, nn.Module]: 223 | # common 224 | norm_layer = layers.LayerNorm 225 | 226 | # 
preprocessor 227 | patchifier = layers.PatchEmbed( 228 | img_size=img_size, 229 | patch_size=patch_size, 230 | in_chans=num_channels, 231 | embed_dim=embed_dim, 232 | norm_layer=norm_layer, 233 | ) 234 | preprocessor = layers.ViTPreprocessor( 235 | patchifier, drop_patches=False, cls_token=False 236 | ) 237 | 238 | # trunk 239 | if isinstance(probe_layers, int): 240 | probe_layers = tuple(range(num_blocks - probe_layers, num_blocks)) 241 | assert all(layer >= 0 for layer in probe_layers), probe_layers 242 | 243 | attn_target = _get_attention_target(dim=embed_dim, num_heads=num_heads) 244 | post_transform_layer = layers.AverageLayers(probe_layers, reduce=False) 245 | trunk = Transformer( 246 | attn_target, 247 | embed_dim=embed_dim, 248 | num_blocks=num_blocks, 249 | norm_layer=norm_layer, 250 | post_transformer_layer=post_transform_layer, 251 | **kwargs, 252 | ) 253 | 254 | # head 255 | if mode == 'classification': 256 | head = layers.AttentionPoolingClassifier( 257 | dim=embed_dim, 258 | out_features=num_classes, 259 | num_heads=num_heads, 260 | qkv_bias=False, 261 | num_queries=1, 262 | ) 263 | elif mode == 'reconstruction': 264 | head = layers.ReconstructionHead( 265 | dim = embed_dim, 266 | patch_size = patch_size, 267 | ) 268 | return preprocessor, trunk, head 269 | 270 | @gin.configurable() 271 | def aim_tiny(img_size: Union[int, Tuple[int, int]] = 32, **kwargs: Any) -> AIMPretrain: 272 | preprocessor, trunk, head = aim_config( 273 | img_size=img_size, 274 | patch_size=4, 275 | embed_dim=192, 276 | num_blocks=12, 277 | num_heads=12, 278 | **kwargs, 279 | ) 280 | return AIMPretrain(preprocessor, trunk, head) 281 | 282 | @gin.configurable() 283 | def aim_600M(img_size: Union[int, Tuple[int, int]] = 224, **kwargs: Any) -> AIMPretrain: 284 | preprocessor, trunk, head = aim_config( 285 | img_size=img_size, 286 | patch_size=14, 287 | embed_dim=1536, 288 | num_blocks=24, 289 | num_heads=12, 290 | **kwargs, 291 | ) 292 | return AIMPretrain(preprocessor, trunk, head) 293 | 294 | 295 | def aim_1B(img_size: Union[int, Tuple[int, int]] = 224, **kwargs: Any) -> AIMPretrain: 296 | preprocessor, trunk, head = aim_config( 297 | img_size=img_size, 298 | patch_size=14, 299 | embed_dim=2048, 300 | num_blocks=24, 301 | num_heads=16, 302 | **kwargs, 303 | ) 304 | return AIMPretrain(preprocessor, trunk, head) 305 | 306 | 307 | def aim_3B( 308 | img_size: Union[int, Tuple[int, int]] = 224, patch_size: int = 14, **kwargs: Any 309 | ) -> AIMPretrain: 310 | preprocessor, trunk, head = aim_config( 311 | img_size=img_size, 312 | patch_size=patch_size, 313 | embed_dim=3072, 314 | num_blocks=24, 315 | num_heads=24, 316 | **kwargs, 317 | ) 318 | return AIMPretrain(preprocessor, trunk, head) 319 | 320 | 321 | def aim_7B( 322 | img_size: Union[int, Tuple[int, int]] = 224, patch_size: int = 14, **kwargs: Any 323 | ) -> AIMPretrain: 324 | preprocessor, trunk, head = aim_config( 325 | img_size=img_size, 326 | patch_size=patch_size, 327 | embed_dim=4096, 328 | num_blocks=32, 329 | num_heads=32, 330 | **kwargs, 331 | ) 332 | return AIMPretrain(preprocessor, trunk, head) 333 | 334 | 335 | 336 | if __name__ == '__main__': 337 | import torch.nn.functional as F 338 | model = aim_tiny(num_classes=10) 339 | params = sum(p.numel() for p in model.parameters()) 340 | x = torch.randn(10, 3, 32, 32) 341 | out = model(x) 342 | print(f"AIM: Autoregressive Image Models, size = {params:_}") 343 | # print(out.shape) 344 | -------------------------------------------------------------------------------- /model/base.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class BaseModel(torch.nn.Module): 4 | """ the template of creating a custom model""" 5 | def update(self): 6 | raise NotImplementedError 7 | 8 | def representation(self, x): 9 | raise NotImplementedError 10 | 11 | def forward(self, imgs, **kwargs): 12 | self.log = {} 13 | loss = 0 14 | return loss,self.log -------------------------------------------------------------------------------- /model/dino.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/facebookresearch/dino 3 | 4 | # Keypoints 5 | - Center: the key to avoiding collapse. 6 | - DINO head: WeightNorm is applied at the end and the weight_g (magnitude) is fixed. Therefore, it only optimizes the direction, which equals to L2-normalization. In addition, BN is removed. L2-normalization bottleneck stabilizes the training of DINO with deep projection head. 7 | - freeze last layer: the .last_layer in 8 | - output dimension: large output dimensionality improves the performance. 65536 is the best. 9 | 10 | 11 | # Result: 12 | 13 | 14 | """ 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.distributed as dist 20 | import gin 21 | from timm.layers import trunc_normal_ 22 | 23 | from layers.backbone import create_backbone 24 | 25 | @gin.configurable 26 | class DINOHead(nn.Module): 27 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 28 | super().__init__() 29 | nlayers = max(nlayers, 1) 30 | if nlayers == 1: 31 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 32 | else: 33 | layers = [nn.Linear(in_dim, hidden_dim)] 34 | if use_bn: 35 | layers.append(nn.BatchNorm1d(hidden_dim)) 36 | layers.append(nn.GELU()) 37 | for _ in range(nlayers - 2): 38 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 39 | if use_bn: 40 | layers.append(nn.BatchNorm1d(hidden_dim)) 41 | layers.append(nn.GELU()) 42 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 43 | self.mlp = nn.Sequential(*layers) 44 | self.apply(self._init_weights) 45 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 46 | self.last_layer.weight_g.data.fill_(1) 47 | if norm_last_layer: 48 | self.last_layer.weight_g.requires_grad = False 49 | 50 | def _init_weights(self, m): 51 | if isinstance(m, nn.Linear): 52 | trunc_normal_(m.weight, std=.02) 53 | if isinstance(m, nn.Linear) and m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | 56 | def forward(self, x): 57 | x = self.mlp(x) 58 | x = nn.functional.normalize(x, dim=-1, p=2) 59 | x = self.last_layer(x) 60 | return x 61 | 62 | 63 | @gin.configurable 64 | class DINO(nn.Module): 65 | def __init__(self, 66 | embed_dim = 2048, 67 | norm_last_layer=True, 68 | out_dim=60000, 69 | teacher_temp=0.05, student_temp=0.1, 70 | center_momentum=0.9): 71 | """ 72 | dim: feature dimension (default: 60000) 73 | teacher_temp: softmax temperature for teacher. Final value (after linear warmup) of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend starting with the default value of 0.04 and increase this slightly if needed. 
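The teacher logits are first centered by a running mean (see update_center) and then divided by this temperature before the softmax.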
74 | student_temp: 75 | """ 76 | super().__init__() 77 | self.out_dim = out_dim 78 | self.center_momentum = center_momentum 79 | self.register_buffer("center", torch.zeros(1, out_dim)) 80 | self.student_temp = student_temp 81 | self.teacher_temp = teacher_temp # TODO: adjust the temperature dynamically for the teacher 82 | 83 | # build encoders 84 | backbone = create_backbone() 85 | projector = DINOHead(embed_dim, out_dim, norm_last_layer=norm_last_layer) 86 | 87 | self.embed_dim = embed_dim 88 | self.student = nn.Sequential(backbone,projector) 89 | 90 | _teacher = nn.Sequential(create_backbone(), 91 | DINOHead(embed_dim, out_dim,)) 92 | _teacher.requires_grad_(False) 93 | self._teacher = _teacher 94 | self.update(0) 95 | 96 | @torch.no_grad() 97 | def teacher(self,x): 98 | return self._teacher(x) 99 | 100 | @torch.no_grad() 101 | def update_center(self, teacher_output): 102 | """ 103 | Update center used for teacher output. 104 | """ 105 | batch_center = torch.mean(teacher_output, dim=0, keepdim=True) 106 | if dist.is_initialized(): 107 | dist.all_reduce(batch_center) 108 | batch_center = batch_center / dist.get_world_size() 109 | 110 | # ema update 111 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 112 | 113 | @torch.no_grad() 114 | def update(self,m): 115 | for ema_v, model_v in zip(self._teacher.state_dict().values(), self.student.state_dict().values()): 116 | ema_v.data.mul_(m).add_((1 - m) * model_v.detach().data) 117 | 118 | def representation(self, x): 119 | if isinstance(x, list) or isinstance(x, tuple): 120 | x = x[0] 121 | latent = self._teacher[0](x) 122 | proj = self._teacher[1](latent) 123 | return dict(latent=latent,proj=proj) 124 | 125 | def forward(self, imgs, **kwargs): 126 | """ 127 | Input: 128 | x1: first views of images 129 | x2: second views of images 130 | m: moco momentum 131 | Output: 132 | loss 133 | """ 134 | self.log = {} 135 | x1, x2 = imgs[:2] 136 | local_x = imgs[2:] 137 | 138 | # predict the distribution of clusters 139 | logit1 = self.student(x1) 140 | logit2 = self.student(x2) 141 | 142 | with torch.no_grad(): 143 | t_output1 = self.teacher(x1) 144 | t_output2 = self.teacher(x2) 145 | # teacher centering and sharpening 146 | q1 = F.softmax((t_output1 - self.center)/self.teacher_temp,dim=-1) 147 | q2 = F.softmax((t_output2 - self.center)/self.teacher_temp,dim=-1) 148 | t_output = (t_output1 + t_output2)/2 149 | self.update_center(t_output) 150 | 151 | loss = ( 152 | torch.sum(- q1 * F.log_softmax(logit2/self.student_temp,dim=-1),-1) + 153 | torch.sum(- q2 * F.log_softmax(logit1/self.student_temp,dim=-1),-1) 154 | ).mean()/2 155 | 156 | loss_local = 0 157 | for lx in local_x: 158 | lz = self.student(lx) 159 | 160 | loss_local += ( 161 | torch.sum(- q1 * F.log_softmax(lz/self.student_temp),-1) + 162 | torch.sum(- q2 * F.log_softmax(lz/self.student_temp),-1) 163 | ).mean()/2 164 | 165 | 166 | p1 = F.softmax(logit1/self.student_temp,dim=-1) 167 | p2 = F.softmax(logit2/self.student_temp,dim=-1) 168 | K = (p2.shape[1]) 169 | 170 | # CE = q(z) * log(q(z|x)) = H(z) - KL(q(z)||q(z|x)), KL(q(z)||q(z|x)) > H(q(z|x)) 171 | # minimizing CE causes collapse H(Z) -> 0 and reduce the MI(Z,X) 172 | self.log['z_ce'] = - (p1 * torch.log(p2)).sum(-1).mean().item() 173 | self.log['H_zcx'] = - (p1 * torch.log(p1)).sum(-1).mean().item() 174 | # self.log['MI'] = - (q1 * torch.log(p1)).sum(-1).mean().item() 175 | 176 | return loss,self.log 177 | -------------------------------------------------------------------------------- 
/model/mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from timm.models.vision_transformer import PatchEmbed, Block 19 | 20 | from timm.layers.pos_embed import resample_abs_pos_embed 21 | from util.pos_embed import get_2d_sincos_pos_embed 22 | import gin 23 | from torchvision.transforms import GaussianBlur 24 | from layers.target import build_target 25 | # recipe https://github.com/facebookresearch/mae/blob/main/PRETRAIN.md 26 | @gin.configurable() 27 | class MaskedAutoencoderViT(nn.Module): 28 | """ Masked Autoencoder with VisionTransformer backbone 29 | """ 30 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 31 | mask_ratio=0.75, 32 | embed_dim=1024, depth=24, num_heads=16, 33 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, decoder_feature_size=None, 34 | mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6)): 35 | super().__init__() 36 | self.mask_ratio = mask_ratio 37 | self.embed_dim = embed_dim 38 | self.img_size=img_size 39 | # -------------------------------------------------------------------------- 40 | # MAE encoder specifics 41 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim, strict_img_size=False) 42 | num_patches = self.patch_embed.num_patches 43 | 44 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 45 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 46 | 47 | self.blocks = nn.ModuleList([ 48 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 49 | for i in range(depth)]) 50 | self.norm = norm_layer(embed_dim) 51 | # -------------------------------------------------------------------------- 52 | 53 | # -------------------------------------------------------------------------- 54 | # MAE decoder specifics 55 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 56 | 57 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 58 | 59 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 60 | 61 | self.decoder_blocks = nn.ModuleList([ 62 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 63 | for i in range(decoder_depth)]) 64 | 65 | self.decoder_norm = norm_layer(decoder_embed_dim) 66 | if decoder_feature_size is None: 67 | decoder_feature_size = patch_size**2 * in_chans 68 | self.decoder_pred = nn.Linear(decoder_embed_dim, decoder_feature_size, bias=True) # decoder to patch 69 | # -------------------------------------------------------------------------- 70 | 71 | self.target_loss = build_target(patch_size=patch_size) 72 | 73 | self.initialize_weights() 74 | 75 | def initialize_weights(self): 76 | # initialization 77 | # initialize (and freeze) pos_embed by sin-cos embedding 78 | pos_embed 
= get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 79 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 80 | 81 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 82 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 83 | 84 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 85 | w = self.patch_embed.proj.weight.data 86 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 87 | 88 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 89 | torch.nn.init.normal_(self.cls_token, std=.02) 90 | torch.nn.init.normal_(self.mask_token, std=.02) 91 | 92 | # initialize nn.Linear and nn.LayerNorm 93 | self.apply(self._init_weights) 94 | 95 | def _init_weights(self, m): 96 | if isinstance(m, nn.Linear): 97 | # we use xavier_uniform following official JAX ViT: 98 | torch.nn.init.xavier_uniform_(m.weight) 99 | if isinstance(m, nn.Linear) and m.bias is not None: 100 | nn.init.constant_(m.bias, 0) 101 | elif isinstance(m, nn.LayerNorm): 102 | nn.init.constant_(m.bias, 0) 103 | nn.init.constant_(m.weight, 1.0) 104 | 105 | def patchify(self, imgs): 106 | """ 107 | imgs: (N, 3, H, W) 108 | x: (N, L, patch_size**2 *3) 109 | """ 110 | p = self.patch_embed.patch_size[0] 111 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 112 | 113 | h = w = imgs.shape[2] // p 114 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 115 | x = torch.einsum('nchpwq->nhwpqc', x) 116 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 117 | return x 118 | 119 | def unpatchify(self, x): 120 | """ 121 | x: (N, L, patch_size**2 *3) 122 | imgs: (N, 3, H, W) 123 | """ 124 | p = self.patch_embed.patch_size[0] 125 | h = w = int(x.shape[1]**.5) 126 | assert h * w == x.shape[1] 127 | 128 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 129 | x = torch.einsum('nhwpqc->nchpwq', x) 130 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 131 | return imgs 132 | 133 | def random_masking(self, x, mask_ratio): 134 | """ 135 | Perform per-sample random masking by per-sample shuffling. 136 | Per-sample shuffling is done by argsort random noise. 
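For example, with L=4 patches and mask_ratio=0.75, len_keep=1: only the patch with the smallest noise value is kept, while the other three are marked as removed (mask=1) and later replaced by mask tokens in the decoder.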
137 | x: [N, L, D], sequence 138 | """ 139 | N, L, D = x.shape # batch, length, dim 140 | len_keep = int(L * (1 - mask_ratio)) 141 | 142 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 143 | 144 | # sort noise for each sample 145 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 146 | ids_restore = torch.argsort(ids_shuffle, dim=1) 147 | 148 | # keep the first subset 149 | ids_keep = ids_shuffle[:, :len_keep] 150 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 151 | 152 | # generate the binary mask: 0 is keep, 1 is remove 153 | mask = torch.ones([N, L], device=x.device) 154 | mask[:, :len_keep] = 0 155 | # unshuffle to get the binary mask 156 | mask = torch.gather(mask, dim=1, index=ids_restore) 157 | 158 | return x_masked, mask, ids_restore 159 | 160 | def representation(self, x, pos_embed=None): 161 | B, C, H, W = x.shape 162 | ## dynamic pos embed 163 | pos_embed = resample_abs_pos_embed( 164 | self.pos_embed, 165 | (H//self.patch_embed.patch_size[0], W//self.patch_embed.patch_size[1]), 166 | ) 167 | # embed patches 168 | x = self.patch_embed(x) 169 | 170 | # add pos embed w/o cls token 171 | x = x + pos_embed[:, 1:, :] 172 | 173 | 174 | # append cls token 175 | cls_token = self.cls_token + pos_embed[:, :1, :] 176 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 177 | x = torch.cat((cls_tokens, x), dim=1) 178 | 179 | # apply Transformer blocks 180 | for blk in self.blocks: 181 | x = blk(x) 182 | # x = self.norm(x) remove normalization 183 | 184 | return x[:,1:].mean(1) 185 | 186 | def forward_encoder(self, x, mask_ratio:float,pos_embed=None): 187 | B, C, H, W = x.shape 188 | if pos_embed is None: 189 | pos_embed = self.pos_embed 190 | # embed patches 191 | x = self.patch_embed(x) 192 | 193 | # add pos embed w/o cls token 194 | x = x + pos_embed[:, 1:, :] 195 | 196 | # masking: length -> length * mask_ratio 197 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 198 | 199 | # append cls token 200 | cls_token = self.cls_token + pos_embed[:, :1, :] 201 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 202 | x = torch.cat((cls_tokens, x), dim=1) 203 | 204 | # apply Transformer blocks 205 | for blk in self.blocks: 206 | x = blk(x) 207 | x = self.norm(x) 208 | 209 | return x, mask, ids_restore 210 | 211 | def forward_decoder(self, x, ids_restore,pos_embed=None): 212 | if pos_embed is None: 213 | pos_embed = self.decoder_pos_embed 214 | # embed tokens 215 | x = self.decoder_embed(x) 216 | 217 | # append mask tokens to sequence 218 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 219 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 220 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 221 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 222 | 223 | # add pos embed 224 | x = x + pos_embed 225 | 226 | # apply Transformer blocks 227 | for blk in self.decoder_blocks: 228 | x = blk(x) 229 | x = self.decoder_norm(x) 230 | 231 | # predictor projection 232 | x = self.decoder_pred(x) 233 | 234 | # remove cls token 235 | x = x[:, 1:, :] 236 | 237 | return x 238 | 239 | 240 | 241 | def forward(self, imgs,**kwargs): 242 | self.log = {} 243 | if isinstance(imgs, list) or isinstance(imgs, tuple): 244 | imgs = imgs[0] 245 | ## dynamic pos embed 246 | B, C, H, W = imgs.shape 247 | pos_embed = resample_abs_pos_embed( 248 | self.pos_embed, 249 | (H//self.patch_embed.patch_size[0], 
W//self.patch_embed.patch_size[1]), 250 | ) 251 | decoder_pos_embed = resample_abs_pos_embed( 252 | self.decoder_pos_embed, 253 | (H//self.patch_embed.patch_size[0], W//self.patch_embed.patch_size[1]), 254 | ) 255 | # pos_embed, decoder_pos_embed = self.pos_embed, self.decoder_embed 256 | ## dynamic pos embed 257 | 258 | mask_ratio= self.mask_ratio 259 | 260 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio,pos_embed) 261 | pred = self.forward_decoder(latent, ids_restore,decoder_pos_embed) # [N, L, p*p*3] 262 | 263 | loss = self.target_loss(imgs, pred, mask) 264 | return loss, self.log 265 | 266 | @gin.configurable() 267 | def mae_tiny(**kwargs): 268 | default_cfg = dict( 269 | patch_size=16,embed_dim=192,depth=12,num_heads=12, 270 | decoder_embed_dim=96,decoder_depth=1,decoder_num_heads=3, 271 | mlp_ratio=4, 272 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 273 | ) 274 | default_cfg.update(kwargs) 275 | 276 | model = MaskedAutoencoderViT(**default_cfg) 277 | return model 278 | 279 | @gin.configurable() 280 | def mae_vit_small_patch16_dec512d8b(**kwargs): 281 | default_cfg = dict( 282 | patch_size=16, embed_dim=384, depth=12, num_heads=6, 283 | decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16, 284 | mlp_ratio=4, 285 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 286 | default_cfg.update(kwargs) 287 | 288 | model = MaskedAutoencoderViT(**default_cfg) 289 | return model 290 | 291 | @gin.configurable() 292 | def mae_vit_base_patch16_dec512d8b(**kwargs): 293 | default_cfg = dict( 294 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 295 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 296 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 297 | default_cfg.update(kwargs) 298 | 299 | model = MaskedAutoencoderViT(**default_cfg) 300 | return model 301 | 302 | @gin.configurable() 303 | def mae_vit_large_patch16_dec512d8b(**kwargs): 304 | default_cfg = dict( 305 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 306 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 307 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 308 | default_cfg.update(kwargs) 309 | 310 | model = MaskedAutoencoderViT(**default_cfg) 311 | return model 312 | 313 | @gin.configurable() 314 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 315 | default_cfg = dict( 316 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 317 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 318 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6)) 319 | default_cfg.update(kwargs) 320 | 321 | model = MaskedAutoencoderViT(**default_cfg) 322 | return model 323 | 324 | # set recommended archs 325 | mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b 326 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 327 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 328 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 329 | -------------------------------------------------------------------------------- /model/mcl.py: -------------------------------------------------------------------------------- 1 | """SimDINO: Simplifying DINO via Coding Rate Regularization. 
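SimDINO replaces DINO's centering/sharpening cross-entropy with the MCRLoss below: a cosine-similarity (compression) term between student and teacher features plus a coding-rate (expansion) regularizer.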
2 | Reference: https://github.com/RobinWu218/SimDINO/tree/main 3 | 4 | torchrun main_pretrain_ema.py --data_set cifar10 --data_path ../data/torch_data/ --batch_size 512 --epochs=200 --warmup_epochs=10 --ckpt_freq 100 --cfgs configs/cifar.gin configs/vitt.gin --gin build_model.model_fn=@SimDINO SimDINO.embed_dim=192 SimDINO.out_dim=128 MCRLoss.coeff=100 -m 0.996 --output_dir outputs/simdino_cifar10 5 | """ 6 | import torch 7 | from torch import nn 8 | import torch.distributed as dist 9 | import torch.distributed.nn as dist_nn 10 | import torch.nn.functional as F 11 | import gin 12 | 13 | from layers.backbone import create_backbone 14 | from model.dino import DINOHead 15 | 16 | 17 | @gin.configurable 18 | class MCRLoss(nn.Module): 19 | def __init__(self, ncrops=2, reduce_cov=0, expa_type=1, eps=0.5, coeff=1.0): 20 | """ 21 | Args: 22 | ncrops (int, optional): _description_. Defaults to 2. 23 | reduce_cov (int, optional): Whether or not all_reduce covariance matrices across gpus. Defaults to 0. 24 | expa_type (int, optional): Whether or not apply smoothing in expansion_term. Defaults to 1. 25 | eps (float, optional): eps for TCR. Defaults to 0.5. 26 | coeff (float, optional): coefficient of cosine similarity. Defaults to 1.0. 27 | """ 28 | super().__init__() 29 | self.ncrops = ncrops 30 | self.eps = eps 31 | self.coeff = coeff 32 | self.reduce_cov = reduce_cov 33 | self.expa_type = expa_type 34 | 35 | def forward(self, student_feat, teacher_feat): 36 | """ 37 | Expansion Loss and Compression Loss between features of the teacher and student networks. 38 | """ 39 | student_feat = student_feat.view(self.ncrops, -1, student_feat.shape[-1]) 40 | teacher_feat = teacher_feat.view(2, -1, teacher_feat.shape[-1]) 41 | if student_feat.isnan().any(): 42 | print("Warning: NaN student_feat") 43 | raise ValueError("NaN loss") 44 | 45 | comp_loss = self.calc_compression(student_feat, teacher_feat) 46 | if self.expa_type == 0: # only compute expansion on global views 47 | expa_loss = self.calc_expansion(student_feat[:len(teacher_feat)]) 48 | elif self.expa_type == 1: 49 | expa_loss = self.calc_expansion((student_feat[:2]+teacher_feat)/2) 50 | loss = - self.coeff * comp_loss - expa_loss 51 | return loss, comp_loss.detach(), expa_loss.detach() 52 | 53 | def calc_compression(self, student_feat_list, teacher_feat_list): 54 | """ 55 | Compute compression loss between student and teacher features. 56 | The average cosine similarity between the student and teacher features. This should be high. 57 | """ 58 | # Convert lists of tensors to a single tensor for vectorized operations 59 | 60 | sim = F.cosine_similarity(teacher_feat_list.unsqueeze(1), student_feat_list.unsqueeze(0), dim=-1) 61 | sim.view(-1, sim.shape[-1])[:: (len(student_feat_list) + 1), :].fill_(0) # Trick to fill diagonal 62 | 63 | n_loss_terms = len(teacher_feat_list)* len(student_feat_list) - min(len(teacher_feat_list), len(student_feat_list)) 64 | # Sum the cosine similarities 65 | comp_loss = sim.mean(2).sum()/n_loss_terms 66 | # global_comp_loss = (sim[:, :len(teacher_feat_list)].mean(2).sum()).detach_().div_(len(teacher_feat_list)) 67 | 68 | if torch.isnan(comp_loss): 69 | print("Warning: NaN comp_loss") 70 | raise ValueError("NaN loss") 71 | return comp_loss 72 | 73 | def calc_expansion(self, feat_list) -> torch.Tensor: 74 | """ 75 | Compute expansion loss using Coding Rate estimation. 76 | This denotes the information content of the features. This should be high. 
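For a view with features W of shape (m, p), the rate is estimated as (1/2) * logdet(I + p/(m*N*eps) * W^T W) via a Cholesky factorization (N = number of distributed workers); the result is averaged over views.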
77 | """ 78 | cov_list = [] 79 | num_views = len(feat_list) 80 | m, p = feat_list[0].shape 81 | 82 | cov_list = [W.T.matmul(W) for W in feat_list] 83 | cov_list = torch.stack(cov_list) 84 | N=1 85 | if dist.is_initialized(): 86 | N = dist.get_world_size() 87 | if self.reduce_cov == 1: 88 | cov_list = dist_nn.all_reduce(cov_list) 89 | scalar = p / (m * N * self.eps) 90 | I = torch.eye(p, device=cov_list[0].device) 91 | loss:torch.Tensor = 0 92 | for i in range(num_views): 93 | lossi = torch.linalg.cholesky_ex(I + scalar * cov_list[i])[0].diagonal().log().sum() 94 | if torch.isnan(lossi): 95 | print("Warning: NaN comp_loss") 96 | torch.save(feat_list, "z.pt") 97 | raise ValueError("NaN loss") 98 | loss += lossi 99 | loss /= num_views 100 | # loss *= (p+N*m)/(p*N*m) # the balancing factor gamma, you can also use the next line. This is ultimately a heuristic, so feel free to experiment. 101 | # loss *= ((self.eps * N * m) ** 0.5 / p) 102 | return loss 103 | 104 | def rate_distortion(z,eps=1e-3): 105 | """ 106 | Compute the rate distortion of a given tensor. 107 | """ 108 | m, p = z.shape 109 | cov = z.T.matmul(z) 110 | I = torch.eye(p, device=cov.device) 111 | scalar = p / (m * eps) 112 | return torch.linalg.cholesky_ex(I + scalar * cov)[0].diagonal().log().sum() 113 | 114 | @gin.configurable 115 | class SimDINO(nn.Module): 116 | def __init__(self, 117 | embed_dim = 2048, 118 | out_dim=60000, 119 | norm_last_layer = True, 120 | ): 121 | """ 122 | dim: feature dimension (default: 60000) 123 | teacher_temp: softmax temperature for teacher. Final value (after linear warmup) of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend starting with the default value of 0.04 and increase this slightly if needed. 124 | student_temp: 125 | """ 126 | super().__init__() 127 | self.out_dim = out_dim 128 | 129 | # build encoders 130 | self.loss_fn = MCRLoss() 131 | 132 | self.embed_dim = embed_dim 133 | self.student = nn.Sequential(create_backbone(), 134 | DINOHead(embed_dim, out_dim,norm_last_layer=norm_last_layer)) 135 | 136 | _teacher = nn.Sequential(create_backbone(), 137 | DINOHead(embed_dim, out_dim)) 138 | _teacher.requires_grad_(False) 139 | self._teacher = _teacher 140 | self.update(0) 141 | 142 | @torch.no_grad() 143 | def teacher(self,x): 144 | return self._teacher(x) 145 | 146 | @torch.no_grad() 147 | def update(self,m): 148 | for ema_v, model_v in zip(self._teacher.state_dict().values(), self.student.state_dict().values()): 149 | ema_v.data.mul_(m).add_((1 - m) * model_v.detach().data) 150 | 151 | def representation(self, x): 152 | if isinstance(x, list) or isinstance(x, tuple): 153 | x = x[0] 154 | latent = self._teacher[0](x) 155 | proj = self._teacher[1](latent) 156 | return dict(latent=latent,proj=proj) 157 | 158 | def forward(self, imgs, **kwargs): 159 | """ 160 | Input: 161 | x1: first views of images 162 | x2: second views of images 163 | m: moco momentum 164 | Output: 165 | loss 166 | """ 167 | self.log = {} 168 | x1, x2 = imgs[:2] 169 | local_x = imgs[2:] 170 | 171 | # predict the distribution of clusters 172 | student_output= torch.cat([self.student(x1),self.student(x2)]) 173 | with torch.no_grad(): 174 | teacher_output = torch.cat([self.teacher(x1),self.teacher(x2)]) 175 | 176 | loss, comp_loss, expa_loss = self.loss_fn(student_output, teacher_output) 177 | 178 | self.log['comp_loss'] = comp_loss.item() 179 | self.log['expa_loss'] = expa_loss.item() 180 | 181 | return loss,self.log 182 | 183 | if __name__ == '__main__': 184 | model = 
SimDINO(out_dim=32) 185 | x = torch.rand(10,3,32,32) 186 | print(model([x,x])) 187 | -------------------------------------------------------------------------------- /model/moco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/facebookresearch/moco-v3 3 | 4 | # Note 5 | - Projector: a MLP ending with BN. 2048-4096-256 6 | - Predictor: a MLP ending without BN. 256-4096-256 7 | 8 | ## [train](https://github.com/facebookresearch/moco-v3/blob/main/CONFIG.md) 9 | 10 | ResNet50 11 | `torchrun --nproc_per_node=8 main_pretrain_ema.py --batch_size=128 --opt LARS --blr=5e-2 --weight_decay=1.5e-6 --epochs=100 --warmup_epochs=10 --ckpt_freq=100 --data_path $train_path --prob_path $val_path --gin build_dataset.transform_fn=@MultiviewPipeline MultiviewPipeline.scale="(0.2,1)" build_model.model_fn=@MoCo MoCo.T=1 -m 0.99 ` 12 | 13 | 14 | ## network 15 | MoCo v3 consists of backbone $f(*)$, projector $f_q(*)$, predictor $f_k(*)$, and their momentum version, where projector is updated by EMA, predictor is learnable. 16 | 17 | q = f_q(f_k(f(x))), k =g_k(g(x)). 18 | 19 | The projector is crutial for the performance improvement refer to SimCLR. 20 | 21 | The predictor has no BN at the end. 22 | 23 | SyncBN is also beneficial. 24 | 25 | ## momentum encoder 26 | 27 | 28 | # Result: 29 | | Model | pretrain epochs | pretrain crops | linear acc | | 30 | |----------|:---------------:|:--------------:|:----------:|:-:| 31 | | resnet50 | 100 | 2x224 | 68.9 | | 32 | | resnet50 | 300 | 2x224 | 72.8 | | 33 | | resnet50 | 1000 | 2x224 | 74.6 | | 34 | | | | | | | 35 | """ 36 | 37 | from copy import deepcopy 38 | import torch 39 | import torch.nn as nn 40 | import torch.nn.functional as F 41 | import gin 42 | 43 | from layers.backbone import create_backbone 44 | from layers.operation import build_head, contrastive_loss 45 | 46 | @gin.configurable 47 | class MoCo(nn.Module): 48 | """ 49 | Build a MoCo model with a base encoder, a momentum encoder, and two MLPs 50 | https://arxiv.org/abs/1911.05722 51 | """ 52 | def __init__(self, 53 | embed_dim = 2048, 54 | out_dim=256, 55 | mlp_dim=4096, 56 | T=1.0,): 57 | """ 58 | dim: feature dimension (default: 256) 59 | mlp_dim: hidden dimension in MLPs (default: 4096) 60 | T: softmax temperature (default: 1.0) 61 | """ 62 | super(MoCo, self).__init__() 63 | 64 | self.T = T 65 | self.m = 0 66 | 67 | # build encoders 68 | backbone = create_backbone() 69 | 70 | self.embed_dim = embed_dim 71 | projector = build_head(2,embed_dim,mlp_dim,out_dim) 72 | self.student = nn.Sequential(backbone, 73 | projector) 74 | 75 | self._teacher = deepcopy(self.student) 76 | self._teacher.requires_grad_(False) 77 | self.update(0) 78 | 79 | self.predictor = build_head(2,out_dim,mlp_dim,out_dim,False) 80 | 81 | @torch.no_grad() 82 | def teacher(self,x): 83 | return self._teacher(x) 84 | 85 | @torch.no_grad() 86 | def update(self,m): 87 | for param_b, param_m in zip(self.student.parameters(), self._teacher.parameters()): 88 | param_m.data = param_m.data * m + param_b.data * (1. 
- m) 89 | 90 | @torch.no_grad() 91 | def representation(self, x): 92 | if isinstance(x, list) or isinstance(x, tuple): 93 | x = x[0] 94 | latent = self._teacher[0](x) 95 | proj = self._teacher[1](latent) 96 | s_latent = self.student[0](x) 97 | s_proj = self.student[1](s_latent) 98 | rep = dict(latent=latent,proj=proj,s_latent=s_latent,s_proj=s_proj) 99 | return rep 100 | 101 | def forward(self, imgs, **kwargs): 102 | """ 103 | Input: 104 | x1: first views of images 105 | x2: second views of images 106 | m: moco momentum 107 | Output: 108 | loss 109 | """ 110 | x1, x2 = imgs[:2] 111 | local_x = imgs[2:] 112 | self.log = {} 113 | 114 | # compute features 115 | z1 = self.student(x1) 116 | q1 = self.predictor(z1) 117 | q2 = self.predictor(self.student(x2)) 118 | k1 = self.teacher(x1) 119 | k2 = self.teacher(x2) 120 | loss = ( contrastive_loss(q1, k2, self.T) + 121 | contrastive_loss(q2, k1, self.T))/2 122 | 123 | loss_local = 0 124 | for lx in local_x: 125 | lz = self.student(lx) 126 | lp = self.predictor(lz) 127 | 128 | loss_local += ( 129 | contrastive_loss(q1,lp, self.T) + 130 | contrastive_loss(lp,k1, self.T) + 131 | contrastive_loss(q2,lp, self.T) + 132 | contrastive_loss(lp,k2, self.T) 133 | )/4 134 | 135 | self.log['qk@sim'] = F.cosine_similarity(q1,k1).mean().item() 136 | self.log['z@sim'] = F.cosine_similarity(q1,k2).mean().item() 137 | return loss,self.log 138 | 139 | 140 | class ConvStem(nn.Module): 141 | """ 142 | ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881 143 | """ 144 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 145 | super().__init__() 146 | 147 | assert patch_size == 16, 'ConvStem only supports patch size of 16' 148 | assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' 149 | 150 | img_size = (img_size, img_size) 151 | patch_size = (patch_size, patch_size) 152 | self.img_size = img_size 153 | self.patch_size = patch_size 154 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 155 | self.num_patches = self.grid_size[0] * self.grid_size[1] 156 | self.flatten = flatten 157 | 158 | # build stem, similar to the design in https://arxiv.org/abs/2106.14881 159 | stem = [] 160 | input_dim, output_dim = 3, embed_dim // 8 161 | for l in range(4): 162 | stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False)) 163 | stem.append(nn.BatchNorm2d(output_dim)) 164 | stem.append(nn.ReLU(inplace=True)) 165 | input_dim = output_dim 166 | output_dim *= 2 167 | stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1)) 168 | self.proj = nn.Sequential(*stem) 169 | 170 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 171 | 172 | def forward(self, x): 173 | B, C, H, W = x.shape 174 | assert H == self.img_size[0] and W == self.img_size[1], \ 175 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
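        # stem = four stride-2 3x3 convs plus a final 1x1 conv:
        # (B, 3, H, W) -> (B, embed_dim, H/16, W/16), before optional flattening to (B, N, C)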
176 | x = self.proj(x) 177 | if self.flatten: 178 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 179 | x = self.norm(x) 180 | return x 181 | x = self.norm(x) 182 | return x -------------------------------------------------------------------------------- /model/msmae.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from timm.models.vision_transformer import PatchEmbed, Block 8 | 9 | from timm.layers.pos_embed import resample_abs_pos_embed 10 | from util.pos_embed import get_2d_sincos_pos_embed 11 | import gin 12 | from .target import build_target 13 | from model.models_mae import MaskedAutoencoderViT 14 | 15 | @gin.configurable() 16 | class MultiScaleMAE(MaskedAutoencoderViT): 17 | """ Masked Autoencoder with VisionTransformer backbone 18 | """ 19 | def forward_decoder(self, x,pos_embed=None): 20 | if pos_embed is None: 21 | pos_embed = self.decoder_pos_embed 22 | # embed tokens 23 | x = self.decoder_embed(x) 24 | # exclude cls token 25 | pos_embed = pos_embed[:,1:] 26 | # append mask tokens to sequence 27 | x_ = self.mask_token + pos_embed.expand(x.shape[0], -1, -1) 28 | x = torch.cat([x, x_], dim=1) # append cls token 29 | len_pred = x_.shape[1] 30 | 31 | # apply Transformer blocks 32 | for blk in self.decoder_blocks: 33 | x = blk(x) 34 | x = self.decoder_norm(x) 35 | 36 | # predictor projection 37 | x = self.decoder_pred(x) 38 | 39 | # remove cls token 40 | x = x[:, -len_pred:, :] 41 | 42 | return x 43 | 44 | def forward(self, imgs, return_variables=False,**kwargs): 45 | if isinstance(imgs, list) or isinstance(imgs, tuple): 46 | imgs = imgs[0] 47 | ## dynamic pos embed 48 | self.log = {} 49 | B, C, H, W = imgs.shape 50 | lx = imgs 51 | sx = F.interpolate(lx, scale_factor=0.5, mode='bilinear', align_corners=False) 52 | 53 | patch_size = self.patch_embed.patch_size[0] 54 | 55 | lpos_embed = resample_abs_pos_embed( 56 | self.pos_embed, 57 | (H//patch_size, W//patch_size), 58 | ) 59 | ldecoder_pos_embed = resample_abs_pos_embed( 60 | self.decoder_pos_embed, 61 | (H//patch_size, W//patch_size), 62 | ) 63 | 64 | spos_embed = resample_abs_pos_embed( 65 | self.pos_embed, 66 | (H//2//patch_size, W//2//patch_size), 67 | ) 68 | sdecoder_pos_embed = resample_abs_pos_embed( 69 | self.decoder_pos_embed, 70 | (H//2//patch_size, W//2//patch_size), 71 | ) 72 | # pos_embed, decoder_pos_embed = self.pos_embed, self.decoder_embed 73 | ## dynamic pos embed 74 | 75 | mask_ratio= self.mask_ratio 76 | 77 | llatent, lmask, lids_restore = self.forward_encoder(lx, mask_ratio,lpos_embed) 78 | slatent, smask, sids_restore = self.forward_encoder(sx, 0, spos_embed) 79 | 80 | # cross prediction 81 | spred = self.forward_decoder(llatent, sdecoder_pos_embed) 82 | lpred = self.forward_decoder(slatent, ldecoder_pos_embed) 83 | 84 | sloss = self.target_loss(sx, spred) 85 | lloss = self.target_loss(lx, lpred) 86 | 87 | loss = (sloss + lloss)/2 88 | self.log.update({ 89 | 'rec': sloss.item(), 90 | 'sr': lloss.item(), 91 | }) 92 | if return_variables: 93 | return loss, llatent, slatent, lmask, smask 94 | else: 95 | return loss,self.log 96 | 97 | @gin.configurable() 98 | def msmae_tiny_patch16(**kwargs): 99 | model = MultiScaleMAE( 100 | patch_size=16,embed_dim=196,depth=12,num_heads=12, 101 | decoder_embed_dim=96,decoder_depth=1,decoder_num_heads=3, 102 | mlp_ratio=4, 103 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 104 | return model 105 | 106 | 
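# Illustrative usage sketch (not part of the original file); it assumes the MaskedAutoencoderViT
# defaults accept 224x224 inputs. MultiScaleMAE masks the full-resolution view, encodes an
# unmasked half-resolution copy, and cross-predicts each scale from the other's latent:
#
#   model = msmae_small_patch16()
#   imgs = torch.randn(2, 3, 224, 224)
#   loss, log = model(imgs)   # log['rec'] is the half-scale loss, log['sr'] the full-scale one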
@gin.configurable() 107 | def msmae_small_patch16(**kwargs): 108 | model = MultiScaleMAE( 109 | patch_size=16, embed_dim=384, depth=12, num_heads=6, 110 | decoder_embed_dim=512, decoder_depth=4, decoder_num_heads=16, 111 | mlp_ratio=4, 112 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 113 | return model 114 | 115 | @gin.configurable() 116 | def msmae_base_patch16(**kwargs): 117 | model = MultiScaleMAE( 118 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 119 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 120 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 121 | return model 122 | 123 | 124 | @gin.configurable() 125 | def msmae_large_patch16(**kwargs): 126 | model = MultiScaleMAE( 127 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 128 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 129 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 130 | return model 131 | 132 | @gin.configurable() 133 | def msmae_huge_patch14(**kwargs): 134 | model = MultiScaleMAE( 135 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 136 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 137 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 138 | return model 139 | -------------------------------------------------------------------------------- /model/simclr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/google-research/simclr 3 | 4 | # Note 5 | 6 | ## Keypoints in SimCLR 7 | 8 | - data augmentation: random cropping and random color distortion stand out. 9 | - Global BN (SyncBN): This operation aggregates BN mean and variance over all devices during the training. 10 | - Projector: a MLP with BN. By leveraging the nonlinear transformation g(·), more information can be formed and maintained in h. 2048-2048-256 11 | - Batch size: it is crucial for improving performance. BS=4096 achieves good results. 12 | - Epoch: Contrastive learning benefits (more) from larger batch sizes and longer training. At least 400 epochs. 13 | # Result 14 | 15 | 16 | """ 17 | """ 18 | Reference: https://github.com/google-research/simclr 19 | 20 | # Note 21 | 22 | ## Keypoints in SimCLR 23 | 24 | - data augmentation: random cropping and random color distortion stand out. 25 | - Global BN (SyncBN): This operation aggregates BN mean and variance over all devices during the training. 26 | - Projector: a MLP with BN. By leveraging the nonlinear transformation g(·), more information can be formed and maintained in h. 2048-2048-256 27 | - Batch size: it is crucial for improving performance. BS=4096 achieves good results. 28 | - Epoch: Contrastive learning benefits (more) from larger batch sizes and longer training. At least 400 epochs. 
29 | # Result 30 | 31 | 32 | """ 33 | import torch 34 | from torch import nn 35 | import torchvision.transforms as transforms 36 | import gin 37 | import timm 38 | from layers.operation import * 39 | from layers.backbone import create_backbone 40 | 41 | @gin.configurable 42 | class SimCLR(nn.Module): 43 | def __init__(self, 44 | out_dim=256, 45 | embed_dim=2048, 46 | mlp_dim=2048, 47 | temperature=0.5): 48 | super(SimCLR, self).__init__() 49 | self.temperature = temperature 50 | self.out_dim = out_dim 51 | backbone = create_backbone() 52 | 53 | self.embed_dim = embed_dim 54 | self.backbone = backbone 55 | self.projector = build_head(2,embed_dim,mlp_dim,out_dim, False) 56 | 57 | @torch.no_grad() 58 | def representation(self, x): 59 | if isinstance(x, list) or isinstance(x, tuple): 60 | x = x[0] 61 | latent = self.backbone(x) 62 | proj = self.projector(latent) 63 | rep = dict(latent=latent,proj=proj) 64 | return rep 65 | 66 | def forward(self, samples, **kwargs): 67 | self.log = {} 68 | x1,x2 = samples[:2] 69 | local_x = samples[2:] 70 | 71 | z1 = self.projector(self.backbone(x1)) 72 | z2 = self.projector(self.backbone(x2)) 73 | 74 | loss = (contrastive_loss(z1,z2,self.temperature) + 75 | contrastive_loss(z2,z1,self.temperature))/2 76 | 77 | loss_local = 0 78 | for lx in local_x: 79 | lz = self.backbone(lx) 80 | lp = self.projector(lz) 81 | 82 | loss_local += ( 83 | contrastive_loss(z1,lp,self.temperature) + 84 | contrastive_loss(lp,z1,self.temperature) + 85 | contrastive_loss(z2,lp,self.temperature) + 86 | contrastive_loss(lp,z2,self.temperature) 87 | )/4 88 | if loss_local>0: 89 | loss = loss + loss_local 90 | self.log['loss_local'] = loss_local.item() 91 | self.log['z@sim'] = F.cosine_similarity(z1,z2).mean().item() 92 | 93 | return loss, self.log 94 | 95 | 96 | 97 | def dynamic_contrastive_loss(hidden1, hidden2, index=None, gamma=0.9, distributed=True): 98 | """ 99 | paper: Provable stochastic optimization for global contrastive learning: Small batch does not harm performance 100 | reference: https://github.com/Optimization-AI/SogCLR/blob/PyTorch/sogclr/builder.py#L66 101 | """ 102 | # Get (normalized) hidden1 and hidden2. 103 | hidden1, hidden2 = F.normalize(hidden1, p=2, dim=1), F.normalize(hidden2, p=2, dim=1) 104 | batch_size = hidden1.shape[0] 105 | 106 | # Gather hidden1/hidden2 across replicas and create local labels. 
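    # When distributed, features from every replica are gathered so each local example is
    # contrasted against the full global batch; its positive sits at index
    # rank * batch_size in the gathered tensor, which is what the one-hot labels/masks encode.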
107 | if distributed: 108 | hidden1_large = torch.cat(all_gather_layer.apply(hidden1), dim=0) # why concat_all_gather() 109 | hidden2_large = torch.cat(all_gather_layer.apply(hidden2), dim=0) 110 | enlarged_batch_size = hidden1_large.shape[0] 111 | 112 | labels_idx = (torch.arange(batch_size, dtype=torch.long) + batch_size * torch.distributed.get_rank()).to(self.device) 113 | labels = F.one_hot(labels_idx, enlarged_batch_size*2).to(self.device) 114 | masks = F.one_hot(labels_idx, enlarged_batch_size).to(self.device) 115 | batch_size = enlarged_batch_size 116 | else: 117 | hidden1_large = hidden1 118 | hidden2_large = hidden2 119 | labels = F.one_hot(torch.arange(batch_size, dtype=torch.long), batch_size * 2).to(self.device) 120 | masks = F.one_hot(torch.arange(batch_size, dtype=torch.long), batch_size).to(self.device) 121 | 122 | logits_aa = torch.matmul(hidden1, hidden1_large.T) 123 | logits_aa = logits_aa - masks * self.LARGE_NUM 124 | logits_bb = torch.matmul(hidden2, hidden2_large.T) 125 | logits_bb = logits_bb - masks * self.LARGE_NUM 126 | logits_ab = torch.matmul(hidden1, hidden2_large.T) 127 | logits_ba = torch.matmul(hidden2, hidden1_large.T) 128 | 129 | # SogCLR 130 | neg_mask = 1-labels 131 | logits_ab_aa = torch.cat([logits_ab, logits_aa ], 1) 132 | logits_ba_bb = torch.cat([logits_ba, logits_bb ], 1) 133 | 134 | neg_logits1 = torch.exp(logits_ab_aa /self.T)*neg_mask #(B, 2B) 135 | neg_logits2 = torch.exp(logits_ba_bb /self.T)*neg_mask 136 | 137 | # u init 138 | if self.u[index.cpu()].sum() == 0: 139 | gamma = 1 140 | 141 | u1 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits1, dim=1, keepdim=True)/(2*(batch_size-1)) 142 | u2 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits2, dim=1, keepdim=True)/(2*(batch_size-1)) 143 | 144 | # this sync on all devices (since "hidden" are gathering from all devices) 145 | if distributed: 146 | u1_large = concat_all_gather(u1) 147 | u2_large = concat_all_gather(u2) 148 | index_large = concat_all_gather(index) 149 | self.u[index_large.cpu()] = (u1_large.detach().cpu() + u2_large.detach().cpu())/2 150 | else: 151 | self.u[index.cpu()] = (u1.detach().cpu() + u2.detach().cpu())/2 152 | 153 | p_neg_weights1 = (neg_logits1/u1).detach() 154 | p_neg_weights2 = (neg_logits2/u2).detach() 155 | 156 | def softmax_cross_entropy_with_logits(labels, logits, weights): 157 | expsum_neg_logits = torch.sum(weights*logits, dim=1, keepdim=True)/(2*(batch_size-1)) 158 | normalized_logits = logits - expsum_neg_logits 159 | return -torch.sum(labels * normalized_logits, dim=1) 160 | 161 | loss_a = softmax_cross_entropy_with_logits(labels, logits_ab_aa, p_neg_weights1) 162 | loss_b = softmax_cross_entropy_with_logits(labels, logits_ba_bb, p_neg_weights2) 163 | loss = (loss_a + loss_b).mean() 164 | 165 | return loss -------------------------------------------------------------------------------- /model/simsiam.py: -------------------------------------------------------------------------------- 1 | """ 2 | reference: https://github.com/facebookresearch/simsiam 3 | # Note 4 | 5 | We apply cosine decay for lr to predictor, whereas, removing it gains improvement. 6 | The paper uses SGD to optimize the model. 7 | 8 | Warning: Not working with ViT. 9 | 10 | # keypoints 11 | - break symmetry 12 | - Stop-gradient: the key to avoid collapse. 13 | - Predictor: the key to converge. The loss remains high if removing the predictor. 2048-512-2048, no BN. The bottleneck instead of inverse bottleneck is vital for simsiam. 
The bottleneck prevents the predictor learning identity map, avoiding collapse. 14 | - projector: Use BN at the end. 2048-2048-2048. 15 | - BN: adding BN to the hidden layers is vital to the success of learning semantic representation. However, adding BN to the output of predictor will cause unstable training and the loss oscillates. 16 | - Hypothesis: The presence of stop-gradient is the consequence of introducing the extra set of variables. 17 | 18 | 19 | # Result: 20 | 21 | | Model | Note | pre-train epochs | batch size | linprob (top-1)| 22 | |----------|-----------|:----------------:|:----------:|:----------:| 23 | | resnet50 | official | 100 | 512 | 68.1 | 24 | | resnet50 | official | 100 | 256 | 68.3 | 25 | | resnet50 | our impl. | 100 | 256 | | 26 | """ 27 | import torch 28 | from torch import nn 29 | import gin 30 | import torch.nn.functional as F 31 | from layers.backbone import create_backbone 32 | 33 | 34 | 35 | @gin.configurable 36 | class SimSiam(nn.Module): 37 | """warning: the model is not stable. The loss oscillates. """ 38 | def __init__(self, 39 | proj_dim = 2048, 40 | embed_dim=2048, 41 | mlp_dim=512): 42 | super(SimSiam, self).__init__() 43 | backbone = create_backbone() 44 | self.embed_dim = embed_dim 45 | self.backbone = backbone 46 | 47 | self.projector = nn.Sequential( 48 | nn.Linear(embed_dim, embed_dim, bias=False), 49 | nn.BatchNorm1d(embed_dim), 50 | nn.ReLU(inplace=True), # first layer 51 | nn.Linear(embed_dim, embed_dim, bias=False), 52 | nn.BatchNorm1d(embed_dim), 53 | nn.ReLU(inplace=True), # second layer 54 | nn.Linear(embed_dim, proj_dim,bias=False), # output layer 55 | nn.BatchNorm1d(proj_dim, affine=False), 56 | ) 57 | 58 | 59 | 60 | self.predictor = nn.Sequential( 61 | nn.Linear(proj_dim, mlp_dim, bias=False), 62 | nn.BatchNorm1d(mlp_dim), 63 | nn.ReLU(), 64 | # nn.Softmax(), # simulate class centroid selection 65 | nn.Linear(mlp_dim, proj_dim)) # output layer 66 | self.criterion = nn.CosineSimilarity(dim=1) 67 | 68 | @torch.no_grad() 69 | def representation(self, x): 70 | if isinstance(x, list) or isinstance(x, tuple): 71 | x = x[0] 72 | latent = self.backbone(x) 73 | proj = self.projector(latent) 74 | pred = self.predictor(proj) 75 | return dict(latent=latent,proj=proj,pred=pred) 76 | 77 | 78 | def forward(self, samples, **kwargs): 79 | self.log = {} 80 | x1,x2 = samples[:2] 81 | local_x = samples[2:] 82 | 83 | z1 = self.projector(self.backbone(x1)) 84 | z2 = self.projector(self.backbone(x2)) 85 | 86 | p1 = self.predictor(z1) 87 | p2 = self.predictor(z2) 88 | 89 | loss = 1 - (self.criterion(p1, z2.detach()).mean() + 90 | self.criterion(p2, z1.detach()).mean()) * 0.5 91 | 92 | loss_local = 0 93 | for lx in local_x: 94 | lz = self.backbone(lx) 95 | lp = self.predictor(lz) 96 | 97 | loss_local += 1 - ( 98 | self.criterion(lp,z2.detach()) + 99 | self.criterion(lp,z1.detach()) 100 | ).mean()/2 101 | 102 | self.log["loss"] = loss.item() 103 | self.log['qk@sim'] = self.criterion(p1.detach(), z1.detach()).mean().item() 104 | with torch.no_grad(): 105 | 106 | self.log['qk@sim'] = F.cosine_similarity(p1,z1).mean().item() 107 | self.log['z@sim'] = F.cosine_similarity(p1,z2).mean().item() 108 | return loss, self.log 109 | -------------------------------------------------------------------------------- /model/sit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/Sara-Ahmed/SiT 3 | 4 | # keypoints: 5 | - reconstruction for patch tokens 6 | - contrastive learning for CLS token 7 | """ 8 | 
from functools import partial 9 | 10 | import timm.utils 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from timm.layers.pos_embed import resample_abs_pos_embed 16 | from model.operation import build_mlp, contrastive_loss 17 | from util.pos_embed import get_2d_sincos_pos_embed 18 | import gin 19 | import timm 20 | from model.models_mae import MaskedAutoencoderViT,build_mae_backbone 21 | 22 | 23 | 24 | 25 | @gin.configurable() 26 | class SiT(nn.Module): 27 | """ Masked Autoencoder with VisionTransformer backbone 28 | """ 29 | def __init__(self, 30 | model_name = 'vit_small', 31 | out_dim = 256, 32 | mlp_dim=1024, 33 | momentum=0.996, 34 | lambd=1, 35 | T = 0.2, 36 | ): 37 | super().__init__() 38 | self.lambd = lambd 39 | self.T = T 40 | self.student = build_mae_backbone(model_name) 41 | self.embed_dim = self.student.embed_dim 42 | self._teacher = timm.utils.ModelEmaV2(self.student,momentum) 43 | self._teacher.module.mask_ratio=0 44 | self._teacher.requires_grad_(False) 45 | # projector 46 | self.projector = build_mlp(2, self.embed_dim,mlp_dim,out_dim) 47 | teacher_projector = timm.utils.ModelEmaV2(self.projector,momentum) 48 | teacher_projector.requires_grad_(False) 49 | self.teacher_projector = teacher_projector 50 | # predictor 51 | self.predictor = build_mlp(2,out_dim,mlp_dim, out_dim,False) 52 | 53 | @torch.no_grad() 54 | def teacher(self,x): 55 | latent, mask, ids_restore = self._teacher.module.forward_encoder(x,0) 56 | return self.teacher_projector.module(latent[:,0]) 57 | 58 | def update(self): 59 | self._teacher.update(self.student) 60 | self.teacher_projector.update(self.projector) 61 | 62 | def group_masking(self, x, mask_ratio): 63 | """ 64 | Perform per-sample random masking by per-sample shuffling. 65 | Per-sample shuffling is done by argsort random noise. 
66 | x: [N, L, D], sequence 67 | """ 68 | # TODO 69 | N, L, D = x.shape # batch, length, dim 70 | len_keep = int(L * (1 - mask_ratio)) 71 | 72 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 73 | 74 | # sort noise for each sample 75 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 76 | ids_restore = torch.argsort(ids_shuffle, dim=1) 77 | 78 | # keep the first subset 79 | ids_keep = ids_shuffle[:, :len_keep] 80 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 81 | 82 | # generate the binary mask: 0 is keep, 1 is remove 83 | mask = torch.ones([N, L], device=x.device) 84 | mask[:, :len_keep] = 0 85 | # unshuffle to get the binary mask 86 | mask = torch.gather(mask, dim=1, index=ids_restore) 87 | return x_masked, mask, ids_restore 88 | 89 | def representation(self, x, pos_embed=None): 90 | if isinstance(x,list): 91 | x= x[0] 92 | return self._teacher.module.representation(x,pos_embed) 93 | 94 | def forward(self, imgs, **kwargs): 95 | self.log = {} 96 | x1, x2 = imgs[:2] 97 | local_x = imgs[2:] 98 | 99 | recon1,z1 = self.student(x1,return_variables=True) 100 | recon2,z2 = self.student(x2,return_variables=True) 101 | q1 = self.predictor(self.projector(z1[:,0])) 102 | q2 = self.predictor(self.projector(z2[:,0])) 103 | k1 = self.teacher(x1) 104 | k2 = self.teacher(x2) 105 | 106 | loss_rec = (recon1+recon2)/2 107 | loss_cl = ( contrastive_loss(q1, k2, self.T) + 108 | contrastive_loss(q2, k1, self.T))/2 109 | self.log.update({'loss_cl':loss_cl.item(),'loss_rec':loss_rec.item()}) 110 | loss = self.lambd * loss_cl + loss_rec 111 | return loss,self.log 112 | -------------------------------------------------------------------------------- /model/vcl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/google-research/simclr 3 | FFCV_DEFAULT_CACHE_PROCESS=1 WANDB_NAME=VCL torchrun --nproc_per_node=8 main_pretrain_ffcv.py --data_path $train100_path --gin build_model.model_fn=@VCL MultiviewPipeline.img_size=112 MultiviewPipeline.local_crops_number=0 --multiview --batch_size 256 --epochs=400 --ckpt_freq 50 --online_prob --weight_decay=1e-4 --output_dir outputs/vcl 4 | # Note 5 | """ 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | import gin 10 | from layers.backbone import create_backbone 11 | from layers.operation import build_head, concat_all_gather 12 | import torch.nn.functional as F 13 | 14 | def kl_normal_loss(mean, logvar, mean_dim=None): 15 | """ 16 | Calculates the KL divergence between a normal distribution 17 | with diagonal covariance and a unit normal distribution. 18 | 19 | Parameters 20 | ---------- 21 | mean : torch.Tensor 22 | Mean of the normal distribution. Shape (batch_size, num_latent) where 23 | D is dimension of distribution. 24 | 25 | logvar : torch.Tensor 26 | Diagonal log variance of the normal distribution. 
Shape (batch_size, 27 | num_latent) 28 | """ 29 | if mean_dim is None: 30 | mean_dim = [0] 31 | latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=mean_dim) 32 | return latent_kl 33 | 34 | class CenterHead(nn.Module): 35 | def __init__(self, dim, num_clusters): 36 | super(CenterHead, self).__init__() 37 | self.centroid = nn.Parameter(torch.randn(dim)) 38 | self.centroid = nn.utils.weight_norm(nn.Linear(dim, num_clusters, bias=False)) 39 | self.centroid.weight_g.data.fill_(1) 40 | 41 | def forward(self, x): 42 | x = F.normalize(x, p=2, dim=-1) 43 | logits = torch.einsum('nc,mc->nm', [x, self.centroid]) 44 | return logits 45 | 46 | @gin.configurable 47 | class VCL(nn.Module): 48 | def __init__(self, 49 | beta=0, 50 | embed_dim=2048, 51 | out_dim=256, 52 | mlp_dim=2048,): 53 | super(VCL, self).__init__() 54 | self.beta=beta 55 | self.embed_dim = embed_dim 56 | 57 | backbone = create_backbone() 58 | self.backbone = backbone 59 | self.projector = build_head(2,embed_dim,mlp_dim,out_dim, last_norm='none') 60 | self.predictor = build_head(2,out_dim,mlp_dim,out_dim*2, last_norm='none') 61 | 62 | @torch.no_grad() 63 | def representation(self, x): 64 | if isinstance(x, list) or isinstance(x, tuple): 65 | x = x[0] 66 | x = self.backbone(x) 67 | return dict(latent=x) 68 | 69 | def forward(self, samples, **kwargs): 70 | self.log = {} 71 | x1,x2 = samples[:2] 72 | local_x = samples[2:] 73 | 74 | 75 | z1 = self.projector(self.backbone(x1)) 76 | z2 = self.projector(self.backbone(x2)) 77 | mu, logvar = self.predictor(z1).chunk(2, dim=-1) 78 | # fix variance 79 | logvar = torch.zeros_like(logvar) 80 | 81 | # q(yz) 82 | rank = torch.distributed.get_rank() 83 | B, D = mu.shape 84 | 85 | # avoid the centroids collapse 86 | # q(z|y) = H(z,y) - H(y) = q(zy) / (sum_z q(zy)): B x (B N) 87 | tz = concat_all_gather(z2) 88 | log_zy = -0.5 * (logvar.unsqueeze(1) + (tz.unsqueeze(0) - mu.unsqueeze(1)).pow(2) / logvar.exp().unsqueeze(1)).sum(dim=-1) 89 | 90 | log_zcy = torch.diagonal(log_zy, offset=rank*B) 91 | h_zcy = (torch.logsumexp(log_zy, dim=1) - log_zcy).mean() 92 | 93 | 94 | # minimize the distance between features and centroids 95 | # q(y|z) = q(zy)/ (sum_y q(zy)): B x (B N) 96 | tmu = concat_all_gather(mu.contiguous()) 97 | tlogvar = concat_all_gather(logvar.contiguous()) 98 | log_yz = -0.5 * (tlogvar.unsqueeze(0) + (tmu.unsqueeze(0) - z2.unsqueeze(1)).pow(2) / tlogvar.exp().unsqueeze(0)).sum(dim=-1) 99 | 100 | log_ycz = torch.diagonal(log_yz, offset=rank*B) 101 | h_ycz = (torch.logsumexp(log_yz, dim=1) - log_ycz).mean() 102 | 103 | 104 | kl_loss = kl_normal_loss(mu, logvar).sum() 105 | loss = h_zcy + h_ycz + kl_loss * self.beta 106 | self.log['h_zcy'] = h_zcy.item() 107 | self.log['h_ycz'] = h_ycz.item() 108 | self.log['kl_loss'] = kl_loss.item() 109 | return loss, self.log -------------------------------------------------------------------------------- /pics/n01518878_10165.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erow/FastSSL/67f28b6eeb8b940468562df7228124ea72a8d929/pics/n01518878_10165.JPEG -------------------------------------------------------------------------------- /profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | from PIL import Image # a trick to solve loading lib problem 12 | 13 | import argparse 14 | import datetime 15 | import json 16 | import numpy as np 17 | import os 18 | import time 19 | import math 20 | import sys 21 | import gin 22 | from pathlib import Path 23 | 24 | from os import getpid 25 | from psutil import Process, net_io_counters 26 | import torch 27 | import torch.backends.cudnn as cudnn 28 | from torch.utils.tensorboard import SummaryWriter 29 | import torchvision.transforms as transforms 30 | import torchvision.datasets as datasets 31 | from tqdm import tqdm 32 | 33 | from torch.profiler import profile, record_function, ProfilerActivity, schedule 34 | 35 | 36 | class ramqdm(tqdm): 37 | """tqdm progress bar that reports RAM usage with each update""" 38 | _empty_desc = "using ? GB RAM; ? CPU ? IO" 39 | _desc = "{:.2f} GB RAM; {:.2f} % CPU {:.2f} MB IO" 40 | _GB = 10**9 41 | """""" 42 | def __init__(self, *args, **kwargs): 43 | """Override desc and get reference to current process""" 44 | if "desc" in kwargs: 45 | # prepend desc to the reporter mask: 46 | self._empty_desc = kwargs["desc"] + " " + self._empty_desc 47 | self._desc = kwargs["desc"] + " " + self._desc 48 | del kwargs["desc"] 49 | else: 50 | # nothing to prepend, reporter mask is at start of sentence: 51 | self._empty_desc = self._empty_desc.capitalize() 52 | self._desc = self._desc.capitalize() 53 | super().__init__(*args, desc=self._empty_desc, **kwargs) 54 | self._process = Process(getpid()) 55 | self.metrics = [] 56 | """""" 57 | def update(self, n=1): 58 | """Calculate RAM usage and update progress bar""" 59 | rss = self._process.memory_info().rss 60 | ps = self._process.cpu_percent() 61 | io_counters = self._process.io_counters().read_bytes 62 | # net_io = net_io_counters().bytes_recv 63 | # io_counters += net_io 64 | 65 | current_desc = self._desc.format(rss/self._GB, ps, io_counters/1e6) + f" pid {getpid()} " 66 | self.set_description(current_desc) 67 | self.metrics.append({'mem':rss/self._GB, 'cpu':ps, 'io':io_counters/1e6}) 68 | super().update(n) 69 | 70 | def summary(self): 71 | res = {} 72 | for key in self.metrics[0].keys(): 73 | res[key] = np.mean([i[key] for i in self.metrics]) 74 | return res 75 | 76 | def backward_hook_wrapper(module, details=None): 77 | 78 | # define register_full_backward_pre_hook function 79 | def bwd_pre_hook_print(self, output): 80 | message = f'before backward of {module.__class__.__qualname__}' 81 | if details: 82 | message = f'{message}: {details}' 83 | with torch.profiler.record_function(message): 84 | return output 85 | 86 | # define register_full_backward_hook function 87 | def bwd_hook_print(self, input, output): 88 | message = f'after backward of {module.__class__.__qualname__}' 89 | if details: 90 | message = f'{message}: {details}' 91 | with torch.profiler.record_function(message): 92 | return input 93 | print(f"Register backward hook for {module.__class__.__qualname__}") 94 | # register hooks 95 | module.register_full_backward_pre_hook(bwd_pre_hook_print) 96 | module.register_full_backward_hook(bwd_hook_print) 97 | return module 98 | 99 | def main(args): 100 | misc.init_distributed_mode(args) 101 | 102 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 103 | 
print("{}".format(args).replace(', ', ',\n')) 104 | 105 | device = torch.device(args.device) 106 | 107 | # fix the seed for reproducibility 108 | seed = args.seed + misc.get_rank() 109 | torch.manual_seed(seed) 110 | np.random.seed(seed) 111 | 112 | cudnn.benchmark = True 113 | 114 | dataset_train = build_dataset(args) 115 | 116 | num_tasks = misc.get_world_size() 117 | global_rank = misc.get_rank() 118 | if args.data_set != "ffcv": 119 | sampler_train = torch.utils.data.DistributedSampler( 120 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 121 | ) 122 | data_loader_train = torch.utils.data.DataLoader( 123 | dataset_train, sampler=sampler_train, 124 | batch_size=args.batch_size, 125 | num_workers=args.num_workers, 126 | pin_memory=args.pin_mem, 127 | drop_last=True, 128 | ) 129 | else: 130 | data_loader_train = dataset_train 131 | print("Memory Manager = %s" % str(data_loader_train.memory_manager)) 132 | print("data set : ", dataset_train) 133 | # define the model 134 | model = build_model(args) 135 | model.to(device) 136 | 137 | torch.compile(model) 138 | 139 | model_without_ddp = model 140 | 141 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 142 | 143 | if args.lr is None: # only base_lr is specified 144 | args.lr = args.blr * eff_batch_size / 256 145 | 146 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 147 | print("actual lr: %.2e" % args.lr) 148 | 149 | print("accumulate grad iterations: %d" % args.accum_iter) 150 | print("effective batch size: %d" % eff_batch_size) 151 | 152 | if args.distributed: 153 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 154 | model_without_ddp = model.module 155 | 156 | # following timm: set wd as 0 for bias and norm layers 157 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) 158 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 159 | loss_scaler = NativeScaler() 160 | 161 | print("Preload data") 162 | 163 | for _ in ramqdm(data_loader_train): 164 | pass 165 | model.train() 166 | print(f"Start profiling for {args.num_samples} samples.") 167 | 168 | scaler = torch.cuda.amp.GradScaler() 169 | 170 | 171 | ## Profiling 172 | if args.no_profile: 173 | 174 | for _ in range(3): 175 | print("Start training one epoch.") 176 | l = ramqdm(data_loader_train) 177 | start = time.time() 178 | num_samples = 0 179 | for data_iter_step, data in enumerate(l): 180 | samples,y = data 181 | num_samples+=len(samples) 182 | with torch.cuda.amp.autocast(): 183 | loss = model(samples,epoch=0) 184 | 185 | scaler.scale(loss).backward() 186 | scaler.unscale_(optimizer) 187 | scaler.step(optimizer) 188 | scaler.update() 189 | optimizer.zero_grad() 190 | torch.cuda.synchronize() 191 | 192 | end = time.time() 193 | res = l.summary() 194 | res.update(args.__dict__) 195 | res['runtime'] = end-start 196 | res['throughput'] = float(num_samples)/(end-start) 197 | 198 | print(f"throughput : {res['throughput']} ") 199 | with open(os.path.join(args.output_dir, f"train_one_epoch-{global_rank}.json"), "a+") as file: 200 | file.write(json.dumps(res)+"\n") 201 | else: 202 | my_schedule = schedule( 203 | skip_first=100, 204 | wait=5, 205 | warmup=5, 206 | active=10) 207 | print_freq = 10 208 | optimizer.zero_grad() 209 | n_samples = 0 210 | 211 | print("Start profiling.") 212 | with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], 213 | 
profile_memory=True,use_cuda=True,schedule=my_schedule,with_stack=True) as prof: 214 | metric_logger = misc.MetricLogger(delimiter=" ") 215 | for data_iter_step, data in enumerate(metric_logger.log_every(data_loader_train, print_freq, "")): 216 | with record_function('forward'): 217 | if args.data_set == "ffcv": 218 | samples = data[0] 219 | else: 220 | (samples, _) = data 221 | samples = samples.cuda(non_blocking=True) 222 | 223 | with torch.cuda.amp.autocast(): 224 | loss = model(samples,epoch=0) 225 | with record_function('backward'): 226 | scaler.scale(loss).backward() 227 | 228 | with record_function('opt'): 229 | scaler.unscale_(optimizer) 230 | scaler.step(optimizer) 231 | scaler.update() 232 | optimizer.zero_grad() 233 | torch.cuda.synchronize() 234 | 235 | n_samples +=len(samples) 236 | if n_samples >=args.num_samples: 237 | prof.step() 238 | n_samples = 0 239 | 240 | if prof.step_num >= 120: break 241 | 242 | print(prof.key_averages(group_by_stack_n=3).table(sort_by="self_cuda_time_total", row_limit=10)) 243 | 244 | prof.export_chrome_trace(os.path.join(args.output_dir, f"profile-{global_rank}.json")) 245 | 246 | if __name__ == '__main__': 247 | from util.helper import aug_parse 248 | parser = get_args_parser() 249 | parser.add_argument("-n", "--num_samples", type=int, default=512, help="number of samples to record one step for profile.") 250 | parser.add_argument("--no_profile",default=False,action="store_true",help="whether to profile the model.") 251 | args = aug_parse(parser) 252 | assert args.num_samples > 0, "num_samples should be larger than 0." 253 | assert args.num_samples % args.batch_size == 0, "num_samples should be divisible by batch_size." 254 | 255 | if args.output_dir: 256 | output_dir=Path(args.output_dir) 257 | output_dir.mkdir(parents=True, exist_ok=True) 258 | main(args) 259 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchmetrics>=1.2.1 4 | wandb 5 | gin-config 6 | # numba 7 | einops 8 | tensorboard==2.15.1 9 | timm>=0.9.12 10 | scipy==1.12.0 11 | # submitit==1.5.1 12 | # ffcv@git+https://github.com/erow/ffcv.git -------------------------------------------------------------------------------- /submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import submitit 16 | import importlib 17 | def get_args_parser(): 18 | # trainer_parser = trainer.get_args_parser() 19 | parser = argparse.ArgumentParser("Submitit for evaluation") 20 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 21 | parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request") 22 | parser.add_argument("-t", "--timeout", default=1440, type=int, help="Duration of the job") 23 | parser.add_argument("--module", default='main_pretrain_ffcv', type=str) 24 | parser.add_argument("--mem", default=400, type=float, help="Memory to request") 25 | 26 | parser.add_argument("-p", "--partition", default="big", type=str, help="Partition where to submit") 27 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 28 | parser.add_argument( "--job_dir", default='',type=str,) 29 | 30 | return parser 31 | 32 | 33 | def get_shared_folder(root) -> Path: 34 | root = root.replace("%j", "shared") 35 | p = Path(root) 36 | os.makedirs(str(p), exist_ok=True) 37 | if Path(root).is_dir(): 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(root): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder(root)), exist_ok=True) 45 | init_file = get_shared_folder(root) / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args, module_params): 53 | self.args = args 54 | self.module_params = module_params 55 | self.module = importlib.import_module(args.module) 56 | 57 | ## reassing args 58 | parser = self.module.get_args_parser() 59 | module_args = parser.parse_args(module_params) 60 | module_args.output_dir = args.job_dir 61 | module_args.dist_url = args.dist_url 62 | self.module_args = module_args 63 | 64 | def __call__(self): 65 | self._setup_gpu_args() 66 | 67 | self.module.main(self.module_args) 68 | 69 | def checkpoint(self): 70 | print("Checkpointing") 71 | import os 72 | import submitit 73 | job_env = submitit.JobEnvironment() 74 | print("Requeuing ", self.args, self.module_args) 75 | 76 | output_dir = self.module_args.output_dir 77 | 78 | checkpoint_file = os.path.join(output_dir, "checkpoint.pth") 79 | self.args.dist_url = get_init_file(output_dir).as_uri() 80 | empty_trainer = type(self)(self.args,self.module_params) 81 | if os.path.exists(checkpoint_file): 82 | empty_trainer.module_args.resume = checkpoint_file 83 | 84 | print("Requeueing with ", empty_trainer.module_args) 85 | return submitit.helpers.DelayedSubmission(empty_trainer) 86 | 87 | def _setup_gpu_args(self): 88 | import submitit 89 | module_args = self.module_args 90 | job_env = submitit.JobEnvironment() 91 | output_dir = str(self.args.job_dir).replace("%j", str(job_env.job_id)) 92 | module_args.output_dir = output_dir 93 | 94 | module_args.gpu = job_env.local_rank 95 | module_args.rank = job_env.global_rank 96 | module_args.world_size = job_env.num_tasks 97 | 98 | module_args.comment = f"Job {job_env.job_id} on {job_env.num_tasks} GPUs" 99 | 100 | print("Setting up GPU args", module_args) 101 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 102 | 103 | 104 | def main(): 105 | parser = get_args_parser() 106 | args, 
module_params = parser.parse_known_args() 107 | print("args:", args) 108 | print("module_params:", module_params) 109 | if args.job_dir=='': 110 | args.job_dir = f"outputs/experiments/%j" 111 | args.job_dir = os.path.abspath(args.job_dir) 112 | # Note that the folder will depend on the job_id, to easily track experiments 113 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 114 | 115 | num_gpus_per_node = args.ngpus 116 | nodes = args.nodes 117 | timeout_min = args.timeout 118 | 119 | partition = args.partition 120 | kwargs = {} 121 | 122 | if args.comment: 123 | kwargs['slurm_comment'] = args.comment 124 | 125 | executor.update_parameters( 126 | mem_gb=args.mem, 127 | gpus_per_node=num_gpus_per_node, 128 | tasks_per_node=num_gpus_per_node, # one task per GPU 129 | cpus_per_task=10, 130 | nodes=nodes, 131 | timeout_min=timeout_min, # max is 60 * 72 132 | # Below are cluster dependent parameters 133 | slurm_partition=partition, 134 | slurm_signal_delay_s=120, 135 | **kwargs 136 | ) 137 | executor.update_parameters(name="pretrain") 138 | args.dist_url = get_init_file(args.job_dir).as_uri() 139 | 140 | trainer = Trainer(args, module_params) 141 | job = executor.submit(trainer) 142 | 143 | print("Submitted job_id:", job.job_id) 144 | 145 | 146 | if __name__ == "__main__": 147 | main() 148 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erow/FastSSL/67f28b6eeb8b940468562df7228124ea72a8d929/util/__init__.py -------------------------------------------------------------------------------- /util/clustering.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from sklearn.metrics import normalized_mutual_info_score 4 | import torch 5 | 6 | def evaluate_clustering(assignments, true_labels): 7 | """ 8 | Evaluates the clustering performance using Normalized Mutual Information (NMI). 9 | 10 | Parameters: 11 | assignments (np.ndarray): The cluster assignments for each data point, shape (num_points,). 12 | true_labels (np.ndarray): The true labels for each data point, shape (num_points,). 13 | 14 | Returns: 15 | nmi (float): The Normalized Mutual Information score. 16 | """ 17 | if isinstance(assignments, torch.Tensor): 18 | assignments = assignments.cpu().numpy() 19 | if isinstance(true_labels, torch.Tensor): 20 | true_labels = true_labels.cpu().numpy() 21 | 22 | nmi = normalized_mutual_info_score(true_labels, assignments) 23 | return nmi 24 | 25 | class StreamingKMeans: 26 | def __init__(self, k, lr=0.01): 27 | """ 28 | Initialize Streaming K-Means. 29 | 30 | Args: 31 | k (int): Number of clusters. 32 | dim (int): Dimensionality of the data. 33 | lr (float): Learning rate for updating cluster centers. 34 | """ 35 | self.k = k 36 | self.lr = lr 37 | self.cluster_centers = None # Initialize cluster centers randomly 38 | 39 | def update_predict(self, data): 40 | """ 41 | Update cluster centers with new data. 42 | 43 | Args: 44 | data (torch.Tensor): A tensor of shape (batch_size, dim) containing the new data points. 
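        Returns:
            torch.Tensor: Tensor of shape (batch_size,) giving the index of the nearest
            (freshly updated) cluster center for each point in the batch.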
45 | """ 46 | batch_size = data.size(0) 47 | 48 | 49 | if self.cluster_centers is None: 50 | # Initialize cluster centers with the first batch of data 51 | self.dim = data.size(1) 52 | self.cluster_centers = torch.randn(self.k, self.dim).to(data.device) 53 | 54 | # Compute distances between data points and cluster centers 55 | distances = torch.cdist(data, self.cluster_centers) # Shape: (batch_size, k) 56 | 57 | # Assign each data point to the nearest cluster 58 | nearest_cluster_indices = torch.argmin(distances, dim=1) # Shape: (batch_size,) 59 | 60 | # Update cluster centers 61 | for i in range(self.k): 62 | # Get the data points assigned to the i-th cluster 63 | cluster_data = data[nearest_cluster_indices == i] 64 | 65 | if cluster_data.size(0) > 0: 66 | # Compute the mean of the data points in the cluster 67 | cluster_mean = torch.mean(cluster_data, dim=0) 68 | 69 | # Update the cluster center using the learning rate 70 | self.cluster_centers[i] = (1 - self.lr) * self.cluster_centers[i] + self.lr * cluster_mean 71 | return nearest_cluster_indices 72 | 73 | def predict(self, data): 74 | """ 75 | Predict the nearest cluster for each data point. 76 | 77 | Args: 78 | data (torch.Tensor): A tensor of shape (batch_size, dim) containing the data points. 79 | 80 | Returns: 81 | torch.Tensor: A tensor of shape (batch_size,) containing the cluster indices. 82 | """ 83 | distances = torch.cdist(data, self.cluster_centers) 84 | return torch.argmin(distances, dim=1) 85 | from layers.operation import concat_all_gather 86 | class KmeansProb(): 87 | def __init__(self, representation_fn, num_clusters=100): 88 | self.representation_fn = representation_fn 89 | self.num_clusters = num_clusters 90 | self.kmeans_dict = {} 91 | 92 | def step(self,x,y): 93 | if isinstance(x,list): 94 | x = x[0] 95 | log = self.step_train(x.cuda(), y.cuda()) 96 | return log 97 | 98 | def step_train(self, x,y): 99 | log = {} 100 | if isinstance(x, list) or isinstance(x, tuple): 101 | x = x[0] 102 | 103 | y = concat_all_gather(y.contiguous()) 104 | 105 | for name, z in self.representation_fn(x).items(): 106 | z = concat_all_gather(z.contiguous()) 107 | if name not in self.kmeans_dict: 108 | self.kmeans_dict[name] = StreamingKMeans(self.num_clusters) 109 | 110 | assignments = self.kmeans_dict[name].update_predict(z) 111 | nmi = evaluate_clustering(assignments, y) 112 | log[name+'@nmi'] = nmi 113 | return log 114 | 115 | # Example usage 116 | if __name__ == "__main__": 117 | # Generate some random data 118 | np.random.seed(42) 119 | num_samples = 20000 120 | noise = np.random.rand(num_samples, 128).astype(np.float32) 121 | true_labels = np.random.randint(0, 10, num_samples) # Assuming we have 10 true clusters 122 | true_centroids = np.random.rand(10, 20).astype(np.float32) 123 | W = np.random.rand(128, 20).astype(np.float32)/4 124 | 125 | # for k in np.linspace(0,2,10): 126 | # data = noise*k + true_centroids[true_labels].dot(W.T) 127 | 128 | # # Run k-means clustering 129 | # num_clusters = 10 130 | # centroids, assignments = run_kmeans(data, num_clusters) 131 | 132 | # # Evaluate clustering performance 133 | # nmi = evaluate_clustering(assignments, true_labels) 134 | # print(": Normalized Mutual Information (NMI):", nmi, "k:", k) 135 | 136 | skm = StreamingKmeans(num_clusters=100) 137 | data = noise*0.1 + true_centroids[true_labels].dot(W.T) 138 | for chunk,labels in zip(np.array_split(data, 10), np.array_split(true_labels, 10)): 139 | assignments = skm.push(chunk) 140 | nmi = evaluate_clustering(assignments, labels) if 
assignments is not None else None 141 | print("Streaming Kmeans: Normalized Mutual Information (NMI):", nmi) 142 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /util/dres.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import gin 3 | 4 | 5 | DEFAULT_SCHEME ={ 6 | 1: [ 7 | dict(res=160,mask_ratio=0.5,), 8 | dict(res=192,mask_ratio=0.66), 9 | dict(res=224,mask_ratio=0.75), 10 | ], 11 | 2: [ 12 | dict(res=160,mask_ratio=0.75,), 13 | dict(res=192,mask_ratio=0.75), 14 | dict(res=224,mask_ratio=0.75), 15 | ], 16 | 3: [ 17 | dict(res=224,mask_ratio=0.75), 18 | dict(res=192,mask_ratio=0.75), 19 | dict(res=160,mask_ratio=0.75), 20 | ], 21 | 4: [ 22 | dict(res=160,mask_ratio=0.75,), 23 | dict(res=192,mask_ratio=0.80), 24 | dict(res=224,mask_ratio=0.85), 25 | ], 26 | 5: [ 27 | dict(res=224,mask_ratio=0.75), 28 | dict(res=192,mask_ratio=0.75), 29 | dict(res=224,mask_ratio=0.75), 30 | ], 31 | 6: [ 32 | dict(res=224,mask_ratio=0.75), 33 | dict(res=192,mask_ratio=0.75), 34 | dict(res=224,mask_ratio=0.85), 35 | ], 36 | } 37 | 38 | @gin.configurable 39 | class DynamicMasking: 40 | def __init__(self, start_ramp=gin.REQUIRED, end_ramp=gin.REQUIRED, 41 | scheme = 2): 42 | if isinstance(scheme, int): 43 | scheme = DEFAULT_SCHEME[scheme] 44 | else: 45 | assert isinstance(scheme, list) 46 | self.scheme = scheme 47 | self.start_ramp = start_ramp 48 | self.end_ramp = end_ramp 49 | 50 | def get_config(self, epoch): 51 | if epoch <= self.start_ramp: 52 | return self.scheme[0] 53 | elif epoch>=self.end_ramp: 54 | return self.scheme[-1] 55 | else: 56 | i = (epoch-self.start_ramp) * (len(self.scheme)-1) // (self.end_ramp-self.start_ramp) 57 | return self.scheme[i] 58 | 59 | def __call__(self, model, loader, epoch,is_ffcv=False): 60 | config = self.get_config(epoch) 61 | print(", ".join([f"{k}={v}" for k,v in config.items()])) 
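        # config is one entry of the active scheme, e.g. {'res': 192, 'mask_ratio': 0.75};
        # the resolution is applied to the data pipeline and the mask ratio to the model below.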
62 | img_size = config['res'] 63 | mask_ratio = config['mask_ratio'] 64 | 65 | assert hasattr(model,"mask_ratio") 66 | model.mask_ratio = mask_ratio 67 | if is_ffcv: 68 | pipeline=loader.pipeline_specs['image'] 69 | if pipeline.decoder.output_size[0] != img_size: 70 | pipeline.decoder.output_size = (img_size,img_size) 71 | loader.generate_code() 72 | else: 73 | print(loader.dataset.transforms) 74 | augmentation = loader.dataset.transforms.transform 75 | augmentation.change_resolution(img_size) 76 | -------------------------------------------------------------------------------- /util/helper.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | from pathlib import Path 3 | import gin 4 | 5 | def aug_parse(parser: argparse.ArgumentParser): 6 | import yaml 7 | parser.add_argument('--no_resume',default=False,action='store_true',help="") 8 | parser.add_argument('--cfgs', nargs='+', default=[], 9 | help=' Config files *.gin.', required=False) 10 | parser.add_argument('--gin', nargs='+', 11 | help='Overrides config values. e.g. --gin "section.option=value"') 12 | 13 | args = parser.parse_args() 14 | 15 | if args.output_dir: 16 | output_dir=Path(args.output_dir) 17 | output_dir.mkdir(parents=True, exist_ok=True) 18 | gin.parse_config_files_and_bindings(args.cfgs,args.gin) 19 | 20 | if args.output_dir: 21 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 22 | with open(os.path.join(args.output_dir,'config.yml'), 'w') as f: 23 | yaml.dump(vars(args), f) 24 | 25 | open(output_dir/"config.gin",'w').write(gin.config_str(),) 26 | 27 | return args -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
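For parameters with more than one dimension, the update (gradient + weight decay) is rescaled
by trust_coefficient * ||param|| / ||update|| before the momentum step, so the effective step
size adapts to each layer's weight norm; 1D parameters (biases, norm scales) skip both the
decay and the rescaling.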
17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | import numpy as np 15 | 16 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 17 | warmup_schedule = np.array([]) 18 | warmup_iters = warmup_epochs * niter_per_ep 19 | if warmup_epochs > 0: 20 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 21 | 22 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 23 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 24 | 25 | schedule = np.concatenate((warmup_schedule, schedule)) 26 | assert len(schedule) == epochs * niter_per_ep 27 | return schedule 28 | 29 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 30 | """ 31 | Parameter groups for layer-wise lr decay 32 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 33 | """ 34 | param_group_names = {} 35 | param_groups = {} 36 | 37 | num_layers = len(model.blocks) + 1 38 | 39 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 40 | 41 | for n, p in model.named_parameters(): 42 | if not p.requires_grad: 43 | continue 44 | 45 | # no decay: all 1D parameters and model specific ones 46 | if p.ndim == 1 or n in no_weight_decay_list: 47 | g_decay = "no_decay" 48 | this_decay = 0. 
49 | else: 50 | g_decay = "decay" 51 | this_decay = weight_decay 52 | 53 | layer_id = get_layer_id_for_vit(n, num_layers) 54 | group_name = "layer_%d_%s" % (layer_id, g_decay) 55 | 56 | if group_name not in param_group_names: 57 | this_scale = layer_scales[layer_id] 58 | 59 | param_group_names[group_name] = { 60 | "lr_scale": this_scale, 61 | "weight_decay": this_decay, 62 | "params": [], 63 | } 64 | param_groups[group_name] = { 65 | "lr_scale": this_scale, 66 | "weight_decay": this_decay, 67 | "params": [], 68 | } 69 | 70 | param_group_names[group_name]["params"].append(n) 71 | param_groups[group_name]["params"].append(p) 72 | 73 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 74 | 75 | return list(param_groups.values()) 76 | 77 | 78 | def get_layer_id_for_vit(name, num_layers): 79 | """ 80 | Assign a parameter with its layer id 81 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 82 | """ 83 | if name in ['cls_token', 'pos_embed']: 84 | return 0 85 | elif name.startswith('patch_embed'): 86 | return 0 87 | elif name.startswith('blocks'): 88 | return int(name.split('.')[1]) + 1 89 | else: 90 | return num_layers -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | 23 | def adjust_moco_momentum(epoch, args): 24 | """Adjust moco momentum based on current epoch""" 25 | m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * (1. - args.m) 26 | return m -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import math 16 | import time 17 | from collections import defaultdict, deque 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.distributed as dist 22 | from torch import inf, nn 23 | import numpy as np 24 | 25 | 26 | 27 | class SmoothedValue(object): 28 | """Track a series of values and provide access to smoothed values over a 29 | window or the global series average. 
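Window statistics (median, avg, max, value) are local to this process; synchronize_between_processes
all-reduces only count and total, so global_avg is the only cross-process figure.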
30 | """ 31 | 32 | def __init__(self, window_size=20, fmt=None): 33 | if fmt is None: 34 | fmt = "{median:.4f} ({global_avg:.4f})" 35 | self.deque = deque(maxlen=window_size) 36 | self.total = 0.0 37 | self.count = 0 38 | self.fmt = fmt 39 | 40 | def update(self, value, n=1): 41 | self.deque.append(value) 42 | self.count += n 43 | self.total += value * n 44 | 45 | def synchronize_between_processes(self): 46 | """ 47 | Warning: does not synchronize the deque! 48 | """ 49 | if not is_dist_avail_and_initialized(): 50 | return 51 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 52 | dist.barrier() 53 | dist.all_reduce(t) 54 | t = t.tolist() 55 | self.count = int(t[0]) 56 | self.total = t[1] 57 | 58 | @property 59 | def median(self): 60 | d = torch.tensor(list(self.deque)) 61 | return d.median().item() 62 | 63 | @property 64 | def avg(self): 65 | d = torch.tensor(list(self.deque), dtype=torch.float32) 66 | return d.mean().item() 67 | 68 | @property 69 | def global_avg(self): 70 | return self.total / self.count 71 | 72 | @property 73 | def max(self): 74 | return max(self.deque) 75 | 76 | @property 77 | def value(self): 78 | return self.deque[-1] 79 | 80 | def __str__(self): 81 | return self.fmt.format( 82 | median=self.median, 83 | avg=self.avg, 84 | global_avg=self.global_avg, 85 | max=self.max, 86 | value=self.value) 87 | 88 | 89 | class MetricLogger(object): 90 | def __init__(self, delimiter="\t"): 91 | self.meters = defaultdict(SmoothedValue) 92 | self.delimiter = delimiter 93 | 94 | def update(self, **kwargs): 95 | for k, v in kwargs.items(): 96 | if v is None: 97 | continue 98 | if isinstance(v, torch.Tensor): 99 | v = v.item() 100 | assert isinstance(v, (float, int)) 101 | self.meters[k].update(v) 102 | 103 | def __getattr__(self, attr): 104 | if attr in self.meters: 105 | return self.meters[attr] 106 | if attr in self.__dict__: 107 | return self.__dict__[attr] 108 | raise AttributeError("'{}' object has no attribute '{}'".format( 109 | type(self).__name__, attr)) 110 | 111 | def __str__(self): 112 | loss_str = [] 113 | for name, meter in self.meters.items(): 114 | loss_str.append( 115 | "{}: {}".format(name, str(meter)) 116 | ) 117 | return self.delimiter.join(loss_str) 118 | 119 | def synchronize_between_processes(self): 120 | for meter in self.meters.values(): 121 | meter.synchronize_between_processes() 122 | 123 | def add_meter(self, name, meter): 124 | self.meters[name] = meter 125 | 126 | def log_every(self, iterable, print_freq, header=None): 127 | i = 0 128 | if not header: 129 | header = '' 130 | start_time = time.time() 131 | end = time.time() 132 | iter_time = SmoothedValue(fmt='{avg:.4f}') 133 | data_time = SmoothedValue(fmt='{avg:.4f}') 134 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 135 | log_msg = [ 136 | header, 137 | '[{0' + space_fmt + '}/{1}]', 138 | 'eta: {eta}', 139 | '{meters}', 140 | 'time: {time}', 141 | 'data: {data}' 142 | ] 143 | if torch.cuda.is_available(): 144 | log_msg.append('max mem: {memory:.0f}') 145 | log_msg = self.delimiter.join(log_msg) 146 | MB = 1024.0 * 1024.0 147 | for obj in iterable: 148 | data_time.update(time.time() - end) 149 | yield obj 150 | iter_time.update(time.time() - end) 151 | if i % print_freq == 0 or i == len(iterable) - 1: 152 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 153 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 154 | if torch.cuda.is_available(): 155 | print(log_msg.format( 156 | i, len(iterable), eta=eta_string, 157 | meters=str(self), 158 | 
time=str(iter_time), data=str(data_time), 159 | memory=torch.cuda.max_memory_allocated() / MB)) 160 | else: 161 | print(log_msg.format( 162 | i, len(iterable), eta=eta_string, 163 | meters=str(self), 164 | time=str(iter_time), data=str(data_time))) 165 | i += 1 166 | end = time.time() 167 | total_time = time.time() - start_time 168 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 169 | print('{} Total time: {} ({:.4f} s / it)'.format( 170 | header, total_time_str, total_time / len(iterable))) 171 | 172 | 173 | def setup_for_distributed(is_master): 174 | """ 175 | This function disables printing when not in master process 176 | """ 177 | builtin_print = builtins.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | force = force or (get_world_size() > 8) 182 | if is_master or force: 183 | now = datetime.datetime.now().time() 184 | builtin_print('[{}] '.format(now), *args, **kwargs) # print with time stamp 185 | 186 | builtins.print = print 187 | 188 | 189 | def is_dist_avail_and_initialized(): 190 | if not dist.is_available(): 191 | return False 192 | if not dist.is_initialized(): 193 | return False 194 | return True 195 | 196 | 197 | def get_world_size(): 198 | if not is_dist_avail_and_initialized(): 199 | return 1 200 | return dist.get_world_size() 201 | 202 | 203 | def get_rank(): 204 | if not is_dist_avail_and_initialized(): 205 | return 0 206 | return dist.get_rank() 207 | 208 | 209 | def is_main_process(): 210 | return get_rank() == 0 211 | 212 | 213 | def save_on_master(*args, **kwargs): 214 | if is_main_process(): 215 | torch.save(*args, **kwargs) 216 | 217 | 218 | def init_distributed_mode(args): 219 | if hasattr(args,'dist_on_itp') and args.dist_on_itp: 220 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 221 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 222 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 223 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 224 | os.environ['LOCAL_RANK'] = str(args.gpu) 225 | os.environ['RANK'] = str(args.rank) 226 | os.environ['WORLD_SIZE'] = str(args.world_size) 227 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 228 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 229 | args.rank = int(os.environ["RANK"]) 230 | args.world_size = int(os.environ['WORLD_SIZE']) 231 | args.gpu = int(os.environ['LOCAL_RANK']) 232 | elif 'SLURM_PROCID' in os.environ: 233 | args.rank = int(os.environ['SLURM_PROCID']) 234 | args.gpu = args.rank % torch.cuda.device_count() 235 | else: 236 | print('Not using distributed mode') 237 | setup_for_distributed(is_master=True) # hack 238 | args.distributed = False 239 | return 240 | 241 | args.distributed = True 242 | 243 | if torch.cuda.is_available(): 244 | torch.cuda.set_device(args.gpu) 245 | args.dist_backend = 'nccl' 246 | print('| distributed init (rank {}): {}, gpu {}'.format( 247 | args.rank, args.dist_url, args.gpu), flush=True) 248 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 249 | world_size=args.world_size, rank=args.rank) 250 | torch.distributed.barrier() 251 | setup_for_distributed(args.rank == 0) 252 | elif torch.backends.mps.is_available(): 253 | args.device = 'mps' # MPS uses a single GPU device 254 | args.distributed = False 255 | else: 256 | print('No CUDA or MPS device found') 257 | args.distributed = False 258 | 259 | 260 | class NativeScalerWithGradNormCount: 261 | state_dict_key = "amp_scaler" 262 
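# Wraps torch.cuda.amp.GradScaler: __call__ runs the scaled backward pass, optionally unscales
# and clips gradients, steps the optimizer, and returns the gradient norm for logging.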
| 263 | def __init__(self): 264 | self._scaler = torch.cuda.amp.GradScaler() 265 | 266 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 267 | self._scaler.scale(loss).backward(create_graph=create_graph) 268 | if update_grad: 269 | if clip_grad is not None: 270 | assert parameters is not None 271 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 272 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 273 | else: 274 | self._scaler.unscale_(optimizer) 275 | norm = get_grad_norm_(parameters) 276 | self._scaler.step(optimizer) 277 | self._scaler.update() 278 | else: 279 | norm = None 280 | return norm 281 | 282 | def state_dict(self): 283 | return self._scaler.state_dict() 284 | 285 | def load_state_dict(self, state_dict): 286 | self._scaler.load_state_dict(state_dict) 287 | 288 | 289 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 290 | if isinstance(parameters, torch.Tensor): 291 | parameters = [parameters] 292 | parameters = [p for p in parameters if p.grad is not None] 293 | norm_type = float(norm_type) 294 | if len(parameters) == 0: 295 | return torch.tensor(0.) 296 | device = parameters[0].grad.device 297 | if norm_type == inf: 298 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 299 | else: 300 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 301 | return total_norm 302 | 303 | 304 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, bac=True): 305 | output_dir = Path(args.output_dir) 306 | epoch_name = str(epoch) 307 | if loss_scaler is not None: 308 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 309 | if bac: 310 | checkpoint_paths.append(output_dir / ('checkpoint-%s.pth' % epoch_name)) 311 | for checkpoint_path in checkpoint_paths: 312 | to_save = { 313 | 'model': model_without_ddp.state_dict(), 314 | 'optimizer': optimizer.state_dict(), 315 | 'epoch': epoch, 316 | 'scaler': loss_scaler.state_dict(), 317 | 'args': args, 318 | } 319 | 320 | save_on_master(to_save, checkpoint_path) 321 | else: 322 | client_state = {'epoch': epoch} 323 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 324 | 325 | 326 | def interpolate_pos_encoding(pos_embed, w, h, patch_size=16): 327 | """interpolate pos embedding to w, h 328 | 329 | Args: 330 | pos_embed (Tensor): pos embedding 331 | w (int): the width of target image size 332 | h (int): the height of target image size 333 | 334 | Returns: 335 | Tensor: pos embedding 336 | """ 337 | N = pos_embed.shape[1] - 1 338 | class_pos_embed = pos_embed[:, 0] 339 | patch_pos_embed = pos_embed[:, 1:] 340 | dim = pos_embed.shape[-1] 341 | w0 = w // patch_size 342 | h0 = h // patch_size 343 | w0, h0 = w0 + 0.1, h0 + 0.1 344 | patch_pos_embed = nn.functional.interpolate( 345 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 346 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 347 | mode='bicubic', 348 | ) 349 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 350 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 351 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 352 | 353 | def load_pretrained_weights(model, weight_path): 354 | state_dict = torch.load(weight_path, map_location='cpu') 355 | 
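# drop position-embedding tensors and load the rest with strict=False, so the model keeps its
# own pos_embed (sidesteps shape mismatches when the image size or patch count differs)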
pos_embed_keys = [k for k in state_dict.keys() if 'pos_embed' in k] # find pos embedding 356 | for k in pos_embed_keys: 357 | del state_dict[k] 358 | print("Loading pretrained weights from %s" % weight_path) 359 | print(model.load_state_dict(state_dict, strict=False)) 360 | 361 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 362 | if args.resume: 363 | if args.resume.startswith('https'): 364 | checkpoint = torch.hub.load_state_dict_from_url( 365 | args.resume, map_location='cpu', check_hash=True) 366 | else: 367 | checkpoint = torch.load(args.resume, map_location='cpu') 368 | model_without_ddp.load_state_dict(checkpoint['model']) 369 | print("Resume checkpoint %s" % args.resume) 370 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 371 | optimizer.load_state_dict(checkpoint['optimizer']) 372 | args.start_epoch = checkpoint['epoch'] + 1 373 | if 'scaler' in checkpoint: 374 | loss_scaler.load_state_dict(checkpoint['scaler']) 375 | print("With optim & sched!") 376 | 377 | 378 | def all_reduce_mean(x): 379 | world_size = get_world_size() 380 | if world_size > 1: 381 | x_reduce = torch.tensor(x).cuda() 382 | dist.all_reduce(x_reduce) 383 | x_reduce /= world_size 384 | return x_reduce.item() 385 | else: 386 | return x 387 | 388 | def init_seed(seed): 389 | torch.manual_seed(seed) 390 | torch.cuda.manual_seed_all(seed) 391 | np.random.seed(seed) -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float32) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /util/prob.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.distributed 4 | import gin 5 | import torch.nn.functional as F 6 | 7 | 8 | class LinearProb(nn.Module): 9 | def __init__(self,data_path, names, representations_fn,num_classes=1): 10 | super().__init__() 11 | self.names = names 12 | self.representations_fn = representations_fn 13 | self.heads = nn.ModuleDict({ 14 | name: nn.LazyLinear(num_classes) for name in names 15 | }) 16 | self.regression = num_classes==1 17 | 18 | distributed = torch.distributed.is_initialized() 19 | self.optimizer = torch.optim.Adam(self.heads.parameters(), lr=1e-3) 20 | if num_classes>1: 21 | self.criterion = nn.CrossEntropyLoss() 22 | else: 23 | self.criterion = nn.MSELoss() 24 | 25 | if ".ffcv" in data_path: 26 | from ffcv.loader import Loader, OrderOption 27 | from dataset.ffcv_transform import ValPipeline 28 | self.dl = Loader(data_path, batch_size=64, order=OrderOption.RANDOM, num_workers=10,drop_last=True, pipelines=ValPipeline(),distributed=distributed) 29 | else: 30 | from torchvision import datasets, transforms 31 | data = datasets.ImageFolder(data_path, 32 | transform=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(), 33 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])) 34 | self.dl = torch.utils.data.DataLoader(data, batch_size=64, shuffle=True, num_workers=10,drop_last=True) 35 | self.next_batch = iter(self.dl) 36 | 37 | def step(self,x,y): 38 
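"""Train the probe heads on the incoming batch (first view only if x is a list of crops),
then evaluate on one held-out batch from the probe's own loader; returns a dict mapping each
head name to accuracy (or MSE loss for regression heads)."""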
| if isinstance(x,list): 39 | x = x[0] 40 | self.step_train(x.cuda(), y.cuda()) 41 | log = self.step_val() 42 | return log 43 | 44 | @torch.no_grad() 45 | def step_val(self): 46 | try: 47 | x, y = [t.cuda() for t in next(self.next_batch)] # ensure the held-out batch is on the same device as the heads/backbone (no-op for GPU-resident FFCV batches) 48 | except StopIteration: 49 | self.next_batch = iter(self.dl) 50 | x, y = [t.cuda() for t in next(self.next_batch)] 51 | log = {} 52 | for name, z in zip(self.names, self.representations_fn(x,y)): 53 | pred = self.heads[name](z.detach()) 54 | if self.regression: 55 | loss = self.criterion(pred.flatten(), y.float()) 56 | log[name]=loss.item() 57 | else: 58 | acc = (pred.argmax(1) == y).float().mean() 59 | log[name] = acc.item() 60 | return log 61 | 62 | def step_train(self, x, y): 63 | self.optimizer.zero_grad() 64 | 65 | loss = 0 66 | for name, z in zip(self.names, self.representations_fn(x,y)): 67 | pred = self.heads[name](z.detach()) 68 | if self.regression: 69 | loss = loss + self.criterion(pred.flatten(), y.float()) 70 | 71 | else: 72 | loss = loss + self.criterion(pred, y) 73 | 74 | 75 | loss.backward() 76 | self.optimizer.step() 77 | 78 | 79 | 80 | @gin.configurable(denylist=['model']) 81 | def build_representations(model): 82 | if hasattr(model,'projector'): 83 | @torch.no_grad() 84 | def representations_fn(x,y): 85 | latent = model.representation(x) 86 | proj = model.projector(latent) 87 | return latent, proj 88 | names = "latent","proj" 89 | else: 90 | @torch.no_grad() 91 | def representations_fn(x,y): 92 | latent = model.representation(x) 93 | return latent, 94 | names = "latent", 95 | 96 | return names, representations_fn 97 | 98 | @gin.configurable(denylist=['model']) 99 | def build_representations_fn(model,fn=build_representations): 100 | return fn(model) 101 | 102 | --------------------------------------------------------------------------------
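A minimal wiring sketch for util/prob.py: it assumes a CUDA device, an ImageFolder-style probe set at ./data/val, and a stand-in backbone (TinyEncoder below) exposing the representation()/projector interface that build_representations looks for; none of these names or paths come from the repository's own scripts.

import torch
from torch import nn

from util.prob import LinearProb, build_representations_fn


class TinyEncoder(nn.Module):
    """Stand-in SSL backbone (a real run would build one via layers/build_model.py)."""
    def __init__(self, dim=192):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, dim, kernel_size=7, stride=4, padding=3),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.projector = nn.Linear(dim, 128)  # having .projector adds a second ("proj") head

    def representation(self, x):
        return self.stem(x)


model = TinyEncoder().cuda().eval()
names, reps_fn = build_representations_fn(model)        # -> ("latent", "proj")
probe = LinearProb("./data/val", names, reps_fn, num_classes=10).cuda()

# One probe step per pretraining iteration: fit the linear heads on the current batch,
# then report per-head accuracy on a held-out batch drawn from ./data/val.
images = torch.randn(64, 3, 224, 224, device="cuda")    # stand-in for the SSL batch
labels = torch.randint(0, 10, (64,), device="cuda")
print(probe.step(images, labels))                        # e.g. {'latent': 0.11, 'proj': 0.10}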