├── .gitignore ├── LICENSE ├── README.md ├── configs ├── dino.yaml └── esvit.yaml ├── datasets ├── __init__.py ├── imagenet.py └── transforms.py ├── models ├── __init__.py ├── classifier.py ├── dino.py ├── esvit.py ├── vit.py └── xcit.py ├── tools ├── train.py ├── val_knn.py ├── val_linear.py └── visualize_attention.py └── utils ├── __init__.py ├── loss.py ├── metrics.py ├── schedulers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- 2 | *.jpg 3 | *.jpeg 4 | *.png 5 | *.bmp 6 | *.tif 7 | *.tiff 8 | *.heic 9 | *.JPG 10 | *.JPEG 11 | *.PNG 12 | *.BMP 13 | *.TIF 14 | *.TIFF 15 | *.HEIC 16 | *.mp4 17 | *.mov 18 | *.MOV 19 | *.avi 20 | *.data 21 | *.json 22 | 23 | *.cfg 24 | !cfg/yolov3*.cfg 25 | 26 | storage.googleapis.com 27 | test_imgs/ 28 | runs/* 29 | data/* 30 | !data/images/zidane.jpg 31 | !data/images/bus.jpg 32 | !data/coco.names 33 | !data/coco_paper.names 34 | !data/coco.data 35 | !data/coco_*.data 36 | !data/coco_*.txt 37 | !data/trainvalno5k.shapes 38 | !data/*.sh 39 | 40 | pycocotools/* 41 | results*.txt 42 | gcp_test*.sh 43 | 44 | checkpoints/ 45 | output/ 46 | 47 | # Datasets ------------------------------------------------------------------------------------------------------------- 48 | coco/ 49 | coco128/ 50 | VOC/ 51 | 52 | # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- 53 | *.m~ 54 | *.mat 55 | !targets*.mat 56 | 57 | # Neural Network weights ----------------------------------------------------------------------------------------------- 58 | *.weights 59 | *.pt 60 | *.onnx 61 | *.mlmodel 62 | *.torchscript 63 | darknet53.conv.74 64 | yolov3-tiny.conv.15 65 | 66 | # GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- 67 | # 
Byte-compiled / optimized / DLL files 68 | __pycache__/ 69 | *.py[cod] 70 | *$py.class 71 | 72 | # C extensions 73 | *.so 74 | 75 | # Distribution / packaging 76 | .Python 77 | env/ 78 | build/ 79 | develop-eggs/ 80 | dist/ 81 | downloads/ 82 | eggs/ 83 | .eggs/ 84 | lib/ 85 | lib64/ 86 | parts/ 87 | sdist/ 88 | var/ 89 | wheels/ 90 | *.egg-info/ 91 | wandb/ 92 | .installed.cfg 93 | *.egg 94 | 95 | 96 | # PyInstaller 97 | # Usually these files are written by a python script from a template 98 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 99 | *.manifest 100 | *.spec 101 | 102 | # Installer logs 103 | pip-log.txt 104 | pip-delete-this-directory.txt 105 | 106 | # Unit test / coverage reports 107 | htmlcov/ 108 | .tox/ 109 | .coverage 110 | .coverage.* 111 | .cache 112 | nosetests.xml 113 | coverage.xml 114 | *.cover 115 | .hypothesis/ 116 | 117 | # Translations 118 | *.mo 119 | *.pot 120 | 121 | # Django stuff: 122 | *.log 123 | local_settings.py 124 | 125 | # Flask stuff: 126 | instance/ 127 | .webassets-cache 128 | 129 | # Scrapy stuff: 130 | .scrapy 131 | 132 | # Sphinx documentation 133 | docs/_build/ 134 | 135 | # PyBuilder 136 | target/ 137 | 138 | # Jupyter Notebook 139 | .ipynb_checkpoints 140 | 141 | # pyenv 142 | .python-version 143 | 144 | # celery beat schedule file 145 | celerybeat-schedule 146 | 147 | # SageMath parsed files 148 | *.sage.py 149 | 150 | # dotenv 151 | .env 152 | 153 | # virtualenv 154 | .venv* 155 | venv*/ 156 | ENV*/ 157 | 158 | # Spyder project settings 159 | .spyderproject 160 | .spyproject 161 | 162 | # Rope project settings 163 | .ropeproject 164 | 165 | # mkdocs documentation 166 | /site 167 | 168 | # mypy 169 | .mypy_cache/ 170 | 171 | 172 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- 173 | 174 | # General 175 | .DS_Store 176 | .AppleDouble 177 | .LSOverride 178 | 179 | # Icon must end with two \r 180 | Icon 181 | Icon? 
182 | 183 | # Thumbnails 184 | ._* 185 | 186 | # Files that might appear in the root of a volume 187 | .DocumentRevisions-V100 188 | .fseventsd 189 | .Spotlight-V100 190 | .TemporaryItems 191 | .Trashes 192 | .VolumeIcon.icns 193 | .com.apple.timemachine.donotpresent 194 | 195 | # Directories potentially created on remote AFP share 196 | .AppleDB 197 | .AppleDesktop 198 | Network Trash Folder 199 | Temporary Items 200 | .apdisk 201 | 202 | 203 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 204 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 205 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 206 | 207 | # User-specific stuff: 208 | .idea/* 209 | .idea/**/workspace.xml 210 | .idea/**/tasks.xml 211 | .idea/dictionaries 212 | .html # Bokeh Plots 213 | .pg # TensorFlow Frozen Graphs 214 | .avi # videos 215 | 216 | # Sensitive or high-churn files: 217 | .idea/**/dataSources/ 218 | .idea/**/dataSources.ids 219 | .idea/**/dataSources.local.xml 220 | .idea/**/sqlDataSources.xml 221 | .idea/**/dynamic.xml 222 | .idea/**/uiDesigner.xml 223 | 224 | # Gradle: 225 | .idea/**/gradle.xml 226 | .idea/**/libraries 227 | 228 | # CMake 229 | cmake-build-debug/ 230 | cmake-build-release/ 231 | 232 | # Mongo Explorer plugin: 233 | .idea/**/mongoSettings.xml 234 | 235 | ## File-based project format: 236 | *.iws 237 | 238 | ## Plugin-specific files: 239 | 240 | # IntelliJ 241 | out/ 242 | 243 | # mpeltonen/sbt-idea plugin 244 | .idea_modules/ 245 | 246 | # JIRA plugin 247 | atlassian-ide-plugin.xml 248 | 249 | # Cursive Clojure plugin 250 | .idea/replstate.xml 251 | 252 | # Crashlytics plugin (for Android Studio and IntelliJ) 253 | com_crashlytics_export_strings.xml 254 | crashlytics.properties 255 | crashlytics-build.properties 256 | fabric.properties 257 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 sithu3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self Supervised Image Classification 2 | 3 | ## Introduction 4 | 5 | Read a blog post from FAIR >> [Self-supervised learning: The dark matter of intelligence](https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/). 
6 | 7 | 18 | 19 | Models 20 | 21 | * [EsViT](https://arxiv.org/abs/2106.09785) 22 | * [XCiT](https://arxiv.org/abs/2106.09681v2) 23 | * [DINO](https://arxiv.org/abs/2104.14294v2) 24 | * [MoCov3](https://arxiv.org/abs/2104.02057) 25 | 26 | ## Model Zoo 27 | 28 | Method | Model | ImageNet Top1 Acc (Linear) | ImageNet Top1 Acc (k-NN) | Params (M) | Weights 29 | --- | --- | --- | --- | --- | --- 30 | EsViT | Swin-B/W=14 | 81.3 | 79.3 | 87 | N/A 31 | EsViT | Swin-S/W=14 | 80.8 | 79.1 | 49 | N/A 32 | EsViT | Swin-T/W=14 | 78.7 | 77.0 | 28 | N/A 33 | DINO | XCiT-M24/8 | 80.3 | 77.9 | 84 | [model](https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth)/[checkpoint](https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain_full_checkpoint.pth) 34 | DINO | XCiT-S12/8 | 79.2 | 77.1 | 26 | [model](https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth)/[checkpoint](https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain_full_checkpoint.pth) 35 | DINO | ViT-B/8 | 80.1 | 77.4 | 85 | [model](https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth)/[checkpoint](https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain_full_checkpoint.pth) 36 | DINO | ViT-S/8 | 79.7 | 78.3 | 21 | [model](https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth)/[checkpoint](https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain_full_checkpoint.pth) 37 | MoCov3 | ViT-B | 76.7 | - | - | N/A 38 | MoCov3 | ViT-S | 73.2 | - | - | N/A 39 | 40 | ## Configuration 41 | 42 | Create a configuration file in `configs`. A sample configuration for the ImageNet dataset with DINO can be found [here](configs/dino.yaml). Then edit any fields as needed.
This configuration file is required by all of the training, evaluation and prediction scripts. 43 | 44 | ## Training 45 | 46 | ### Single GPU 47 | ```bash 48 | $ python tools/train.py --cfg configs/CONFIG_FILE_NAME.yaml 49 | ``` 50 | 51 | ### Multiple GPUs 52 | 53 | Training with 2 GPUs: 54 | 55 | ```bash 56 | $ python -m torch.distributed.launch --nproc_per_node=2 --use_env tools/train.py --cfg configs/CONFIG_FILE_NAME.yaml 57 | ``` 58 | 59 | ## Evaluation 60 | 61 | Make sure to set `MODEL_PATH` of the configuration file to your trained model directory. 62 | 63 | ### Linear Classification 64 | 65 | This will train a supervised linear classifier on top of trained weights and evaluate the result. 66 | 67 | ```bash 68 | $ python -m torch.distributed.launch --nproc_per_node=2 --use_env tools/val_linear.py --cfg configs/CONFIG_FILE_NAME.yaml 69 | ``` 70 | 71 | ### k-NN Classification 72 | 73 | ```bash 74 | $ python -m torch.distributed.launch --nproc_per_node=1 --use_env tools/val_knn.py --cfg configs/CONFIG_FILE_NAME.yaml 75 | ``` 76 | 77 | 78 | ## Attention Visualization 79 | 80 | Make sure to set `MODEL_PATH` of the configuration file to the model's weights.
81 | 82 | ```bash 83 | $ python tools/visualize_attention.py --cfg configs/CONFIG_FILE_NAME.yaml 84 | ``` -------------------------------------------------------------------------------- /configs/dino.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | SAVE_DIR: './output' # output folder name used for saving the trained model and logs 3 | MODEL_PATH: 'checkpoints/xcit/dino_xcit_small_12_p8_pretrain.pth' # trained model path (used for evaluation, inference and optimization) 4 | 5 | METHOD: 'dino' # name of the method you are using 6 | 7 | MODEL: 8 | NAME: 'xcit' 9 | VARIANT: 'S12' # sub name of the model you are using 10 | 11 | DATASET: 12 | NAME: imagenet # dataset name 13 | ROOT: '../datasets/imagenet' # dataset root path 14 | 15 | TRAIN: 16 | IMAGE_SIZE: [224, 224] # image size used in training the model 17 | EPOCHS: 100 # number of epochs to train 18 | BATCH_SIZE: 8 # batch size used to train 19 | WORKERS: 8 # number of workers used in training dataloader 20 | LR: 0.01 # initial learning rate used in optimizer 21 | DECAY: 0.0005 # decay rate used in optimizer 22 | LOSS: dinoloss # loss function name (vanilla, label_smooth, soft_target) 23 | DINO: 24 | CROP_SCALE: 0.4 25 | LOCAL_CROPS: 10 26 | HEAD_DIM: 65536 27 | TEACHER_TEMP: 0.07 28 | WARMUP_TEACHER_TEMP: 0.04 29 | WARMUP_TEACHER_EPOCHS: 30 30 | TEACHER_MOMENTUM: 0.996 31 | SCHEDULER: 32 | NAME: steplr 33 | PARAMS: (30, 0.1) 34 | EVAL_INTERVAL: 20 # interval to evaluate the model during training 35 | SEED: 123 # random seed number 36 | AMP: false # use Automatic Mixed Precision training or not 37 | DDP: false 38 | 39 | EVAL: 40 | IMAGE_SIZE: [224, 224] # evaluation image size 41 | BATCH_SIZE: 8 # evaluation batch size 42 | WORKERS: 4 # number of workers used in evaluation dataloader 43 | NUM_CLASSES: 1000 44 | KNN: 45 | NB_KNN: [10, 20, 100, 200] # number of NN to use, 20 is usually the best 46 | TEMP: 0.07 # temperature used in
voting coefficient 47 | 48 | TEST: 49 | MODE: image # inference mode (image) 50 | FILE: 'test_imgs' # filename or foldername (image mode) 51 | IMAGE_SIZE: [480, 480] # inference image size -------------------------------------------------------------------------------- /configs/esvit.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | SAVE_DIR: './output' # output folder name used for saving the trained model and logs 3 | MODEL_PATH: '' # trained model path (used for evaluation, inference and optimization) 4 | 5 | METHOD: 'esvit' # name of the method you are using 6 | 7 | MODEL: 8 | NAME: 'vit' 9 | VARIANT: 'B' # sub name of the model you are using 10 | 11 | DATASET: 12 | NAME: imagenet # dataset name 13 | ROOT: '../datasets/imagenet' # dataset root path 14 | 15 | TRAIN: 16 | IMAGE_SIZE: [224, 224] # image size used in training the model 17 | EPOCHS: 100 # number of epochs to train 18 | BATCH_SIZE: 8 # batch size used to train 19 | WORKERS: 8 # number of workers used in training dataloader 20 | LR: 0.01 # initial learning rate used in optimizer 21 | DECAY: 0.0005 # decay rate used in optimizer 22 | LOSS: ddinoloss # loss function name (vanilla, label_smooth, soft_target) 23 | DINO: 24 | CROP_SCALE: 0.4 25 | LOCAL_CROPS: 10 26 | HEAD_DIM: 65536 27 | TEACHER_TEMP: 0.07 28 | WARMUP_TEACHER_TEMP: 0.04 29 | WARMUP_TEACHER_EPOCHS: 30 30 | TEACHER_MOMENTUM: 0.996 31 | SCHEDULER: 32 | NAME: steplr 33 | PARAMS: (30, 0.1) 34 | EVAL_INTERVAL: 20 # interval to evaluate the model during training 35 | SEED: 123 # random seed number 36 | AMP: false # use Automatic Mixed Precision training or not 37 | DDP: false 38 | 39 | EVAL: 40 | IMAGE_SIZE: [224, 224] # evaluation image size 41 | BATCH_SIZE: 8 # evaluation batch size 42 | WORKERS: 4 # number of workers used in evaluation dataloader 43 | NUM_CLASSES: 1000 44 | KNN: 45 | NB_KNN: [10, 20, 100, 200] # number of NN to use, 20 is usually the best 46 |
TEMP: 0.07 # temperature used in voting coefficient 47 | 48 | TEST: 49 | MODE: image # inference mode (image) 50 | FILE: 'test_imgs' # filename or foldername (image mode) 51 | IMAGE_SIZE: [480, 480] # inference image size -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/self-supervised-learning/490f9dd4dc932ccd666caf85ea38ecce5221dd57/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | from typing import Optional, Callable 3 | from pathlib import Path 4 | 5 | CLASSES = ['tench, Tinca tinca', 'goldfish, Carassius auratus', 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias', 'tiger shark, Galeocerdo cuvieri', 'hammerhead, hammerhead shark', 'electric ray, crampfish, numbfish, torpedo', 'stingray', 'cock', 'hen', 'ostrich, Struthio camelus', 'brambling, Fringilla montifringilla', 'goldfinch, Carduelis carduelis', 'house finch, linnet, Carpodacus mexicanus', 'junco, snowbird', 'indigo bunting, indigo finch, indigo bird, Passerina cyanea', 'robin, American robin, Turdus migratorius', 'bulbul', 'jay', 'magpie', 'chickadee', 'water ouzel, dipper', 'kite', 'bald eagle, American eagle, Haliaeetus leucocephalus', 'vulture', 'great grey owl, great gray owl, Strix nebulosa', 'European fire salamander, Salamandra salamandra', 'common newt, Triturus vulgaris', 'eft', 'spotted salamander, Ambystoma maculatum', 'axolotl, mud puppy, Ambystoma mexicanum', 'bullfrog, Rana catesbeiana', 'tree frog, tree-frog', 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui', 'loggerhead, loggerhead turtle, Caretta caretta', 'leatherback turtle, leatherback, leathery 
turtle, Dermochelys coriacea', 'mud turtle', 'terrapin', 'box turtle, box tortoise', 'banded gecko', 'common iguana, iguana, Iguana iguana', 'American chameleon, anole, Anolis carolinensis', 'whiptail, whiptail lizard', 'agama', 'frilled lizard, Chlamydosaurus kingi', 'alligator lizard', 'Gila monster, Heloderma suspectum', 'green lizard, Lacerta viridis', 'African chameleon, Chamaeleo chamaeleon', 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis', 'African crocodile, Nile crocodile, Crocodylus niloticus', 'American alligator, Alligator mississipiensis', 'triceratops', 'thunder snake, worm snake, Carphophis amoenus', 'ringneck snake, ring-necked snake, ring snake', 'hognose snake, puff adder, sand viper', 'green snake, grass snake', 'king snake, kingsnake', 'garter snake, grass snake', 'water snake', 'vine snake', 'night snake, Hypsiglena torquata', 'boa constrictor, Constrictor constrictor', 'rock python, rock snake, Python sebae', 'Indian cobra, Naja naja', 'green mamba', 'sea snake', 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus', 'diamondback, diamondback rattlesnake, Crotalus adamanteus', 'sidewinder, horned rattlesnake, Crotalus cerastes', 'trilobite', 'harvestman, daddy longlegs, Phalangium opilio', 'scorpion', 'black and gold garden spider, Argiope aurantia', 'barn spider, Araneus cavaticus', 'garden spider, Aranea diademata', 'black widow, Latrodectus mactans', 'tarantula', 'wolf spider, hunting spider', 'tick', 'centipede', 'black grouse', 'ptarmigan', 'ruffed grouse, partridge, Bonasa umbellus', 'prairie chicken, prairie grouse, prairie fowl', 'peacock', 'quail', 'partridge', 'African grey, African gray, Psittacus erithacus', 'macaw', 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita', 'lorikeet', 'coucal', 'bee eater', 'hornbill', 'hummingbird', 'jacamar', 'toucan', 'drake', 'red-breasted merganser, Mergus serrator', 'goose', 'black swan, Cygnus atratus', 'tusker', 'echidna, spiny anteater, 
anteater', 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus', 'wallaby, brush kangaroo', 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus', 'wombat', 'jellyfish', 'sea anemone, anemone', 'brain coral', 'flatworm, platyhelminth', 'nematode, nematode worm, roundworm', 'conch', 'snail', 'slug', 'sea slug, nudibranch', 'chiton, coat-of-mail shell, sea cradle, polyplacophore', 'chambered nautilus, pearly nautilus, nautilus', 'Dungeness crab, Cancer magister', 'rock crab, Cancer irroratus', 'fiddler crab', 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica', 'American lobster, Northern lobster, Maine lobster, Homarus americanus', 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish', 'crayfish, crawfish, crawdad, crawdaddy', 'hermit crab', 'isopod', 'white stork, Ciconia ciconia', 'black stork, Ciconia nigra', 'spoonbill', 'flamingo', 'little blue heron, Egretta caerulea', 'American egret, great white heron, Egretta albus', 'bittern', 'crane', 'limpkin, Aramus pictus', 'European gallinule, Porphyrio porphyrio', 'American coot, marsh hen, mud hen, water hen, Fulica americana', 'bustard', 'ruddy turnstone, Arenaria interpres', 'red-backed sandpiper, dunlin, Erolia alpina', 'redshank, Tringa totanus', 'dowitcher', 'oystercatcher, oyster catcher', 'pelican', 'king penguin, Aptenodytes patagonica', 'albatross, mollymawk', 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus', 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca', 'dugong, Dugong dugon', 'sea lion', 'Chihuahua', 'Japanese spaniel', 'Maltese dog, Maltese terrier, Maltese', 'Pekinese, Pekingese, Peke', 'Shih-Tzu', 'Blenheim spaniel', 'papillon', 'toy terrier', 'Rhodesian ridgeback', 'Afghan hound, Afghan', 'basset, basset hound', 'beagle', 'bloodhound, sleuthhound', 'bluetick', 'black-and-tan coonhound', 'Walker hound, Walker foxhound', 'English foxhound', 
'redbone', 'borzoi, Russian wolfhound', 'Irish wolfhound', 'Italian greyhound', 'whippet', 'Ibizan hound, Ibizan Podenco', 'Norwegian elkhound, elkhound', 'otterhound, otter hound', 'Saluki, gazelle hound', 'Scottish deerhound, deerhound', 'Weimaraner', 'Staffordshire bullterrier, Staffordshire bull terrier', 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier', 'Bedlington terrier', 'Border terrier', 'Kerry blue terrier', 'Irish terrier', 'Norfolk terrier', 'Norwich terrier', 'Yorkshire terrier', 'wire-haired fox terrier', 'Lakeland terrier', 'Sealyham terrier, Sealyham', 'Airedale, Airedale terrier', 'cairn, cairn terrier', 'Australian terrier', 'Dandie Dinmont, Dandie Dinmont terrier', 'Boston bull, Boston terrier', 'miniature schnauzer', 'giant schnauzer', 'standard schnauzer', 'Scotch terrier, Scottish terrier, Scottie', 'Tibetan terrier, chrysanthemum dog', 'silky terrier, Sydney silky', 'soft-coated wheaten terrier', 'West Highland white terrier', 'Lhasa, Lhasa apso', 'flat-coated retriever', 'curly-coated retriever', 'golden retriever', 'Labrador retriever', 'Chesapeake Bay retriever', 'German short-haired pointer', 'vizsla, Hungarian pointer', 'English setter', 'Irish setter, red setter', 'Gordon setter', 'Brittany spaniel', 'clumber, clumber spaniel', 'English springer, English springer spaniel', 'Welsh springer spaniel', 'cocker spaniel, English cocker spaniel, cocker', 'Sussex spaniel', 'Irish water spaniel', 'kuvasz', 'schipperke', 'groenendael', 'malinois', 'briard', 'kelpie', 'komondor', 'Old English sheepdog, bobtail', 'Shetland sheepdog, Shetland sheep dog, Shetland', 'collie', 'Border collie', 'Bouvier des Flandres, Bouviers des Flandres', 'Rottweiler', 'German shepherd, German shepherd dog, German police dog, alsatian', 'Doberman, Doberman pinscher', 'miniature pinscher', 'Greater Swiss Mountain dog', 'Bernese mountain dog', 'Appenzeller', 'EntleBucher', 'boxer', 'bull mastiff', 'Tibetan mastiff', 
'French bulldog', 'Great Dane', 'Saint Bernard, St Bernard', 'Eskimo dog, husky', 'malamute, malemute, Alaskan malamute', 'Siberian husky', 'dalmatian, coach dog, carriage dog', 'affenpinscher, monkey pinscher, monkey dog', 'basenji', 'pug, pug-dog', 'Leonberg', 'Newfoundland, Newfoundland dog', 'Great Pyrenees', 'Samoyed, Samoyede', 'Pomeranian', 'chow, chow chow', 'keeshond', 'Brabancon griffon', 'Pembroke, Pembroke Welsh corgi', 'Cardigan, Cardigan Welsh corgi', 'toy poodle', 'miniature poodle', 'standard poodle', 'Mexican hairless', 'timber wolf, grey wolf, gray wolf, Canis lupus', 'white wolf, Arctic wolf, Canis lupus tundrarum', 'red wolf, maned wolf, Canis rufus, Canis niger', 'coyote, prairie wolf, brush wolf, Canis latrans', 'dingo, warrigal, warragal, Canis dingo', 'dhole, Cuon alpinus', 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus', 'hyena, hyaena', 'red fox, Vulpes vulpes', 'kit fox, Vulpes macrotis', 'Arctic fox, white fox, Alopex lagopus', 'grey fox, gray fox, Urocyon cinereoargenteus', 'tabby, tabby cat', 'tiger cat', 'Persian cat', 'Siamese cat, Siamese', 'Egyptian cat', 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor', 'lynx, catamount', 'leopard, Panthera pardus', 'snow leopard, ounce, Panthera uncia', 'jaguar, panther, Panthera onca, Felis onca', 'lion, king of beasts, Panthera leo', 'tiger, Panthera tigris', 'cheetah, chetah, Acinonyx jubatus', 'brown bear, bruin, Ursus arctos', 'American black bear, black bear, Ursus americanus, Euarctos americanus', 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus', 'sloth bear, Melursus ursinus, Ursus ursinus', 'mongoose', 'meerkat, mierkat', 'tiger beetle', 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle', 'ground beetle, carabid beetle', 'long-horned beetle, longicorn, longicorn beetle', 'leaf beetle, chrysomelid', 'dung beetle', 'rhinoceros beetle', 'weevil', 'fly', 'bee', 'ant, emmet, pismire', 'grasshopper, hopper', 'cricket', 
'walking stick, walkingstick, stick insect', 'cockroach, roach', 'mantis, mantid', 'cicada, cicala', 'leafhopper', 'lacewing, lacewing fly', "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", 'damselfly', 'admiral', 'ringlet, ringlet butterfly', 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus', 'cabbage butterfly', 'sulphur butterfly, sulfur butterfly', 'lycaenid, lycaenid butterfly', 'starfish, sea star', 'sea urchin', 'sea cucumber, holothurian', 'wood rabbit, cottontail, cottontail rabbit', 'hare', 'Angora, Angora rabbit', 'hamster', 'porcupine, hedgehog', 'fox squirrel, eastern fox squirrel, Sciurus niger', 'marmot', 'beaver', 'guinea pig, Cavia cobaya', 'sorrel', 'zebra', 'hog, pig, grunter, squealer, Sus scrofa', 'wild boar, boar, Sus scrofa', 'warthog', 'hippopotamus, hippo, river horse, Hippopotamus amphibius', 'ox', 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis', 'bison', 'ram, tup', 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis', 'ibex, Capra ibex', 'hartebeest', 'impala, Aepyceros melampus', 'gazelle', 'Arabian camel, dromedary, Camelus dromedarius', 'llama', 'weasel', 'mink', 'polecat, fitch, foulmart, foumart, Mustela putorius', 'black-footed ferret, ferret, Mustela nigripes', 'otter', 'skunk, polecat, wood pussy', 'badger', 'armadillo', 'three-toed sloth, ai, Bradypus tridactylus', 'orangutan, orang, orangutang, Pongo pygmaeus', 'gorilla, Gorilla gorilla', 'chimpanzee, chimp, Pan troglodytes', 'gibbon, Hylobates lar', 'siamang, Hylobates syndactylus, Symphalangus syndactylus', 'guenon, guenon monkey', 'patas, hussar monkey, Erythrocebus patas', 'baboon', 'macaque', 'langur', 'colobus, colobus monkey', 'proboscis monkey, Nasalis larvatus', 'marmoset', 'capuchin, ringtail, Cebus capucinus', 'howler monkey, howler', 'titi, titi monkey', 'spider monkey, Ateles geoffroyi', 'squirrel monkey, Saimiri 
sciureus', 'Madagascar cat, ring-tailed lemur, Lemur catta', 'indri, indris, Indri indri, Indri brevicaudatus', 'Indian elephant, Elephas maximus', 'African elephant, Loxodonta africana', 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens', 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca', 'barracouta, snoek', 'eel', 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch', 'rock beauty, Holocanthus tricolor', 'anemone fish', 'sturgeon', 'gar, garfish, garpike, billfish, Lepisosteus osseus', 'lionfish', 'puffer, pufferfish, blowfish, globefish', 'abacus', 'abaya', "academic gown, academic robe, judge's robe", 'accordion, piano accordion, squeeze box', 'acoustic guitar', 'aircraft carrier, carrier, flattop, attack aircraft carrier', 'airliner', 'airship, dirigible', 'altar', 'ambulance', 'amphibian, amphibious vehicle', 'analog clock', 'apiary, bee house', 'apron', 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', 'assault rifle, assault gun', 'backpack, back pack, knapsack, packsack, rucksack, haversack', 'bakery, bakeshop, bakehouse', 'balance beam, beam', 'balloon', 'ballpoint, ballpoint pen, ballpen, Biro', 'Band Aid', 'banjo', 'bannister, banister, balustrade, balusters, handrail', 'barbell', 'barber chair', 'barbershop', 'barn', 'barometer', 'barrel, cask', 'barrow, garden cart, lawn cart, wheelbarrow', 'baseball', 'basketball', 'bassinet', 'bassoon', 'bathing cap, swimming cap', 'bath towel', 'bathtub, bathing tub, bath, tub', 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon', 'beacon, lighthouse, beacon light, pharos', 'beaker', 'bearskin, busby, shako', 'beer bottle', 'beer glass', 'bell cote, bell cot', 'bib', 'bicycle-built-for-two, tandem bicycle, tandem', 'bikini, two-piece', 'binder, ring-binder', 'binoculars, field glasses, opera glasses', 'birdhouse', 'boathouse', 'bobsled, bobsleigh, bob', 'bolo tie, bolo, 
bola tie, bola', 'bonnet, poke bonnet', 'bookcase', 'bookshop, bookstore, bookstall', 'bottlecap', 'bow', 'bow tie, bow-tie, bowtie', 'brass, memorial tablet, plaque', 'brassiere, bra, bandeau', 'breakwater, groin, groyne, mole, bulwark, seawall, jetty', 'breastplate, aegis, egis', 'broom', 'bucket, pail', 'buckle', 'bulletproof vest', 'bullet train, bullet', 'butcher shop, meat market', 'cab, hack, taxi, taxicab', 'caldron, cauldron', 'candle, taper, wax light', 'cannon', 'canoe', 'can opener, tin opener', 'cardigan', 'car mirror', 'carousel, carrousel, merry-go-round, roundabout, whirligig', "carpenter's kit, tool kit", 'carton', 'car wheel', 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM', 'cassette', 'cassette player', 'castle', 'catamaran', 'CD player', 'cello, violoncello', 'cellular telephone, cellular phone, cellphone, cell, mobile phone', 'chain', 'chainlink fence', 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour', 'chain saw, chainsaw', 'chest', 'chiffonier, commode', 'chime, bell, gong', 'china cabinet, china closet', 'Christmas stocking', 'church, church building', 'cinema, movie theater, movie theatre, movie house, picture palace', 'cleaver, meat cleaver, chopper', 'cliff dwelling', 'cloak', 'clog, geta, patten, sabot', 'cocktail shaker', 'coffee mug', 'coffeepot', 'coil, spiral, volute, whorl, helix', 'combination lock', 'computer keyboard, keypad', 'confectionery, confectionary, candy store', 'container ship, containership, container vessel', 'convertible', 'corkscrew, bottle screw', 'cornet, horn, trumpet, trump', 'cowboy boot', 'cowboy hat, ten-gallon hat', 'cradle', 'crane', 'crash helmet', 'crate', 'crib, cot', 'Crock Pot', 'croquet ball', 'crutch', 'cuirass', 'dam, dike, dyke', 'desk', 'desktop computer', 'dial telephone, dial phone', 'diaper, nappy, napkin', 'digital clock', 'digital watch', 'dining table, board', 'dishrag, dishcloth', 
'dishwasher, dish washer, dishwashing machine', 'disk brake, disc brake', 'dock, dockage, docking facility', 'dogsled, dog sled, dog sleigh', 'dome', 'doormat, welcome mat', 'drilling platform, offshore rig', 'drum, membranophone, tympan', 'drumstick', 'dumbbell', 'Dutch oven', 'electric fan, blower', 'electric guitar', 'electric locomotive', 'entertainment center', 'envelope', 'espresso maker', 'face powder', 'feather boa, boa', 'file, file cabinet, filing cabinet', 'fireboat', 'fire engine, fire truck', 'fire screen, fireguard', 'flagpole, flagstaff', 'flute, transverse flute', 'folding chair', 'football helmet', 'forklift', 'fountain', 'fountain pen', 'four-poster', 'freight car', 'French horn, horn', 'frying pan, frypan, skillet', 'fur coat', 'garbage truck, dustcart', 'gasmask, respirator, gas helmet', 'gas pump, gasoline pump, petrol pump, island dispenser', 'goblet', 'go-kart', 'golf ball', 'golfcart, golf cart', 'gondola', 'gong, tam-tam', 'gown', 'grand piano, grand', 'greenhouse, nursery, glasshouse', 'grille, radiator grille', 'grocery store, grocery, food market, market', 'guillotine', 'hair slide', 'hair spray', 'half track', 'hammer', 'hamper', 'hand blower, blow dryer, blow drier, hair dryer, hair drier', 'hand-held computer, hand-held microcomputer', 'handkerchief, hankie, hanky, hankey', 'hard disc, hard disk, fixed disk', 'harmonica, mouth organ, harp, mouth harp', 'harp', 'harvester, reaper', 'hatchet', 'holster', 'home theater, home theatre', 'honeycomb', 'hook, claw', 'hoopskirt, crinoline', 'horizontal bar, high bar', 'horse cart, horse-cart', 'hourglass', 'iPod', 'iron, smoothing iron', "jack-o'-lantern", 'jean, blue jean, denim', 'jeep, landrover', 'jersey, T-shirt, tee shirt', 'jigsaw puzzle', 'jinrikisha, ricksha, rickshaw', 'joystick', 'kimono', 'knee pad', 'knot', 'lab coat, laboratory coat', 'ladle', 'lampshade, lamp shade', 'laptop, laptop computer', 'lawn mower, mower', 'lens cap, lens cover', 'letter opener, paper knife, paperknife', 
'library', 'lifeboat', 'lighter, light, igniter, ignitor', 'limousine, limo', 'liner, ocean liner', 'lipstick, lip rouge', 'Loafer', 'lotion', 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', "loupe, jeweler's loupe", 'lumbermill, sawmill', 'magnetic compass', 'mailbag, postbag', 'mailbox, letter box', 'maillot', 'maillot, tank suit', 'manhole cover', 'maraca', 'marimba, xylophone', 'mask', 'matchstick', 'maypole', 'maze, labyrinth', 'measuring cup', 'medicine chest, medicine cabinet', 'megalith, megalithic structure', 'microphone, mike', 'microwave, microwave oven', 'military uniform', 'milk can', 'minibus', 'miniskirt, mini', 'minivan', 'missile', 'mitten', 'mixing bowl', 'mobile home, manufactured home', 'Model T', 'modem', 'monastery', 'monitor', 'moped', 'mortar', 'mortarboard', 'mosque', 'mosquito net', 'motor scooter, scooter', 'mountain bike, all-terrain bike, off-roader', 'mountain tent', 'mouse, computer mouse', 'mousetrap', 'moving van', 'muzzle', 'nail', 'neck brace', 'necklace', 'nipple', 'notebook, notebook computer', 'obelisk', 'oboe, hautboy, hautbois', 'ocarina, sweet potato', 'odometer, hodometer, mileometer, milometer', 'oil filter', 'organ, pipe organ', 'oscilloscope, scope, cathode-ray oscilloscope, CRO', 'overskirt', 'oxcart', 'oxygen mask', 'packet', 'paddle, boat paddle', 'paddlewheel, paddle wheel', 'padlock', 'paintbrush', "pajama, pyjama, pj's, jammies", 'palace', 'panpipe, pandean pipe, syrinx', 'paper towel', 'parachute, chute', 'parallel bars, bars', 'park bench', 'parking meter', 'passenger car, coach, carriage', 'patio, terrace', 'pay-phone, pay-station', 'pedestal, plinth, footstall', 'pencil box, pencil case', 'pencil sharpener', 'perfume, essence', 'Petri dish', 'photocopier', 'pick, plectrum, plectron', 'pickelhaube', 'picket fence, paling', 'pickup, pickup truck', 'pier', 'piggy bank, penny bank', 'pill bottle', 'pillow', 'ping-pong ball', 'pinwheel', 'pirate, pirate ship', 'pitcher, ewer', "plane, 
carpenter's plane, woodworking plane", 'planetarium', 'plastic bag', 'plate rack', 'plow, plough', "plunger, plumber's helper", 'Polaroid camera, Polaroid Land camera', 'pole', 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria', 'poncho', 'pool table, billiard table, snooker table', 'pop bottle, soda bottle', 'pot, flowerpot', "potter's wheel", 'power drill', 'prayer rug, prayer mat', 'printer', 'prison, prison house', 'projectile, missile', 'projector', 'puck, hockey puck', 'punching bag, punch bag, punching ball, punchball', 'purse', 'quill, quill pen', 'quilt, comforter, comfort, puff', 'racer, race car, racing car', 'racket, racquet', 'radiator', 'radio, wireless', 'radio telescope, radio reflector', 'rain barrel', 'recreational vehicle, RV, R.V.', 'reel', 'reflex camera', 'refrigerator, icebox', 'remote control, remote', 'restaurant, eating house, eating place, eatery', 'revolver, six-gun, six-shooter', 'rifle', 'rocking chair, rocker', 'rotisserie', 'rubber eraser, rubber, pencil eraser', 'rugby ball', 'rule, ruler', 'running shoe', 'safe', 'safety pin', 'saltshaker, salt shaker', 'sandal', 'sarong', 'sax, saxophone', 'scabbard', 'scale, weighing machine', 'school bus', 'schooner', 'scoreboard', 'screen, CRT screen', 'screw', 'screwdriver', 'seat belt, seatbelt', 'sewing machine', 'shield, buckler', 'shoe shop, shoe-shop, shoe store', 'shoji', 'shopping basket', 'shopping cart', 'shovel', 'shower cap', 'shower curtain', 'ski', 'ski mask', 'sleeping bag', 'slide rule, slipstick', 'sliding door', 'slot, one-armed bandit', 'snorkel', 'snowmobile', 'snowplow, snowplough', 'soap dispenser', 'soccer ball', 'sock', 'solar dish, solar collector, solar furnace', 'sombrero', 'soup bowl', 'space bar', 'space heater', 'space shuttle', 'spatula', 'speedboat', "spider web, spider's web", 'spindle', 'sports car, sport car', 'spotlight, spot', 'stage', 'steam locomotive', 'steel arch bridge', 'steel drum', 'stethoscope', 'stole', 'stone wall', 
'stopwatch, stop watch', 'stove', 'strainer', 'streetcar, tram, tramcar, trolley, trolley car', 'stretcher', 'studio couch, day bed', 'stupa, tope', 'submarine, pigboat, sub, U-boat', 'suit, suit of clothes', 'sundial', 'sunglass', 'sunglasses, dark glasses, shades', 'sunscreen, sunblock, sun blocker', 'suspension bridge', 'swab, swob, mop', 'sweatshirt', 'swimming trunks, bathing trunks', 'swing', 'switch, electric switch, electrical switch', 'syringe', 'table lamp', 'tank, army tank, armored combat vehicle, armoured combat vehicle', 'tape player', 'teapot', 'teddy, teddy bear', 'television, television system', 'tennis ball', 'thatch, thatched roof', 'theater curtain, theatre curtain', 'thimble', 'thresher, thrasher, threshing machine', 'throne', 'tile roof', 'toaster', 'tobacco shop, tobacconist shop, tobacconist', 'toilet seat', 'torch', 'totem pole', 'tow truck, tow car, wrecker', 'toyshop', 'tractor', 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi', 'tray', 'trench coat', 'tricycle, trike, velocipede', 'trimaran', 'tripod', 'triumphal arch', 'trolleybus, trolley coach, trackless trolley', 'trombone', 'tub, vat', 'turnstile', 'typewriter keyboard', 'umbrella', 'unicycle, monocycle', 'upright, upright piano', 'vacuum, vacuum cleaner', 'vase', 'vault', 'velvet', 'vending machine', 'vestment', 'viaduct', 'violin, fiddle', 'volleyball', 'waffle iron', 'wall clock', 'wallet, billfold, notecase, pocketbook', 'wardrobe, closet, press', 'warplane, military plane', 'washbasin, handbasin, washbowl, lavabo, wash-hand basin', 'washer, automatic washer, washing machine', 'water bottle', 'water jug', 'water tower', 'whiskey jug', 'whistle', 'wig', 'window screen', 'window shade', 'Windsor tie', 'wine bottle', 'wing', 'wok', 'wooden spoon', 'wool, woolen, woollen', 'worm fence, snake fence, snake-rail fence, Virginia fence', 'wreck', 'yawl', 'yurt', 'web site, website, internet site, site', 'comic book', 'crossword puzzle, crossword', 'street 
sign', 'traffic light, traffic signal, stoplight', 'book jacket, dust cover, dust jacket, dust wrapper', 'menu', 'plate', 'guacamole', 'consomme', 'hot pot, hotpot', 'trifle', 'ice cream, icecream', 'ice lolly, lolly, lollipop, popsicle', 'French loaf', 'bagel, beigel', 'pretzel', 'cheeseburger', 'hotdog, hot dog, red hot', 'mashed potato', 'head cabbage', 'broccoli', 'cauliflower', 'zucchini, courgette', 'spaghetti squash', 'acorn squash', 'butternut squash', 'cucumber, cuke', 'artichoke, globe artichoke', 'bell pepper', 'cardoon', 'mushroom', 'Granny Smith', 'strawberry', 'orange', 'lemon', 'fig', 'pineapple, ananas', 'banana', 'jackfruit, jak, jack', 'custard apple', 'pomegranate', 'hay', 'carbonara', 'chocolate sauce, chocolate syrup', 'dough', 'meat loaf, meatloaf', 'pizza, pizza pie', 'potpie', 'burrito', 'red wine', 'espresso', 'cup', 'eggnog', 'alp', 'bubble', 'cliff, drop, drop-off', 'coral reef', 'geyser', 'lakeside, lakeshore', 'promontory, headland, head, foreland', 'sandbar, sand bar', 'seashore, coast, seacoast, sea-coast', 'valley, vale', 'volcano', 'ballplayer, baseball player', 'groom, bridegroom', 'scuba diver', 'rapeseed', 'daisy', "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", 'corn', 'acorn', 'hip, rose hip, rosehip', 'buckeye, horse chestnut, conker', 'coral fungus', 'agaric', 'gyromitra', 'stinkhorn, carrion fungus', 'earthstar', 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa', 'bolete', 'ear, spike, capitulum', 'toilet tissue, toilet paper, bathroom tissue'] 6 | 7 | 8 | class ImageNet(ImageFolder): 9 | def __init__(self, root: str, split: str = 'train', transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): 10 | assert split in ('train', 'val') 11 | split_folder = Path(root) / split 12 | super().__init__(split_folder, transform=transform, target_transform=target_transform) 13 | 14 | 15 | if __name__ == '__main__': 16 | from 
class GaussianBlur:
    """Randomly blur a PIL image with a Gaussian kernel.

    With probability ``p`` the image is blurred using a radius drawn uniformly
    from ``[radius_min, radius_max]``. The radius is drawn again on every
    call, so different images (and different crops of one image) receive
    different blur strengths.

    Args:
        p: probability of applying the blur.
        radius_min: lower bound of the sampled blur radius.
        radius_max: upper bound of the sampled blur radius.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.) -> None:
        self.p = p
        self.radius_min = radius_min
        self.radius_max = radius_max
        # Kept for backward compatibility: holds the most recently drawn radius.
        self.radius = random.uniform(radius_min, radius_max)

    def __call__(self, img):
        """Return a blurred copy of ``img`` with probability ``p``, else ``img`` itself."""
        if random.random() < self.p:
            # Sample here, not in __init__: a radius fixed at construction time
            # would apply the same blur strength to every image of the run.
            self.radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(self.radius))
        return img


class Solarization:
    """Randomly solarize a PIL image.

    Args:
        p: probability of applying ``ImageOps.solarize``.
    """

    def __init__(self, p=0.2) -> None:
        self.p = p

    def __call__(self, img):
        """Return the solarized image with probability ``p``, else ``img`` itself."""
        if random.random() < self.p:
            return ImageOps.solarize(img)
        return img


class DINOAug:
    """Multi-crop augmentation producing two global views plus several local views.

    Args:
        img_size: size of the global crops; an ``(h, w)`` sequence or a single
            int (treated as a square size). Local crops use half of each side.
        crop_scale: area fraction separating the two regimes; global crops are
            sampled from ``(crop_scale, 1.0)``, local crops from
            ``(0.05, crop_scale)``.
        local_crops_number: number of local crops returned per image.
    """

    def __init__(self, img_size, crop_scale, local_crops_number) -> None:
        # Accept a bare int as well; the local crop size below indexes img_size.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)

        flip_color = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2)
        ])
        normalize = T.Compose([
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # First global view: always blurred.
        self.global_transform1 = T.Compose([
            T.RandomResizedCrop(img_size, (crop_scale, 1.0), interpolation=Image.BICUBIC),
            flip_color,
            GaussianBlur(1.0),
            normalize
        ])

        # Second global view: rarely blurred, sometimes solarized.
        self.global_transform2 = T.Compose([
            T.RandomResizedCrop(img_size, (crop_scale, 1.0), interpolation=Image.BICUBIC),
            flip_color,
            GaussianBlur(0.1),
            Solarization(0.2),
            normalize
        ])

        self.local_crops_number = local_crops_number
        self.local_transform = T.Compose([
            T.RandomResizedCrop((img_size[0]//2, img_size[1]//2), (0.05, crop_scale), interpolation=Image.BICUBIC),
            flip_color,
            GaussianBlur(0.5),
            normalize
        ])

    def __call__(self, img):
        """Return ``[global1, global2, local_1, ..., local_n]`` for one image."""
        crops = [self.global_transform1(img), self.global_transform2(img)]
        crops.extend(self.local_transform(img) for _ in range(self.local_crops_number))
        return crops
def get_model(model: str, variant: str, img_size):
    """Instantiate a backbone from the module-level ``__all__`` registry.

    Args:
        model: registry key of the backbone (a key of ``__all__``).
        variant: model-size name forwarded as the first positional argument.
        img_size: forwarded to the backbone as ``image_size``.

    Raises:
        AssertionError: if ``model`` is not a registered backbone.
    """
    assert model in __all__, f"Unknown model '{model}', expected one of {list(__all__)}"
    return __all__[model](variant, image_size=img_size)


def get_method(method: str, model: str, variant: str, img_size, head_dim):
    """Build a backbone and wrap it in a method from the ``methods`` registry.

    Args:
        method: registry key of the wrapper (a key of ``methods``).
        model: registry key of the backbone, see :func:`get_model`.
        variant: backbone size name.
        img_size: image size forwarded to the backbone.
        head_dim: output dimension passed to the wrapper.

    Raises:
        AssertionError: if ``method`` or ``model`` is not registered.
    """
    # Validate against `methods`, not the backbone registry `__all__`: the two
    # key sets are disjoint, so checking `__all__` rejected every valid method.
    assert method in methods, f"Unknown method '{method}', expected one of {list(methods)}"
    backbone = get_model(model, variant, img_size)
    return methods[method](backbone, head_dim)


class LinearClassifier(nn.Module):
    """Single linear layer applied to flattened features.

    Args:
        dim: number of input features after flattening.
        num_classes: number of output logits.
    """

    def __init__(self, dim, num_classes=1000):
        super().__init__()
        self.linear = nn.Linear(dim, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        # Everything after the batch dimension is flattened before the layer.
        return self.linear(x.flatten(1))
class DINOHead(nn.Module):
    """Projection head: 3-layer MLP, L2 normalisation, weight-normalised linear layer.

    Args:
        c1: input feature dimension.
        c2: output dimension.
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.mlp = nn.Sequential(*[
            nn.Linear(c1, 2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.GELU(),
            nn.Linear(2048, 256)
        ])

        self.last_layer = nn.utils.weight_norm(nn.Linear(256, c2, bias=False))
        # The magnitude of the weight-normalised layer is fixed at 1 and frozen.
        self.last_layer.weight_g.data.fill_(1)
        self.last_layer.weight_g.requires_grad = False

    def forward(self, x: Tensor) -> Tensor:
        x = self.mlp(x)
        x = F.normalize(x, p=2, dim=-1)
        x = self.last_layer(x)
        return x


class DINO(nn.Module):
    """Backbone followed by a :class:`DINOHead`, with multi-crop batching.

    Args:
        backbone: module exposing ``embed_dim`` and returning one feature
            vector per image.
        head_dim: output dimension of the head.
    """

    def __init__(self, backbone: nn.Module, head_dim: int = 65536):
        super().__init__()
        self.backbone = backbone
        self.head = DINOHead(self.backbone.embed_dim, head_dim)

    def forward(self, x) -> Tensor:
        """Run a tensor or a list of crop batches through backbone and head.

        Consecutive crops sharing the same last-dimension size are
        concatenated and sent through the backbone in one pass.
        """
        if not isinstance(x, list):
            x = [x]

        # Cumulative end index of each run of consecutive same-resolution crops.
        idx_crops = torch.cumsum(torch.unique_consecutive(torch.tensor([inp.shape[-1] for inp in x]), return_counts=True)[1], dim=0)
        start_idx, output = 0, torch.empty(0).to(x[0].device)

        for end_idx in idx_crops:
            out = self.backbone(torch.cat(x[start_idx:end_idx]))
            output = torch.cat((output, out))
            start_idx = end_idx

        return self.head(output)


class EsViT(nn.Module):
    """Backbone with one head for the class token and one for the dense tokens.

    Args:
        backbone: module exposing ``embed_dim`` whose ``forward`` accepts
            ``return_dense=True`` and then returns ``(cls, tokens)`` with
            ``tokens`` of shape ``(B, N, C)``.
        head_dim: output dimension of both heads.
    """

    def __init__(self, backbone: nn.Module, head_dim: int = 65536):
        super().__init__()
        self.backbone = backbone
        self.head = DINOHead(self.backbone.embed_dim, head_dim)
        self.head_dense = DINOHead(self.backbone.embed_dim, head_dim)

    def forward(self, x) -> Tensor:
        """Run a tensor or a list of crop batches through the model.

        Returns:
            ``(cls_out, dense_out, feats, npatch)``: the class-token head
            output, the dense head output, the flattened ``(sum(B*N), C)``
            token features, and the token count ``N`` of every resolution
            group.
        """
        if not isinstance(x, list):
            x = [x]

        # Cumulative end index of each run of consecutive same-resolution crops.
        idx_crops = torch.cumsum(torch.unique_consecutive(torch.tensor([inp.shape[-1] for inp in x]), return_counts=True)[1], dim=0)
        start_idx = 0
        npatch = []
        output_cls = torch.empty(0).to(x[0].device)
        output_feats = torch.empty(0).to(x[0].device)

        for end_idx in idx_crops:
            out_cls, out_feats = self.backbone(torch.cat(x[start_idx:end_idx]), return_dense=True)
            B, N, C = out_feats.shape

            npatch.append(N)

            output_cls = torch.cat((output_cls, out_cls))
            # reshape, not view: the backbones in this file return the dense
            # tokens as a slice (`x[:, 1:]`), which is non-contiguous and
            # cannot be viewed as (B*N, C) once B > 1.
            output_feats = torch.cat((output_feats, out_feats.reshape(-1, C)))
            start_idx = end_idx

        return self.head(output_cls), self.head_dense(output_feats), output_feats, npatch
class DropPath(nn.Module):
    """Stochastic depth: randomly zero a residual branch per sample.

    Args:
        drop_prob: probability of dropping the branch for a given sample.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: Tensor) -> Tensor:
        # Identity at evaluation time or when nothing is dropped.
        if self.drop_prob == 0. or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        # One random value per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        # Rescale kept samples so the expected value is unchanged.
        return x.div(keep_prob) * random_tensor


class MLP(nn.Module):
    """Two-layer feed-forward network with GELU; ``out_dim`` defaults to ``dim``."""

    def __init__(self, dim, hidden_dim, out_dim=None) -> None:
        super().__init__()
        out_dim = out_dim or dim
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc2(self.act(self.fc1(x)))


class PatchEmbedding(nn.Module):
    """Image to Patch Embedding

    Args:
        img_size: int (square image) or ``(h, w)`` sequence.
        patch_size: side of each square patch; also the conv stride.
        embed_dim: number of output channels per patch.
    """
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super().__init__()
        # Normalise first so the divisibility check also works for sequences.
        img_size = (img_size, img_size) if isinstance(img_size, int) else tuple(img_size)
        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, 'Image size must be divisible by patch size'

        self.grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(3, embed_dim, patch_size, patch_size)

    def forward(self, x: torch.Tensor):
        x = self.proj(x)                    # b x embed_dim x grid_h x grid_w
        x = x.flatten(2).swapaxes(1, 2)     # b x (grid_h*grid_w) x embed_dim

        return x


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention map."""

    def __init__(self, dim, heads=12):
        super().__init__()
        self.num_heads = heads
        self.scale = (dim // heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        """Return ``(out, attn)`` with ``out`` shaped like ``x`` and ``attn`` of shape (B, heads, N, N)."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x, attn


class TransformerEncoder(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, dim, heads, drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * 4))

    def forward(self, x: torch.Tensor, return_attention=False):
        """Return the block output, or only the attention map if ``return_attention``."""
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        # Out-of-place residuals: `x += ...` would overwrite the caller's
        # tensor and the input LayerNorm saved for its backward pass.
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


vit_settings = {
    'T': [8, 12, 192, 3, 0.1],      #[patch_size, number_of_layers, embed_dim, heads, drop_path_rate]
    'S': [8, 12, 384, 6, 0.1],
    'B': [8, 12, 768, 12, 0.1]
}


class ViT(nn.Module):
    """Vision Transformer backbone returning the normalised class token.

    Args:
        model_name: key of ``vit_settings`` (``'T'``, ``'S'`` or ``'B'``).
        pretrained: optional path of a state dict to load.
        image_size: training image size used to size the position embedding.
    """

    def __init__(self, model_name: str = 'S', pretrained: str = None, image_size: int = 224) -> None:
        super().__init__()
        assert model_name in vit_settings.keys(), f"ViT model name should be in {list(vit_settings.keys())}"
        patch_size, layers, embed_dim, heads, drop_path_rate = vit_settings[model_name]

        self.patch_size = patch_size
        self.patch_embed = PatchEmbedding(image_size, patch_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Drop-path rate grows linearly from 0 to drop_path_rate over the blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers)]

        self.blocks = nn.ModuleList([
            TransformerEncoder(embed_dim, heads, dpr[i])
        for i in range(layers)])

        self.norm = nn.LayerNorm(embed_dim)

        self.embed_dim = embed_dim

        self._init_weights(pretrained)

    def _init_weights(self, pretrained: str = None) -> None:
        """Load ``pretrained`` if given, otherwise initialise all layers."""
        if pretrained:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            self.load_state_dict(torch.load(pretrained, map_location='cpu'))
        else:
            for n, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    if n.startswith('head'):
                        nn.init.zeros_(m.weight)
                        nn.init.zeros_(m.bias)
                    else:
                        nn.init.xavier_uniform_(m.weight)
                        if m.bias is not None:
                            nn.init.zeros_(m.bias)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)
                elif isinstance(m, nn.Conv2d):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def interpolate_pos_encoding(self, x: Tensor, W: int, H: int) -> Tensor:
        """Return the position embedding, bicubically resized if the input grid differs.

        ``W`` and ``H`` are the sizes of dims 2 and 3 of the input image
        tensor, in that order, as unpacked in ``forward``.
        """
        num_patches = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1

        if num_patches == N and H == W:
            return self.pos_embed

        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]

        dim = x.shape[-1]
        w0 = W // self.patch_size
        h0 = H // self.patch_size

        # Small offset so the scale factor does not round down to one patch
        # fewer; the assert below checks the resulting grid size.
        w0, h0 = w0 + 0.1, h0 + 0.1

        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic'
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, x: Tensor, return_attention=False, return_dense=False) -> Tensor:
        """Encode a batch of images.

        Returns the normalised class token; with ``return_dense`` a
        ``(cls, patch_tokens)`` pair; with ``return_attention`` the last
        block's attention of the class token, shape (B, heads, N).
        """
        B, C, W, H = x.shape
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, W, H)

        for i, blk in enumerate(self.blocks):
            if i + 1 == len(self.blocks):
                if return_attention:
                    return blk(x, return_attention=return_attention)[:, :, 0, :]
            x = blk(x)

        x = self.norm(x)
        if return_dense:
            return x[:, 0], x[:, 1:]
        return x[:, 0]
class PositionalEncodingFourier(nn.Module):
    """Sine/cosine positional encoding over a 2-D grid, projected to ``dim`` channels."""

    def __init__(self, dim: int = 768, temp: int = 10000):
        super().__init__()
        self.hidden_dim = 32
        # 1x1 conv mapping the 2 * hidden_dim sin/cos channels to `dim`.
        self.token_projection = nn.Conv2d(self.hidden_dim * 2, dim, 1)
        self.scale = 2 * math.pi
        self.temperature = temp
        self.dim = dim

    def forward(self, B, H, W):
        """Return a (B, dim, H, W) encoding for a batch of H x W grids."""
        mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device)
        not_mask = ~mask
        # Running row / column index of every grid cell (1-based).
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        eps = 1e-6
        # Normalise the indices by their last value and scale to (0, 2*pi].
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device)
        dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / self.hidden_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        # Interleave sin of the even channels with cos of the odd channels.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
                             pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
                             pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos)
        return pos


class Conv3x3(nn.Sequential):
    """3x3 convolution (padding 1, no bias, stride ``s``) followed by BatchNorm2d."""

    def __init__(self, c1, c2, s=1):
        super().__init__(
            nn.Conv2d(c1, c2, 3, s, 1, bias=False),
            nn.BatchNorm2d(c2)
        )


class ConvPatchEmbed(nn.Module):
    """Image to Patch Embedding using multiple convolutional layers
    """
    def __init__(self, img_size=224, patch_size=8, embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
        self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
        # Three stride-2 convolutions; `patch_size` is only used for
        # num_patches above, not for building this stack.
        self.proj = nn.Sequential(
            Conv3x3(3, embed_dim // 4, 2),
            nn.GELU(),
            Conv3x3(embed_dim // 4, embed_dim // 2, 2),
            nn.GELU(),
            Conv3x3(embed_dim // 2, embed_dim, 2),
        )

    def forward(self, x: Tensor):
        """Return ``(tokens, (Hp, Wp))``: tokens of shape (B, Hp*Wp, C) and the grid size."""
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)
        return x, (Hp, Wp)


class LPI(nn.Module):
    """
    Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows
    to augment the implicit communication performed by the block diagonal scatter attention.
    Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d
    """
    def __init__(self, dim, out_dim=None):
        super().__init__()
        out_dim = out_dim or dim

        self.conv1 = nn.Conv2d(dim, out_dim, 3, 1, 1, groups=out_dim)
        self.act = nn.GELU()
        self.bn = nn.BatchNorm2d(dim)
        self.conv2 = nn.Conv2d(dim, out_dim, 3, 1, 1, groups=out_dim)

    def forward(self, x, H, W):
        """Apply the convolutions on the (H, W) token grid; input and output are (B, N, C)."""
        B, N, C = x.shape
        # Tokens -> (B, C, H, W) feature map; requires N == H * W.
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.conv2(self.bn(self.act(self.conv1(x))))
        x = x.reshape(B, C, N).permute(0, 2, 1)
        return x


class ClassAttention(nn.Module):
    """ClassAttention as in CaiT
    """
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.num_heads = heads
        self.scale = (dim // heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Update only token 0 by attending over all tokens.

        Returns ``(x, attn_cls)``: the sequence with token 0 replaced, and the
        attention weights of token 0 with shape (B, heads, N).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        qc = q[:, :, 0:1]       # CLS token

        attn_cls = (qc * k).sum(dim=-1) * self.scale
        attn_cls = attn_cls.softmax(dim=-1)

        cls_token = (attn_cls.unsqueeze(2) @ v).transpose(1, 2).reshape(B, 1, C)
        cls_token = self.proj(cls_token)

        # The remaining tokens are passed through unchanged.
        x = torch.cat([cls_token, x[:, 1:]], dim=1)
        return x, attn_cls


class XCA(nn.Module):
    """ Cross-Covariance Attention (XCA) operation where the channels are updated using a weighted
    sum. The weights are obtained from the (softmax normalized) Cross-covariance
    matrix (Q^T K \\in d_h \\times d_h)
    """
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.num_heads = heads
        # Learnable per-head scaling applied to the attention logits.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # Transpose so that attention is computed between channels, not tokens.
        q, k, v = qkv[0].transpose(-2, -1), qkv[1].transpose(-2, -1), qkv[2].transpose(-2, -1)
        # L2-normalise q and k along the token axis before the product.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x


class ClassAttentionBlock(nn.Module):
    """Class-attention layer followed by an MLP applied to token 0 only.

    ``gamma1`` / ``gamma2`` are learnable per-channel scales initialised to ``eta``.
    """

    def __init__(self, dim, heads, eta=1e-5):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = ClassAttention(dim, heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * 4))

        self.gamma1 = nn.Parameter(eta * torch.ones(dim))
        self.gamma2 = nn.Parameter(eta * torch.ones(dim))

    def forward(self, x: Tensor, return_attention=False) -> Tensor:
        """Return the updated sequence, or only the class attention if ``return_attention``."""
        y, attn = self.attn(self.norm1(x))
        if return_attention: return attn
        x = x + self.gamma1 * y
        x = self.norm2(x)

        x_res = x
        cls_token = x[:, 0:1]
        cls_token = self.gamma2 * self.mlp(cls_token)

        # Only token 0 went through the MLP; the other tokens are added to
        # themselves by the residual below.
        x = torch.cat([cls_token, x[:, 1:]], dim=1)
        x = x_res + x

        return x


class XCABlock(nn.Module):
    """XCA, local patch interaction and MLP, each with a scaled residual connection."""

    def __init__(self, dim, heads, eta=1e-5):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = XCA(dim, heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * 4))
        self.norm3 = nn.LayerNorm(dim)
        self.local_mp = LPI(dim)

        self.gamma1 = nn.Parameter(eta * torch.ones(dim))
        self.gamma2 = nn.Parameter(eta * torch.ones(dim))
        self.gamma3 = nn.Parameter(eta * torch.ones(dim))

    def forward(self, x: Tensor, H, W) -> Tensor:
        # Order of application: attention (norm1), LPI (norm3), MLP (norm2).
        x = x + self.gamma1 * self.attn(self.norm1(x))
        x = x + self.gamma3 * self.local_mp(self.norm3(x), H, W)
        x = x + self.gamma2 * self.mlp(self.norm2(x))
        return x


xcit_settings = {
    'T12': [8, 12, 192, 4],
    'T24': [8, 24, 192, 4],     #[patch_size, layers, embed dim, heads]
    'S12': [8, 12, 384, 8],
    'S24': [8, 24, 384, 8],
    'M24': [8, 24, 512, 8],
    'L24': [8, 24, 768, 16]
}


class XciT(nn.Module):
    """XCiT backbone: conv patch embedding, XCA blocks, then two class-attention blocks.

    Args:
        model_name: key of ``xcit_settings``.
        image_size: forwarded to ``ConvPatchEmbed``.
    """

    def __init__(self, model_name: str = 'S24', image_size: int = 224) -> None:
        super().__init__()
        assert model_name in xcit_settings.keys(), f"XciT model name should be in {list(xcit_settings.keys())}"
        patch_size, layers, embed_dim, heads = xcit_settings[model_name]

        self.patch_embed = ConvPatchEmbed(image_size, patch_size, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)

        self.blocks = nn.ModuleList([
            XCABlock(embed_dim, heads)
        for _ in range(layers)])

        self.cls_attn_blocks = nn.ModuleList([
            ClassAttentionBlock(embed_dim, heads)
        for _ in range(2)])

        self.norm = nn.LayerNorm(embed_dim)

        self.embed_dim = embed_dim
        self.patch_size = patch_size

    def forward(self, x, return_attention=False, return_dense=False):
        """Encode a batch of images.

        Returns the normalised class token; with ``return_dense`` a
        ``(cls, patch_tokens)`` pair; with ``return_attention`` the class
        attention of the last class-attention block.
        """
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x)
        # (B, dim, Hp, Wp) -> (B, Hp*Wp, dim) to match the token layout.
        pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
        x = x + pos_encoding

        for blk in self.blocks:
            x = blk(x, Hp, Wp)

        # The class token is prepended only after the XCA blocks.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for i, blk in enumerate(self.cls_attn_blocks):
            if i + 1 == len(self.cls_attn_blocks):
                if return_attention:
                    return blk(x, return_attention=return_attention)
            x = blk(x)

        x = self.norm(x)
        if return_dense:
            return x[:, 0], x[:, 1:]
        return x[:, 0]
DINOAug(cfg['TRAIN']['IMAGE_SIZE'], cfg['TRAIN']['DINO']['CROP_SCALE'], cfg['TRAIN']['DINO']['LOCAL_CROPS']) 36 | dataset = ImageNet(cfg['DATASET']['ROOT'], split='train', transform=transform) 37 | sampler = DistributedSampler(dataset, shuffle=True) 38 | dataloader = DataLoader(dataset, batch_size=cfg['TRAIN']['BATCH_SIZE'], num_workers=cfg['TRAIN']['WORKERS'], drop_last=True, pin_memory=True, sampler=sampler) 39 | 40 | # student and teacher networks 41 | student = get_method(cfg['METHOD'], cfg['MODEL']['NAME'], cfg['MODEL']['VARIANT'], cfg['TRAIN']['IMAGE_SIZE'][0], cfg['TRAIN']['DINO']['HEAD_DIM']) 42 | teacher = get_method(cfg['METHOD'], cfg['MODEL']['NAME'], cfg['MODEL']['VARIANT'], cfg['TRAIN']['IMAGE_SIZE'][0], cfg['TRAIN']['DINO']['HEAD_DIM']) 43 | student, teacher = student.to(device), teacher.to(device) 44 | 45 | if ddp_enable: 46 | student = DDP(student, device_ids=[gpu]) 47 | teacher.load_state_dict(student.module.state_dict()) 48 | else: 49 | teacher.load_state_dict(student.state_dict()) 50 | 51 | for p in teacher.parameters(): 52 | p.requires_grad = False 53 | 54 | # loss function, optimizer, scheduler, AMP scaler, tensorboard writer 55 | loss_fn = get_loss(cfg, epochs).to(device) 56 | optimizer = SGD(student.parameters(), lr=cfg['TRAIN']['LR']) 57 | scheduler = get_scheduler(cfg, optimizer) 58 | scaler = GradScaler(enabled=cfg['TRAIN']['AMP']) 59 | writer = SummaryWriter(save_dir / 'logs') 60 | 61 | iters_per_epoch = int(len(dataset)) / cfg['TRAIN']['BATCH_SIZE'] 62 | 63 | for epoch in range(1, epochs+1): 64 | student.train() 65 | 66 | if ddp_enable: 67 | dataloader.sampler.set_epoch(epoch) 68 | 69 | train_loss = 0.0 70 | 71 | pbar = tqdm(enumerate(dataloader), total=iters_per_epoch, desc=f"Epoch: [{epoch}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {cfg['TRAIN']['LR']:.8f} Loss: {0:.8f}") 72 | 73 | for iter, (images, _) in pbar: 74 | images = [image.to(device) for image in images] 75 | 76 | with autocast(enabled=cfg['TRAIN']['AMP']): 77 | 
teacher_pred = teacher(images[:2]) # only 2 global views pass through the teacher 78 | student_pred = student(images) 79 | loss = loss_fn(student_pred, teacher_pred, epoch) 80 | 81 | # Backpropagation 82 | optimizer.zero_grad() 83 | scaler.scale(loss).backward() 84 | scaler.step(optimizer) 85 | scaler.update() 86 | 87 | # EMA update for the teacher 88 | with torch.no_grad(): 89 | for p, q in zip(student.module.parameters(), teacher.module.parameters()): 90 | q.data.mul_(cfg['TRAIN']['DINO']['TEACHER_MOMENTUM']).add_((1 - cfg['TRAIN']['DINO']['TEACHER_MOMENTUM']) * p.detach().data) 91 | 92 | lr = scheduler.get_last_lr()[0] 93 | train_loss += loss.item() * images[0].shape[0] 94 | 95 | pbar.set_description(f"Epoch: [{epoch}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {loss.item():.8f}") 96 | 97 | train_loss /= len(dataset) 98 | writer.add_scalar('train/loss', train_loss, epoch) 99 | writer.add_scalar('train/lr', lr, epoch) 100 | writer.flush() 101 | 102 | scheduler.step() 103 | torch.cuda.synchronize() 104 | torch.cuda.empty_cache() 105 | 106 | torch.save({ 107 | "student": student.module.state_dict() if ddp_enable else student.state_dict(), 108 | "teacher": teacher.state_dict() 109 | }, f"{cfg['METHOD']}_{cfg['MODEL']['NAME']}_{cfg['MODEL']['VARIANT']}_checkpoint.pth") 110 | torch.save( 111 | teacher.state_dict(), 112 | f"{cfg['METHOD']}_{cfg['MODEL']['NAME']}_{cfg['MODEL']['VARIANT']}.pth" 113 | ) 114 | 115 | writer.close() 116 | pbar.close() 117 | 118 | end = time.gmtime(time_synchronized() - start) 119 | total_time = time.strftime("%H:%M:%S", end) 120 | 121 | print(f"Total Training Time: {total_time}") 122 | 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--cfg', type=str, required=True, help='Experiment configuration file name') 127 | args = parser.parse_args() 128 | 129 | with open(args.cfg) as f: 130 | cfg = yaml.load(f, Loader=yaml.FullLoader) 131 | 132 | 
fix_seeds(cfg['TRAIN']['SEED']) 133 | setup_cudnn() 134 | main(cfg) -------------------------------------------------------------------------------- /tools/val_knn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import yaml 4 | from torch.nn import functional as F 5 | from pathlib import Path 6 | from torch.utils.data import DataLoader, DistributedSampler 7 | from torchvision import transforms as T 8 | from torch import distributed as dist 9 | 10 | import sys 11 | sys.path.insert(0, '.') 12 | from datasets.imagenet import ImageNet 13 | from models import get_model 14 | from utils.utils import fix_seeds, setup_cudnn, setup_ddp 15 | 16 | 17 | def main(cfg): 18 | save_dir = Path(cfg['SAVE_DIR']) 19 | if not save_dir.exists(): save_dir.mkdir() 20 | 21 | device = torch.device(cfg['DEVICE']) 22 | _ = setup_ddp() 23 | 24 | # setup augmentation, dataset and dataloader 25 | transform = T.Compose([ 26 | T.Resize((256, 256)), 27 | T.CenterCrop(cfg['EVAL']['IMAGE_SIZE']), 28 | T.ToTensor(), 29 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 30 | ]) 31 | 32 | trainset = ReturnIndexDataset(cfg['DATASET']['ROOT'], split='train', transform=transform) 33 | valset = ReturnIndexDataset(cfg['DATASET']['ROOT'], split='val', transform=transform) 34 | sampler = DistributedSampler(trainset, shuffle=True) 35 | trainloader = DataLoader(trainset, batch_size=cfg['TRAIN']['BATCH_SIZE'], num_workers=cfg['TRAIN']['WORKERS'], drop_last=True, pin_memory=True, sampler=sampler) 36 | valloader = DataLoader(valset, batch_size=cfg['EVAL']['BATCH_SIZE'], num_workers=cfg['EVAL']['WORKERS'], pin_memory=True) 37 | 38 | # student and teacher networks 39 | model = get_model(cfg['MODEL']['NAME'], cfg['MODEL']['VARIANT'], cfg['EVAL']['IMAGE_SIZE'][0]) 40 | model.load_state_dict(torch.load(cfg['MODEL_PATH'], map_location='cpu')) 41 | model = model.to(device) 42 | model.eval() 43 | 44 | # extract features 45 | 
train_features = extract_features(model, trainloader, device) 46 | val_features = extract_features(model, valloader, device) 47 | 48 | train_features = F.normalize(train_features, p=2, dim=1) 49 | val_features = F.normalize(val_features, p=2, dim=1) 50 | 51 | train_labels = torch.tensor([s[-1] for s in trainset.samples]).long() 52 | val_labels = torch.tensor([s[-1] for s in valset.samples]).long() 53 | 54 | for k in cfg['EVAL']['KNN']['NB_KNN']: 55 | top1, top5 = knn_classifier(train_features, train_labels, val_features, val_labels, k, cfg['EVAL']['KNN']['TEMP'], cfg['EVAL']['NUM_CLASSES'], device) 56 | print(f"{k}-NN classifier results >> Top1: {top1}, Top5: {top5}") 57 | 58 | 59 | class ReturnIndexDataset(ImageNet): 60 | def __getitem__(self, idx): 61 | img, _ = super().__getitem__(idx) 62 | return img, idx 63 | 64 | 65 | @torch.no_grad() 66 | def extract_features(model, dataloader, device): 67 | features = None 68 | for img, index in dataloader: 69 | img = img.to(device) 70 | index = index.to(device) 71 | 72 | feats = model(img).clone() 73 | 74 | if dist.get_rank() == 0 and features is None: 75 | features = torch.zeros(len(dataloader.dataset), feats.shape[-1]) 76 | features = features.to(device) 77 | 78 | # get indexes from all processes 79 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=device) 80 | y_l = list(y_all.unbind(0)) 81 | y_all_reduce = dist.all_gather(y_l, index, async_op=True) 82 | y_all_reduce.wait() 83 | 84 | # share features between processes 85 | feats_all = torch.empty(dist.get_world_size(), feats.size(0), feats.size(1), dtype=feats.dtype, device=device) 86 | output_l = list(feats_all.unbind(0)) 87 | output_all_reduce = dist.all_gather(output_l, feats, async_op=True) 88 | output_all_reduce.wait() 89 | 90 | if dist.get_rank() == 0: 91 | features.index_copy_(0, torch.cat(y_l), torch.cat(output_l)) 92 | 93 | return features 94 | 95 | 96 | @torch.no_grad() 97 | def knn_classifier(train_features, train_labels, 
test_features, test_labels, k, temp, num_classes, device): 98 | top1, top5, total = 0.0, 0.0, 0 99 | 100 | train_features = train_features.t() 101 | num_test_images, num_chunks = test_labels.shape[0], 100 102 | imgs_per_chunk = num_test_images // num_chunks 103 | 104 | retrieval_one_hot = torch.zeros(k, num_classes, device=device) 105 | 106 | for idx in range(0, num_test_images, imgs_per_chunk): 107 | features = test_features[idx:min((idx+imgs_per_chunk), num_test_images), :] 108 | targets = test_labels[idx:min((idx+imgs_per_chunk), num_test_images)] 109 | 110 | # calculate dot product and compute top-k neighbors 111 | similarity = torch.mm(features, train_features) 112 | distances, indices = similarity.topk(k) 113 | candidates = train_labels.view(1, -1).expand(targets.shape[0], -1) 114 | retrieved_neighbors = torch.gather(candidates, 1, indices) 115 | 116 | retrieval_one_hot.resize_(targets.shape[0] * k, num_classes).zero_() 117 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 118 | distances_transform = distances.clone().div_(temp).exp_() 119 | 120 | probs = torch.sum(torch.mul( 121 | retrieval_one_hot.view(targets.shape[0], -1, num_classes), 122 | distances_transform.view(targets.shape[0], -1, 1) 123 | ), dim=1) 124 | 125 | _, preds = probs.sort(1, descending=True) 126 | 127 | # find the preds that match the target 128 | correct = preds.eq(targets.data.view(-1, 1)) 129 | top1 += correct.narrow(1, 0, 1).sum().item() 130 | top5 += correct.narrow(1, 0, min(5, k)).sum().item() 131 | total += targets.size(0) 132 | 133 | top1 *= 100.0 / total 134 | top5 *= 100.0 / total 135 | 136 | return top1, top5 137 | 138 | 139 | 140 | if __name__ == '__main__': 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--cfg', type=str, required=True, help='Experiment configuration file name') 143 | args = parser.parse_args() 144 | 145 | with open(args.cfg) as f: 146 | cfg = yaml.load(f, Loader=yaml.FullLoader) 147 | 148 | 
fix_seeds(cfg['TRAIN']['SEED']) 149 | setup_cudnn() 150 | main(cfg) -------------------------------------------------------------------------------- /tools/val_linear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.nn.modules import linear 4 | import yaml 5 | import time 6 | from tqdm import tqdm 7 | from torch.utils.data import DistributedSampler 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from pathlib import Path 10 | from torch.optim import SGD 11 | from torch.utils.data import DataLoader 12 | from torchvision import transforms as T 13 | from torch.cuda.amp import GradScaler, autocast 14 | from torch.nn import CrossEntropyLoss 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | import sys 18 | sys.path.insert(0, '.') 19 | from datasets.imagenet import ImageNet 20 | from datasets.transforms import DINOAug 21 | from models import get_model, get_backbone 22 | from models.classifier import LinearClassifier 23 | from utils.utils import fix_seeds, time_synchronized, setup_cudnn, setup_ddp 24 | from utils import get_scheduler 25 | from utils.metrics import accuracy 26 | 27 | 28 | def main(cfg): 29 | start = time_synchronized() 30 | save_dir = Path(cfg['SAVE_DIR']) 31 | if not save_dir.exists(): save_dir.mkdir() 32 | 33 | device = torch.device(cfg['DEVICE']) 34 | ddp_enable = cfg['TRAIN']['DDP']['ENABLE'] 35 | epochs = cfg['TRAIN']['EPOCHS'] 36 | gpu = setup_ddp() 37 | 38 | # setup augmentation, dataset and dataloader 39 | train_transform = T.Compose([ 40 | T.RandomResizedCrop(cfg['TRAIN']['IMAGE_SIZE']), 41 | T.RandomHorizontalFlip(), 42 | T.ToTensor(), 43 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 44 | ]) 45 | val_transform = T.Compose([ 46 | T.Resize((256, 256)), 47 | T.CenterCrop(cfg['EVAL']['IMAGE_SIZE']), 48 | T.ToTensor(), 49 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 50 | ]) 51 | trainset = 
ImageNet(cfg['DATASET']['ROOT'], split='train', transform=train_transform) 52 | valset = ImageNet(cfg['DATASET']['ROOT'], split='val', transform=val_transform) 53 | sampler = DistributedSampler(trainset, shuffle=True) 54 | trainloader = DataLoader(trainset, batch_size=cfg['TRAIN']['BATCH_SIZE'], num_workers=cfg['TRAIN']['WORKERS'], drop_last=True, pin_memory=True, sampler=sampler) 55 | valloader = DataLoader(valset, batch_size=cfg['EVAL']['BATCH_SIZE'], num_workers=cfg['EVAL']['WORKERS'], pin_memory=True) 56 | 57 | # load model and classifier 58 | model = get_model(cfg['MODEL']['NAME'], cfg['MODEL']['VARIANT'], cfg['TRAIN']['IMAGE_SIZE'][0]) 59 | model.load_state_dict(torch.load(cfg['MODEL_PATH'], map_location='cpu')) 60 | model = model.to(device) 61 | model.eval() 62 | 63 | linear_classifier = LinearClassifier(model.embed_dim, cfg['EVAL']['NUM_CLASSES']) 64 | linear_classifier = linear_classifier.to(device) 65 | 66 | if ddp_enable: 67 | linear_classifier = DDP(linear_classifier, device_ids=[gpu]) 68 | 69 | # loss function, optimizer, scheduler, AMP scaler, tensorboard writer 70 | loss_fn = CrossEntropyLoss() 71 | optimizer = SGD(linear_classifier.parameters(), lr=cfg['TRAIN']['LR'], momentum=0.9, weight_decay=0) 72 | scheduler = get_scheduler(cfg, optimizer) 73 | scaler = GradScaler(enabled=cfg['TRAIN']['AMP']) 74 | writer = SummaryWriter(save_dir / 'logs') 75 | 76 | iters_per_epoch = int(len(trainset)) / cfg['TRAIN']['BATCH_SIZE'] 77 | 78 | for epoch in range(1, epochs+1): 79 | linear_classifier.train() 80 | 81 | if ddp_enable: 82 | trainloader.sampler.set_epoch(epoch) 83 | 84 | train_loss = 0.0 85 | 86 | pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f"Epoch: [{epoch}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {cfg['TRAIN']['LR']:.8f} Loss: {0:.8f}") 87 | 88 | for iter, (img, target) in pbar: 89 | img = img.to(device) 90 | target = target.to(device) 91 | 92 | with torch.no_grad(): 93 | pred = model(img) 94 | 95 | with 
autocast(enabled=cfg['TRAIN']['AMP']): 96 | pred = linear_classifier(pred) 97 | loss = loss_fn(pred, target) 98 | 99 | # Backpropagation 100 | optimizer.zero_grad() 101 | scaler.scale(loss).backward() 102 | scaler.step(optimizer) 103 | scaler.update() 104 | 105 | lr = scheduler.get_last_lr()[0] 106 | train_loss += loss.item() * img.shape[0] 107 | 108 | pbar.set_description(f"Epoch: [{epoch}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {loss.item():.8f}") 109 | 110 | train_loss /= len(trainset) 111 | writer.add_scalar('train/loss', train_loss, epoch) 112 | writer.add_scalar('train/lr', lr, epoch) 113 | writer.flush() 114 | 115 | scheduler.step() 116 | torch.cuda.synchronize() 117 | torch.cuda.empty_cache() 118 | 119 | if epoch > cfg['TRAIN']['EVAL_INTERVAL'] and epoch % cfg['TRAIN']['EVAL_INTERVAL'] == 0: 120 | linear_classifier.eval() 121 | val_loss = 0.0 122 | for img, target in valloader: 123 | img = img.to(device) 124 | target = target.to(device) 125 | 126 | with torch.no_grad(): 127 | pred = model(img) 128 | 129 | pred = linear_classifier(pred) 130 | loss = loss_fn(pred, target) 131 | 132 | acc1, acc5 = accuracy(pred, target, topk=(1, 5)) 133 | 134 | val_loss += loss.item() * img.shape[0] 135 | 136 | val_loss /= len(valset) 137 | writer.add_scalar('val/loss', val_loss, epoch) 138 | writer.add_scalar('val/acc1', acc1, epoch) 139 | writer.add_scalar('val/acc5', acc5, epoch) 140 | 141 | 142 | writer.close() 143 | pbar.close() 144 | 145 | end = time.gmtime(time_synchronized() - start) 146 | total_time = time.strftime("%H:%M:%S", end) 147 | 148 | print(f"Total Training Time: {total_time}") 149 | 150 | 151 | if __name__ == '__main__': 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('--cfg', type=str, required=True, help='Experiment configuration file name') 154 | args = parser.parse_args() 155 | 156 | with open(args.cfg) as f: 157 | cfg = yaml.load(f, Loader=yaml.FullLoader) 158 | 159 | fix_seeds(cfg['TRAIN']['SEED']) 160 | 
setup_cudnn() 161 | main(cfg) -------------------------------------------------------------------------------- /tools/visualize_attention.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import yaml 4 | import requests 5 | import matplotlib.pyplot as plt 6 | from pathlib import Path 7 | from io import BytesIO 8 | from PIL import Image 9 | from torchvision import transforms as T 10 | from torch.nn import functional as F 11 | from torchvision.utils import save_image, make_grid 12 | 13 | import sys 14 | sys.path.insert(0, '.') 15 | from models import get_model 16 | from utils.utils import fix_seeds, setup_cudnn 17 | 18 | 19 | def main(cfg): 20 | save_dir = Path(cfg['SAVE_DIR']) 21 | if not save_dir.exists(): save_dir.mkdir() 22 | 23 | device = torch.device(cfg['DEVICE']) 24 | 25 | # load model and weights 26 | model = get_model(cfg['MODEL']['NAME'], cfg['MODEL']['VARIANT'], cfg['TRAIN']['IMAGE_SIZE'][0]) 27 | model.load_state_dict(torch.load(cfg['MODEL_PATH'], map_location='cpu')) 28 | model = model.to(device) 29 | model.eval() 30 | 31 | response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png") 32 | img = Image.open(BytesIO(response.content)).convert('RGB') 33 | 34 | transform = T.Compose([ 35 | T.Resize(cfg['TEST']['IMAGE_SIZE']), 36 | T.ToTensor(), 37 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 38 | ]) 39 | img = transform(img) 40 | 41 | # make the image divisible by patch size 42 | W, H = img.shape[1] - img.shape[1] % model.patch_size, img.shape[2] - img.shape[2] % model.patch_size 43 | img = img[:, :W, :H].unsqueeze(0) 44 | img = img.to(device) 45 | w_featmap = img.shape[-2] // model.patch_size 46 | h_featmap = img.shape[-1] // model.patch_size 47 | 48 | attentions = model(img, return_attention=True) 49 | 50 | # keep only the output patch attention 51 | attentions = attentions.squeeze()[:, 1:].view(-1, w_featmap, h_featmap) 52 | attentions = 
F.interpolate(attentions.unsqueeze(0), scale_factor=model.patch_size, mode='nearest')[0].detach().cpu().numpy() 53 | 54 | save_image(make_grid(img, normalize=True, scale_each=True), str(save_dir / "img.png")) 55 | 56 | for i, attn in enumerate(attentions): 57 | fname = save_dir / f"attn-head{i}.png" 58 | plt.imsave(str(fname), attn, format='png') 59 | print(f"{fname} saved.") 60 | 61 | 62 | if __name__ == '__main__': 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--cfg', type=str, default='configs/dino.yaml', help='Experiment configuration file name') 65 | args = parser.parse_args() 66 | 67 | with open(args.cfg) as f: 68 | cfg = yaml.load(f, Loader=yaml.FullLoader) 69 | 70 | fix_seeds(cfg['TRAIN']['SEED']) 71 | setup_cudnn() 72 | main(cfg) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .schedulers import * 2 | from .loss import * 3 | 4 | schs = { 5 | "steplr": StepLR 6 | } 7 | 8 | losses = { 9 | "dinoloss": DINOLoss, 10 | "ddinoloss": DDINOLoss 11 | } 12 | 13 | def get_loss(cfg, epochs): 14 | loss_fn_name = cfg['TRAIN']['LOSS'] 15 | assert loss_fn_name in losses.keys() 16 | return losses[loss_fn_name](cfg['TRAIN']['DINO']['HEAD_DIM'], cfg['TRAIN']['DINO']['LOCAL_CROPS']+2, cfg['TRAIN']['DINO']['WARMUP_TEACHER_TEMP'], cfg['TRAIN']['DINO']['TEACHER_TEMP'], cfg['TRAIN']['DINO']['WARMUP_TEACHER_EPOCHS'], epochs) 17 | 18 | 19 | def get_scheduler(cfg, optimizer): 20 | scheduler_name = cfg['TRAIN']['SCHEDULER']['NAME'] 21 | assert scheduler_name in schs.keys(), f"Unavailable scheduler name >> {scheduler_name}.\nList of available schedulers: {list(schs.keys())}" 22 | return schs[scheduler_name](optimizer, *cfg['TRAIN']['SCHEDULER']['PARAMS']) -------------------------------------------------------------------------------- /utils/loss.py: 
# --------------------------------------------------------------------------------
import torch
import numpy as np
from torch import nn, Tensor
from torch.nn import functional as F
from torch import distributed as dist


def _all_reduce_sum(tensor: Tensor) -> int:
    """Sum ``tensor`` in place across all ranks and return the world size.

    Without an initialised process group (single-process runs, unit tests)
    the tensor is left untouched and 1 is returned.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor)
        return dist.get_world_size()
    return 1


def _teacher_temp_schedule(warmup_teacher_temp, teacher_temp, warmup_teacher_epochs, nepochs):
    """Linear warm-up from ``warmup_teacher_temp`` to ``teacher_temp``, then constant; ``nepochs`` entries."""
    return np.concatenate((
        np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_epochs),
        np.ones(nepochs - warmup_teacher_epochs) * teacher_temp
    ))


def _temp_at(schedule, epoch: int):
    """Look up the temperature for ``epoch``; epochs past the end reuse the last entry."""
    return schedule[min(int(epoch), len(schedule) - 1)]


class DINOLoss(nn.Module):
    """Cross-entropy between centred/sharpened teacher outputs and student outputs.

    Args:
        out_dim: width of the projection-head output (size of the ``center`` buffer).
        ncrops: number of views per image seen by the student; the first two
            are the views that are also given to the teacher.
        warmup_teacher_temp, teacher_temp, warmup_teacher_epochs, nepochs:
            define the per-epoch teacher temperature schedule.
        student_temp: temperature applied to the student logits.
        center_momentum: EMA factor of the teacher-output centre.
    """

    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, warmup_teacher_epochs, nepochs, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.ncrops = ncrops
        self.center_momentum = center_momentum
        self.register_buffer('center', torch.zeros(1, out_dim))

        self.teacher_temp_schedule = _teacher_temp_schedule(warmup_teacher_temp, teacher_temp, warmup_teacher_epochs, nepochs)

        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, student_pred: Tensor, teacher_pred: Tensor, epoch: int):
        """Return the mean loss over all (teacher view, different student view) pairs.

        ``student_pred`` is split into ``ncrops`` chunks and ``teacher_pred``
        into 2 chunks along dim 0. ``epoch`` indexes the temperature schedule.
        As a side effect the ``center`` buffer is updated from the raw teacher
        output.
        """
        student_chunks = (student_pred / self.student_temp).chunk(self.ncrops)

        # teacher centering and sharpening
        temp = _temp_at(self.teacher_temp_schedule, epoch)
        teacher_chunks = self.softmax((teacher_pred - self.center) / temp).detach().chunk(2)

        total_loss, n_loss_terms = 0, 0

        for i, q in enumerate(teacher_chunks):
            for j, v in enumerate(student_chunks):
                if j == i:
                    # skip the pair where student and teacher saw the same view
                    continue

                loss = torch.sum(-q * self.logsoftmax(v), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1

        total_loss = total_loss / n_loss_terms
        # The centre tracks the raw teacher output, not the softmaxed chunks.
        self.update_center(teacher_pred)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_pred):
        """EMA-update ``center`` with the mean teacher output over the batch and all ranks."""
        batch_center = torch.sum(teacher_pred, dim=0, keepdim=True)
        world_size = _all_reduce_sum(batch_center)
        batch_center /= len(teacher_pred) * world_size

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


class DDINOLoss(nn.Module):
    """View-level plus region-level variant of :class:`DINOLoss`.

    Takes the same constructor arguments as ``DINOLoss`` and keeps a second
    centre buffer, ``center_grid``, for the region-level teacher outputs.
    """

    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, warmup_teacher_epochs, nepochs, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.ncrops = ncrops
        self.center_momentum = center_momentum
        self.register_buffer('center', torch.zeros(1, out_dim))
        self.register_buffer('center_grid', torch.zeros(1, out_dim))

        self.teacher_temp_schedule = _teacher_temp_schedule(warmup_teacher_temp, teacher_temp, warmup_teacher_epochs, nepochs)

        self.softmax = nn.Softmax(dim=-1)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, student_outputs: Tensor, teacher_outputs: Tensor, epoch: int):
        """Return the averaged view-level + region-level loss.

        Both arguments are 4-tuples ``(cls_pred, region_pred, feats, npatch)``.
        ``npatch[0]`` is the patch count of a global view; for the student,
        ``npatch[1]`` is the patch count of a local view. Region predictions
        and features are flattened over batch and patches along dim 0.
        """
        student_cls_pred, student_region_pred, student_feats, student_npatch = student_outputs
        teacher_cls_pred, teacher_region_pred, teacher_feats, teacher_npatch = teacher_outputs

        # teacher centering and sharpening
        temp = _temp_at(self.teacher_temp_schedule, epoch)
        teacher_cls = self.softmax((teacher_cls_pred - self.center) / temp).detach().chunk(2)
        teacher_region = self.softmax((teacher_region_pred - self.center_grid) / temp).detach().chunk(2)
        teacher_feats = teacher_feats.chunk(2)

        N = teacher_npatch[0]                   # number of patches in the first view
        B = teacher_region[0].shape[0] // N     # batch size recovered from the flattened rows

        # student sharpening
        student_cls = (student_cls_pred / self.student_temp).chunk(self.ncrops)

        student_region = student_region_pred / self.student_temp
        # two global views followed by (ncrops - 2) local views
        student_split_size = [student_npatch[0]] * 2 + [student_npatch[1]] * (self.ncrops - 2)
        student_split_size_bs = [i * B for i in student_split_size]
        student_region = torch.split(student_region, student_split_size_bs, dim=0)
        student_feats = torch.split(student_feats, student_split_size_bs, dim=0)

        total_loss, n_loss_terms = 0, 0

        for i, q in enumerate(teacher_cls):
            for j, v in enumerate(student_cls):
                if j == i:
                    continue

                # view level prediction loss
                loss = 0.5 * torch.sum(-q * self.logsoftmax(v), dim=-1)

                # region level prediction loss
                s_region_cur = student_region[j].view(B, student_split_size[j], -1)
                s_fea_cur = student_feats[j].view(B, student_split_size[j], -1)

                t_region_cur = teacher_region[i].view(B, N, -1)
                t_fea_cur = teacher_feats[i].view(B, N, -1)

                # similarity matrix between two sets of region features
                region_sim_matrix = torch.matmul(F.normalize(s_fea_cur, p=2, dim=-1), F.normalize(t_fea_cur, p=2, dim=-1).permute(0, 2, 1))
                # for every student region: index of the most similar teacher region
                region_sim_ind = region_sim_matrix.max(dim=2)[1]

                t_indexed_region = torch.gather(t_region_cur, 1, region_sim_ind.unsqueeze(2).expand(-1, -1, t_region_cur.shape[2]))

                loss += 0.5 * torch.sum(-t_indexed_region * self.logsoftmax(s_region_cur), dim=-1).mean(-1)

                total_loss += loss.mean()
                n_loss_terms += 1

        total_loss = total_loss / n_loss_terms
        self.update_center(teacher_cls_pred, teacher_region_pred)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_pred, teacher_grid_pred):
        """EMA-update ``center`` and ``center_grid`` from the raw teacher outputs."""
        # view level center update
        batch_center = torch.sum(teacher_pred, dim=0, keepdim=True)
        world_size = _all_reduce_sum(batch_center)
        batch_center /= len(teacher_pred) * world_size

        # region level center update
        batch_grid_center = torch.sum(teacher_grid_pred, dim=0, keepdim=True)
        world_size = _all_reduce_sum(batch_grid_center)
        batch_grid_center /= len(teacher_grid_pred) * world_size

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
        self.center_grid = self.center_grid * self.center_momentum + batch_grid_center * (1 - self.center_momentum)

# --------------------------------------------------------------------------------
# /utils/metrics.py:
# --------------------------------------------------------------------------------
import torch


def accuracy(pred: torch.Tensor, target: torch.Tensor, topk: tuple = (1,)):
    """Top-k accuracy in percent.

    ``pred`` holds one row of scores per sample and ``target`` the class
    index per sample. Returns a list with one 0-dim tensor per entry of
    ``topk``.
    """
    maxk = max(topk)
    batch_size = target.shape[0]
    # indices of the maxk best classes, transposed so that row r holds every sample's rank-r guess
    top_idx = pred.topk(maxk, 1)[-1].t()
    correct = top_idx == target.view(1, -1).expand_as(top_idx)

    return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]

# --------------------------------------------------------------------------------
# /utils/schedulers.py:
# --------------------------------------------------------------------------------
from torch.optim.lr_scheduler import StepLR

# --------------------------------------------------------------------------------
# /utils/utils.py:
# --------------------------------------------------------------------------------
import torch
import numpy as np
import random
import time
import os
import tempfile
from pathlib import Path
from torch.backends import cudnn
from torch import nn
from torch.autograd import profiler
from typing import Union
from torch import distributed as dist


def fix_seeds(seed: int = 123) -> None:
    """Seed the torch (CPU and all CUDA devices), NumPy and ``random`` generators."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def setup_cudnn() -> None:
    """Enable the cuDNN autotuner and allow non-deterministic kernels."""
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    cudnn.benchmark = True
    cudnn.deterministic = False


def time_synchronized() -> float:
    """Return ``time.time()`` after waiting for pending CUDA work, if CUDA is available."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]):
    """Return the serialized size of ``model`` in MB (bytes / 1e6).

    Script modules are saved with ``torch.jit.save``; for other modules only
    the ``state_dict`` is saved. The file is written to a private temporary
    directory so nothing in the working directory is touched and the file is
    removed even if saving fails.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_model_path = Path(tmp_dir) / 'temp.p'
        if isinstance(model, torch.jit.ScriptModule):
            torch.jit.save(model, str(tmp_model_path))
        else:
            torch.save(model.state_dict(), tmp_model_path)
        size = tmp_model_path.stat().st_size
    return size / 1e6   # in MB


@torch.no_grad()
def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float:
    """Profile one forward pass and return the profiler's total self CPU time divided by 1000 (ms)."""
    with profiler.profile(use_cuda=use_cuda) as prof:
        _ = model(inputs)
    return prof.self_cpu_time_total / 1000  # ms


# The name starts with "test_" but this is a utility: keep pytest from collecting it.
test_model_latency.__test__ = False


def setup_ddp() -> int:
    """Initialise the NCCL process group and return the CUDA device index of this process.

    Under a distributed launcher (``RANK`` and ``WORLD_SIZE`` set) the rank
    information comes from the environment. Otherwise, if CUDA is available,
    a single-process group on GPU 0 is created.

    Raises:
        RuntimeError: if neither launcher variables nor a CUDA device are available.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        if 'LOCAL_RANK' in os.environ:
            gpu = int(os.environ['LOCAL_RANK'])
        else:
            # Launcher did not export LOCAL_RANK: derive it from the global rank.
            gpu = rank % max(torch.cuda.device_count(), 1)
    elif torch.cuda.is_available():
        rank, world_size, gpu = 0, 1, 0
        # The default env:// rendezvous needs an address and a port even for one process.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '29500')
    else:
        raise RuntimeError("setup_ddp requires either RANK/WORLD_SIZE in the environment or an available CUDA device")

    torch.cuda.set_device(gpu)
    dist.init_process_group('nccl', world_size=world_size, rank=rank)
    dist.barrier()

    return gpu


def cleanup_ddp():
    """Tear down the default process group."""
    dist.destroy_process_group()