├── README.md
├── checkpoints
│   └── placeholder.rtf
├── data_utils
│   ├── __init__.py
│   ├── data_stats.py
│   ├── dataloader.py
│   └── dataset_to_beton.py
├── explore.ipynb
├── finetune.py
├── image.png
├── models
│   ├── __init__.py
│   └── networks.py
├── requirements.txt
├── train.py
└── utils
    ├── __init__.py
    ├── config.py
    ├── download.py
    ├── get_compute.py
    ├── metrics.py
    ├── optimizer.py
    └── parsers.py
/README.md:
--------------------------------------------------------------------------------
1 | # Scaling MLPs :fire:
2 | 
3 | ![](https://user-images.githubusercontent.com/38691167/274573298-0aaf4b37-e9b3-4d67-9c0f-ef27fe17089d.png)
4 | ## Overview
5 | This repository contains the code accompanying our [paper](https://arxiv.org/abs/2306.13575) *Scaling MLPs: A Tale of Inductive Bias*. In this work we explore the limits of the *multi-layer perceptron*, or MLP for short, when subjected to higher amounts of compute. More precisely, we study architectures with the following block form:
6 | ![](https://lh3.googleusercontent.com/pw/AIL4fc_3gvNmHfrvhN38zgU2OMTHqG-4w0zMY6of3S7Gi0EoV498btfYB2H7NnYUlpm8d0Va7COQAigFYZ9BCEI93qIqkV4_CKLKtdED6VQ8p-uJrKb6zD0yRfoe2yaMRdFFZeyPXaiFGWkJEurH-wvNGMY1=w1426-h154-s-no?authuser=0)
7 | 
8 | **Why?** We argue that such an MLP has minimal inductive bias (compared to convolutional networks, vision transformers, MLPMixers etc.) and thus offers an interesting test bed to explore whether simply scaling compute can make even the simplest models work (to some degree). The importance of inductive bias has recently been questioned, as vision transformers and MLPMixers have eclipsed the more structured convolutional models on standard benchmarks.
9 | 
10 | Moreover, MLPs remain the main protagonists in ML theory works, yet surprisingly little is known about their empirical performance at scale! We aim to close this gap here and provide the community with very performant MLPs to analyse!
11 | 
12 | ## Explore
13 | You can easily explore our pre-trained and fine-tuned models by specifying the checkpoint flag. For instance, to load a BottleneckMLP with 12 blocks of width 1024, pre-trained on ImageNet21k, simply run
14 | >`model = get_model(architecture='B_12-Wi_1024', resolution=64, num_classes=11230,
15 | checkpoint='in21k')`
16 | 
17 | If you need an already fine-tuned model, you can specify
18 | >`model = get_model(architecture='B_12-Wi_1024', resolution=64, num_classes=10,
19 | checkpoint='in21k_cifar10')`
20 | 
21 | Check out the Jupyter notebook *explore.ipynb* to play around with the models.
22 | 
23 | ## Pretrained Models
24 | 
25 | We further publish our models pre-trained on ImageNet21k for various numbers of epochs at an image resolution of $64\times 64$ [here](https://drive.google.com/drive/folders/17pbKnQgftxkGW5zZGuUvN1C---DesqOW?usp=sharing).
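All checkpoints share the block structure sketched in the *Overview*: in code, each block computes the following (a minimal sketch mirroring `models/networks.py`; the function names are illustrative only)
>`x = x + linear_wide_to_thin(gelu(linear_thin_to_wide(layernorm(x))))`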
Fine-tuning the $800$-epoch models for $100$ epochs should give you roughly the following downstream performances (see the *Fine-tuning* section for hyperparameter details):
26 | 
27 | |                        | #Params | CIFAR10 | CIFAR100 | STL10 | TinyImageNet | ImageNet | ImageNetReal
28 | | ---------------- | ------- | ------- | -------- | ----- | ------------ | ---------- | ------------
29 | | **B_6-Wi_512**   | 24M     | 88.5%   | 71.2%    | 79.9% | 53.2%        | 33.3%      | 38.2%
30 | | **B_12-Wi_512**  | 37M     | 91.4%   | 75.1%    | 84.4% | 60.0%        | 38.0%      | 42.8%
31 | | **B_6-Wi_1024**  | 74M     | 92.5%   | 77.1%    | 86.5% | 64.3%        | 40.0%      | 47.0%
32 | | **B_12-Wi_1024** | 124M    | 94.2%   | 80.0%    | 89.9% | 69.9%        | 43.2%      | 48.6%
33 | | **B_12-Wi_1024 + TTA** | 124M | 95.5% | 82.6%   | 92.2% | 73.1%        | 51.4%      | 57.9%
34 | 
35 | 
36 | Make sure that you also download the corresponding config.txt file and place it in the same folder as the checkpoint.
37 | ## Environment Setup
38 | 
39 | For installing the *FFCV* dataloading framework, we refer to the original [repository](https://github.com/libffcv/ffcv). To install the remaining packages, activate the FFCV environment and run
40 | >`pip install -r requirements.txt`
41 | 
42 | ## Creating .beton Files
43 | In order to use the efficiency of MLPs to the fullest, we need a more optimised data loading framework than the standard one provided by *torch*: otherwise the data transfer from CPU to GPU, not the gradient computation, becomes the bottleneck of training!
44 | To ensure a faster data transfer, we use the *FFCV* framework, which requires converting your dataset first to the **beton** format. This can be achieved by specifying your dataset as a torchvision.dataset object.
45 | 
46 | If your dataset is implemented in the torchvision.datasets library, simply add the corresponding lines of code to the `get_dataset` function in `dataset_to_beton.py` (an example sketch is given in that file). We provide implementations for *CIFAR10* and *CIFAR100*.
47 | 
48 | If you have your dataset in the standard hierarchical subfolder structure, i.e. your dataset consists of subfolders each corresponding to a separate class, you can simply specify the `data_path` argument in `create_beton` in order to obtain the *.beton* file.
49 | 
50 | 
51 | Conversion to *.beton* accepts a resolution parameter `res`, specifying the resolution of the images. We recommend using `--res 64` for very large datasets such as *ImageNet21k* in order to keep the computational requirements manageable for users with fewer resources.
52 | 
53 | 
54 | Downloading and converting the trainset of CIFAR10 to the *.beton* format can for instance be achieved by running
55 | >`python3 data_utils/dataset_to_beton.py --dataset_name cifar10 --mode train --res 32`
56 | 
57 | A subfolder-structured dataset can be converted to the *.beton* format at resolution 64 by running
58 | >`python3 data_utils/dataset_to_beton.py --data_path path/to/folders --mode train --res 64`
59 | 
60 | ## Pre-training
61 | 
62 | **ImageNet21k.** Due to legal reasons, we cannot provide *ImageNet21k* in the *.beton* format directly. We recommend applying [here](https://www.image-net.org/download.php) to download it, but in case you cannot get access, you can use the torrent [here](https://academictorrents.com/details/8ec0d8df0fbb507594557bce993920442f4f6477). Similarly for *ImageNet1k*. Once you have downloaded the dataset, we recommend pre-processing it as detailed in this [paper](https://arxiv.org/abs/2104.10972) to remove faulty images and classes with only very few examples. Then produce the *.beton* files as outlined above.
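For example, assuming the pre-processed dataset ends up in a class-per-subfolder layout under a hypothetical path *path/to/imagenet21k_processed*, the subfolder conversion from above applies directly:
>`python3 data_utils/dataset_to_beton.py --data_path path/to/imagenet21k_processed --mode train --res 64`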
63 | 
64 | 
65 | 
66 | **Pre-training.** For pre-training the `B_12-Wi_1024` *BottleneckMLP* on *ImageNet21k* at resolution $64 \times 64$, you can use the following command:
67 | >`python3 train.py --dataset imagenet21 --model BottleneckMLP --architecture B_12-Wi_1024 --batch_size 16384 --resolution 64`
68 | 
69 | 
70 | For more specific configurations, we encourage the user to check out all available flags in `train.py`. In case you run into memory issues, try to reduce the batch size. We remark, however, that smaller batch sizes tend to lead to worse results; check out our paper, where we highlight this effect. During training, the parameters will automatically be saved to the `checkpoints` folder.
71 | ## Fine-tuning
72 | 
73 | You can fine-tune our pre-trained checkpoints or your own using the script `finetune.py`. For instance, the following command fine-tunes a pre-trained B_12-Wi_1024 model on CIFAR10, provided you have converted the CIFAR10 dataset to the *.beton* format:
74 | > `python3 finetune.py --architecture B_12-Wi_1024 --checkpoint res_64_in21k --dataset cifar10 --data_resolution 32 --batch_size 2048 --epochs 100 --lr 0.01 --weight_decay 0.0001 --data_path /local/home/stuff/ --crop_scale 0.4 1. --crop_ratio 1. 1. --optimizer sgd --augment --mode finetune --smooth 0.3`
75 | 
76 | 
77 | You can also train only a linear layer on top by specifying the flag `--mode linear` instead.
78 | 
79 | 
80 | ## Own Dataset
81 | If you want to add your own dataset, convert it to the FFCV format as detailed above and make sure to fill in the same values that are provided in *data_utils/data_stats.py* for the other datasets, such as the number of classes, number of samples etc. (an example sketch is given in that file).
--------------------------------------------------------------------------------
/checkpoints/placeholder.rtf:
--------------------------------------------------------------------------------
1 | {\rtf1\ansi\ansicpg1252\cocoartf2634
2 | \cocoatextscaling0\cocoaplatform0{\fonttbl\f0\fswiss\fcharset0 Helvetica;}
3 | {\colortbl;\red255\green255\blue255;}
4 | {\*\expandedcolortbl;;}
5 | \paperw11900\paperh16840\margl1440\margr1440\vieww11520\viewh8400\viewkind0
6 | \pard\tx566\tx1133\tx1700\tx2267\tx2834\tx3401\tx3968\tx4535\tx5102\tx5669\tx6236\tx6803\pardirnatural\partightenfactor0
7 | 
8 | \f0\fs24 \cf0 Empty file as placeholder}
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gregorbachmann/scaling_mlps/50cb3a90fdac8fad5b9e4d3733f29464d4344277/data_utils/__init__.py
--------------------------------------------------------------------------------
/data_utils/data_stats.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | # Define all the relevant stats for the datasets to look up
4 | 
5 | # Number of samples
6 | SAMPLE_DICT = {
7 |     "imagenet21": 11801680,
8 |     "imagenet": 1281167,
9 |     "imagenet_real": 1281167,
10 |     "tinyimagenet": 100000,
11 |     "cifar10": 50000,
12 |     "cifar100": 50000,
13 |     "stl10": 5000,
14 | }
15 | 
16 | # Number of classes
17 | CLASS_DICT = {
18 |     "imagenet21": 11230,
19 |     "in21k": 11230,  # Need the short name here too for loading models
20 |     "imagenet": 1000,
21 |     "in1k": 1000,  # Need the short name here too for loading models
22 |     "imagenet_real": 1000,
23 |     "tinyimagenet": 200,
24 |     "cifar10": 10,
25 |     "cifar100": 100,
26 |     "stl10": 10,
27 | }
28 | 
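# To register a new dataset (hypothetical example, as mentioned in the
# README's "Own Dataset" section), add an entry to every dict in this file,
# e.g.
#   SAMPLE_DICT["mydataset"] = 12000
#   CLASS_DICT["mydataset"] = 7
# and likewise for the resolution, directory, eval split and normalization
# statistics below.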
29 | # Image resolutions
30 | DEFAULT_RES_DICT = {
31 |     "imagenet21": 64,
32 |     "imagenet": 64,
33 |     "imagenet_real": 64,
34 |     "tinyimagenet": 64,
35 |     "cifar10": 32,
36 |     "cifar100": 32,
37 |     "stl10": 64,
38 | }
39 | 
40 | 
41 | # Parent directory name
42 | DATA_DICT = {
43 |     "imagenet21": "imagenet21",
44 |     "imagenet": "imagenet",
45 |     "tinyimagenet": "tiny-imagenet-200",
46 |     "imagenet_real": "imagenet",
47 |     "cifar10": "cifar10",
48 |     "cifar100": "cifar100",
49 |     "stl10": "stl10",
50 | }
51 | 
52 | MODE_DICT = {
53 |     "imagenet21": "test",
54 |     "imagenet": "val",
55 |     "imagenet_real": "val",
56 |     "tinyimagenet": "val",
57 |     "cifar10": "val",
58 |     "cifar100": "test",
59 |     "stl10": "val",
60 | }
61 | 
62 | # Standardization statistics
63 | MEAN_DICT = {
64 |     "imagenet21": np.array([0.485, 0.456, 0.406]) * 255,
65 |     "imagenet": np.array([0.485, 0.456, 0.406]) * 255,
66 |     "imagenet_real": np.array([0.485, 0.456, 0.406]) * 255,
67 |     "tinyimagenet": np.array([0.485, 0.456, 0.406]) * 255,
68 |     "cifar10": np.array([0.49139968, 0.48215827, 0.44653124]) * 255,
69 |     "cifar100": np.array([0.49139968, 0.48215827, 0.44653124]) * 255,
70 |     "stl10": np.array([0.4914, 0.4822, 0.4465]) * 255,
71 | }
72 | 
73 | 
74 | STD_DICT = {
75 |     "imagenet21": np.array([0.229, 0.224, 0.225]) * 255,
76 |     "imagenet": np.array([0.229, 0.224, 0.225]) * 255,
77 |     "imagenet_real": np.array([0.229, 0.224, 0.225]) * 255,
78 |     "tinyimagenet": np.array([0.229, 0.224, 0.225]) * 255,
79 |     "cifar10": np.array([0.24703233, 0.24348505, 0.26158768]) * 255,
80 |     "cifar100": np.array([0.24703233, 0.24348505, 0.26158768]) * 255,
81 |     "stl10": np.array([0.2471, 0.2435, 0.2616]) * 255,
82 | }
83 | 
84 | # Whether dataset can be cached in memory
85 | OS_CACHED_DICT = {
86 |     "imagenet21": False,
87 |     "imagenet": False,
88 |     "imagenet_real": False,
89 |     "tinyimagenet": True,
90 |     "cifar10": True,
91 |     "cifar100": True,
92 |     "stl10": True,
93 |     "pets": True,
94 |     "coyo": False,
95 | }
96 | 
--------------------------------------------------------------------------------
/data_utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | from typing import List
5 | 
6 | import torch
7 | import torchvision
8 | from ffcv.fields.decoders import IntDecoder, NDArrayDecoder
9 | from ffcv.fields.rgb_image import (
10 |     CenterCropRGBImageDecoder,
11 |     RandomResizedCropRGBImageDecoder,
12 | )
13 | from ffcv.loader import Loader, OrderOption
14 | from ffcv.pipeline.operation import Operation
15 | from ffcv.transforms import (
16 |     Convert,
17 |     ImageMixup,
18 |     LabelMixup,
19 |     RandomHorizontalFlip,
20 |     ToDevice,
21 |     ToTensor,
22 |     ToTorchImage,
23 | )
24 | from ffcv.transforms.common import Squeeze
25 | 
26 | from .data_stats import *
27 | 
28 | 
29 | # Define an ffcv dataloader
30 | def get_loader(
31 |     dataset,
32 |     bs,
33 |     mode,
34 |     augment,
35 |     dev,
36 |     data_resolution=None,
37 |     crop_resolution=None,
38 |     crop_ratio=(0.75, 1.3333333333333333),
39 |     crop_scale=(0.08, 1.0),
40 |     num_samples=None,
41 |     dtype=torch.float32,
42 |     mixup=None,
43 |     data_path='./beton',
44 | ):
45 |     mode_name = MODE_DICT[dataset] if mode != 'train' else mode
46 |     os_cache = OS_CACHED_DICT[dataset]
47 | 
48 |     if data_resolution is None:
49 |         data_resolution = DEFAULT_RES_DICT[dataset]
50 |     if crop_resolution is None:
51 |         crop_resolution = data_resolution
52 | 
53 |     real = '' if dataset != 'imagenet_real' or mode == 'train' else 'real_'
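    # (Layout note: the .beton files are expected under
    # <data_path>/<DATA_DICT[dataset]>/ffcv/<mode_name>/<mode_name>_<res>.beton,
    # with a 'real_' prefix for ImageNet-ReaL evaluation labels and a
    # '_ntrain_<n>' suffix for sub-sampled training sets, as assembled below.)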
54 |     sub_sampled = '' if num_samples is None or num_samples == SAMPLE_DICT[dataset] else '_ntrain_' + str(num_samples)
55 | 
56 |     beton_path = os.path.join(
57 |         data_path,
58 |         DATA_DICT[dataset],
59 |         'ffcv',
60 |         mode_name,
61 |         real + f'{mode_name}_{data_resolution}' + sub_sampled + '.beton',
62 |     )
63 | 
64 |     print(f'Loading {beton_path}')
65 | 
66 |     mean = MEAN_DICT[dataset]
67 |     std = STD_DICT[dataset]
68 | 
69 |     if dataset == 'imagenet_real' and mode != 'train':
70 |         label_pipeline: List[Operation] = [NDArrayDecoder()]
71 |     else:
72 |         label_pipeline: List[Operation] = [IntDecoder()]
73 | 
74 |     if augment:
75 |         image_pipeline: List[Operation] = [
76 |             RandomResizedCropRGBImageDecoder((crop_resolution, crop_resolution), ratio=crop_ratio, scale=crop_scale),
77 |             RandomHorizontalFlip(),
78 |         ]
79 |     else:
80 |         image_pipeline: List[Operation] = [
81 |             CenterCropRGBImageDecoder(output_size=(crop_resolution, crop_resolution), ratio=1)
82 |         ]
83 | 
84 |     # Add image transforms and normalization
85 |     if mode == 'train' and augment and mixup is not None and mixup > 0:  # guard against the default mixup=None
86 |         image_pipeline.extend([ImageMixup(alpha=mixup, same_lambda=True)])
87 |         label_pipeline.extend([LabelMixup(alpha=mixup, same_lambda=True)])
88 | 
89 |     label_pipeline.extend([ToTensor(), ToDevice(dev, non_blocking=True), Squeeze()])
90 | 
91 |     image_pipeline.extend(
92 |         [
93 |             ToTensor(),
94 |             ToDevice(dev, non_blocking=True),
95 |             ToTorchImage(),
96 |             Convert(dtype),
97 |             torchvision.transforms.Normalize(mean, std),
98 |         ]
99 |     )
100 | 
101 |     if mode == 'train':
102 |         num_samples = SAMPLE_DICT[dataset] if num_samples is None else num_samples
103 | 
104 |         # Shuffle indices in case the classes are ordered
105 |         #indices = list(range(num_samples))
106 | 
107 |         #random.seed(0)
108 |         #random.shuffle(indices)
109 |         indices = None
110 |     else:
111 |         indices = None
112 | 
113 |     return Loader(
114 |         beton_path,
115 |         batch_size=bs,
116 |         num_workers=4,
117 |         order=OrderOption.QUASI_RANDOM if mode == 'train' else OrderOption.SEQUENTIAL,
118 |         drop_last=(mode == 'train'),
119 |         pipelines={'image': image_pipeline, 'label': label_pipeline},
120 |         os_cache=os_cache,
121 |         indices=indices,
122 |     )
123 | 
--------------------------------------------------------------------------------
/data_utils/dataset_to_beton.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import torchvision
5 | from ffcv.fields import BytesField, IntField, RGBImageField
6 | from ffcv.writer import DatasetWriter
7 | 
8 | 
9 | def get_dataset(dataset_name, mode, data_path):
10 |     if data_path is not None:
11 |         return torchvision.datasets.ImageFolder(root=data_path, transform=None)
12 | 
13 |     if dataset_name == "cifar10":
14 |         return torchvision.datasets.CIFAR10(
15 |             root="/tmp", train=mode == "train", download=True
16 |         )
17 |     elif dataset_name == "cifar100":
18 |         return torchvision.datasets.CIFAR100(
19 |             root="/tmp", train=mode == "train", download=True
20 |         )
21 |     else:
22 |         raise NotImplementedError(
23 |             f"Dataset {dataset_name} not supported. Please add it here."
24 |         )
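    # Example (sketch, not part of the original file): to support another
    # torchvision dataset, add a branch above, e.g.
    #     elif dataset_name == "stl10":
    #         return torchvision.datasets.STL10(
    #             root="/tmp", split="train" if mode == "train" else "test", download=True
    #         )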
25 | 
26 | 
27 | def create_beton(args):
28 |     dataset = get_dataset(args.dataset_name, args.mode, args.data_path)
29 | 
30 |     write_path = os.path.join(
31 |         args.write_path, args.dataset_name, args.mode, f"{args.mode}_{args.res}.beton"
32 |     )
33 | 
34 |     os.makedirs(os.path.dirname(write_path), exist_ok=True)
35 | 
36 |     writer = DatasetWriter(
37 |         write_path,
38 |         {
39 |             "image": RGBImageField(write_mode="smart", max_resolution=args.res),
40 |             "label": IntField(),
41 |         },
42 |         num_workers=args.num_workers,
43 |     )
44 | 
45 |     writer.from_indexed_dataset(dataset, chunksize=100)
46 | 
47 | 
48 | if __name__ == "__main__":
49 |     parser = argparse.ArgumentParser(description="Convert dataset to .beton format")
50 |     parser.add_argument("--dataset_name", type=str, default=None, help="dataset name")
51 |     parser.add_argument(
52 |         "--data_path",
53 |         type=str,
54 |         default=None,
55 |         help="path to dataset if data is given in a hierarchical subfolder structure.",
56 |     )
57 |     parser.add_argument("--mode", type=str, default="train", help="train or test")
58 |     parser.add_argument("--res", type=int, default=32, help="resolution of images")
59 |     parser.add_argument(
60 |         "--write_path",
61 |         type=str,
62 |         default="./beton/",
63 |         help="path to write .beton file to",
64 |     )
65 |     parser.add_argument(
66 |         "--num_workers", type=int, default=16, help="number of workers to use"
67 |     )
68 |     args = parser.parse_args()
69 | 
70 |     assert (
71 |         args.dataset_name is not None or args.data_path is not None
72 |     ), "Either dataset_name or data_path must be specified."
73 | 
74 |     create_beton(args)
75 | 
--------------------------------------------------------------------------------
/explore.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import os\n",
10 |     "\n",
11 |     "import torch\n",
12 |     "from tqdm import tqdm\n",
13 |     "from ffcv.fields import BytesField, IntField, RGBImageField\n",
14 |     "from ffcv.writer import DatasetWriter\n",
15 |     "\n",
16 |     "from data_utils.data_stats import *\n",
17 |     "from data_utils.dataloader import get_loader\n",
18 |     "from utils.metrics import topk_acc, real_acc, AverageMeter\n",
19 |     "from models.networks import get_model\n",
20 |     "from data_utils.dataset_to_beton import get_dataset"
21 |    ]
22 |   },
23 |   {
24 |    "cell_type": "code",
25 |    "execution_count": null,
26 |    "metadata": {},
27 |    "outputs": [],
28 |    "source": [
29 |     "dataset = 'cifar10'  # One of cifar10, cifar100, stl10, imagenet or imagenet21\n",
30 |     "architecture = 'B_12-Wi_1024'\n",
31 |     "data_resolution = 32  # Resolution of data as it is stored\n",
32 |     "crop_resolution = 64  # Resolution of fine-tuned model (64 for all models we provide)\n",
33 |     "num_classes = CLASS_DICT[dataset]\n",
34 |     "data_path = './beton/'\n",
35 |     "eval_batch_size = 1024\n",
36 |     "checkpoint = 'in21k_cifar10'  # This means you want the network pre-trained on ImageNet21k and fine-tuned on CIFAR10"
37 |    ]
38 |   },
39 |   {
40 |    "cell_type": "code",
41 |    "execution_count": null,
42 |    "metadata": {},
43 |    "outputs": [],
44 |    "source": [
45 |     "# If you did not yet, produce the .beton file for CIFAR10 (check the README for how to do that for ImageNet)\n",
46 |     "def create_beton(dataset_name, mode, data_path, res):\n",
47 |     "    dataset = get_dataset(dataset_name, mode, None)  # download via torchvision\n",
48 |     "    split = MODE_DICT[dataset_name] if mode != 'train' else mode\n",
49 |     "    write_path = os.path.join(\n",
50 |     "        data_path, DATA_DICT[dataset_name], 'ffcv', split, f\"{split}_{res}.beton\"\n",
51 |     "    )\n",
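    "    # (Assumed layout: get_loader in data_utils/dataloader.py reads the beton from\n",
    "    # <data_path>/<dataset>/ffcv/<split>/<split>_<res>.beton, where the eval split\n",
    "    # name comes from MODE_DICT, e.g. 'val' for cifar10.)\n",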
52 | "\n", 53 | " os.makedirs(os.path.dirname(write_path), exist_ok=True)\n", 54 | "\n", 55 | " writer = DatasetWriter(\n", 56 | " write_path,\n", 57 | " {\n", 58 | " \"image\": RGBImageField(write_mode=\"smart\", max_resolution=res),\n", 59 | " \"label\": IntField(),\n", 60 | " },\n", 61 | " num_workers=0,\n", 62 | " )\n", 63 | "\n", 64 | " writer.from_indexed_dataset(dataset, chunksize=100)\n", 65 | "\n", 66 | "\n", 67 | "create_beton(dataset, 'test', data_path, data_resolution)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "torch.backends.cuda.matmul.allow_tf32 = True\n", 77 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 78 | "\n", 79 | "# Define the model and specify the pre-trained weights\n", 80 | "model = get_model(architecture=architecture, resolution=crop_resolution, num_classes=CLASS_DICT[dataset],\n", 81 | " checkpoint='in21k_cifar10')\n", 82 | "model.cuda()" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "# Get the test loader\n", 92 | "loader = get_loader(\n", 93 | " dataset,\n", 94 | " bs=eval_batch_size,\n", 95 | " mode=\"test\",\n", 96 | " augment=False,\n", 97 | " dev=device,\n", 98 | " mixup=0.0,\n", 99 | " data_path=data_path,\n", 100 | " data_resolution=data_resolution,\n", 101 | " crop_resolution=crop_resolution,\n", 102 | ")" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "# Define a test function that evaluates test accuracy\n", 112 | "@torch.no_grad()\n", 113 | "def test(model, loader):\n", 114 | " model.eval()\n", 115 | " total_acc, total_top5 = AverageMeter(), AverageMeter()\n", 116 | "\n", 117 | " for ims, targs in tqdm(loader, desc=\"Evaluation\"):\n", 118 | " ims = torch.reshape(ims, (ims.shape[0], -1))\n", 119 | " preds = model(ims)\n", 120 | "\n", 121 | " if dataset != 'imagenet_real':\n", 122 | " acc, top5 = topk_acc(preds, targs, k=5, avg=True)\n", 123 | " else:\n", 124 | " acc = real_acc(preds, targs, k=5, avg=True)\n", 125 | " top5 = 0\n", 126 | "\n", 127 | " total_acc.update(acc, ims.shape[0])\n", 128 | " total_top5.update(top5, ims.shape[0])\n", 129 | "\n", 130 | "\n", 131 | " return (\n", 132 | " total_acc.get_avg(percentage=True),\n", 133 | " total_top5.get_avg(percentage=True),\n", 134 | " )" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "test_acc, test_top5 = test(model, loader)\n", 144 | "\n", 145 | "# Print all the stats\n", 146 | "print(\"Test Accuracy \", \"{:.4f}\".format(test_acc))\n", 147 | "print(\"Top 5 Test Accuracy \", \"{:.4f}\".format(test_top5))" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "language_info": { 153 | "name": "python" 154 | }, 155 | "orig_nbformat": 4 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 2 159 | } 160 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from torch.nn import CrossEntropyLoss, Linear 6 | from tqdm import tqdm 7 | 8 | from data_utils.data_stats import * 9 | from data_utils.dataloader import get_loader 10 | from models import get_architecture 11 | from models.networks import get_model 
12 | from utils.parsers import get_finetune_parser 13 | from utils.config import config_to_name, model_from_config, model_from_checkpoint 14 | from utils.metrics import topk_acc, real_acc 15 | from utils.optimizer import ( 16 | OPTIMIZERS_DICT, 17 | SCHEDULERS, 18 | get_optimizer, 19 | get_scheduler, 20 | ) 21 | from train import train, test 22 | 23 | 24 | @torch.no_grad() 25 | def test_time_aug(model, loader, num_augs, args): 26 | model.eval() 27 | all_preds = torch.zeros(len(loader.indices), model.linear_out.out_features) 28 | 29 | for _ in tqdm(range(num_augs)): 30 | targets = [] 31 | cnt = 0 32 | 33 | for ims, targs in loader: 34 | ims = torch.reshape(ims, (ims.shape[0], -1)) 35 | preds = model(ims) 36 | 37 | all_preds[cnt:cnt + ims.shape[0]] += torch.nn.functional.softmax(preds.detach().cpu(), dim=-1) 38 | targets.append(targs.detach().cpu()) 39 | 40 | cnt += ims.shape[0] 41 | 42 | all_preds = all_preds / num_augs 43 | targets = torch.cat(targets) 44 | 45 | if args.dataset != 'imagenet_real': 46 | acc, top5 = topk_acc(all_preds, targets, k=5, avg=True) 47 | else: 48 | acc = real_acc(all_preds, targets, k=5, avg=True) 49 | top5 = 0. 50 | 51 | return 100 * acc, 100 * top5 52 | 53 | 54 | def finetune(args): 55 | # Use mixed precision matrix multiplication 56 | torch.backends.cuda.matmul.allow_tf32 = True 57 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 58 | 59 | pretrained, crop_resolution, num_pretrain_classes = model_from_checkpoint(args.checkpoint) 60 | model = get_model(architecture=args.architecture, resolution=crop_resolution, num_classes=num_pretrain_classes, 61 | checkpoint=pretrained) 62 | args.crop_resolution = crop_resolution 63 | 64 | # Get the dataloaders 65 | train_loader = get_loader( 66 | args.dataset, 67 | bs=args.batch_size, 68 | mode='train', 69 | augment=args.augment, 70 | dev=device, 71 | num_samples=args.n_train, 72 | mixup=args.mixup, 73 | data_path=args.data_path, 74 | data_resolution=args.data_resolution, 75 | crop_resolution=args.crop_resolution, 76 | crop_ratio=tuple(args.crop_ratio), 77 | crop_scale=tuple(args.crop_scale) 78 | ) 79 | 80 | test_loader = get_loader( 81 | args.dataset, 82 | bs=args.batch_size, 83 | mode='test', 84 | augment=False, 85 | dev=device, 86 | data_path=args.data_path, 87 | data_resolution=args.data_resolution, 88 | crop_resolution=args.crop_resolution, 89 | ) 90 | 91 | test_loader_aug = get_loader( 92 | args.dataset, 93 | bs=args.batch_size, 94 | mode='test', 95 | augment=True, 96 | dev=device, 97 | data_path=args.data_path, 98 | data_resolution=args.data_resolution, 99 | crop_resolution=args.crop_resolution, 100 | crop_ratio=tuple(args.crop_ratio), 101 | crop_scale=tuple(args.crop_scale) 102 | 103 | ) 104 | 105 | model.linear_out = Linear(model.linear_out.in_features, args.num_classes) 106 | model.cuda() 107 | 108 | param_groups = [ 109 | { 110 | 'params': [v for k, v in model.named_parameters() if 'linear_out' in k], 111 | 'lr': args.lr, 112 | }, 113 | ] 114 | 115 | if args.mode != "linear": 116 | param_groups.append( 117 | { 118 | 'params': [ 119 | v for k, v in model.named_parameters() if 'linear_out' not in k 120 | ], 121 | 'lr': args.lr * args.body_learning_rate_multiplier, 122 | }, 123 | ) 124 | else: 125 | # freeze the body 126 | for name, param in model.named_parameters(): 127 | if 'linear_out' not in name: 128 | param.requires_grad = False 129 | 130 | # Create folder to store the checkpoints 131 | path = os.path.join(args.checkpoint_folder, args.checkpoint + '_' + args.dataset) 132 | if not 
os.path.exists(path): 133 | os.makedirs(path) 134 | with open(path + '/config.txt', 'w') as f: 135 | json.dump(args.__dict__, f, indent=2) 136 | 137 | 138 | opt = get_optimizer(args.optimizer)(param_groups, lr=args.lr) 139 | 140 | scheduler = get_scheduler(opt, args.scheduler, **args.__dict__) 141 | loss_fn = CrossEntropyLoss(label_smoothing=args.smooth) 142 | 143 | for ep in range(args.epochs): 144 | train_acc, train_top5, train_loss, train_time = train( 145 | model, opt, scheduler, loss_fn, ep, train_loader, args 146 | ) 147 | 148 | if (ep + 1) % args.calculate_stats == 0: 149 | test_acc, test_top5, test_loss, test_time = test( 150 | model, test_loader, loss_fn, args 151 | ) 152 | 153 | # Print all the stats 154 | print('Epoch', ep, ' Time:', train_time) 155 | print('-------------- Training ----------------') 156 | print('Average Training Loss: ', '{:.6f}'.format(train_loss)) 157 | print('Average Training Accuracy: ', '{:.4f}'.format(train_acc)) 158 | print('Top 5 Training Accuracy: ', '{:.4f}'.format(train_top5)) 159 | print('---------------- Test ------------------') 160 | print('Test Accuracy ', '{:.4f}'.format(test_acc)) 161 | print('Top 5 Test Accuracy ', '{:.4f}'.format(test_top5)) 162 | print() 163 | 164 | if ep % args.save_freq == 0 and args.save: 165 | torch.save( 166 | model.state_dict(), 167 | path + "/epoch_" + str(ep), 168 | ) 169 | 170 | print('-------- Test Time Augmentation Evaluation -------') 171 | 172 | num_augs = 100 173 | acc, top5 = test_time_aug(model, test_loader_aug, num_augs, args) 174 | print(num_augs, 'augmentations: Test accuracy:', acc) 175 | print(num_augs, 'augmentations: Test Top5 accuracy:', top5) 176 | 177 | 178 | if __name__ == "__main__": 179 | parser = get_finetune_parser() 180 | args = parser.parse_args() 181 | 182 | args.num_classes = CLASS_DICT[args.dataset] 183 | 184 | if args.n_train is None: 185 | args.n_train = SAMPLE_DICT[args.dataset] 186 | 187 | finetune(args) 188 | -------------------------------------------------------------------------------- /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregorbachmann/scaling_mlps/50cb3a90fdac8fad5b9e4d3733f29464d4344277/image.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .networks import BottleneckMLP, StandardMLP 2 | 3 | 4 | def get_architecture( 5 | architecture="B_6-Wi_1024", 6 | model="BottleneckMLP", 7 | num_channels=3, 8 | crop_resolution=None, 9 | num_classes=None, 10 | normalization='layer', 11 | act='gelu', 12 | drop_rate=None, 13 | **kwargs, 14 | ): 15 | assert model in ["BottleneckMLP", "MLP"], f"Model {model} not supported." 
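    # Architecture strings follow "B_<num_blocks>-Wi_<width>", e.g. "B_12-Wi_1024".
    # An optional third "-<tag>_<expansion_factor>" component overrides the
    # default expansion factor of 4 (only the number after '_' is parsed).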
16 | 
17 |     sep = architecture.split("-")
18 |     num_blocks = int(sep[0].split("_")[1])
19 |     thin = int(sep[1].split("_")[1])
20 | 
21 |     if len(sep) == 3:
22 |         expansion_factor = int(sep[2].split("_")[1])
23 |     else:
24 |         expansion_factor = 4
25 | 
26 |     if model == "BottleneckMLP":
27 |         blocks = [[expansion_factor * thin, thin] for _ in range(num_blocks)]
28 |         dim_in = crop_resolution**2 * num_channels
29 |         return BottleneckMLP(
30 |             dim_in=dim_in,
31 |             dim_out=num_classes,
32 |             block_dims=blocks,
33 |             norm=normalization,
34 |             act=act,
35 |             drop_rate=drop_rate
36 |         )
37 | 
38 |     elif model == "MLP":
39 |         blocks = [thin for _ in range(num_blocks)]
40 | 
41 |         return StandardMLP(
42 |             dim_in=crop_resolution**2 * num_channels,
43 |             dim_out=num_classes,
44 |             widths=blocks,
45 |         )
46 | 
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | from torch import nn
4 | import torch
5 | import numpy as np
6 | 
7 | from utils.download import download, default_checkpoints
8 | 
9 | 
10 | NORMS = {
11 |     'layer': nn.LayerNorm,
12 |     'batch': nn.BatchNorm1d,
13 |     'none': nn.Identity
14 | }
15 | 
16 | ACT = {
17 |     'gelu': nn.GELU(),
18 |     'relu': nn.ReLU()
19 | }
20 | 
21 | 
22 | class StandardMLP(nn.Module):
23 |     def __init__(self, dim_in, dim_out, widths):
24 |         super(StandardMLP, self).__init__()
25 |         self.dim_in = dim_in
26 |         self.dim_out = dim_out
27 |         self.widths = widths
28 |         self.linear_in = nn.Linear(self.dim_in, self.widths[0])
29 |         self.linear_out = nn.Linear(self.widths[-1], self.dim_out)
30 |         self.layers = []
31 |         self.layer_norms = []
32 |         for i in range(len(self.widths) - 1):
33 |             self.layers.append(nn.Linear(self.widths[i], self.widths[i + 1]))
34 |             self.layer_norms.append(nn.LayerNorm(widths[i + 1]))
35 | 
36 |         self.layers = nn.ModuleList(self.layers)
37 |         self.layer_norms = nn.ModuleList(self.layer_norms)  # register the norms under the name forward() uses
38 | 
39 |     def forward(self, x):
40 |         z = self.linear_in(x)
41 |         for layer, norm in zip(self.layers, self.layer_norms):
42 |             z = norm(z)
43 |             z = nn.GELU()(z)
44 |             z = layer(z)
45 | 
46 |         out = self.linear_out(z)
47 | 
48 |         return out
49 | 
50 | 
51 | class BottleneckMLP(nn.Module):
52 |     def __init__(self, dim_in, dim_out, block_dims, norm='layer', act='gelu', drop_rate=None, checkpoint=None, name=None):
53 |         super(BottleneckMLP, self).__init__()
54 |         self.dim_in = dim_in
55 |         self.dim_out = dim_out
56 |         self.block_dims = block_dims
57 |         self.norm = NORMS[norm]
58 |         # act and drop_rate are accepted here so that get_architecture's call
59 |         # above works; drop_rate, if set, applies dropout on the residual branch
60 |         self.act = ACT[act]
61 |         self.dropout = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
62 |         self.checkpoint = checkpoint
63 | 
64 |         self.name = name
65 |         self.linear_in = nn.Linear(self.dim_in, self.block_dims[0][1])
66 |         self.linear_out = nn.Linear(self.block_dims[-1][1], self.dim_out)
67 |         blocks = []
68 |         layernorms = []
69 | 
70 |         for block_dim in self.block_dims:
71 |             wide, thin = block_dim
72 |             blocks.append(BottleneckBlock(thin=thin, wide=wide, act=self.act))
73 |             layernorms.append(self.norm(thin))
74 | 
75 |         self.blocks = nn.ModuleList(blocks)
76 |         self.layernorms = nn.ModuleList(layernorms)
77 | 
78 |         if self.checkpoint is not None:
79 |             self.load(self.checkpoint)
80 | 
81 |     def forward(self, x):
82 |         x = self.linear_in(x)
83 | 
84 |         for block, norm in zip(self.blocks, self.layernorms):
85 |             x = x + self.dropout(block(norm(x)))
86 | 
87 |         out = self.linear_out(x)
88 | 
89 |         return out
90 | 
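    # Shape sketch (illustration): for 64x64 RGB inputs flattened to
    # dim_in = 64*64*3 = 12288, linear_in projects to the thin width (e.g.
    # 1024); each block maps thin -> wide -> thin and is added back through a
    # residual connection; linear_out maps the thin width to num_classes.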
91 |     def load(self, name, checkpoint_path='./checkpoints/'):
92 |         # The short checkpoint name (e.g. 'in21k' or 'in21k_cifar10') is
93 |         # combined with the model name and mapped to the full identifier
94 |         name = self.name + '_' + name
95 |         name = default_checkpoints[name]
96 |         weight_path, config_path = download(name, checkpoint_path)
97 | 
98 |         with open(config_path, 'r') as f:
99 |             self.config = json.load(f)
100 | 
101 |         params = torch.load(weight_path, map_location='cpu')  # also works on CPU-only machines
102 | 
103 |         # Load pre-trained parameters
104 |         print('Load_state output', self.load_state_dict(params, strict=False))
105 | 
106 | 
107 | class BottleneckBlock(nn.Module):
108 |     def __init__(self, thin, wide, act=nn.GELU()):
109 |         super(BottleneckBlock, self).__init__()
110 | 
111 |         self.block = nn.Sequential(
112 |             nn.Linear(thin, wide), act, nn.Linear(wide, thin)
113 |         )
114 | 
115 |     def forward(self, x):
116 |         out = self.block(x)
117 | 
118 |         return out
119 | 
120 | 
121 | def B_12_Wi_1024(dim_in, dim_out, checkpoint=None):
122 |     block_dims = [[4 * 1024, 1024] for _ in range(12)]
123 |     return BottleneckMLP(dim_in=dim_in, dim_out=dim_out, norm='layer', block_dims=block_dims, checkpoint=checkpoint,
124 |                          name='B_' + str(len(block_dims)) + '-Wi_' + str(block_dims[0][1]) + '_res_' + str(int(np.sqrt(dim_in/3))))
125 | 
126 | 
127 | def B_12_Wi_512(dim_in, dim_out, checkpoint=None):
128 |     block_dims = [[4 * 512, 512] for _ in range(12)]
129 |     return BottleneckMLP(dim_in=dim_in, dim_out=dim_out, norm='layer', block_dims=block_dims, checkpoint=checkpoint,
130 |                          name='B_' + str(len(block_dims)) + '-Wi_' + str(block_dims[0][1]) + '_res_' + str(int(np.sqrt(dim_in/3))))
131 | 
132 | 
133 | def B_6_Wi_1024(dim_in, dim_out, checkpoint=None):
134 |     block_dims = [[4 * 1024, 1024] for _ in range(6)]
135 |     return BottleneckMLP(dim_in=dim_in, dim_out=dim_out, norm='layer', block_dims=block_dims, checkpoint=checkpoint,
136 |                          name='B_' + str(len(block_dims)) + '-Wi_' + str(block_dims[0][1]) + '_res_' + str(int(np.sqrt(dim_in/3))))
137 | 
138 | 
139 | def B_6_Wi_512(dim_in, dim_out, checkpoint=None):
140 |     block_dims = [[4 * 512, 512] for _ in range(6)]
141 |     return BottleneckMLP(dim_in=dim_in, dim_out=dim_out, norm='layer', block_dims=block_dims, checkpoint=checkpoint,
142 |                          name='B_' + str(len(block_dims)) + '-Wi_' + str(block_dims[0][1]) + '_res_' + str(int(np.sqrt(dim_in/3))))
143 | 
144 | 
145 | model_list = {
146 |     'B_12-Wi_1024': B_12_Wi_1024,
147 |     'B_12-Wi_512': B_12_Wi_512,
148 |     'B_6-Wi_1024': B_6_Wi_1024,
149 |     'B_6-Wi_512': B_6_Wi_512
150 | }
151 | 
152 | 
153 | def get_model(architecture, checkpoint, resolution, num_classes):
154 |     return model_list[architecture](dim_in=resolution**2 * 3, dim_out=num_classes, checkpoint=checkpoint)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fvcore==0.1.5.post20221221
2 | matplotlib==3.6.3
3 | numba==0.56.4
4 | numpy==1.23.5
5 | Pillow==10.0.0
6 | psutil==5.9.4
7 | torch==1.12.1+cu113
8 | torchvision==0.13.1+cu113
9 | tqdm==4.64.1
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | 
5 | import torch
6 | import wandb
7 | from torch.nn import CrossEntropyLoss
8 | from tqdm import tqdm
9 | 
10 | from
models import get_architecture 11 | from data_utils.data_stats import * 12 | from data_utils.dataloader import get_loader 13 | from utils.config import config_to_name 14 | from utils.get_compute import get_compute 15 | from utils.metrics import topk_acc, real_acc, AverageMeter 16 | from utils.optimizer import get_optimizer, get_scheduler, OPTIMIZERS_DICT, SCHEDULERS 17 | from utils.parsers import get_training_parser 18 | 19 | 20 | def train(model, opt, scheduler, loss_fn, epoch, train_loader, args): 21 | start = time.time() 22 | model.train() 23 | 24 | total_acc, total_top5 = AverageMeter(), AverageMeter() 25 | total_loss = AverageMeter() 26 | 27 | for step, (ims, targs) in enumerate(tqdm(train_loader, desc="Training epoch: " + str(epoch))): 28 | ims = torch.reshape(ims, (ims.shape[0], -1)) 29 | preds = model(ims) 30 | 31 | if args.mixup > 0: 32 | targs_perm = targs[:, 1].long() 33 | weight = targs[0, 2].squeeze() 34 | targs = targs[:, 0].long() 35 | if weight != -1: 36 | loss = loss_fn(preds, targs) * weight + loss_fn(preds, targs_perm) * ( 37 | 1 - weight 38 | ) 39 | else: 40 | loss = loss_fn(preds, targs) 41 | targs_perm = None 42 | else: 43 | loss = loss_fn(preds, targs) 44 | targs_perm = None 45 | 46 | acc, top5 = topk_acc(preds, targs, targs_perm, k=5, avg=True) 47 | total_acc.update(acc, ims.shape[0]) 48 | total_top5.update(top5, ims.shape[0]) 49 | 50 | loss = loss / args.accum_steps 51 | loss.backward() 52 | 53 | if (step + 1) % args.accum_steps == 0 or (step + 1) == len(train_loader): 54 | if args.clip > 0: 55 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 56 | opt.step() 57 | opt.zero_grad() 58 | 59 | total_loss.update(loss.item() * args.accum_steps, ims.shape[0]) 60 | 61 | end = time.time() 62 | 63 | scheduler.step() 64 | 65 | return ( 66 | total_acc.get_avg(percentage=True), 67 | total_top5.get_avg(percentage=True), 68 | total_loss.get_avg(percentage=False), 69 | end - start, 70 | ) 71 | 72 | 73 | @torch.no_grad() 74 | def test(model, loader, loss_fn, args): 75 | start = time.time() 76 | model.eval() 77 | total_acc, total_top5, total_loss = AverageMeter(), AverageMeter(), AverageMeter() 78 | 79 | for ims, targs in tqdm(loader, desc="Evaluation"): 80 | ims = torch.reshape(ims, (ims.shape[0], -1)) 81 | preds = model(ims) 82 | 83 | if args.dataset != 'imagenet_real': 84 | acc, top5 = topk_acc(preds, targs, k=5, avg=True) 85 | loss = loss_fn(preds, targs).item() 86 | else: 87 | acc = real_acc(preds, targs, k=5, avg=True) 88 | top5 = 0 89 | loss = 0 90 | 91 | total_acc.update(acc, ims.shape[0]) 92 | total_top5.update(top5, ims.shape[0]) 93 | total_loss.update(loss) 94 | 95 | end = time.time() 96 | 97 | return ( 98 | total_acc.get_avg(percentage=True), 99 | total_top5.get_avg(percentage=True), 100 | total_loss.get_avg(percentage=False), 101 | end - start, 102 | ) 103 | 104 | 105 | def main(args): 106 | # Use mixed precision matrix multiplication 107 | torch.backends.cuda.matmul.allow_tf32 = True 108 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 109 | 110 | model = get_architecture(**args.__dict__).cuda() 111 | 112 | # Count number of parameters for logging purposes 113 | args.num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 114 | 115 | # Create unique identifier 116 | name = config_to_name(args) 117 | path = os.path.join(args.checkpoint_folder, name) 118 | 119 | # Create folder to store the checkpoints 120 | if not os.path.exists(path): 121 | os.makedirs(path) 122 | with open(path + '/config.txt', 'w') as f: 123 | 
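        # (This config.txt is what model_from_config in utils/config.py reads
        # back, and the README likewise asks you to keep a config.txt next to
        # each downloaded checkpoint.)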
json.dump(args.__dict__, f, indent=2)
124 | 
125 |     # Get the dataloaders
126 |     local_batch_size = args.batch_size // args.accum_steps
127 | 
128 |     train_loader = get_loader(
129 |         args.dataset,
130 |         bs=local_batch_size,
131 |         mode="train",
132 |         augment=args.augment,
133 |         dev=device,
134 |         num_samples=args.n_train,
135 |         mixup=args.mixup,
136 |         data_path=args.data_path,
137 |         data_resolution=args.resolution,
138 |         crop_resolution=args.crop_resolution
139 |     )
140 | 
141 |     test_loader = get_loader(
142 |         args.dataset,
143 |         bs=local_batch_size,
144 |         mode="test",
145 |         augment=False,
146 |         dev=device,
147 |         data_path=args.data_path,
148 |         data_resolution=args.resolution,
149 |         crop_resolution=args.crop_resolution
150 |     )
151 | 
152 |     start_ep = 1
153 |     if args.reload:
154 |         try:
155 |             params = torch.load(path + "/name_of_checkpoint")
156 |             model.load_state_dict(params)
157 |             start_ep = 350
158 |         except FileNotFoundError:  # avoid a bare except that would also swallow interrupts
159 |             print("No pretrained model found, training from scratch")
160 | 
161 |     opt = get_optimizer(args.optimizer)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
162 |     scheduler = get_scheduler(opt, args.scheduler, **args.__dict__)
163 | 
164 |     loss_fn = CrossEntropyLoss(label_smoothing=args.smooth)
165 | 
166 |     if args.wandb:
167 |         # Add your wandb credentials and project name
168 |         wandb.init(
169 |             project=args.wandb_project,
170 |             entity=args.wandb_entity,
171 |             config=args.__dict__,
172 |             tags=["pretrain", args.dataset],
173 |         )
174 |         wandb.run.name = name
175 | 
176 |     compute_per_epoch = get_compute(model, args.n_train, args.crop_resolution)
177 | 
178 |     for ep in range(start_ep, args.epochs):
179 |         calc_stats = (ep + 1) % args.calculate_stats == 0
180 | 
181 |         current_compute = compute_per_epoch * ep
182 | 
183 |         train_acc, train_top5, train_loss, train_time = train(
184 |             model, opt, scheduler, loss_fn, ep, train_loader, args
185 |         )
186 | 
187 |         if args.wandb:
188 |             wandb.log({"Training time": train_time, "Training loss": train_loss})
189 | 
190 |         if ep % args.save_freq == 0 and args.save:
191 |             torch.save(
192 |                 model.state_dict(),
193 |                 path + "/epoch_" + str(ep) + "_compute_" + str(current_compute),
194 |             )
195 | 
196 |         if calc_stats:
197 |             test_acc, test_top5, test_loss, test_time = test(
198 |                 model, test_loader, loss_fn, args
199 |             )
200 |             if args.wandb:
201 |                 wandb.log(
202 |                     {
203 |                         "Training accuracy": train_acc,
204 |                         "Training Top 5 accuracy": train_top5,
205 |                         "Test accuracy": test_acc,
206 |                         "Test Top 5 accuracy": test_top5,
207 |                         "Test loss": test_loss,
208 |                         "Inference time": test_time,
209 |                     }
210 |                 )
211 | 
212 |             # Print all the stats
213 |             print("Epoch", ep, " Time:", train_time)
214 |             print("-------------- Training ----------------")
215 |             print("Average Training Loss: ", "{:.6f}".format(train_loss))
216 |             print("Average Training Accuracy: ", "{:.4f}".format(train_acc))
217 |             print("Top 5 Training Accuracy: ", "{:.4f}".format(train_top5))
218 |             print("---------------- Test ------------------")
219 |             print("Test Accuracy ", "{:.4f}".format(test_acc))
220 |             print("Top 5 Test Accuracy ", "{:.4f}".format(test_top5))
221 |             print()
222 | 
223 | 
224 | 
225 | if __name__ == "__main__":
226 |     parser = get_training_parser()
227 |     args = parser.parse_args()
228 | 
229 |     args.num_classes = CLASS_DICT[args.dataset]
230 | 
231 |     if args.n_train is None:
232 |         args.n_train = SAMPLE_DICT[args.dataset]
233 | 
234 |     if args.crop_resolution is None:
235 |         args.crop_resolution = args.resolution
236 |     if args.wandb_entity is None:
237 |         print("No wandb entity provided, continuing without
wandb") 238 | args.wandb = False 239 | 240 | main(args) 241 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregorbachmann/scaling_mlps/50cb3a90fdac8fad5b9e4d3733f29464d4344277/utils/__init__.py -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from data_utils.data_stats import CLASS_DICT 5 | 6 | 7 | def config_to_name(args): 8 | return os.path.join( 9 | str(args.dataset) + '_res_' + str(args.crop_resolution), 10 | f"{args.model}_{args.architecture}_norm_{args.normalization}", 11 | f"batchsize_{args.batch_size}", 12 | f"{args.optimizer}_lr_{args.lr}_smooth_{args.smooth}" 13 | + f"_decay_{args.weight_decay}_augment_{args.augment}_mixup_{args.mixup}_droprate_{args.drop_rate}", 14 | f"ntrain_{args.n_train}", 15 | ) 16 | 17 | 18 | def model_from_config(path): 19 | """Return model class from checkpoint path.""" 20 | path = os.path.dirname(path) 21 | with open(path + '/config.txt', 'r') as f: 22 | config = json.load(f) 23 | model = config["model"] 24 | architecture = config["architecture"] 25 | norm = config["normalization"] 26 | crop_resolution = int(config["crop_resolution"]) 27 | 28 | return model, architecture, crop_resolution, norm 29 | 30 | 31 | def model_from_checkpoint(checkpoint): 32 | res = int(checkpoint.split('_')[1]) 33 | num_classes = CLASS_DICT[checkpoint.split('_')[-1]] 34 | if len(checkpoint.split('_')) == 4: 35 | pretrained_finetuned = checkpoint.split('_')[-2] + '_' +checkpoint.split('_')[-1] 36 | else: 37 | pretrained_finetuned = checkpoint.split('_')[-1] 38 | 39 | return pretrained_finetuned, res, num_classes 40 | -------------------------------------------------------------------------------- /utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from urllib.request import urlretrieve 3 | import progressbar 4 | 5 | 6 | default_checkpoints = { 7 | 'B_12-Wi_1024_res_64_in21k': 'B-12_Wi-1024_res_64_imagenet21_epochs_800', 8 | 'B_12-Wi_512_res_64_in21k': 'B-12_Wi-512_res_64_imagenet21_epochs_600', 9 | 'B_6-Wi_1024_res_64_in21k': 'B-6_Wi-1024_res_64_imagenet21_epochs_800', 10 | 'B_6-Wi_512_res_64_in21k': 'B-6_Wi-512_res_64_imagenet21_epochs_800', 11 | 'B_12-Wi_1024_res_64_in21k_cifar10': 'B-12_Wi-1024_res_64_cifar10_epochs_20', 12 | 'B_12-Wi_1024_res_64_in21k_cifar100': 'B-12_Wi-1024_res_64_cifar100_epochs_40', 13 | 'B_12-Wi_1024_res_64_in21k_imagenet': 'B-12_Wi-1024_res_64_imagenet_epochs_50', 14 | 'B_12-Wi_512_res_64_in21k_cifar10': 'B-12_Wi-512_res_64_cifar10_epochs_20', 15 | 'B_12-Wi_512_res_64_in21k_cifar100': 'B-12_Wi-512_res_64_cifar100_epochs_20', 16 | 'B_12-Wi_512_res_64_in21k_imagenet': 'B-12_Wi-512_res_64_imagenet_epochs_20', 17 | 'B_6-Wi_512_res_64_in21k_cifar10': 'B-6_Wi-512_res_64_cifar10_epochs_20', 18 | 'B_6-Wi_512_res_64_in21k_cifar100': 'B-6_Wi-512_res_64_cifar100_epochs_20', 19 | 'B_6-Wi_512_res_64_in21k_imagenet': 'B-6_Wi-512_res_64_imagenet_epochs_20', 20 | 'B_6-Wi_1024_res_64_in21k_cifar10': 'B-6_Wi-1024_res_64_cifar10_epochs_20', 21 | 'B_6-Wi_1024_res_64_in21k_cifar100': 'B-6_Wi-1024_res_64_cifar100_epochs_20', 22 | 'B_6-Wi_1024_res_64_in21k_imagenet': 'B-6_Wi-1024_res_64_imagenet_epochs_20' 23 | } 24 | 25 | weight_urls = { 26 | 
'B-12_Wi-1024_res_64_imagenet21_epochs_800': 27 | 'https://drive.usercontent.google.com/download?id=1rcV8RXij_kW9X2zSLNyNOTO_bKUjE0cJ&export=download&authuser=0&confirm=t&uuid=72ba7ef7-5c0e-43a8-8538-c78a7b6ae34c&at=APZUnTVYImDEDtOncjUjlRW2Fa-v%3A1718049334362', 28 | 'B-12_Wi-512_res_64_imagenet21_epochs_600': 29 | 'https://drive.usercontent.google.com/download?id=1sL9j_4FFeBTWTzuRFbHLfwrmPVLXhUtW&export=download&authuser=0&confirm=t&uuid=91299093-c2dc-4538-93fa-d34be798cedc&at=APZUnTUEplQUhKAe6zbjUuFWUGiV%3A1718049319992', 30 | 'B-6_Wi-1024_res_64_imagenet21_epochs_800': 31 | 'https://drive.usercontent.google.com/download?id=1cmO3QSz_hfHtyzkUOZnPmXPxIE2YzEbf&export=download&authuser=0&confirm=t&uuid=c102f0ec-18b9-496a-b615-819513501d65&at=APZUnTX6iqbWKmcVQzv4nf04efor%3A1718049304706', 32 | 'B-6_Wi-512_res_64_imagenet21_epochs_800': 33 | 'https://drive.usercontent.google.com/download?id=1QV3a99UT8llfh9zDDuNKWDH_5c6S_YT5&export=download&authuser=0&confirm=t&uuid=fa3e3e51-9eae-4f4c-9c88-f9882258160c&at=APZUnTWSzjI5fY70cc3I1t_E3nv1%3A1718049288621', 34 | 'B-12_Wi-1024_res_64_cifar10_epochs_20': 35 | 'https://drive.usercontent.google.com/download?id=1GyxuewoOzMRhzEOyUrIBLQzquc-QEYNV&export=download&authuser=0&confirm=t&uuid=02337a36-362b-41bc-8b66-c1e8737c6729&at=APZUnTV0pOpn9aeIkKng_OtiRw0l%3A1718049274446', 36 | 'B-12_Wi-1024_res_64_cifar100_epochs_40': 37 | 'https://drive.usercontent.google.com/download?id=1LNqC58cSwtuDr-C4bk1O3GA_vAWls-UH&export=download&authuser=0&confirm=t&uuid=37ee7032-ec34-4414-ac3b-fcb8a3f5e17d&at=APZUnTXsPCgnt2__IQ7fScHpXmcX%3A1718049262372', 38 | 'B-12_Wi-1024_res_64_imagenet_epochs_50': 39 | 'https://drive.usercontent.google.com/download?id=1MVebvnSGL02k_ql1gUCjh4quGqM9RM4F&export=download&authuser=0&confirm=t&uuid=4118f07e-ffdd-4b74-9b74-2508ffcc2b02&at=APZUnTWAAiNwrzrTzDm3Sl3MtzMF%3A1718049247748', 40 | 'B-12_Wi-512_res_64_cifar10_epochs_20': 41 | 'https://drive.usercontent.google.com/download?id=1F1NvoOsYCgsn1GOZcwsoToOtsw-9Aw1v&export=download&authuser=0&confirm=t&uuid=899cb74b-2bce-4b51-81ab-2df63af3dcbe&at=APZUnTWAC-eENGH6rRWchnMHsSBm%3A1718049232656', 42 | 'B-12_Wi-512_res_64_cifar100_epochs_20': 43 | 'https://drive.usercontent.google.com/download?id=1KIULehrqOyxIkZj0HiqowNmBy4Ye1EQ2&export=download&authuser=0&confirm=t&uuid=bf208699-bbf3-4ad3-9bb6-d61eec237265&at=APZUnTXwvPbxLngt1wCVCVNriiXA%3A1718049215282', 44 | 'B-12_Wi-512_res_64_imagenet_epochs_20': 45 | 'https://drive.usercontent.google.com/download?id=1f0ZYzB_XujX8hDcEn_J6iWvw1meJ4Cbg&export=download&authuser=0&confirm=t&uuid=90d9b1d0-fd5f-468e-b637-53afa56d3f22&at=APZUnTXFslB7n2WqcFxsmxZzNqoB%3A1718049045796', 46 | 'B-6_Wi-1024_res_64_cifar10_epochs_20': 47 | 'https://drive.usercontent.google.com/download?id=1Tyd5CkROPCMQybnrZ_o1wiAW7VoK6AfJ&export=download&authuser=0&confirm=t&uuid=7534d430-40ab-4475-a862-9413499b0f79&at=APZUnTXX2ioldX_5JCXDt0nP3pCu%3A1718049195583', 48 | 'B-6_Wi-1024_res_64_cifar100_epochs_20': 49 | 'https://drive.usercontent.google.com/download?id=1FrRb78bjun6QGbbH-pCWDaaE_8LWW785&export=download&authuser=0&confirm=t&uuid=b2c5459c-ede5-4ba4-97dc-e7a247cfba6a&at=APZUnTWa7Uha96h-6FxJosR1b2F0%3A1718048945010', 50 | 'B-6_Wi-1024_res_64_imagenet_epochs_20': 51 | 'https://drive.google.com/uc?id=115Lks211vx1at2dWn3JtQ57EZ7eNAVP4E&export=download&confirm=t&uuid', 52 | 'B-6_Wi-512_res_64_cifar10_epochs_20': 53 | 
'https://drive.usercontent.google.com/download?id=1VjHgjheSm_w7xPtheEmY5kV_KE4-38zQ&export=download&authuser=0&confirm=t&uuid=71fbd376-79f4-43db-bd4a-381255571319&at=APZUnTUJN8LEutgI-L0oVjbG3df3%3A1718048695392', 54 | 'B-6_Wi-512_res_64_cifar100_epochs_20': 55 | 'https://drive.usercontent.google.com/download?id=1iK3t20-GS_Vs-_Q3ZexSiCfjGJ3IaPC2&export=download&authuser=0&confirm=t&uuid=0196232e-f83d-4c9d-921e-857db8848725&at=APZUnTV3aw0EOkJS4SEIo4XToVT4%3A1718048904050', 56 | 'B-6_Wi-512_res_64_imagenet_epochs_20': 57 | 'https://drive.usercontent.google.com/download?id=1iK3t20-GS_Vs-_Q3ZexSiCfjGJ3IaPC2&export=download&authuser=0&confirm=t&uuid=d8f548aa-ccd6-4f0a-be49-356e9ee2e243&at=APZUnTXb7Ss81nGKgrixYS0binTs%3A1718048598551' 58 | } 59 | 60 | config_urls = { 61 | 'B-12_Wi-1024_res_64_imagenet21_epochs_800': 62 | 'https://drive.google.com/uc?id=1envpLKUa9LhUlp2dLIL8Jb8447wwpXF0&export=download&confirm=t&uuid', 63 | 'B-12_Wi-512_res_64_imagenet21_epochs_600': 64 | 'https://drive.google.com/uc?id=14GKtQ1iYwOqYpy4RcrWz2Ue3AGG7eGLz&export=download&confirm=t&uuid', 65 | 'B-6_Wi-1024_res_64_imagenet21_epochs_800': 66 | 'https://drive.google.com/uc?id=11zFGFiKKxxrZOGk5oyk3AzBDnIY7KN3s&export=download&confirm=t&uuid', 67 | 'B-6_Wi-512_res_64_imagenet21_epochs_800': 68 | 'https://drive.google.com/uc?id=1Fjf4RA_yUXHgHncb9GIlf9zBNAJ-8giv&export=download&confirm=t&uuid', 69 | 'B-12_Wi-1024_res_64_cifar10_epochs_20': 70 | 'https://drive.google.com/uc?id=1envpLKUa9LhUlp2dLIL8Jb8447wwpXF0&export=download&confirm=t&uuid', 71 | 'B-12_Wi-1024_res_64_cifar100_epochs_40': 72 | 'https://drive.google.com/uc?id=1envpLKUa9LhUlp2dLIL8Jb8447wwpXF0&export=download&confirm=t&uuid', 73 | 'B-12_Wi-1024_res_64_imagenet_epochs_50': 74 | 'https://drive.google.com/uc?id=1envpLKUa9LhUlp2dLIL8Jb8447wwpXF0&export=download&confirm=t&uuid', 75 | 'B-6_Wi-1024_res_64_cifar10_epochs_20': 76 | 'https://drive.google.com/uc?id=11zFGFiKKxxrZOGk5oyk3AzBDnIY7KN3s&export=download&confirm=t&uuid', 77 | 'B-6_Wi-1024_res_64_cifar100_epochs_40': 78 | 'https://drive.google.com/uc?id=11zFGFiKKxxrZOGk5oyk3AzBDnIY7KN3s&export=download&confirm=t&uuid', 79 | 'B-6_Wi-1024_res_64_cifar100_epochs_20': 80 | 'https://drive.google.com/uc?id=11zFGFiKKxxrZOGk5oyk3AzBDnIY7KN3s&export=download&confirm=t&uuid', 81 | 'B-6_Wi-1024_res_64_imagenet_epochs_50': 82 | 'https://drive.google.com/uc?id=11zFGFiKKxxrZOGk5oyk3AzBDnIY7KN3s&export=download&confirm=t&uuid', 83 | 'B-12_Wi-512_res_64_cifar10_epochs_20': 84 | 'https://drive.google.com/uc?id=14GKtQ1iYwOqYpy4RcrWz2Ue3AGG7eGLz&export=download&confirm=t&uuid', 85 | 'B-12_Wi-512_res_64_cifar100_epochs_20': 86 | 'https://drive.google.com/uc?id=14GKtQ1iYwOqYpy4RcrWz2Ue3AGG7eGLz&export=download&confirm=t&uuid', 87 | 'B-12_Wi-512_res_64_cifar100_epochs_40': 88 | 'https://drive.google.com/uc?id=14GKtQ1iYwOqYpy4RcrWz2Ue3AGG7eGLz&export=download&confirm=t&uuid', 89 | 'B-12_Wi-512_res_64_imagenet_epochs_50': 90 | 'https://drive.google.com/uc?id=14GKtQ1iYwOqYpy4RcrWz2Ue3AGG7eGLz&export=download&confirm=t&uuid', 91 | 'B-6_Wi-512_res_64_cifar10_epochs_20': 92 | 'https://drive.google.com/uc?id=1Fjf4RA_yUXHgHncb9GIlf9zBNAJ-8giv&export=download&confirm=t&uuid', 93 | 'B-6_Wi-512_res_64_cifar100_epochs_40': 94 | 'https://drive.google.com/uc?id=1Fjf4RA_yUXHgHncb9GIlf9zBNAJ-8giv&export=download&confirm=t&uuid', 95 | 'B-6_Wi-512_res_64_cifar100_epochs_20': 96 | 'https://drive.google.com/uc?id=1Fjf4RA_yUXHgHncb9GIlf9zBNAJ-8giv&export=download&confirm=t&uuid', 97 | 'B-6_Wi-512_res_64_imagenet_epochs_50': 98 | 
'https://drive.google.com/uc?id=1Fjf4RA_yUXHgHncb9GIlf9zBNAJ-8giv&export=download&confirm=t&uuid', 99 | 100 | } 101 | 102 | 103 | def download(name, checkpoint_path): 104 | weight_url = weight_urls[name] 105 | config_url = config_urls[name] 106 | 107 | weight_path = checkpoint_path + name + '_weights' 108 | config_path = checkpoint_path + name + '_config' 109 | weight_exists = os.path.isfile(weight_path) 110 | config_exists = os.path.isfile(config_path) 111 | 112 | if not weight_exists: 113 | print('Downloading weights...') 114 | urlretrieve(weight_url, weight_path, show_progress) 115 | else: 116 | print('Weights already downloaded') 117 | if not config_exists: 118 | urlretrieve(config_url, config_path) 119 | 120 | return weight_path, config_path 121 | 122 | 123 | pbar = None 124 | 125 | 126 | def show_progress(block_num, block_size, total_size): 127 | global pbar 128 | if pbar is None: 129 | pbar = progressbar.ProgressBar(maxval=total_size) 130 | pbar.start() 131 | 132 | downloaded = block_num * block_size 133 | if downloaded < total_size: 134 | pbar.update(downloaded) 135 | else: 136 | pbar.finish() 137 | pbar = None 138 | -------------------------------------------------------------------------------- /utils/get_compute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fvcore.nn import FlopCountAnalysis 3 | 4 | 5 | def get_compute(model, dataset_size, res): 6 | input = torch.randn(1, 3 * res * res).cuda() 7 | flops = FlopCountAnalysis(model, input) 8 | 9 | return flops.total() * 3 * dataset_size 10 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from torch import topk, any, sum 2 | import torch 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | 22 | def get_avg(self, percentage=False): 23 | return self.sum / self.count if not percentage else self.sum * 100 / self.count 24 | 25 | 26 | def topk_acc(preds, targs, targs_perm=None, k=5, avg=False): 27 | if avg: 28 | num = preds.shape[0] 29 | else: 30 | num = 1 31 | _, top_k_inds = topk(preds, k) 32 | top_5 = 1 / num * sum(any(top_k_inds == targs.unsqueeze(dim=1), dim=1), dim=0) 33 | acc = 1 / num * sum(top_k_inds[:, 0].eq(targs), dim=0) 34 | 35 | if targs_perm is not None: 36 | top_5_perm = ( 37 | 1 / num * sum(any(top_k_inds == targs_perm.unsqueeze(dim=1), dim=1), dim=0) 38 | ) 39 | acc_perm = 1 / num * sum(top_k_inds[:, 0].eq(targs_perm), dim=0) 40 | 41 | return torch.maximum(acc, acc_perm), torch.maximum(top_5, top_5_perm) 42 | 43 | return acc.item(), top_5.item() 44 | 45 | 46 | def real_acc(preds, targs, k, avg=False): 47 | if avg: 48 | num = preds.shape[0] 49 | else: 50 | num = 1 51 | _, top_k_inds = topk(preds, k) 52 | top_1_inds = top_k_inds[:, 0] 53 | acc_real = 1 / num * sum(any(top_1_inds.unsqueeze(dim=1).eq(targs), dim=1), dim=0) 54 | 55 | return acc_real -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google Research. All Rights Reserved. 
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from torch import topk, any, sum  # note: shadows the built-in any/sum within this module
2 | import torch
3 | 
4 | 
5 | class AverageMeter(object):
6 |     """Computes and stores the average and current value"""
7 | 
8 |     def __init__(self):
9 |         self.reset()
10 | 
11 |     def reset(self):
12 |         self.val = 0
13 |         self.avg = 0
14 |         self.sum = 0
15 |         self.count = 0
16 | 
17 |     def update(self, val, n=1):
18 |         self.val = val
19 |         self.sum += val * n
20 |         self.count += n
21 | 
22 |     def get_avg(self, percentage=False):
23 |         return self.sum / self.count if not percentage else self.sum * 100 / self.count
24 | 
25 | 
26 | def topk_acc(preds, targs, targs_perm=None, k=5, avg=False):  # returns batch means if avg=True, raw counts otherwise
27 |     if avg:
28 |         num = preds.shape[0]
29 |     else:
30 |         num = 1
31 |     _, top_k_inds = topk(preds, k)
32 |     top_k = 1 / num * sum(any(top_k_inds == targs.unsqueeze(dim=1), dim=1), dim=0)
33 |     acc = 1 / num * sum(top_k_inds[:, 0].eq(targs), dim=0)
34 | 
35 |     if targs_perm is not None:  # mixup: credit the better of the two mixed labels
36 |         top_k_perm = (
37 |             1 / num * sum(any(top_k_inds == targs_perm.unsqueeze(dim=1), dim=1), dim=0)
38 |         )
39 |         acc_perm = 1 / num * sum(top_k_inds[:, 0].eq(targs_perm), dim=0)
40 | 
41 |         return torch.maximum(acc, acc_perm), torch.maximum(top_k, top_k_perm)
42 | 
43 |     return acc.item(), top_k.item()
44 | 
45 | 
46 | def real_acc(preds, targs, k, avg=False):  # targs holds a set of valid labels per sample (ImageNetReal-style)
47 |     if avg:
48 |         num = preds.shape[0]
49 |     else:
50 |         num = 1
51 |     _, top_k_inds = topk(preds, k)
52 |     top_1_inds = top_k_inds[:, 0]
53 |     acc_real = 1 / num * sum(any(top_1_inds.unsqueeze(dim=1).eq(targs), dim=1), dim=0)
54 | 
55 |     return acc_real
--------------------------------------------------------------------------------
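A self-contained sketch of the intended evaluation pattern (illustrative; random logits stand in for a real model and loader):

    import torch
    from utils.metrics import AverageMeter, topk_acc

    top1, top5 = AverageMeter(), AverageMeter()
    for _ in range(10):                       # stand-in for an evaluation loader
        preds = torch.randn(32, 100)          # logits for a batch of 32, 100 classes
        targs = torch.randint(0, 100, (32,))
        acc, top_k = topk_acc(preds, targs, k=5, avg=True)
        top1.update(acc, n=32)
        top5.update(top_k, n=32)
    print(top1.get_avg(percentage=True), top5.get_avg(percentage=True))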
52 | """ 53 | loss = None 54 | if closure is not None: 55 | with torch.enable_grad(): 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group["params"]: 60 | if p.grad is None: 61 | continue 62 | 63 | # Perform stepweight decay 64 | p.data.mul_(1 - group["lr"] * group["weight_decay"]) 65 | 66 | grad = p.grad 67 | state = self.state[p] 68 | # State initialization 69 | if len(state) == 0: 70 | # Exponential moving average of gradient values 71 | state["exp_avg"] = torch.zeros_like(p) 72 | 73 | exp_avg = state["exp_avg"] 74 | beta1, beta2 = group["betas"] 75 | 76 | # Weight update 77 | update = exp_avg * beta1 + grad * (1 - beta1) 78 | p.add_(torch.sign(update), alpha=-group["lr"]) 79 | # Decay the momentum running average coefficient 80 | exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) 81 | 82 | return loss 83 | 84 | 85 | class SAM(torch.optim.Optimizer): 86 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 87 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 88 | 89 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 90 | super(SAM, self).__init__(params, defaults) 91 | 92 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 93 | self.param_groups = self.base_optimizer.param_groups 94 | self.defaults.update(self.base_optimizer.defaults) 95 | 96 | @torch.no_grad() 97 | def first_step(self, zero_grad=False): 98 | grad_norm = self._grad_norm() 99 | for group in self.param_groups: 100 | scale = group["rho"] / (grad_norm + 1e-12) 101 | 102 | for p in group["params"]: 103 | if p.grad is None: 104 | continue 105 | self.state[p]["old_p"] = p.data.clone() 106 | e_w = ( 107 | (torch.pow(p, 2) if group["adaptive"] else 1.0) 108 | * p.grad 109 | * scale.to(p) 110 | ) 111 | p.add_(e_w) # climb to the local maximum "w + e(w)" 112 | 113 | if zero_grad: 114 | self.zero_grad() 115 | 116 | @torch.no_grad() 117 | def second_step(self, zero_grad=False): 118 | for group in self.param_groups: 119 | for p in group["params"]: 120 | if p.grad is None: 121 | continue 122 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 123 | 124 | self.base_optimizer.step() # do the actual "sharpness-aware" update 125 | 126 | if zero_grad: 127 | self.zero_grad() 128 | 129 | @torch.no_grad() 130 | def step(self, closure=None): 131 | assert ( 132 | closure is not None 133 | ), "Sharpness Aware Minimization requires closure, but it was not provided" 134 | closure = torch.enable_grad()( 135 | closure 136 | ) # the closure should do a full forward-backward pass 137 | 138 | self.first_step(zero_grad=True) 139 | closure() 140 | self.second_step() 141 | 142 | def _grad_norm(self): 143 | shared_device = self.param_groups[0]["params"][ 144 | 0 145 | ].device # put everything on the same device, in case of model parallelism 146 | norm = torch.norm( 147 | torch.stack( 148 | [ 149 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad) 150 | .norm(p=2) 151 | .to(shared_device) 152 | for group in self.param_groups 153 | for p in group["params"] 154 | if p.grad is not None 155 | ] 156 | ), 157 | p=2, 158 | ) 159 | return norm 160 | 161 | def load_state_dict(self, state_dict): 162 | super().load_state_dict(state_dict) 163 | self.base_optimizer.param_groups = self.param_groups 164 | 165 | 166 | MomentumSGD = functools.partial(torch.optim.SGD, momentum=0.9) 167 | 168 | 169 | OPTIMIZERS_DICT = { 170 | "sgd": MomentumSGD, 171 | "adamw": torch.optim.AdamW, 172 | "lion": Lion, 173 | } 174 | 175 | SCHEDULERS = ["cosine", "none"] 176 | 177 | 
177 | 
178 | def get_optimizer(opt_name):
179 |     """Return optimizer class."""
180 |     opt_name = opt_name.lower()
181 | 
182 |     if opt_name not in OPTIMIZERS_DICT:
183 |         raise ValueError(f"Optimizer {opt_name} not supported.")
184 | 
185 |     return OPTIMIZERS_DICT[opt_name]
186 | 
187 | 
188 | def get_scheduler(opt, scheduler_name, **kwargs):
189 |     """Return scheduler instance."""
190 |     scheduler_name = scheduler_name.lower()
191 | 
192 |     if scheduler_name not in SCHEDULERS:
193 |         raise ValueError(f"Scheduler {scheduler_name} not supported.")
194 | 
195 |     if scheduler_name == "cosine":
196 |         return torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=kwargs["epochs"])
197 |     elif scheduler_name == "none":
198 |         return torch.optim.lr_scheduler.StepLR(  # gamma=1.0 keeps the learning rate constant
199 |             opt, step_size=kwargs["epochs"], gamma=1.0
200 |         )
201 |     else:
202 |         raise ValueError(f"Scheduler {scheduler_name} not supported.")
203 | 
--------------------------------------------------------------------------------
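Putting the two helpers together, a minimal sketch (illustrative; the linear layer stands in for a BottleneckMLP):

    import torch.nn as nn
    from utils.optimizer import get_optimizer, get_scheduler

    model = nn.Linear(3 * 64 * 64, 10)    # stand-in for a model from models.networks
    opt = get_optimizer("lion")(model.parameters(), lr=5e-5, weight_decay=0.0)
    sched = get_scheduler(opt, "cosine", epochs=500)
    # training loop: opt.step() per batch, then sched.step() once per epoch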
None for all" 38 | ) 39 | 40 | # Model 41 | parser.add_argument( 42 | "--model", 43 | default="BottleneckMLP", 44 | type=str, 45 | help="Type of model" 46 | ) 47 | parser.add_argument( 48 | "--architecture", 49 | default="B_6-Wi_1024", 50 | type=str, 51 | help="Architecture type" 52 | ) 53 | parser.add_argument( 54 | "--normalization", 55 | default="layer", 56 | type=str, 57 | help="Normalization type" 58 | ) 59 | parser.add_argument( 60 | "--act", 61 | default="gelu", 62 | type=str, 63 | help="Normalization type" 64 | ) 65 | parser.add_argument( 66 | "--drop_rate", 67 | default=None, 68 | type=float, 69 | help="Drop rate for dropout" 70 | ) 71 | 72 | # Training 73 | parser.add_argument( 74 | "--optimizer", 75 | default="lion", 76 | type=str, 77 | help="Choice of optimizer", 78 | choices=OPTIMIZERS_DICT.keys(), 79 | ) 80 | parser.add_argument( 81 | "--batch_size", 82 | default=4096, 83 | type=int, 84 | help="Batch size" 85 | ) 86 | parser.add_argument( 87 | "--accum_steps", 88 | default=1, 89 | type=int, 90 | help="Number of accumulation steps" 91 | ) 92 | parser.add_argument( 93 | "--lr", 94 | default=0.00005, 95 | type=float, 96 | help="Learning rate" 97 | ) 98 | parser.add_argument( 99 | "--scheduler", 100 | type=str, 101 | default="none", 102 | choices=SCHEDULERS, 103 | help="Scheduler" 104 | ) 105 | parser.add_argument( 106 | "--weight_decay", 107 | default=0.0, 108 | type=float, 109 | help="Weight decay" 110 | ) 111 | parser.add_argument( 112 | "--epochs", 113 | default=500, 114 | type=int, 115 | help="Epochs" 116 | ) 117 | parser.add_argument( 118 | "--smooth", 119 | default=0.3, 120 | type=float, 121 | help="Amount of label smoothing" 122 | ) 123 | parser.add_argument( 124 | "--clip", 125 | default=0., 126 | type=float, 127 | help="Gradient clipping" 128 | ) 129 | parser.add_argument( 130 | "--reload", 131 | action=argparse.BooleanOptionalAction, 132 | default=False, 133 | help="Reinitialize from checkpoint", 134 | ) 135 | parser.add_argument( 136 | "--augment", 137 | action=argparse.BooleanOptionalAction, 138 | default=True, 139 | help="Whether to augment data", 140 | ) 141 | parser.add_argument( 142 | "--mixup", 143 | default=0.8, 144 | type=float, 145 | help="Strength of mixup" 146 | ) 147 | 148 | # Logging 149 | parser.add_argument( 150 | "--calculate_stats", 151 | type=int, 152 | default=1, 153 | help="Frequence of calculating stats", 154 | ) 155 | parser.add_argument( 156 | "--checkpoint_folder", 157 | type=str, 158 | default="./checkpoints", 159 | help="Path to checkpoint directory", 160 | ) 161 | parser.add_argument( 162 | "--save_freq", 163 | type=int, 164 | default=50, 165 | help="Save frequency" 166 | ) 167 | parser.add_argument( 168 | "--save", 169 | action=argparse.BooleanOptionalAction, 170 | default=True, 171 | help="Whether to save checkpoints", 172 | ) 173 | parser.add_argument( 174 | "--wandb", 175 | default=True, 176 | action=argparse.BooleanOptionalAction, 177 | help="Whether to log with wandb", 178 | ) 179 | parser.add_argument( 180 | "--wandb_project", 181 | default="mlps", 182 | type=str, 183 | help="Wandb project name" 184 | ) 185 | parser.add_argument( 186 | "--wandb_entity", 187 | default=None, 188 | type=str, 189 | help="Wandb entity name" 190 | ) 191 | 192 | return parser 193 | 194 | 195 | def get_finetune_parser(): 196 | parser = argparse.ArgumentParser(description="Scaling MLPs") 197 | # Data 198 | parser.add_argument( 199 | "--data_path", default="./beton", type=str, help="Path to data directory" 200 | ) 201 | parser.add_argument( 202 | 
"--architecture", default="B_12-Wi_1024", type=str, help="Path to data directory" 203 | ) 204 | parser.add_argument("--dataset", default="cifar100", type=str, help="Dataset") 205 | parser.add_argument("--data_resolution", default=64, type=int, help="Image Resolution") 206 | parser.add_argument( 207 | "--n_train", default=None, type=int, help="Number of samples. None for all" 208 | ) 209 | parser.add_argument( 210 | "--augment", 211 | action=argparse.BooleanOptionalAction, 212 | default=True, 213 | help="Whether to augment data", 214 | ) 215 | parser.add_argument("--mixup", default=0., type=float, help="Strength of mixup") 216 | parser.add_argument('--crop_scale', nargs='+', type=float, default=[0.08, 1.], help="Scale for crop at test time") 217 | parser.add_argument('--crop_ratio', nargs='+', type=float, default=[0.08, 1.], help="Ratio for crop at test time") 218 | parser.add_argument( 219 | "--drop_rate", 220 | default=None, 221 | type=float, 222 | help="Drop rate for dropout" 223 | ) 224 | 225 | # Training 226 | parser.add_argument( 227 | "--optimizer", 228 | default="sgd", 229 | type=str, 230 | help="Choice of optimizer", 231 | choices=OPTIMIZERS_DICT.keys(), 232 | ) 233 | parser.add_argument("--batch_size", default=4096, type=int, help="Batch size") 234 | parser.add_argument("--accum_steps", default=1, type=int, help="Number of accumulation steps") 235 | parser.add_argument("--lr", default=0.01, type=float, help="Learning rate") 236 | parser.add_argument( 237 | "--scheduler", type=str, default="none", choices=SCHEDULERS, help="Scheduler" 238 | ) 239 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay") 240 | parser.add_argument("--epochs", default=500, type=int, help="Epochs") 241 | parser.add_argument( 242 | "--smooth", default=0.3, type=float, help="Amount of label smoothing" 243 | ) 244 | parser.add_argument("--clip", default=1.0, type=float, help="Gradient clipping") 245 | 246 | # Misc 247 | parser.add_argument( 248 | "--mode", 249 | default="linear", 250 | type=str, 251 | help="Mode", 252 | choices=["linear", "finetune"], 253 | ) 254 | parser.add_argument( 255 | "--checkpoint_folder", 256 | default="./checkpoints_finetune", 257 | type=str, 258 | help="Folder to store checkpoints", 259 | ) 260 | parser.add_argument( 261 | "--checkpoint_path", default=None, type=str, help="Checkpoint", required=False 262 | ) 263 | parser.add_argument( 264 | "--checkpoint", default='checkpoints_finetune', type=str, help="Checkpoint", required=False 265 | ) 266 | parser.add_argument( 267 | "--body_learning_rate_multiplier", 268 | default=0.1, 269 | type=float, 270 | help="Percentage of learning rate for the body", 271 | ) 272 | parser.add_argument( 273 | "--calculate_stats", 274 | type=int, 275 | default=1, 276 | help="Frequency of calculating stats", 277 | ) 278 | parser.add_argument( 279 | "--save_freq", 280 | type=int, 281 | default=20, 282 | help="Save frequency" 283 | ) 284 | parser.add_argument( 285 | "--save", 286 | action=argparse.BooleanOptionalAction, 287 | default=True, 288 | help="Whether to save checkpoints", 289 | ) 290 | 291 | return parser --------------------------------------------------------------------------------