├── archs ├── __init__.py ├── xresnet.py ├── vision_transformer.py └── cnns.py ├── docs ├── requirements.txt ├── _toc.yml ├── _config.yml ├── example_extraction.md ├── example_training.md └── index.md ├── requirements.txt ├── utils ├── slurm.py ├── yaml_tfms.py ├── classification_utils.py └── utils.py ├── run_dino.py ├── data_utils ├── cellpainting_dataset.py ├── cell_dataset.py ├── label_dict.py └── file_dataset.py ├── README.md ├── run_get_features.py ├── config.yaml └── main_dino.py /archs/__init__.py: -------------------------------------------------------------------------------- 1 | from .xresnet import * -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter-book 2 | matplotlib 3 | numpy 4 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | # Learn more at https://jupyterbook.org/customize/toc.html 3 | 4 | format: jb-book 5 | root: index 6 | chapters: 7 | - file: example_training 8 | - file: example_extraction 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.19.3 2 | kornia==0.6.8 3 | numpy==1.23.5 4 | oyaml==1.0 5 | pandas==1.5.3 6 | Pillow==9.5.0 7 | PyYAML==6.0 8 | scikit-image==0.19.3 9 | scikit-learn==1.2.2 10 | scipy==1.10.1 11 | torch==1.12.1+cu116 12 | torchvision==0.13.1+cu116 13 | tqdm==4.64.1 14 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | # Learn more at https://jupyterbook.org/customize/config.html 3 | 4 | title: Dino4cells documentation 5 | author: Michael Doron 6 | # logo: logo.png 7 | 8 | # Force re-execution of notebooks on each build. 9 | # See https://jupyterbook.org/content/execute.html 10 | execute: 11 | execute_notebooks: force 12 | 13 | # Define the name of the latex output file for PDF builds 14 | latex: 15 | latex_documents: 16 | targetname: dino4cells_book.tex 17 | 18 | 19 | # Information about where the book exists on the web 20 | repository: 21 | url: https://github.com/broadinstitute/DINO4Cells_code/ 22 | path_to_book: docs # Optional path to your book, relative to the repository root 23 | branch: master # Which branch of the repository should be used when creating links (optional) 24 | 25 | # Add GitHub buttons to your book 26 | # See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository 27 | html: 28 | use_issues_button: true 29 | use_repository_button: true 30 | -------------------------------------------------------------------------------- /utils/slurm.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import os 3 | import signal 4 | import time 5 | 6 | 7 | logger = getLogger() 8 | 9 | 10 | def trigger_job_requeue(): 11 | """ Submit a new job to resume from checkpoint. 12 | Be careful to use only for main process. 
13 | """ 14 | if ( 15 | int(os.environ["SLURM_PROCID"]) == 0 16 | and str(os.getpid()) == os.environ["MAIN_PID"] 17 | ): 18 | print("time is up, back to slurm queue", flush=True) 19 | command = "scontrol requeue " + os.environ["SLURM_JOB_ID"] 20 | print(command) 21 | if os.system(command): 22 | raise RuntimeError("requeue failed") 23 | print("New job submitted to the queue", flush=True) 24 | exit(0) 25 | 26 | 27 | def SIGTERMHandler(a, b): 28 | print("received sigterm") 29 | pass 30 | 31 | 32 | def signalHandler(a, b): 33 | print("Signal received", a, time.time(), flush=True) 34 | os.environ["SIGNAL_RECEIVED"] = "True" 35 | return 36 | 37 | 38 | def init_signal_handler(): 39 | """ 40 | Handle signals sent by SLURM for time limit / pre-emption. 41 | """ 42 | os.environ["SIGNAL_RECEIVED"] = "False" 43 | os.environ["MAIN_PID"] = str(os.getpid()) 44 | 45 | signal.signal(signal.SIGUSR1, signalHandler) 46 | signal.signal(signal.SIGTERM, SIGTERMHandler) 47 | print("Signal handler installed.", flush=True) -------------------------------------------------------------------------------- /run_dino.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | 5 | parser = argparse.ArgumentParser("Get embeddings from model") 6 | parser.add_argument("--config", type=str, default=".", help="path to config file") 7 | parser.add_argument( 8 | "--gpus", type=str, default=".", help='Used GPUs, divided by commas (e.g., "1,2,4")' 9 | ) 10 | parser.add_argument( 11 | "--master_port", 12 | type=str, 13 | default="29501", 14 | help='Used GPUs, divided by commas (e.g., "1,2,4")', 15 | ) 16 | 17 | args = parser.parse_args() 18 | config = yaml.safe_load(open(args.config, "r")) 19 | 20 | num_gpus = len(args.gpus.split(",")) 21 | command = f'CUDA_VISIBLE_DEVICES={args.gpus} torchrun --master_port={args.master_port} --nproc_per_node={num_gpus} main_dino.py --arch {config["model"]["arch"]} --output_dir {config["model"]["output_dir"]} --data_path {config["model"]["data_path"]} --saveckp_freq {config["model"]["saveckp_freq"]} --batch_size_per_gpu {config["model"]["batch_size_per_gpu"]} --num_channels {config["model"]["num_channels"]} --patch_size {config["model"]["patch_size"]} --local_crops_scale {config["model"]["local_crops_scale"]} --epochs {config["model"]["epochs"]} --config {args.config} --center_momentum {config["model"]["center_momentum"]} --lr {config["model"]["lr"]} {"--sample_single_cells" if(config["model"]["sample_single_cells"] == True) else ""}' 22 | print(command) 23 | 24 | os.system(command) 25 | -------------------------------------------------------------------------------- /docs/example_extraction.md: -------------------------------------------------------------------------------- 1 | ## Example extraction 2 | 3 | After we trained DINO of the images, we can now use it to extract features. 4 | 5 | For this, we will again look at the `config.yaml` file, this time at the `embeddings` section. 6 | 7 | Inside `config.yaml`: 8 | 9 | ``` 10 | ... 11 | embedding: 12 | pretrained_weights: /scr/mdoron/DINO4Cells_code/output/checkpoint.pth 13 | output_path: /scr/mdoron/DINO4Cells_code/output/features.pth 14 | df_path: /scr/mdoron/DINO4Cells_code/cellpainting_data/sc-metadata.csv 15 | image_size: 224 16 | num_workers: 0 17 | embedding_has_labels: True 18 | target_labels: False 19 | ... 20 | ``` 21 | 22 | These are the parameters we need to set for DINO to extract features. 
23 | 24 | `pretrained_weights` points to the checkpoint of the trained DINO model 25 | 26 | `output_path` points to the location where the DINO features will be saved 27 | 28 | `df_path` points to the csv file containing the metadata of the data to be extracted 29 | 30 | `image_size` determines the size of the images DINO expects to find 31 | 32 | `num_workers` determines how many workers will be used in extracting the features 33 | 34 | `embedding_has_labels` determines whether the metadata has labels to be saved along with the features 35 | 36 | `target_labels` False 37 | 38 | 39 | Finally, run this command to extract the features: 40 | 41 | `python run_get_features.py --config config.yaml` 42 | -------------------------------------------------------------------------------- /utils/yaml_tfms.py: -------------------------------------------------------------------------------- 1 | import oyaml as yaml 2 | import argparse 3 | import sys 4 | from torchvision import datasets, transforms 5 | from torchvision.transforms import * 6 | from utils.augmentations import * 7 | 8 | 9 | def get_args_parser(): 10 | parser = argparse.ArgumentParser("DINO", add_help=False) 11 | parser.add_argument("--config", default=".", type=str) 12 | return parser 13 | 14 | 15 | def parse_tfms(config, key): 16 | tfms = config[key] 17 | augs = [] 18 | for i in tfms: 19 | f = globals()[i] 20 | if config[key][i][0] == True: 21 | print(f"adding {key}: {i}") 22 | if isinstance(config[key][i][1], dict): 23 | f = f(**config[key][i][1]) 24 | elif isinstance(config[key][i][1], list): 25 | f = f(*config[key][i][1]) 26 | else: 27 | f = f() 28 | if i in ["ColorJitter", "ColorJitter_for_RGBA"]: 29 | print("found color jitter") 30 | f = RandomApply([f], p=0.8) 31 | augs.append(f) 32 | return augs 33 | 34 | 35 | def tfms_from_config(config): 36 | jitter = parse_tfms(config, "flip_and_color_jitter_transforms") 37 | norm = parse_tfms(config, "normalization") 38 | 39 | testing_tfms = parse_tfms(config, "testing_transfo") 40 | testing_tfms = transforms.Compose(testing_tfms) 41 | glb_tfms_1 = parse_tfms(config, "global_transfo1") 42 | global_tfms_1 = transforms.Compose( 43 | glb_tfms_1 + jitter + parse_tfms(config, "global_aug1") + norm 44 | ) 45 | glb_tfms_2 = parse_tfms(config, "global_transfo2") 46 | global_tfms_2 = transforms.Compose( 47 | glb_tfms_2 + jitter + parse_tfms(config, "global_aug2") + norm 48 | ) 49 | loc_tfms = parse_tfms(config, "local_transfo") 50 | local_tfms = transforms.Compose( 51 | loc_tfms + parse_tfms(config, "local_aug") + jitter + norm 52 | ) # note different order! 
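    # Clarifying the "note different order" remark above: the global pipelines compose
    # crop -> flip/color jitter -> extra global augmentations -> normalization, whereas the
    # local pipeline composes crop -> extra local augmentations -> flip/color jitter -> normalization.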
53 | 54 | return global_tfms_1, global_tfms_2, local_tfms, testing_tfms 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser("test", parents=[get_args_parser()]) 59 | args = parser.parse_args() 60 | dummy_func(args.config) 61 | -------------------------------------------------------------------------------- /data_utils/cellpainting_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import skimage.io 4 | 5 | import pandas as pd 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision import transforms, utils 11 | import torchvision 12 | t = torchvision.transforms.ToTensor() 13 | 14 | # Ignore warnings 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | ######################################################## 20 | ## Re-arrange channels from tape format to stack tensor 21 | ######################################################## 22 | 23 | def fold_channels(image, channel_width, mode="drop"): 24 | # Expected input image shape: (h, w * c) 25 | # Output image shape: (h, w, c) 26 | output = np.reshape(image, (image.shape[0], channel_width, -1), order="F") 27 | 28 | if mode == "ignore": 29 | # Keep all channels 30 | pass 31 | elif mode == "drop": 32 | # Drop mask channel (last) 33 | output = output[:, :, 0:-1] 34 | elif mode == "apply": 35 | # Use last channel as a binary mask 36 | mask = output["image"][:, :, -1:] 37 | output = output[:, :, 0:-1] * mask 38 | 39 | return t(output) 40 | 41 | 42 | ######################################################## 43 | ## Dataset Class 44 | ######################################################## 45 | 46 | dataset = FileList( 47 | args.data_path, 48 | config["model"]["root"], 49 | transform=transform, 50 | loader=chosen_loader, 51 | flist_reader=partial( 52 | file_dataset.pandas_reader_only_file, 53 | sample_single_cells=args.sample_single_cells, 54 | ), 55 | with_labels=False, 56 | balance=False, 57 | sample_single_cells=args.sample_single_cells, 58 | ) 59 | class SingleCellDataset(Dataset): 60 | """Single cell dataset.""" 61 | def __init__(self, csv_file, root, transform=None, loader=None, flist_reader=None, with_labels=None, balance=None, sample_single_cells=None, training=None, target_labels=None): 62 | """ 63 | Args: 64 | csv_file (string): Path to the csv file with metadata. 65 | root (string): Directory with all the images. 66 | transform (callable, optional): Optional transform to be applied 67 | on a sample. 
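            The remaining keyword arguments (loader, flist_reader, with_labels,
                balance, sample_single_cells, training, target_labels) are accepted
                to match the signature of the other dataset loaders in this repo,
                but are not used inside this class.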
68 | """ 69 | self.metadata = pd.read_csv(csv_file) 70 | self.root = root 71 | self.transform = transform 72 | 73 | def __len__(self): 74 | return len(self.metadata) 75 | 76 | 77 | def __getitem__(self, idx): 78 | if torch.is_tensor(idx): 79 | idx = idx.tolist() 80 | print(idx) 81 | 82 | img_name = os.path.join(self.root, 83 | self.metadata.loc[idx, "Image_Name"]) 84 | channel_width = self.metadata.loc[idx, 'channel_width'] 85 | image = skimage.io.imread(img_name) 86 | image = fold_channels(image, channel_width) 87 | 88 | label = self.metadata.loc[idx, "Target"] 89 | 90 | if self.transform: 91 | image = self.transform(image) 92 | 93 | return image, label 94 | 95 | -------------------------------------------------------------------------------- /data_utils/cell_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader 5 | from skimage import io 6 | import numpy as np 7 | from tqdm import tqdm 8 | from skimage import exposure 9 | from torchvision.transforms import ToTensor 10 | import pandas as pd 11 | 12 | import torch.nn.functional as F 13 | 14 | class dino_dataset(torch.utils.data.Dataset): 15 | """ 16 | Class build on top of PyTorch dataset 17 | used to load training, validation and test datasets 18 | 19 | inputs: 20 | - dataframe (required): a pandas dataframe conatining at least 21 | three columns, namely, Path, Class_ID, Batch_ID. 22 | - root_dir: a string to be added as a prefix to the "Path" 23 | - transform (optional): any image transforms. Optional. 24 | - class_dict (optional): a dictionary that converts strings 25 | in Class_ID column to numerical class labels 26 | - class_dict (optional): a dictionary that converts strings 27 | in Batch_ID column to numerical class labels 28 | 29 | returns: 30 | - a three-tuple with the image tensor (N,C,W,H) or (C,W,H), 31 | a numerical class id, and a numerical batch id. 32 | The id's are numerical values inferred from the 33 | class and batch dictionaries. An np.nan is returned 34 | if the keys are unavailable in the dictionaries. 
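    Example (illustrative paths; the csv must contain Path, Class_ID and Batch_ID columns):
        >>> ds = dino_dataset("metadata.csv", root_dir="/data/images/")
        >>> image, class_id = ds[0]                      # training=True returns (image, class label)
        >>> ds_eval = dino_dataset("metadata.csv", root_dir="/data/images/", training=False)
        >>> image, class_name, batch_name = ds_eval[0]   # training=False returns raw Class_ID / Batch_ID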
35 | """ 36 | 37 | def __init__(self, dataframe, root_dir='/home/ubuntu/data/CellNet_data/Hirano3D_v2.0/data/', transform=None, 38 | label_dicts=None, RGBmode=False, training=True): 39 | 40 | if not isinstance(dataframe, pd.DataFrame): 41 | dataframe=pd.read_csv(dataframe) 42 | self.dataframe = dataframe 43 | self.root_dir = root_dir 44 | self.transform = transform 45 | self.RGBmode = RGBmode 46 | self.train = training 47 | 48 | # store unique class and batch IDs 49 | self.classes = [cls for cls in self.dataframe.Class_ID.unique()] 50 | self.batches = [bt for bt in self.dataframe.Batch_ID.unique()] 51 | if label_dicts is not None: 52 | print('label_dict received!') 53 | self.class_dict, self.batch_dict = label_dicts['class_dict'], label_dicts['batch_dict'] 54 | else: 55 | self.class_dict = {v: k for k, v in enumerate(self.classes)} 56 | self.batch_dict = {v: k for k, v in enumerate(self.batches)} 57 | self.label_dicts = {'class_dict': self.class_dict, 'batch_dict': self.batch_dict} 58 | 59 | def __len__(self): 60 | return len(self.dataframe) 61 | 62 | def __getitem__(self, index): 63 | row = self.dataframe.iloc[index] 64 | im = io.imread(self.root_dir + row["Path"]) 65 | im = exposure.rescale_intensity(im, in_range='image', out_range=(0, 255)) 66 | im = np.float32(im) 67 | im = np.divide(im, 255) 68 | im = ToTensor()(im) 69 | 70 | if self.RGBmode: im = im[:3,...] 71 | if self.transform is not None: 72 | im = self.transform(im) 73 | 74 | if self.train: 75 | return (im, 76 | torch.tensor(self.class_dict[row["Class_ID"]])) 77 | else: 78 | return (im, 79 | row['Class_ID'], 80 | row['Batch_ID']) 81 | 82 | 83 | # torch.tensor(self.class_dict[row["Class_ID"]] if row["Class_ID"] in 84 | # self.class_dict.keys() else 100) 85 | # , 86 | # torch.tensor(self.batch_dict[row["Batch_ID"]] if row["Batch_ID"] in 87 | # self.batch_dict.keys() else 100), 88 | # ) 89 | -------------------------------------------------------------------------------- /docs/example_training.md: -------------------------------------------------------------------------------- 1 | 2 | ## Example training 3 | 4 | For this example, we assume we have a dataset of 10,000 FOV images of cells. We wish to train an unbiased feature extractror to explore the structure of the data. To do this, we should first prepare a configuration file that will contain the parapeters DINO will use to train: 5 | 6 | Inside `config.yaml`: 7 | 8 | ``` 9 | ... 10 | model: 11 | model_type: DINO 12 | arch: vit_tiny 13 | root: /home/michaeldoron/dino_example/single_cells_data/ 14 | data_path: /home/michaeldoron/dino_example/metadata.csv 15 | output_dir: /home/michaeldoron/dino_example/output/ 16 | datatype: HPA 17 | image_mode: normalized_4_channels 18 | batch_size_per_gpu: 24 19 | num_channels: 4 20 | patch_size: 16 21 | epochs: 10 22 | momentum_teacher: 0.996 23 | center_momentum: 0.9 24 | lr: 0.0005 25 | local_crops_scale: '0.2 0.5' 26 | ... 27 | ``` 28 | 29 | These are the parameters we need to set for DINO to train. 30 | 31 | `arch` can be `vit_tiny`, `vit_small` or `vit_base`, giving us options for a larger feature vector with the price of more memory and computation. 32 | 33 | `root` points to the root directory where the images are stored. 34 | 35 | `data_path` points to the csv file containing the metadata. 36 | 37 | `output_dir` points to the location where the DINO model will be stored. 38 | 39 | `image_mode` described the type of images the data loader should expect to find. 
In this case, we use normalized 4 channels, as our data will contain 4-channel images. You can create new image modes by changing the code in `data_utils/file_dataset.py`.
40 | 
41 | `batch_size_per_gpu` determines the size of the batch per GPU.
42 | 
43 | `num_channels` determines the number of channels the DINO model should expect.
44 | 
45 | `patch_size` determines the size of the vision transformer patch. Usual values are 8 or 16 pixels.
46 | 
47 | `epochs` determines the number of epochs the DINO model will train over the entire dataset.
48 | 
49 | `momentum_teacher` determines the momentum used to transfer the weights from the student network to the teacher network.
50 | 
51 | `center_momentum` determines the momentum used to center the teacher.
52 | 
53 | `lr` determines the learning rate used by DINO.
54 | 
55 | `local_crops_scale` determines the size of the crops used by DINO.
56 | 
57 | 
58 | After determining these parameters, we can decide on the augmentations DINO will use.
59 | 
60 | As written in the paper, DINO relies on a selection of augmentations that transform the images in ways DINO learns to be invariant to in its feature extraction. Thus, to train DINO well, one should choose augmentations that do not alter the important information found inside the images. These augmentations can be altered in these sections of the DINO `config.yaml` file: `flip_and_color_jitter_transforms`, `normalization`, `global_transfo1`, `global_aug1`, `testing_transfo`, `global_transfo2`, `local_transfo`, `local_aug`.
61 | 
62 | Each section is applied at a different stage: the global views, the local views, or normalization. Inside each section there is a list of augmentations, each with a boolean flag signifying whether it is active in the training process, as well as optional additional augmentation parameters.
63 | 
64 | For example:
65 | 
66 | ```
67 | global_transfo1:
68 |   Warp_cell:
69 |   - True
70 |   - # no params
71 |   Single_cell_centered:
72 |   - False
73 |   - # no params
74 |   remove_channel:
75 |   - True
76 |   - {p: 0.2}
77 | ```
78 | 
79 | means that the global views have the Warp_cell augmentation active, the Single_cell_centered augmentation inactive, and the remove_channel augmentation active with an activation probability of 20%.
80 | 
81 | After the config file is set, we can train our DINO model on the data by running `python run_dino.py --config config.yaml --gpus 0,1`, where the `--gpus` argument determines which GPUs are used in training.
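As a side note on two of the hyperparameters above: `momentum_teacher` and `center_momentum` both control exponential moving averages. Below is a minimal sketch of how DINO uses them, following the description in the DINO paper; the function and variable names are illustrative and this is not code taken from this repository:

```
import torch

def ema_update(teacher_params, student_params, momentum_teacher=0.996):
    # The teacher weights track the student as an exponential moving average.
    with torch.no_grad():
        for t, s in zip(teacher_params, student_params):
            t.mul_(momentum_teacher).add_((1.0 - momentum_teacher) * s)

def update_center(center, teacher_output, center_momentum=0.9):
    # The center used to normalize teacher outputs is an EMA of the batch means.
    batch_center = teacher_output.mean(dim=0, keepdim=True)
    return center * center_momentum + batch_center * (1.0 - center_momentum)
```

Values close to 1 make the teacher and the center change slowly, which stabilizes training.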
82 | -------------------------------------------------------------------------------- /archs/xresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch,math,sys 3 | import torch.utils.model_zoo as model_zoo 4 | import functools 5 | from functools import partial 6 | 7 | __all__ = ['XResNet', 'xresnet18', 'xresnet34', 'xresnet50', 'xresnet101', 'xresnet152', 8 | 'xresnet18_deep', 'xresnet34_deep', 'xresnet50_deep', 'xresnet18_c'] 9 | 10 | class PrePostInitMeta(type): 11 | "A metaclass that calls optional `__pre_init__` and `__post_init__` methods" 12 | def __new__(cls, name, bases, dct): 13 | x = super().__new__(cls, name, bases, dct) 14 | old_init = x.__init__ 15 | def _pass(self): pass 16 | @functools.wraps(old_init) 17 | def _init(self,*args,**kwargs): 18 | self.__pre_init__() 19 | old_init(self, *args,**kwargs) 20 | self.__post_init__() 21 | x.__init__ = _init 22 | if not hasattr(x,'__pre_init__'): x.__pre_init__ = _pass 23 | if not hasattr(x,'__post_init__'): x.__post_init__ = _pass 24 | return x 25 | 26 | class Module(nn.Module, metaclass=PrePostInitMeta): 27 | "Same as `nn.Module`, but no need for subclasses to call `super().__init__`" 28 | def __pre_init__(self): super().__init__() 29 | def __init__(self): pass 30 | 31 | # or: ELU+init (a=0.54; gain=1.55) 32 | act_fn = nn.ReLU(inplace=True) 33 | 34 | class Flatten(Module): 35 | def forward(self, x): 36 | return x.view(x.size(0), -1) 37 | 38 | def init_cnn(m): 39 | if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0) 40 | if isinstance(m, (nn.Conv2d,nn.Linear)): nn.init.kaiming_normal_(m.weight) 41 | for l in m.children(): init_cnn(l) 42 | 43 | def conv(ni, nf, ks=3, stride=1, bias=False): 44 | return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=bias) 45 | 46 | def noop(x): return x 47 | 48 | def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True, IN=False): 49 | if IN == 'IN': 50 | insNorm = nn.InstanceNorm2d(nf) 51 | layers = [conv(ni, nf, ks, stride=stride), insNorm] 52 | else: 53 | bn = nn.BatchNorm2d(nf) 54 | nn.init.constant_(bn.weight, 0. if zero_bn else 1.) 
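        # zero_bn=True sets the BatchNorm scale (gamma) to zero, so the final conv of each
        # ResBlock initially outputs zeros and the block starts out as an identity mapping
        # (the common "zero-init residual" trick).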
55 | layers = [conv(ni, nf, ks, stride=stride), bn] 56 | if act: layers.append(act_fn) 57 | return nn.Sequential(*layers) 58 | 59 | class ResBlock(Module): 60 | def __init__(self, expansion, ni, nh, stride=1, IN=False): 61 | nf,ni = nh*expansion,ni*expansion 62 | layers = [conv_layer(ni, nh, 3, stride=stride, IN=IN), 63 | conv_layer(nh, nf, 3, zero_bn=True, act=False, IN=IN) 64 | ] if expansion == 1 else [ 65 | conv_layer(ni, nh, 1, IN=IN), 66 | conv_layer(nh, nh, 3, stride=stride, IN=IN), 67 | conv_layer(nh, nf, 1, zero_bn=True, act=False, IN=IN) 68 | ] 69 | self.convs = nn.Sequential(*layers) 70 | # TODO: check whether act=True works better 71 | self.idconv = noop if ni==nf else conv_layer(ni, nf, 1, act=False, IN=IN) 72 | self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True) 73 | 74 | def forward(self, x): return act_fn(self.convs(x) + self.idconv(self.pool(x))) 75 | 76 | def filt_sz(recep): return min(64, 2**math.floor(math.log2(recep*0.75))) 77 | 78 | class XResNet(nn.Sequential): 79 | def __init__(self, expansion, layers, c_in=3, c_out=1000, IN=False): 80 | stem = [] 81 | sizes = [c_in,32,32,64] 82 | for i in range(3): 83 | stem.append(conv_layer(sizes[i], sizes[i+1], stride=2 if i==0 else 1, IN=IN)) 84 | #nf = filt_sz(c_in*9) 85 | #stem.append(conv_layer(c_in, nf, stride=2 if i==1 else 1)) 86 | #c_in = nf 87 | 88 | block_szs = [64//expansion,64,128,256,512] +[256]*(len(layers)-4) 89 | blocks = [self._make_layer(expansion, block_szs[i], block_szs[i+1], l, 1 if i==0 else 2, IN=IN) 90 | for i,l in enumerate(layers)] 91 | super().__init__( 92 | *stem, 93 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1), 94 | *blocks, 95 | nn.AdaptiveAvgPool2d(1), Flatten(), 96 | nn.Linear(block_szs[-1]*expansion, c_out), 97 | ) 98 | init_cnn(self) 99 | 100 | def _make_layer(self, expansion, ni, nf, blocks, stride, IN=False): 101 | return nn.Sequential( 102 | *[ResBlock(expansion, ni if i==0 else nf, nf, stride if i==0 else 1, IN=IN) 103 | for i in range(blocks)]) 104 | 105 | def xresnet(pretrained, expansion, n_layers, name, c_out=1000, **kwargs): #(!) had to move pretrained to first arg, without default 106 | model = XResNet(expansion, n_layers, c_out=c_out, **kwargs) 107 | if pretrained: model.load_state_dict(model_zoo.load_url(model_urls[name])) 108 | return model 109 | 110 | me = sys.modules[__name__] 111 | for n,e,l in [ 112 | [ 18 , 1, [2,2,2 ,2] ], 113 | [ 34 , 1, [3,4,6 ,3] ], 114 | [ 50 , 4, [3,4,6 ,3] ], 115 | [ 101, 4, [3,4,23,3] ], 116 | [ 152, 4, [3,8,36,3] ], 117 | ]: 118 | name = f'xresnet{n}' 119 | setattr(me, name, partial(xresnet, expansion=e, n_layers=l, name=name)) 120 | 121 | xresnet18_deep = partial(xresnet, expansion=1, n_layers=[2, 2, 2, 2,1,1], name='xresnet18_deep') 122 | xresnet34_deep = partial(xresnet, expansion=1, n_layers=[3, 4, 6, 3,1,1], name='xresnet34_deep') 123 | xresnet50_deep = partial(xresnet, expansion=4, n_layers=[3, 4, 6, 3,1,1], name='xresnet50_deep') 124 | 125 | xresnet18_c = partial(xresnet, expansion=1, n_layers=[2, 2, 2 ,2], name='xresnet18_c') #(!) 126 | 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DINO4Cells_code 2 | This repo will contain the code for training DINO models and extracting features, as described in the paper [Unbiased single-cell morphology with self-supervised vision transformers 3 | ](https://www.biorxiv.org/content/10.1101/2023.06.16.545359v1). 
4 | 5 | For the code to reproduce the results of the paper, go to https://github.com/broadinstitute/Dino4Cells_analysis. 6 | 7 | A more thorough documentation can be found [here](https://broadinstitute.github.io/DINO4Cells_code/). 8 | 9 | ## Installation 10 | 11 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116 12 | 13 | pip install -r requirements.py 14 | 15 | Typical installation time: 10 minutes 16 | 17 | ## Running example 18 | 19 | ### Train dino 20 | python run_dino.py --config config.yaml --gpus 0,1,2,3 21 | Typical running time: 1 day 22 | 23 | ### Extract features 24 | python run_get_features.py --config config.yaml 25 | 26 | ### Train classifier 27 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 run_end_to_end.py --config config.yaml --epochs 100 --balance True --num_classes 35 --train_cell_type True --train_protein False --master_port = 1234 28 | 29 | # Data 30 | 31 | # HPA FOV 32 | 33 | For the HPA FOV data, access [https://zenodo.org/record/8061392](https://zenodo.org/record/8061392) 34 | 35 | ## Model checkpoints 36 | 37 | HPA_FOV_data/DINO_FOV_checkpoint.pth 38 | 39 | HPA_FOV_data/densenet_model_batch.onnx 40 | 41 | HPA_FOV_data/densenet_model.onnx 42 | 43 | ## metadata 44 | 45 | HPA_FOV_data/whole_images.csv 46 | 47 | ## features 48 | 49 | HPA_FOV_data/DINO_features_for_HPA_FOV.pth 50 | 51 | HPA_FOV_data/bestfitting_features_for_HPA_FOV.pth 52 | 53 | HPA_FOV_data/pretrained_DINO_features_for_HPA_FOV.pth 54 | 55 | ## Embeddings 56 | 57 | HPA_FOV_data/DINO_FOV_harmonized_embeddings.csv 58 | 59 | HPA_FOV_data/DINO_FOV_embeddings.csv 60 | 61 | ## Classifiers 62 | 63 | HPA_FOV_data/classifier_cells.pth 64 | 65 | HPA_FOV_data/classifier_proteins.pth 66 | 67 | ## Misc 68 | 69 | ### train / test divisions for protein localizations and cell line classification 70 | 71 | HPA_FOV_data/cells_train_IDs.pth 72 | 73 | HPA_FOV_data/cells_valid_IDs.pth 74 | 75 | HPA_FOV_data/train_IDs.pth 76 | 77 | HPA_FOV_data/valid_IDs.pth 78 | 79 | ### HPA single cell kaggle protein localization competition, download from https://www.kaggle.com/competitions/human-protein-atlas-image-classification/leaderboard 80 | 81 | HPA_FOV_data/human-protein-atlas-image-classification-publicleaderboard.csv 82 | 83 | ### HPA cell line RNASeq data 84 | 85 | HPA_FOV_data/rna_cellline.tsv 86 | 87 | ### HPA FOV color visualization 88 | 89 | HPA_FOV_data/whole_image_cell_color_indices.pth 90 | 91 | HPA_FOV_data/whole_image_protein_color_indices.pth 92 | 93 | # HPA single cells 94 | 95 | For the HPA single cell data, access [https://zenodo.org/record/8061426](https://zenodo.org/record/8061426) 96 | 97 | ## Model checkpoints 98 | 99 | HPA_single_cells_data/DINO_single_cell_checkpoint.pth 100 | 101 | HPA_single_cells_data/HPA_single_cell_model_checkpoint.pth 102 | 103 | HPA_single_cells_data/dualhead_config.json 104 | 105 | HPA_single_cells_data/dualhead_matched_state.pth 106 | 107 | ## metadata 108 | 109 | HPA_single_cells_data/fixed_size_masked_single_cells_for_sc.csv 110 | 111 | ## features 112 | 113 | HPA_single_cells_data/DINO_features_for_HPA_single_cells.pth 114 | 115 | HPA_single_cells_data/dualhead_features_for_HPA_single_cells.pth 116 | 117 | HPA_single_cells_data/pretrained_DINO_features_for_HPA_single_cells.pth 118 | 119 | ## Embeddings 120 | 121 | HPA_single_cells_data/DINO_embedding_average_umap.csv 122 | 123 | HPA_single_cells_data/DINO_harmonized_embedding_average_umap.csv 124 | 125 | ## 
Classifiers 126 | 127 | HPA_single_cells_data/classifier_cells.pth 128 | 129 | HPA_single_cells_data/classifier_proteins.pth 130 | 131 | ## Misc 132 | 133 | ### HPA single cell kaggle protein localization competition, download from https://www.kaggle.com/competitions/hpa-single-cell-image-classification/leaderboard 134 | 135 | HPA_single_cells_data/hpa-single-cell-image-classification-publicleaderboard.csv 136 | 137 | ### HPA XML data 138 | 139 | HPA_single_cells_data/XML_HPA.csv 140 | 141 | ### UNIPROT interaction dataset 142 | 143 | HPA_single_cells_data/uniport_interactions.tsv 144 | 145 | ### HPA gene heterogeneity annotated by experts 146 | 147 | HPA_single_cells_data/gene_heterogeneity.tsv 148 | 149 | ### single cell metadata with genetic information 150 | 151 | HPA_single_cells_data/Master_scKaggle.csv 152 | 153 | ### HPA single cell color visualization 154 | 155 | HPA_single_cells_data/cell_color_indices.pth 156 | 157 | HPA_single_cells_data/protein_color_indices.pth 158 | 159 | # WTC11 160 | 161 | For the WTC11 data, access [https://zenodo.org/record/8061424](https://zenodo.org/record/8061424) 162 | 163 | ## Model checkpoints 164 | 165 | WTC11_data/DINO_checkpoint.pth 166 | 167 | ## metadata 168 | 169 | WTC11_data/normalized_cell_df.csv 170 | 171 | ## features 172 | 173 | WTC11_data/DINO_features_and_df.pth 174 | 175 | WTC11_data/engineered_features.pth 176 | 177 | WTC11_data/pretrained_features_and_df.pth 178 | 179 | ## Embeddings 180 | 181 | WTC11_data/DINO_trained_embedding.pth 182 | 183 | WTC11_data/DINO_trained_harmonized_embedding.pth 184 | 185 | WTC11_data/pretrained_embedding.pth 186 | 187 | WTC11_data/pretrained_harmonized_embedding.pth 188 | 189 | WTC11_data/engineered_embedding.pth 190 | 191 | ## Classifiers 192 | 193 | WTC11_data/predictions_for_WTC11_trained_model.pth 194 | 195 | WTC11_data/predictions_for_WTC11_pretrained_model.pth 196 | 197 | WTC11_data/predictions_for_WTC11_xgb.pth 198 | 199 | ## Misc 200 | 201 | WTC11_data/train_indices.pth 202 | 203 | WTC11_data/test_indices.pth 204 | 205 | # Cell Painting 206 | 207 | For the Cell Painting data, access [https://zenodo.org/record/8061428](https://zenodo.org/record/8061428) 208 | 209 | ## DINO model checkpoints 210 | 211 | Cell_Painting_data/DINO_cell_painting_base_checkpoint.pth 212 | 213 | Cell_Painting_data/DINO_cell_painting_small_checkpoint.pth 214 | 215 | ## metadata and embeddings 216 | 217 | Cell_Painting_data/LINCS_ViT_Small_Compressed_df_and_UMAP.csv 218 | 219 | Cell_Painting_data/Combined_CP_df_and_UMAP.csv 220 | 221 | ## features 222 | 223 | Code to calculate the PUMA results is in: 224 | [https://github.com/CaicedoLab/2023_Moshkov_NatComm](https://github.com/CaicedoLab/2023_Moshkov_NatComm) 225 | 226 | 227 | ## misc (data partition indices, preprocessed data, etc.) 228 | 229 | Cell_Painting_data/scaffold_median_python_dino.csv 230 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Dino4cells handbook 2 | 3 | Dino4cells is a self-supervised method to extract phenotypic information of single cells from their morphology by training unbiased representation learning models on their microscopy images. This guide describes how to install dino4cells, and how to run it on an example small dataset. After completing this guidebook, you should have a good understanding of how to use dino4cells in your own research. 
4 | 
5 | Dino4cells is an extension of [DINO](https://github.com/facebookresearch/dino), or self-DIstillation with NO labels, published by Meta AI. A demonstration of its abilities is presented in the paper [Unbiased single-cell morphology with self-supervised vision transformers](https://www.biorxiv.org/content/10.1101/2023.06.16.545359v1).
6 | 
7 | 
8 | ## Installation
9 | 
10 | To install the code, first clone the [DINO4cells_code git repo](https://github.com/broadinstitute/Dino4Cells_code).
11 | 
12 | Next, install the required dependencies by running
13 | 
14 | `pip install -r requirements.txt`
15 | and
16 | `pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116`
17 | 
18 | This should set you up for running DINO4cells.
19 | 
20 | 
21 | ## Usage
22 | 
23 | ### Train dino
24 | `python run_dino.py --config config.yaml --gpus 0,1,2,3`
25 | 
26 | Typical running time: 1 day
27 | 
28 | ### Extract features
29 | After the model is trained, you can extract features from the microscopy images by running
30 | 
31 | `python run_get_features.py --config config.yaml`
32 | 
33 | ### Train classifier
34 | Next, if you want to use the features for, e.g., predicting some quantity of interest, you can run
35 | 
36 | `CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 run_end_to_end.py --config config.yaml --epochs 100 --balance True --num_classes 35 --train_cell_type True --train_protein False --master_port = 1234`
37 | 
38 | 
39 | ## Example training
40 | 
41 | For this example, we assume we have a dataset of 15,188 images of single cells, which can be downloaded from [here](https://zenodo.org/record/8198252). We wish to train an unbiased feature extractor to explore the structure of the data. To do this, we should first prepare a configuration file that will contain the parameters DINO will use to train:
42 | 
43 | Inside `config.yaml`:
44 | 
45 | ```
46 | ...
47 | model:
48 |   model_type: DINO
49 |   arch: vit_tiny
50 |   root: /home/michaeldoron/dino_example/single_cells_data/
51 |   data_path: /home/michaeldoron/dino_example/metadata.csv
52 |   output_dir: /home/michaeldoron/dino_example/output/
53 |   datatype: HPA
54 |   image_mode: normalized_4_channels
55 |   batch_size_per_gpu: 24
56 |   num_channels: 4
57 |   patch_size: 16
58 |   epochs: 100
59 |   momentum_teacher: 0.996
60 |   center_momentum: 0.9
61 |   lr: 0.0005
62 |   local_crops_scale: '0.2 0.5'
63 | ...
64 | ```
65 | 
66 | These are the parameters we need to set for DINO to train.
67 | 
68 | `arch` can be `vit_tiny`, `vit_small` or `vit_base`, giving us options for a larger feature vector at the price of more memory and computation.
69 | 
70 | `root` points to the root directory where the images are stored.
71 | 
72 | `data_path` points to the csv file containing the metadata.
73 | 
74 | `output_dir` points to the location where the DINO model will be stored.
75 | 
76 | `image_mode` describes the type of images the data loader should expect to find. In this case, we use normalized 4 channels, as our data will contain 4-channel images. You can create new image modes by changing the code in `data_utils/file_dataset.py`.
77 | 
78 | `batch_size_per_gpu` determines the size of the batch per GPU.
79 | 
80 | `num_channels` determines the number of channels the DINO model should expect.
81 | 
82 | `patch_size` determines the size of the vision transformer patch. Usual values are 8 or 16 pixels.
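To get an intuition for how `patch_size` interacts with the image resolution (the `embedding` section below uses `image_size: 224`), here is a small illustrative calculation; this is standard vision-transformer bookkeeping rather than code from this repo:

```
# Illustrative only: number of tokens a ViT processes for a given image and patch size.
image_size = 224                              # image_size used for feature extraction below
patch_size = 16                               # patch_size set above

patches_per_side = image_size // patch_size   # 224 // 16 = 14
num_patches = patches_per_side ** 2           # 14 * 14 = 196
num_tokens = num_patches + 1                  # +1 for the [CLS] token -> 197
```

Halving the patch size to 8 roughly quadruples the number of tokens, which increases memory use and computation accordingly.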
83 | 
84 | `epochs` determines the number of epochs the DINO model will train over the entire dataset.
85 | 
86 | `momentum_teacher` determines the momentum used to transfer the weights from the student network to the teacher network.
87 | 
88 | `center_momentum` determines the momentum used to center the teacher.
89 | 
90 | `lr` determines the learning rate used by DINO.
91 | 
92 | `local_crops_scale` determines the size of the crops used by DINO.
93 | 
94 | 
95 | After determining these parameters, we can decide on the augmentations DINO will use.
96 | 
97 | As written in the paper, DINO relies on a selection of augmentations that transform the images in ways DINO learns to be invariant to in its feature extraction. Thus, to train DINO well, one should choose augmentations that do not alter the important information found inside the images. These augmentations can be altered in these sections of the DINO `config.yaml` file: `flip_and_color_jitter_transforms`, `normalization`, `global_transfo1`, `global_aug1`, `testing_transfo`, `global_transfo2`, `local_transfo`, `local_aug`.
98 | 
99 | Each section is applied at a different stage: the global views, the local views, or normalization. Inside each section there is a list of augmentations, each with a boolean flag signifying whether it is active in the training process, as well as optional additional augmentation parameters.
100 | 
101 | For example:
102 | 
103 | ```
104 | global_transfo1:
105 |   Warp_cell:
106 |   - True
107 |   - # no params
108 |   Single_cell_centered:
109 |   - False
110 |   - # no params
111 |   remove_channel:
112 |   - True
113 |   - {p: 0.2}
114 | ```
115 | 
116 | means that the global views have the Warp_cell augmentation active, the Single_cell_centered augmentation inactive, and the remove_channel augmentation active with an activation probability of 20%.
117 | 
118 | After the config file is set, we can train our DINO model on the data by running `python run_dino.py --config config.yaml --gpus 0,1`, where the `--gpus` argument determines which GPUs are used in training.
119 | 
120 | On a single 3090 GPU, training DINO on 15k images should take about 3.5 hours.
121 | 
122 | 
123 | ## Example extraction
124 | 
125 | After we have trained DINO on the images, we can use it to extract features.
126 | 
127 | For this, we will again look at the `config.yaml` file, this time at the `embedding` section.
128 | 
129 | Inside `config.yaml`:
130 | 
131 | ```
132 | ...
133 | embedding:
134 |   pretrained_weights: /scr/mdoron/DINO4Cells_code/output/checkpoint.pth
135 |   output_path: /scr/mdoron/DINO4Cells_code/output/features.pth
136 |   df_path: /scr/mdoron/DINO4Cells_code/cellpainting_data/sc-metadata.csv
137 |   image_size: 224
138 |   num_workers: 0
139 |   embedding_has_labels: True
140 |   target_labels: False
141 | ...
142 | ```
143 | 
144 | These are the parameters we need to set for DINO to extract features.
145 | 146 | `pretrained_weights` points to the checkpoint of the trained DINO model 147 | 148 | `output_path` points to the location where the DINO features will be saved 149 | 150 | `df_path` points to the csv file containing the metadata of the data to be extracted 151 | 152 | `image_size` determines the size of the images DINO expects to find 153 | 154 | `num_workers` determines how many workers will be used in extracting the features 155 | 156 | `embedding_has_labels` determines whether the metadata has labels to be saved along with the features 157 | 158 | `target_labels` False 159 | 160 | 161 | Finally, run this command to extract the features: 162 | 163 | `python run_get_features.py --config config.yaml` 164 | 165 | 166 | -------------------------------------------------------------------------------- /run_get_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import random 5 | import colorsys 6 | from io import BytesIO 7 | from sklearn.preprocessing import StandardScaler 8 | from tqdm import tqdm 9 | import skimage.io 10 | import matplotlib.pyplot as plt 11 | from matplotlib import cm 12 | import torch 13 | from torch.nn import DataParallel 14 | import torch.nn as nn 15 | import torchvision 16 | from torchvision import datasets, transforms 17 | from tqdm import tqdm 18 | from pathlib import Path 19 | import yaml 20 | from functools import partial # (!) 21 | from utils.yaml_tfms import tfms_from_config 22 | import utils.utils 23 | import archs.vision_transformer as vits 24 | from archs import xresnet as cell_models # (!) 25 | from archs.vision_transformer import DINOHead 26 | from data_utils import file_dataset 27 | 28 | try: 29 | from get_wair_model import get_wair_model 30 | except: 31 | pass 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser("Get embeddings from model") 35 | parser.add_argument("--config", type=str, default=".", help="path to config file") 36 | parser.add_argument( 37 | "--pretrained_weights", 38 | type=str, 39 | default=None, 40 | help="pretrained weights, if different than config", 41 | ) 42 | parser.add_argument( 43 | "--output_prefix", type=str, default=None, help="path to config file" 44 | ) 45 | parser.add_argument("--gpus", type=str, default=".", help="path to config file") 46 | parser.add_argument("--dataset", type=str, default=None, help="path to config file") 47 | parser.add_argument("--whole", action="store_true", help="path to config file") 48 | 49 | args = parser.parse_args() 50 | config = yaml.safe_load(open(args.config, "r")) 51 | 52 | if args.output_prefix is None: 53 | output_path = f'{config["embedding"]["output_path"]}' 54 | else: 55 | output_path = args.output_prefix 56 | Path(output_path).parent.absolute().mkdir(exist_ok=True) 57 | print(f"output_path is {output_path}") 58 | 59 | # TODO: fix these temp compatibility patches: 60 | if not "HEAD" in list(config["embedding"].keys()): 61 | print( 62 | "Please see line 55 in run_get_embeddings.py for additional arguments that can be used to run the full backbone+HEAD model" 63 | ) 64 | config["embedding"]["HEAD"] = ( 65 | True if "HEAD" in list(config["embedding"].keys()) else False 66 | ) 67 | 68 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 69 | # build model 70 | wair_model_name_list = [ 71 | "DenseNet121_change_avg_512_all_more_train_add_3_v2", 72 | "DenseNet121_change_avg_512_all_more_train_add_3_v3", 73 | 
"DenseNet121_change_avg_512_all_more_train_add_3_v5", 74 | "DenseNet169_change_avg_512_all_more_train_add_3_v5", 75 | "se_resnext50_32x4d_512_all_more_train_add_3_v5", 76 | "Xception_osmr_512_all_more_train_add_3_v5", 77 | "ibn_densenet121_osmr_512_all_more_train_add_3_v5_2", 78 | ] 79 | 80 | if config["model"]["model_type"] == "DINO": 81 | if config["model"]["arch"] in vits.__dict__.keys(): 82 | # model = vits.__dict__[config['model']['arch']](img_size=[112], patch_size=config['model']['patch_size'], num_classes=0, in_chans=config['model']['num_channels']) 83 | # model = vits.__dict__[config['model']['arch']](img_size=[512], patch_size=config['model']['patch_size'], num_classes=0, in_chans=config['model']['num_channels']) 84 | model = vits.__dict__[config["model"]["arch"]]( 85 | img_size=[224], 86 | patch_size=config["model"]["patch_size"], 87 | num_classes=0, 88 | in_chans=config["model"]["num_channels"], 89 | ) 90 | # model = vits.__dict__[config['model']['arch']](img_size=[224], patch_size=config['model']['patch_size'], num_classes=0, in_chans=config['model']['num_channels']) 91 | embed_dim = model.embed_dim 92 | elif config["model"]["arch"] in cell_models.__dict__.keys(): 93 | model = partial( 94 | cell_models.__dict__[config["model"]["arch"]], 95 | c_in=config["model"]["num_channels"], 96 | )(False) 97 | embed_dim = model[-1].in_features 98 | model[-1] = nn.Identity() 99 | 100 | if config["embedding"]["HEAD"] == True: 101 | model = utils.MultiCropWrapper( 102 | model, 103 | DINOHead( 104 | embed_dim, 105 | config["model"]["out_dim"], 106 | config["model"]["use_bn_in_head"], 107 | ), 108 | ) 109 | for p in model.parameters(): 110 | p.requires_grad = False 111 | model.eval() 112 | model.to(device) 113 | if args.pretrained_weights is None: 114 | pretrained_weights = config["embedding"]["pretrained_weights"] 115 | print(f'loaded {config["embedding"]["pretrained_weights"]}') 116 | else: 117 | pretrained_weights = args.pretrained_weights 118 | print(f"loaded {args.pretrained_weights}") 119 | if os.path.isfile(pretrained_weights): 120 | state_dict = torch.load(pretrained_weights, map_location="cpu") 121 | if "teacher" in state_dict: 122 | teacher = state_dict["teacher"] 123 | if not config["embedding"]["HEAD"] == True: 124 | teacher = {k.replace("module.", ""): v for k, v in teacher.items()} 125 | teacher = { 126 | k.replace("backbone.", ""): v for k, v in teacher.items() 127 | } 128 | msg = model.load_state_dict(teacher, strict=False) 129 | else: 130 | student = state_dict 131 | if not config["embedding"]["HEAD"] == True: 132 | student = {k.replace("module.", ""): v for k, v in student.items()} 133 | student = { 134 | k.replace("backbone.", ""): v for k, v in student.items() 135 | } 136 | student = {k.replace("0.", ""): v for k, v in student.items()} 137 | msg = model.load_state_dict(student, strict=False) 138 | 139 | for p in model.parameters(): 140 | p.requires_grad = False 141 | model = model.cuda() 142 | model = model.eval() 143 | model = DataParallel(model) 144 | # model = DataParallel(model, device_ids=[eval(args.gpus)]) 145 | # model = nn.parallel.DistributedDataParallel(model, device_ids=[eval(args.gpus)]) 146 | print( 147 | "Pretrained weights found at {} and loaded with msg: {}".format( 148 | pretrained_weights, msg 149 | ) 150 | ) 151 | else: 152 | print( 153 | "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate." 
154 | ) 155 | quit() 156 | elif config["model"]["model_type"] in wair_model_name_list: 157 | model = get_wair_model(config["model"]["model_type"], fold=0) 158 | for p in model.parameters(): 159 | p.requires_grad = False 160 | model.eval() 161 | model.to(device) 162 | 163 | _, _, _, transform = tfms_from_config(config) 164 | 165 | if type(args.dataset) is type(None): 166 | print(f'df_path is {config["embedding"]["df_path"]}') 167 | dataset_path = Path(config["embedding"]["df_path"]) 168 | else: 169 | print(f"df_path is {args.dataset}") 170 | dataset_path = args.dataset 171 | 172 | loader = file_dataset.image_modes[config["model"]["image_mode"]] 173 | 174 | reader = file_dataset.readers[config["embedding"]["embedding_has_labels"]] 175 | 176 | FileList = file_dataset.data_loaders[config["model"]["datatype"]] 177 | 178 | dataset = FileList( 179 | dataset_path, 180 | transform=transform, 181 | flist_reader=reader, 182 | balance=False, 183 | loader=loader, 184 | training=False, 185 | with_labels=config["embedding"]["embedding_has_labels"], 186 | root=config["model"]["root"], 187 | # The target labels are the column names of the protein localizationsm 188 | # used to create the multilabel target matrix 189 | target_labels=config["embedding"]["target_labels"], 190 | ) 191 | 192 | sampler = torch.utils.data.SequentialSampler(dataset) 193 | data_loader = torch.utils.data.DataLoader( 194 | dataset, 195 | sampler=sampler, 196 | batch_size=config["model"]["batch_size_per_gpu"], 197 | num_workers=config["embedding"]["num_workers"], 198 | pin_memory=True, 199 | ) 200 | 201 | labels = None 202 | all_features = None 203 | running_index = 0 204 | 205 | # Main feature extraction loop 206 | for record in tqdm(data_loader): 207 | # Decode record 208 | if labels is None: 209 | labels = [[] for s in record[1:]] 210 | for ind, label in enumerate(record[1:]): 211 | labels[ind].extend(record[1 + ind]) 212 | images = record[0] 213 | # Run model 214 | if isinstance(images, list): 215 | # Compatibility for crops and multi-views 216 | with torch.no_grad(): 217 | f = torch.stack([model(img.to(device)) for img in images]) 218 | f = torch.transpose(f, 0, 1) 219 | features = torch.reshape(f, (f.shape[0], f.shape[1] * f.shape[2])) 220 | del f 221 | else: 222 | # Single image 223 | with torch.no_grad(): 224 | features = model(images.to(device)) 225 | # Append features 226 | if all_features == None: 227 | all_features = torch.zeros(len(dataset), features.shape[1]) 228 | all_features[ 229 | running_index : running_index + len(features), : 230 | ] = features.detach().cpu() 231 | running_index += len(features) 232 | del images, record, features 233 | 234 | # Save results 235 | result = [all_features] 236 | for l in labels: 237 | result.append(l) 238 | torch.save(result, output_path) 239 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | model: 3 | model_type: DINO 4 | arch: vit_tiny 5 | root: /home/michaeldoron/dino_example/single_cells_data/ 6 | data_path: /home/michaeldoron/dino_example/metadata.csv 7 | output_dir: /home/michaeldoron/dino_example/output/ 8 | datatype: HPA 9 | image_mode: normalized_4_channels 10 | saveckp_freq: 50 11 | batch_size_per_gpu: 24 12 | num_channels: 4 13 | patch_size: 16 14 | epochs: 10 15 | momentum_teacher: 0.996 16 | center_momentum: 0.9 17 | sample_single_cells: False 18 | lr: 0.0005 19 | local_crops_scale: '0.2 0.5' 20 | 21 | embedding: 22 | 
pretrained_weights: /home/michaeldoron/dino_example/output/checkpoint.pth 23 | output_path: /home/michaeldoron/dino_example/features.pth 24 | df_path: /home/michaeldoron/dino_example/metadata.csv 25 | image_size: 224 26 | num_workers: 0 27 | embedding_has_labels: True 28 | target_labels: ['actin filaments,focal adhesion sites', 'aggresome', 29 | 'centrosome,centriolar satellite', 'cytosol', 'endoplasmic reticulum', 30 | 'golgi apparatus', 'intermediate filaments', 'microtubules', 31 | 'mitochondria', 'mitotic spindle', 'no staining', 'nuclear bodies', 32 | 'nuclear membrane', 'nuclear speckles', 'nucleoli', 33 | 'nucleoli fibrillar center', 'nucleoplasm', 34 | 'plasma membrane,cell junctions', 35 | 'vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies',] 36 | 37 | flip_and_color_jitter_transforms: 38 | RandomRotation: 39 | - False 40 | - {degrees: 90, expand: False} 41 | RandomHorizontalFlip: 42 | - True 43 | - {p: 0.5} 44 | RandomVerticalFlip: 45 | - True 46 | - {p: 0.5} 47 | Change_brightness: 48 | - True 49 | - {p: 0.5} 50 | Change_contrast: 51 | - True 52 | - {p: 0.5} 53 | GaussianBlur: 54 | - False 55 | - {p: 1.0} 56 | ColorJitter: 57 | - False 58 | - {brightness: 0.4, contrast: 0.4, saturation: 0.2, hue: 0.1} 59 | ColorJitter_for_RGBA: 60 | - False 61 | - {brightness: 0.4, contrast: 0.4, saturation: 0.2, hue: 0.1} 62 | normalization: 63 | Get_specific_channel: # nucleus_only 64 | - False 65 | - {c: 0} 66 | Get_specific_channel: # protein_only 67 | - False 68 | - {c: 1} 69 | Get_specific_channel: # cyto_only 70 | - False 71 | - {c: 2} 72 | Get_specific_channel: # ER_only 73 | - False 74 | - {c: 3} 75 | ToTensor: 76 | - True 77 | - # no params 78 | Normalize: 79 | - False 80 | - {mean: [0.1450534, 0.11360057, 0.1231717, 0.14919987], std: [0.18122554, 0.14004277, 0.18840286, 0.17790672]} 81 | self_normalize: 82 | - True 83 | - # no params 84 | # --- Global crops 1 ---: 85 | global_transfo1: 86 | Warp_cell: 87 | - True 88 | - # no params 89 | Single_cell_centered: 90 | - False 91 | - # no params 92 | Single_cell_random_resize: 93 | - False 94 | - # no params 95 | FA_resize: 96 | - False 97 | - {size: 512} 98 | Single_cell_Resize: 99 | - False 100 | - # no params 101 | Single_cell_Mirror: 102 | - False 103 | - # no params 104 | remove_channel: 105 | - True 106 | - {p: 0.2} 107 | rescale_protein: 108 | - True 109 | - {p: 0.2} 110 | RandomResizedCrop: 111 | - True 112 | - {size: 224, scale: [0.4, 1]} 113 | Threshold_protein: 114 | - False 115 | - {p: 0.8, interpolation: Image.BICUBIC} 116 | RandomResizedCenterCrop: 117 | - False 118 | - {size: 224, scale: [0.5, 1], depth: 1e6, s: 0.7} 119 | global_aug1: 120 | GaussianBlur: 121 | - False 122 | - {p: 1.0} 123 | Solarization: 124 | - False 125 | - {p: 0.2} 126 | Solarization_for_RGBA: 127 | - False 128 | - {p: 0.2} 129 | rnd_dihedral: 130 | - False 131 | - # no params 132 | testing_transfo: 133 | Single_cell_centered: 134 | - False 135 | - # no params 136 | Single_cell_random_resize: 137 | - False 138 | - # no params 139 | FA_resize: 140 | - False 141 | - {size: 512} 142 | Single_cell_Resize: 143 | - False 144 | - # no params 145 | Single_cell_Mirror: 146 | - False 147 | - # no params 148 | Get_specific_channel: 149 | - False 150 | - {c: 0} 151 | Get_specific_channel: 152 | - False 153 | - {c: 1} 154 | Get_specific_channel: 155 | - False 156 | - {c: 2} 157 | Get_specific_channel: 158 | - False 159 | - {c: 3} 160 | ToTensor: 161 | - True 162 | - # no params 163 | Normalize: 164 | - False 165 | - {mean: [0.1450534, 
0.11360057, 0.1231717, 0.14919987], std: [0.18122554, 0.14004277, 0.18840286, 0.17790672]} 166 | self_normalize: 167 | - True 168 | - # no params 169 | # --- Global crops 2 ---: 170 | global_transfo2: 171 | Warp_cell: 172 | - True 173 | - # no params 174 | Single_cell_centered: 175 | - False 176 | - # no params 177 | Single_cell_random_resize: 178 | - False 179 | - # no params 180 | FA_resize: 181 | - False 182 | - {size: 512} 183 | Single_cell_Resize: 184 | - False 185 | - # no params 186 | Single_cell_Mirror: 187 | - False 188 | - # no params 189 | remove_channel: 190 | - True 191 | - {p: 0.2} 192 | rescale_protein: 193 | - True 194 | - {p: 0.2} 195 | RandomResizedCrop: 196 | - True 197 | - {size: 224, scale: [0.4, 1]} 198 | Threshold_protein: 199 | - False 200 | - {p: 0.8, interpolation: Image.BICUBIC} 201 | RandomResizedCenterCrop: 202 | - False 203 | - {size: 224, scale: [0.4, 1], depth: 1e6, s: 0.7} 204 | global_aug2: 205 | GaussianBlur: 206 | - False 207 | - {p: 1.0} 208 | Solarization: 209 | - False 210 | - {p: 0.2} 211 | Solarization_for_RGBA: 212 | - False 213 | - {p: 0.2} 214 | rnd_dihedral: 215 | - False 216 | - # no params 217 | # --- Local crops ---: 218 | local_crops_number: 8 219 | local_transfo: 220 | Warp_cell: 221 | - True 222 | - # no params 223 | Single_cell_centered: 224 | - False 225 | - # no params 226 | Single_cell_random_resize: 227 | - False 228 | - # no params 229 | FA_resize: 230 | - False 231 | - {size: 512} 232 | Single_cell_Resize: 233 | - False 234 | - # no params 235 | Single_cell_Mirror: 236 | - False 237 | - # no params 238 | remove_channel: 239 | - True 240 | - {p: 0.2} 241 | rescale_protein: 242 | - True 243 | - {p: 0.2} 244 | RandomResizedCrop: 245 | - True 246 | - {size: 96, scale: [0.05, 0.4]} 247 | Threshold_protein: 248 | - False 249 | - {p: 0.8, interpolation: Image.BICUBIC} 250 | RandomResizedCenterCrop: 251 | - False 252 | - {size: 96, scale: [0.2, 0.5], depth: 1e6, s: 0.7} 253 | local_aug: 254 | GaussianBlur: 255 | - False 256 | - {p: 1.0} 257 | rnd_dihedral: 258 | - False 259 | - # no params 260 | 261 | 262 | # --- Global crops 1 ---: 263 | global_transfo1: 264 | Warp_cell: 265 | - True 266 | - # no params 267 | Single_cell_centered: 268 | - False 269 | - # no params 270 | Single_cell_random_resize: 271 | - False 272 | - # no params 273 | FA_resize: 274 | - False 275 | - {size: 512} 276 | Single_cell_Resize: 277 | - False 278 | - # no params 279 | Single_cell_Mirror: 280 | - False 281 | - # no params 282 | remove_channel: 283 | - False 284 | - {p: 0.2} 285 | RandomResizedCrop: 286 | - True 287 | - {size: 224, scale: [0.4, 1]} 288 | RandomResizedCenterCrop: 289 | - False 290 | - {size: 224, scale: [0.5, 1], depth: 1e6, s: 0.7} 291 | 292 | global_aug1: 293 | GaussianBlur: 294 | - False 295 | - {p: 1.0} 296 | Solarization: 297 | - False 298 | - {p: 0.2} 299 | Solarization_for_RGBA: 300 | - False 301 | - {p: 0.2} 302 | rnd_dihedral: 303 | - False 304 | - # no params 305 | 306 | testing_transfo: 307 | Single_cell_centered: 308 | - False 309 | - # no params 310 | Single_cell_random_resize: 311 | - False 312 | - # no params 313 | FA_resize: 314 | - False 315 | - {size: 512} 316 | Single_cell_Resize: 317 | - False 318 | - # no params 319 | Single_cell_Mirror: 320 | - False 321 | - # no params 322 | Get_specific_channel: 323 | - False 324 | - {c: 0} 325 | Get_specific_channel: 326 | - False 327 | - {c: 1} 328 | Get_specific_channel: 329 | - False 330 | - {c: 2} 331 | Get_specific_channel: 332 | - False 333 | - {c: 3} 334 | ToTensor: 335 | - True 336 | 
- # no params 337 | Normalize: 338 | - False 339 | - {mean: [0.1450534, 0.11360057, 0.1231717, 0.14919987], std: [0.18122554, 0.14004277, 0.18840286, 0.17790672]} 340 | self_normalize: 341 | - True 342 | - # no params 343 | 344 | # --- Global crops 2 ---: 345 | global_transfo2: 346 | Warp_cell: 347 | - False 348 | - # no params 349 | Single_cell_centered: 350 | - False 351 | - # no params 352 | Single_cell_random_resize: 353 | - False 354 | - # no params 355 | FA_resize: 356 | - False 357 | - {size: 512} 358 | Single_cell_Resize: 359 | - False 360 | - # no params 361 | Single_cell_Mirror: 362 | - False 363 | - # no params 364 | remove_channel: 365 | - True 366 | - {p: 0.2} 367 | RandomResizedCrop: 368 | - True 369 | - {size: 224, scale: [0.4, 1]} 370 | RandomResizedCenterCrop: 371 | - False 372 | - {size: 224, scale: [0.4, 1], depth: 1e6, s: 0.7} 373 | 374 | global_aug2: 375 | GaussianBlur: 376 | - False 377 | - {p: 1.0} 378 | Solarization: 379 | - False 380 | - {p: 0.2} 381 | Solarization_for_RGBA: 382 | - False 383 | - {p: 0.2} 384 | rnd_dihedral: 385 | - False 386 | - # no params 387 | 388 | 389 | # --- Local crops ---: 390 | local_crops_number: 8 391 | local_transfo: 392 | Warp_cell: 393 | - False 394 | - # no params 395 | Single_cell_centered: 396 | - False 397 | - # no params 398 | Single_cell_random_resize: 399 | - False 400 | - # no params 401 | FA_resize: 402 | - False 403 | - {size: 512} 404 | Single_cell_Resize: 405 | - False 406 | - # no params 407 | Single_cell_Mirror: 408 | - False 409 | - # no params 410 | remove_channel: 411 | - True 412 | - {p: 0.2} 413 | RandomResizedCrop: 414 | - True 415 | - {size: 96, scale: [0.05, 0.4]} 416 | RandomResizedCenterCrop: 417 | - False 418 | - {size: 96, scale: [0.2, 0.5], depth: 1e6, s: 0.7} 419 | 420 | local_aug: 421 | GaussianBlur: 422 | - False 423 | - {p: 1.0} 424 | rnd_dihedral: 425 | - False 426 | - # no params 427 | -------------------------------------------------------------------------------- /data_utils/label_dict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | 4 | num_to_protein_full = {} 5 | num_to_protein_full[0] = "nucleoplasm" 6 | num_to_protein_full[1] = "nuclear membrane" 7 | num_to_protein_full[2] = "nucleoli" 8 | num_to_protein_full[3] = "nucleoli fibrillar center" 9 | num_to_protein_full[4] = "nuclear speckles" 10 | num_to_protein_full[5] = "nuclear bodies" 11 | num_to_protein_full[6] = "endoplasmic reticulum" 12 | num_to_protein_full[7] = "golgi apparatus" 13 | num_to_protein_full[8] = "peroxisomes" 14 | num_to_protein_full[9] = "endosomes" 15 | num_to_protein_full[10] = "lysosomes" 16 | num_to_protein_full[11] = "intermediate filaments" 17 | num_to_protein_full[12] = "actin filaments" 18 | num_to_protein_full[13] = "focal adhesion sites" 19 | num_to_protein_full[14] = "microtubules" 20 | num_to_protein_full[15] = "microtubule ends" 21 | num_to_protein_full[16] = "cytokinetic bridge" 22 | num_to_protein_full[17] = "mitotic spindle" 23 | num_to_protein_full[18] = "microtubule organizing center" 24 | num_to_protein_full[19] = "centrosome" 25 | num_to_protein_full[20] = "lipid droplets" 26 | num_to_protein_full[21] = "plasma membrane" 27 | num_to_protein_full[22] = "cell junctions" 28 | num_to_protein_full[23] = "mitochondria" 29 | num_to_protein_full[24] = "aggresome" 30 | num_to_protein_full[25] = "cytosol" 31 | num_to_protein_full[26] = "cytoplasmic bodies" 32 | num_to_protein_full[27] = "rods & rings" 33 | 34 | 
num_to_protein_single_cells = {} 35 | num_to_protein_single_cells[0] = "nucleoplasm" 36 | num_to_protein_single_cells[1] = "nuclear membrane" 37 | num_to_protein_single_cells[2] = "nucleoli" 38 | num_to_protein_single_cells[3] = "nucleoli fibrillar center" 39 | num_to_protein_single_cells[4] = "nuclear speckles" 40 | num_to_protein_single_cells[5] = "nuclear bodies" 41 | num_to_protein_single_cells[6] = "endoplasmic reticulum" 42 | num_to_protein_single_cells[7] = "golgi apparatus" 43 | num_to_protein_single_cells[8] = "intermediate filaments" 44 | num_to_protein_single_cells[9] = "actin filaments,focal adhesion sites" 45 | num_to_protein_single_cells[10] = "microtubules" 46 | num_to_protein_single_cells[11] = "mitotic spindle" 47 | num_to_protein_single_cells[12] = "centrosome,centriolar satellite" 48 | num_to_protein_single_cells[13] = "plasma membrane,cell junctions" 49 | num_to_protein_single_cells[14] = "mitochondria" 50 | num_to_protein_single_cells[15] = "aggresome" 51 | num_to_protein_single_cells[16] = "cytosol" 52 | num_to_protein_single_cells[ 53 | 17 54 | ] = "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies" 55 | num_to_protein_single_cells[18] = "no staining" 56 | 57 | whole2single = { 58 | "nucleoplasm": 0, 59 | "nuclear membrane": 1, 60 | "nucleoli": 2, 61 | "nucleoli fibrillar center": 3, 62 | "nuclear speckles": 4, 63 | "nuclear bodies": 5, 64 | "endoplasmic reticulum": 6, 65 | "golgi apparatus": 7, 66 | "intermediate filaments": 8, 67 | "actin filaments": 9, 68 | "focal adhesion sites": 9, 69 | "microtubules": 10, 70 | "microtubule ends": 10, 71 | # "microtubule organizing center": 10, 72 | # "cytokinetic bridge": 10, 73 | "mitotic spindle": 11, 74 | "mitotic chromosome": 11, 75 | "centrosome": 12, 76 | "centriolar satellite": 12, 77 | "plasma membrane": 13, 78 | "cell junctions": 13, 79 | "mitochondria": 14, 80 | "aggresome": 15, 81 | "cytosol": 16, 82 | "vesicles": 17, 83 | "peroxisomes": 17, 84 | "endosomes": 17, 85 | "lysosomes": 17, 86 | "lipid droplets": 17, 87 | "cytoplasmic bodies": 17, 88 | "no staining": 18, 89 | "rods & rings": 18, 90 | "nucleoli rim": 18, 91 | "kinetochore": 18, 92 | } 93 | 94 | protein_to_num_full = {v.lower(): k for (k, v) in num_to_protein_full.items()} 95 | protein_to_num_single_cells = { 96 | v.lower(): k for (k, v) in num_to_protein_single_cells.items() 97 | } 98 | 99 | num_to_cell_full = {} 100 | num_to_cell_full[0] = "BJ" 101 | num_to_cell_full[1] = "LHCN-M2" 102 | num_to_cell_full[2] = "RH-30" 103 | num_to_cell_full[3] = "SH-SY5Y" 104 | num_to_cell_full[4] = "SiHa" 105 | num_to_cell_full[5] = "U-2 OS" 106 | num_to_cell_full[6] = "ASC TERT1" 107 | num_to_cell_full[7] = "HaCaT" 108 | num_to_cell_full[8] = "A-431" 109 | num_to_cell_full[9] = "U-251 MG" 110 | num_to_cell_full[10] = "HEK 293" 111 | num_to_cell_full[11] = "A549" 112 | num_to_cell_full[12] = "RT4" 113 | num_to_cell_full[13] = "HeLa" 114 | num_to_cell_full[14] = "MCF7" 115 | num_to_cell_full[15] = "PC-3" 116 | num_to_cell_full[16] = "hTERT-RPE1" 117 | num_to_cell_full[17] = "SK-MEL-30" 118 | num_to_cell_full[18] = "EFO-21" 119 | num_to_cell_full[19] = "AF22" 120 | num_to_cell_full[20] = "HEL" 121 | num_to_cell_full[21] = "Hep G2" 122 | num_to_cell_full[22] = "HUVEC TERT2" 123 | num_to_cell_full[23] = "THP-1" 124 | num_to_cell_full[24] = "CACO-2" 125 | num_to_cell_full[25] = "JURKAT" 126 | num_to_cell_full[26] = "RPTEC TERT1" 127 | num_to_cell_full[27] = "SuSa" 128 | num_to_cell_full[28] = "REH" 129 | num_to_cell_full[29] = "HDLM-2" 130 | 
num_to_cell_full[30] = "K-562" 131 | num_to_cell_full[31] = "hTCEpi" 132 | num_to_cell_full[32] = "NB-4" 133 | num_to_cell_full[33] = "HAP1" 134 | num_to_cell_full[34] = "OE19" 135 | 136 | cell_to_num_full = {v.lower(): k for (k, v) in num_to_cell_full.items()} 137 | 138 | num_to_cell_single_cells = {} 139 | num_to_cell_single_cells[0] = "A-431" 140 | num_to_cell_single_cells[1] = "A549" 141 | num_to_cell_single_cells[2] = "AF22" 142 | num_to_cell_single_cells[3] = "ASC TERT1" 143 | num_to_cell_single_cells[4] = "BJ" 144 | num_to_cell_single_cells[5] = "CACO-2" 145 | num_to_cell_single_cells[6] = "EFO-21" 146 | num_to_cell_single_cells[7] = "HAP1" 147 | num_to_cell_single_cells[8] = "HDLM-2" 148 | num_to_cell_single_cells[9] = "HEK 293" 149 | num_to_cell_single_cells[10] = "HEL" 150 | num_to_cell_single_cells[11] = "HUVEC TERT2" 151 | num_to_cell_single_cells[12] = "HaCaT" 152 | num_to_cell_single_cells[13] = "HeLa" 153 | num_to_cell_single_cells[14] = "Hep G2" 154 | num_to_cell_single_cells[15] = "JURKAT" 155 | num_to_cell_single_cells[16] = "K-562" 156 | num_to_cell_single_cells[17] = "MCF7" 157 | num_to_cell_single_cells[18] = "PC-3" 158 | num_to_cell_single_cells[19] = "REH" 159 | num_to_cell_single_cells[20] = "RH-30" 160 | num_to_cell_single_cells[21] = "RPTEC TERT1" 161 | num_to_cell_single_cells[22] = "RT4" 162 | num_to_cell_single_cells[23] = "SH-SY5Y" 163 | num_to_cell_single_cells[24] = "SK-MEL-30" 164 | num_to_cell_single_cells[25] = "SiHa" 165 | num_to_cell_single_cells[26] = "U-2 OS" 166 | num_to_cell_single_cells[27] = "U-251 MG" 167 | num_to_cell_single_cells[28] = "hTCEpi" 168 | cell_to_num_single_cells = {v.lower(): k for (k, v) in num_to_cell_single_cells.items()} 169 | 170 | num_whole2single = { 171 | protein_to_num_full[k]: v 172 | for k, v in whole2single.items() 173 | if k in protein_to_num_full 174 | } 175 | 176 | num_single2whole = { 177 | v: protein_to_num_full[k] 178 | for k, v in whole2single.items() 179 | if k in protein_to_num_full 180 | } 181 | 182 | protein_to_num_ref = copy.deepcopy(protein_to_num_full) 183 | protein_to_num_ref["vesicles"] = 28 184 | protein_to_num_ref["unknown"] = 29 185 | protein_to_num_ref["nucleoli rim"] = 30 186 | protein_to_num_ref["mitotic chromosome"] = 31 187 | protein_to_num_ref["kinetochore"] = 32 188 | 189 | num_to_protein_4k = {} 190 | num_to_protein_4k[0] = "nucleoplasm" 191 | num_to_protein_4k[1] = "plasma membrane" 192 | num_to_protein_4k[2] = "mitochondria" 193 | num_to_protein_4k[3] = "cytosol" 194 | 195 | protein_to_num_4k = {v.lower(): k for (k, v) in num_to_protein_4k.items()} 196 | 197 | num_to_protein_5k = {} 198 | num_to_protein_5k[0] = "nucleoplasm" 199 | num_to_protein_5k[1] = "plasma membrane" 200 | num_to_protein_5k[2] = "mitochondria" 201 | num_to_protein_5k[3] = "cytosol" 202 | num_to_protein_5k[4] = "vesicles" 203 | 204 | protein_to_num_5k = {v.lower(): k for (k, v) in num_to_protein_5k.items()} 205 | 206 | num_to_other_protein_5k = {} 207 | num_to_other_protein_5k[0] = "golgi apparatus" 208 | num_to_other_protein_5k[1] = "nuclear speckles" 209 | num_to_other_protein_5k[2] = "nuclear bodies" 210 | num_to_other_protein_5k[3] = "nucleoli" 211 | num_to_other_protein_5k[4] = "endoplasmic reticulum" 212 | 213 | other_protein_to_num_5k = {v.lower(): k for (k, v) in num_to_other_protein_5k.items()} 214 | 215 | from collections import defaultdict 216 | 217 | hierarchical_organization_whole_image_high_level = [ 218 | [ 219 | "nucleoplasm", 220 | "nuclear bodies", 221 | "nuclear speckles", 222 | "nucleoli", 223 | 
"nucleoli fibrillar center", 224 | "nuclear membrane", 225 | ], 226 | [ 227 | "cytosol", 228 | "aggresome", 229 | "mitochondria", 230 | "intermediate filaments", 231 | "microtubule ends", 232 | "microtubules", 233 | "actin filaments", 234 | "cytokinetic bridge", 235 | "microtubule organizing center", 236 | "centrosome", 237 | "endoplasmic reticulum", 238 | "golgi apparatus", 239 | "vesicles", 240 | "cell junctions", 241 | "focal adhesion sites", 242 | "plasma membrane", 243 | "rods & rings", 244 | "peroxisomes", 245 | "endosomes", 246 | "lysosomes", 247 | "mitotic spindle", 248 | "lipid droplets", 249 | "cytoplasmic bodies", 250 | ], 251 | ] 252 | hierarchical_organization_whole_image_low_level = [ 253 | ["nucleoplasm"], 254 | ["nuclear bodies", "nuclear speckles"], 255 | ["nucleoli", "nucleoli fibrillar center"], 256 | ["nuclear membrane"], 257 | ["cytosol", "aggresome", "mitochondria"], 258 | [ 259 | "intermediate filaments", 260 | "microtubule ends", 261 | "microtubules", 262 | "actin filaments", 263 | "cytokinetic bridge", 264 | ], 265 | ["microtubule organizing center", "centrosome"], 266 | ["endoplasmic reticulum", "golgi apparatus", "vesicles"], 267 | ["cell junctions", "focal adhesion sites", "plasma membrane"], 268 | [ 269 | "rods & rings", 270 | "peroxisomes", 271 | "endosomes", 272 | "lysosomes", 273 | "mitotic spindle", 274 | "lipid droplets", 275 | "cytoplasmic bodies", 276 | ], 277 | ] 278 | 279 | hierarchical_organization_single_cell_high_level = [ 280 | [ 281 | "nucleoplasm", 282 | "nuclear bodies", 283 | "nuclear speckles", 284 | "nucleoli", 285 | "nucleoli fibrillar center", 286 | "nuclear membrane", 287 | ], 288 | [ 289 | "cytosol", 290 | "aggresome", 291 | "mitochondria", 292 | "intermediate filaments", 293 | "microtubule ends", 294 | "microtubules", 295 | "actin filaments,focal adhesion sites", 296 | "cytokinetic bridge", 297 | "microtubule organizing center", 298 | "centrosome,centriolar satellite", 299 | "endoplasmic reticulum", 300 | "golgi apparatus", 301 | "vesicles", 302 | "focal adhesion sites", 303 | "plasma membrane,cell junctions", 304 | "rods & rings", 305 | "peroxisomes", 306 | "endosomes", 307 | "lysosomes", 308 | "mitotic spindle", 309 | "lipid droplets", 310 | "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies", 311 | "no staining", 312 | ], 313 | ] 314 | 315 | hierarchical_organization_single_cell_low_level = [ 316 | ["nucleoplasm"], 317 | ["nuclear bodies", "nuclear speckles"], 318 | ["nucleoli", "nucleoli fibrillar center"], 319 | ["nuclear membrane"], 320 | ["cytosol", "aggresome", "mitochondria"], 321 | [ 322 | "intermediate filaments", 323 | "microtubule ends", 324 | "microtubules", 325 | "actin filaments,focal adhesion sites", 326 | "cytokinetic bridge", 327 | ], 328 | [ 329 | "microtubule organizing center", 330 | "centrosome,centriolar satellite", 331 | ], 332 | ["endoplasmic reticulum", "golgi apparatus", "vesicles"], 333 | ["focal adhesion sites", "plasma membrane,cell junctions"], 334 | [ 335 | "rods & rings", 336 | "peroxisomes", 337 | "endosomes", 338 | "lysosomes", 339 | "mitotic spindle", 340 | "lipid droplets", 341 | "vesicles,peroxisomes,endosomes,lysosomes,lipid droplets,cytoplasmic bodies", 342 | "no staining", 343 | ], 344 | ] 345 | -------------------------------------------------------------------------------- /archs/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from utils.utils import trunc_normal_ 25 | 26 | 27 | def drop_path(x, drop_prob: float = 0., training: bool = False): 28 | if drop_prob == 0. or not training: 29 | return x 30 | keep_prob = 1 - drop_prob 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | output = x.div(keep_prob) * random_tensor 35 | return output 36 | 37 | 38 | class DropPath(nn.Module): 39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 40 | """ 41 | def __init__(self, drop_prob=None): 42 | super(DropPath, self).__init__() 43 | self.drop_prob = drop_prob 44 | 45 | def forward(self, x): 46 | return drop_path(x, self.drop_prob, self.training) 47 | 48 | 49 | class Mlp(nn.Module): 50 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 51 | super().__init__() 52 | out_features = out_features or in_features 53 | hidden_features = hidden_features or in_features 54 | self.fc1 = nn.Linear(in_features, hidden_features) 55 | self.act = act_layer() 56 | self.fc2 = nn.Linear(hidden_features, out_features) 57 | self.drop = nn.Dropout(drop) 58 | 59 | def forward(self, x): 60 | x = self.fc1(x) 61 | x = self.act(x) 62 | x = self.drop(x) 63 | x = self.fc2(x) 64 | x = self.drop(x) 65 | return x 66 | 67 | 68 | class Attention(nn.Module): 69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 70 | super().__init__() 71 | self.num_heads = num_heads 72 | head_dim = dim // num_heads 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 76 | self.attn_drop = nn.Dropout(attn_drop) 77 | self.proj = nn.Linear(dim, dim) 78 | self.proj_drop = nn.Dropout(proj_drop) 79 | 80 | def forward(self, x): 81 | B, N, C = x.shape 82 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 83 | q, k, v = qkv[0], qkv[1], qkv[2] 84 | 85 | attn = (q @ k.transpose(-2, -1)) * self.scale 86 | attn = attn.softmax(dim=-1) 87 | attn = self.attn_drop(attn) 88 | 89 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x, attn 93 | 94 | 95 | class Block(nn.Module): 96 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 97 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 98 | super().__init__() 99 | self.norm1 = norm_layer(dim) 100 | self.attn = Attention( 101 | dim, num_heads=num_heads, 
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 106 | 107 | def forward(self, x, return_attention=False): 108 | y, attn = self.attn(self.norm1(x)) 109 | if return_attention: 110 | return attn 111 | x = x + self.drop_path(y) 112 | x = x + self.drop_path(self.mlp(self.norm2(x))) 113 | return x 114 | 115 | 116 | class PatchEmbed(nn.Module): 117 | """ Image to Patch Embedding 118 | """ 119 | def __init__(self, img_size=224, patch_size=16, in_chans=4, embed_dim=768): 120 | # def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 121 | super().__init__() 122 | num_patches = (img_size // patch_size) * (img_size // patch_size) 123 | self.img_size = img_size 124 | self.patch_size = patch_size 125 | self.num_patches = num_patches 126 | 127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 128 | 129 | def forward(self, x): 130 | B, C, H, W = x.shape 131 | x = self.proj(x).flatten(2).transpose(1, 2) 132 | return x 133 | 134 | 135 | class VisionTransformer(nn.Module): 136 | """ Vision Transformer """ 137 | def __init__(self, img_size=[224], patch_size=16, in_chans=4, num_classes=0, embed_dim=768, depth=12, 138 | # def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 139 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 140 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 141 | super().__init__() 142 | self.num_features = self.embed_dim = embed_dim 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | num_patches = self.patch_embed.num_patches 147 | 148 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 149 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 150 | self.pos_drop = nn.Dropout(p=drop_rate) 151 | 152 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 153 | self.blocks = nn.ModuleList([ 154 | Block( 155 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 156 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 157 | for i in range(depth)]) 158 | self.norm = norm_layer(embed_dim) 159 | 160 | # Classifier head 161 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 1.0) 175 | 176 | def interpolate_pos_encoding(self, x, w, h): 177 | npatch = x.shape[1] - 1 178 | N = self.pos_embed.shape[1] - 1 179 | if npatch == N and w == h: 180 | return self.pos_embed 181 | class_pos_embed = self.pos_embed[:, 0] 182 | patch_pos_embed = self.pos_embed[:, 1:] 183 | dim = x.shape[-1] 184 | w0 = w // self.patch_embed.patch_size 185 | h0 = h // 
self.patch_embed.patch_size 186 | # we add a small number to avoid floating point error in the interpolation 187 | # see discussion at https://github.com/facebookresearch/dino/issues/8 188 | w0, h0 = w0 + 0.1, h0 + 0.1 189 | patch_pos_embed = nn.functional.interpolate( 190 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 191 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 192 | mode='bicubic', 193 | ) 194 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 195 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 196 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 197 | 198 | def prepare_tokens(self, x): 199 | B, nc, w, h = x.shape 200 | x = self.patch_embed(x) # patch linear embedding 201 | 202 | # add the [CLS] token to the embed patch tokens 203 | cls_tokens = self.cls_token.expand(B, -1, -1) 204 | x = torch.cat((cls_tokens, x), dim=1) 205 | 206 | # add positional encoding to each token 207 | x = x + self.interpolate_pos_encoding(x, w, h) 208 | 209 | return self.pos_drop(x) 210 | 211 | def forward(self, x): 212 | x = self.prepare_tokens(x) 213 | for blk in self.blocks: 214 | x = blk(x) 215 | x = self.norm(x) 216 | return x[:, 0] 217 | 218 | def get_last_selfattention(self, x): 219 | x = self.prepare_tokens(x) 220 | for i, blk in enumerate(self.blocks): 221 | if i < len(self.blocks) - 1: 222 | x = blk(x) 223 | else: 224 | # return attention of the last block 225 | return blk(x, return_attention=True) 226 | 227 | def get_intermediate_layers(self, x, n=1): 228 | x = self.prepare_tokens(x) 229 | # we return the output tokens from the `n` last blocks 230 | output = [] 231 | for i, blk in enumerate(self.blocks): 232 | x = blk(x) 233 | if len(self.blocks) - i <= n: 234 | output.append(self.norm(x)) 235 | return output 236 | 237 | 238 | def vit_tiny(patch_size=16, **kwargs): 239 | model = VisionTransformer( 240 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 241 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 242 | return model 243 | 244 | 245 | def vit_small(patch_size=16, **kwargs): 246 | model = VisionTransformer( 247 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 248 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 249 | return model 250 | 251 | 252 | def vit_base(patch_size=16, **kwargs): 253 | model = VisionTransformer( 254 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 255 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 256 | return model 257 | 258 | 259 | class DINOHead(nn.Module): 260 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 261 | super().__init__() 262 | nlayers = max(nlayers, 1) 263 | if nlayers == 1: 264 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 265 | else: 266 | layers = [nn.Linear(in_dim, hidden_dim)] 267 | if use_bn: 268 | layers.append(nn.BatchNorm1d(hidden_dim)) 269 | layers.append(nn.GELU()) 270 | for _ in range(nlayers - 2): 271 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 272 | if use_bn: 273 | layers.append(nn.BatchNorm1d(hidden_dim)) 274 | layers.append(nn.GELU()) 275 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 276 | self.mlp = nn.Sequential(*layers) 277 | self.apply(self._init_weights) 278 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 279 | 
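# weight_norm reparameterizes the last linear layer into a direction (weight_v)
# and a magnitude (weight_g); the next lines fix weight_g at 1 and, when
# norm_last_layer is True, freeze it so the output prototype vectors stay unit-norm.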
self.last_layer.weight_g.data.fill_(1) 280 | if norm_last_layer: 281 | self.last_layer.weight_g.requires_grad = False 282 | 283 | def _init_weights(self, m): 284 | if isinstance(m, nn.Linear): 285 | trunc_normal_(m.weight, std=.02) 286 | if isinstance(m, nn.Linear) and m.bias is not None: 287 | nn.init.constant_(m.bias, 0) 288 | 289 | def forward(self, x): 290 | x = self.mlp(x) 291 | x = nn.functional.normalize(x, dim=-1, p=2) 292 | x = self.last_layer(x) 293 | return x 294 | -------------------------------------------------------------------------------- /utils/classification_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.linear_model import LogisticRegression 3 | from tqdm import tqdm 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.nn.functional import sigmoid 8 | from torch.optim.lr_scheduler import CosineAnnealingLR 9 | import torch.distributed as dist 10 | from random import choices 11 | import sys 12 | 13 | def is_dist_avail_and_initialized(): 14 | if not dist.is_available(): 15 | return False 16 | if not dist.is_initialized(): 17 | return False 18 | return True 19 | 20 | def get_world_size(): 21 | if not is_dist_avail_and_initialized(): 22 | return 1 23 | return dist.get_world_size() 24 | 25 | 26 | # ============================== Classification ======================================= 27 | class MLLR: 28 | def __init__(self, max_iter=100, num_classes=None): 29 | self.index = None 30 | self.y = None 31 | self.num_classes = num_classes 32 | self.max_iter = max_iter 33 | 34 | def fit(self, X, y): 35 | if type(self.num_classes) == type(None): 36 | self.num_classes = np.array(y).shape[1] 37 | 38 | self.classifiers = [ 39 | LogisticRegression(max_iter=self.max_iter) for c in range(self.num_classes) 40 | ] 41 | 42 | pbar = tqdm(enumerate(self.classifiers)) 43 | for ind, c in pbar: 44 | pbar.set_description(f"training LinearRegressor for class {ind}") 45 | if y[:, ind].mean() in [0, 1]: 46 | print(f"No two classes for ind {ind}!") 47 | else: 48 | c.fit(X, y[:, ind]) 49 | 50 | def predict(self, X): 51 | predictions = np.zeros((X.shape[0], self.num_classes)) 52 | for ind, c in enumerate(self.classifiers): 53 | try: 54 | predictions[:, ind] = c.predict(X) 55 | except: 56 | pass 57 | return predictions 58 | 59 | 60 | def threshold_output(prediction, threshold=0.5, use_sigmoid=False): 61 | if use_sigmoid: 62 | prediction = sigmoid(prediction) 63 | return prediction > threshold 64 | 65 | 66 | class Multilabel_classifier(nn.Module): 67 | def __init__( 68 | self, 69 | num_features, 70 | num_classes, 71 | hidden_units=500, 72 | with_sigmoid=True, 73 | num_blocks=3, 74 | ): 75 | super().__init__() 76 | num_blocks = num_blocks 77 | layers = [] 78 | layers.extend( 79 | [ 80 | nn.Linear(num_features, hidden_units), 81 | nn.BatchNorm1d(hidden_units), 82 | nn.ReLU(), 83 | # nn.Dropout(), 84 | ] 85 | ) 86 | for i in range(num_blocks): 87 | layers.extend( 88 | [ 89 | nn.Linear(hidden_units, hidden_units), 90 | nn.BatchNorm1d(hidden_units), 91 | nn.ReLU(), 92 | nn.Dropout() if i == (num_blocks - 1) else nn.Identity(), 93 | ] 94 | ) 95 | layers.append(nn.Linear(hidden_units, num_classes)) 96 | self.layers = torch.nn.Sequential(*layers) 97 | if with_sigmoid: 98 | self.sigmoid = nn.Sigmoid() 99 | else: 100 | self.sigmoid = nn.Identity() 101 | 102 | def forward(self, x): 103 | x = self.layers(x) 104 | x = self.sigmoid(x) 105 | return x 106 | 107 | 108 | class ResBlock(nn.Module): 109 | def 
__init__(self, n_units=1024, norm=torch.nn.Identity, skip=True): 110 | super().__init__() 111 | self.fc1 = nn.Linear(n_units, n_units) 112 | self.norm1 = norm(n_units) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.fc2 = nn.Linear(n_units, n_units) 115 | self.norm2 = norm(n_units) 116 | self.skip = skip 117 | 118 | def forward(self, x): 119 | identity = x 120 | out = self.fc1(x) 121 | out = self.norm1(out) 122 | out = self.relu(out) 123 | out = self.fc2(out) 124 | out = self.norm2(out) 125 | if self.skip: 126 | out += identity 127 | out = self.relu(out) 128 | return out 129 | 130 | 131 | class residual_add_clf(nn.Module): 132 | def __init__( 133 | self, 134 | n_features, 135 | n_classes, 136 | n_layers=2, 137 | n_units=1024, 138 | norm_layer=None, 139 | skip=True, 140 | ): 141 | super().__init__() 142 | if norm_layer == "layer": 143 | self.norm = torch.nn.LayerNorm 144 | self.norm_layer = self.norm(n_units) 145 | elif norm_layer == "instance": 146 | self.norm = torch.nn.InstanceNorm1d 147 | self.norm_layer = self.norm(n_units) 148 | elif norm_layer == "batch_norm": 149 | self.norm = torch.nn.BatchNorm1d 150 | self.norm_layer = self.norm(1) 151 | else: 152 | self.norm = torch.nn.Identity 153 | self.norm_layer = self.norm(n_units) 154 | self.first_layer = nn.Linear(n_features, n_units) 155 | self.layers = torch.nn.Sequential( 156 | *[ 157 | ResBlock(n_units, self.norm, skip) 158 | for i in range(int(max(0, (n_layers - 2)) / 2)) 159 | ] 160 | ) 161 | self.final_layer = nn.Linear(n_units, n_classes) 162 | 163 | def forward(self, x): 164 | x = x.reshape(x.shape[0], 1, x.shape[-1]) 165 | f = self.norm_layer(self.first_layer(x)) 166 | x = torch.nn.functional.relu(f) 167 | if len(self.layers) > 0: 168 | x = self.layers(x) 169 | x = self.final_layer(x) 170 | return x 171 | 172 | class simple_clf(nn.Module): 173 | def __init__(self, n_features, n_classes, p=0.5): 174 | super().__init__() 175 | self.p = p 176 | self.clf = nn.Sequential( 177 | nn.Linear(n_features, 512), 178 | nn.ReLU(inplace=True), 179 | nn.Dropout(p=p), 180 | nn.Linear(512, 256), 181 | nn.ReLU(inplace=True), 182 | nn.Dropout(p=p), 183 | nn.Linear(256, n_classes), 184 | ) 185 | 186 | def forward(self, x): 187 | return self.clf(x) 188 | 189 | 190 | import torch 191 | import torch.nn as nn 192 | import torch.nn.functional as F 193 | 194 | 195 | class FocalLoss(nn.Module): 196 | def __init__(self, gamma=2): 197 | super().__init__() 198 | self.gamma = gamma 199 | 200 | def forward(self, logit, target, epoch=0): 201 | target = target.float() 202 | max_val = (-logit).clamp(min=0) 203 | loss = ( 204 | logit 205 | - logit * target 206 | + max_val 207 | + ((-max_val).exp() + (-logit - max_val).exp()).log() 208 | ) 209 | invprobs = F.logsigmoid(-logit * (target * 2.0 - 1.0)) 210 | loss = (invprobs * self.gamma).exp() * loss 211 | if len(loss.size()) == 2: 212 | loss = loss.sum(dim=1) 213 | return loss.mean() 214 | 215 | 216 | def balance_data(X_train, y_train): 217 | k = len(y_train) 218 | freq_per_class = y_train.mean(axis=0) 219 | balance_freqs_per_class = 1 / (freq_per_class) 220 | balance_freq_per_sample = y_train * balance_freqs_per_class 221 | balance_freq_per_sample = ( 222 | balance_freq_per_sample.max(axis=1) / balance_freqs_per_class.max() 223 | ) 224 | indices = choices(np.arange(k), weights=balance_freq_per_sample, k=k) 225 | return (X_train[indices], y_train[indices]) 226 | 227 | 228 | def get_scheduler(optim, total_steps, len_dl, args): 229 | lr = ( 230 | float(args.lr) * (int(args.batch_size_per_gpu) * get_world_size()) / 
256.0 231 | ) # linear scaling rule 232 | if args.schedule == "Cosine": 233 | scheduler = CosineAnnealingLR(optim, int(total_steps)) 234 | wd_scheduler, lr_scheduler = None, None 235 | else: 236 | print(f"Scheduler {args.schedule} not implemented.") 237 | sys.exit() 238 | print(f"Using {args.schedule} learning rate schedule") 239 | 240 | return scheduler, lr_scheduler, wd_scheduler 241 | 242 | 243 | def get_optimizer(parameters, args): 244 | lr = ( 245 | float(args.lr) * (int(args.batch_size_per_gpu) * get_world_size()) / 256.0 246 | ) # linear scaling rule 247 | if args.optimizer == "AdamW": 248 | optim = torch.optim.AdamW( 249 | parameters, 250 | lr=float(lr), 251 | betas=(0.9, 0.999), 252 | eps=1e-08, 253 | weight_decay=float(args.wd), 254 | amsgrad=False, 255 | ) 256 | elif args.optimizer == "SGD": 257 | optim = torch.optim.SGD(parameters, lr=float(lr), weight_decay=float(args.wd)) 258 | 259 | else: 260 | print(f"Optimizer {args.optimizer} not implemented.") 261 | sys.exit() 262 | print(f"Using {args.optimizer} optimizer") 263 | return optim 264 | 265 | 266 | 267 | def network_predict(classifier, X_test): 268 | prediction = threshold_output(classifier(X_test)).int().cpu().detach().numpy() 269 | return prediction 270 | 271 | 272 | def network_save(classifier, config, task): 273 | torch.save(classifier.state_dict(), config["classification"][f"{task}_classifier"]) 274 | 275 | 276 | def network_load(config, task): 277 | return torch.load(config["classification"][f"{task}_classifier"]) 278 | 279 | 280 | def MLLR_train(num_classes, X_train, y_train, config=None): 281 | classifier = MLLR(max_iter=1000, num_classes=num_classes) 282 | classifier.fit(X_train, y_train) 283 | return classifier 284 | 285 | 286 | def MLLR_predict(classifier, X_test): 287 | prediction = classifier.predict(X_test) 288 | return prediction 289 | 290 | 291 | def MLLR_save(classifier, config, task): 292 | torch.save(classifier, config["classification"][f"{task}_classifier"]) 293 | 294 | 295 | def MLLR_load(config, task): 296 | return torch.load(config["classification"][f"{task}_classifier"]) 297 | 298 | 299 | def profile_features(X, y, inds_per_ID, IDs): 300 | new_X = [] 301 | new_y = [] 302 | for ID in IDs: 303 | new_X.append(X[inds_per_ID[ID]].mean(axis=0)) 304 | new_y.append(y[inds_per_ID[ID]][0]) 305 | return torch.Tensor(new_X), torch.Tensor(new_y) 306 | 307 | 308 | def get_dataset(config, dataset): 309 | features_list = config["classification"][dataset] 310 | # all_features, cell_types, protein_locations, IDs = torch.load(features_list[0]) 311 | all_features, protein_locations, cell_types, IDs = torch.load(features_list[0]) 312 | if len(features_list) > 1: 313 | for features_path in features_list[1:]: 314 | ( 315 | all_features_other, 316 | protein_locations_other, 317 | cell_types_other, 318 | IDs_other, 319 | ) = torch.load(features_path) 320 | all_features = torch.cat((all_features, all_features_other), axis=1) 321 | protein_locations.extend(protein_locations_other) 322 | cell_types.extend(cell_types_other) 323 | IDs.extend(IDs_other) 324 | return all_features, protein_locations, cell_types, IDs 325 | 326 | 327 | class FocalBCELoss(torch.nn.Module): 328 | y_int = True # y interpolation 329 | 330 | def __init__(self, gamma=2.0, weight=None, reduction="mean"): 331 | super().__init__() 332 | self.gamma = gamma 333 | self.weight = weight 334 | self.reduction = reduction 335 | 336 | def forward(self, inp, targ): 337 | "Applies focal loss based on https://arxiv.org/pdf/1708.02002.pdf" 338 | ce_loss = 
F.binary_cross_entropy_with_logits( 339 | inp, targ, weight=self.weight, reduction="none" 340 | ) 341 | p_t = torch.exp(-ce_loss) 342 | loss = (1 - p_t) ** self.gamma * ce_loss 343 | if self.reduction == "mean": 344 | loss = loss.mean() 345 | elif self.reduction == "sum": 346 | loss = loss.sum() 347 | return loss 348 | 349 | 350 | def write_to_tensorboard(writer, description_dict, t="epoch"): 351 | for k, v in description_dict.items(): 352 | if k == "epoch": 353 | continue 354 | if k == "step": 355 | continue 356 | writer.add_scalar(k, v, description_dict[t]) 357 | 358 | 359 | def get_classifier(args, embed_dim): 360 | args.num_classes = int(args.num_classes) 361 | if args.classifier_type == "simple_clf": 362 | classifier = simple_clf(embed_dim, args.num_classes, p=float(args.dropout)) 363 | elif args.classifier_type == "residual_add_clf": 364 | classifier = residual_add_clf( 365 | embed_dim, 366 | args.num_classes, 367 | n_layers=int(args.n_layers), 368 | n_units=int(args.n_units), 369 | norm_layer=args.norm_layer if "norm_layer" in args else None, 370 | skip=args.skip, 371 | ) 372 | return classifier 373 | 374 | -------------------------------------------------------------------------------- /data_utils/file_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision 3 | import torch 4 | import os 5 | import numpy as np 6 | import os.path 7 | import pandas as pd 8 | from skimage import io 9 | from pathlib import Path 10 | from functools import partial 11 | from sklearn.preprocessing import StandardScaler 12 | from data_utils import cellpainting_dataset 13 | 14 | 15 | t = torchvision.transforms.ToTensor() 16 | 17 | 18 | def tensor_loader(path, training=True): 19 | return torch.load(path) 20 | 21 | 22 | def default_loader(path, training=True): 23 | return t(io.imread(path)) 24 | 25 | 26 | def one_channel_loader(path, training=True): 27 | img = io.imread(path) 28 | if training: 29 | ch = np.random.randint(0, 4) 30 | return t(img[:, :, ch]) 31 | else: 32 | return [t(img[:, :, i]) for i in range(img.shape[-1])] 33 | 34 | 35 | def two_channel_loader(path, training=True): 36 | img = io.imread(path) 37 | if training: 38 | out = img[:, :, 0:2] 39 | ch = np.random.randint(0, 4) 40 | out[:, :, 0] = t(img[:, :, ch]) 41 | else: 42 | out = [t(img[:, :, [i, 1]]) for i in range(4)] 43 | return out 44 | 45 | 46 | def protein_channel_loader(path, training=True): 47 | img = io.imread(path) 48 | return t(img[:, :, 1]) 49 | 50 | 51 | def single_channel_loader( 52 | path, 53 | training=True, 54 | channel=0, 55 | triple=True, 56 | ): 57 | img = io.imread(path) 58 | # img = io.imread(path).transpose(1, 2, 0) 59 | img = img[:, :, [channel]] 60 | if triple: 61 | img = np.repeat(img, 3, axis=2) 62 | return t(img.astype(np.uint8)) 63 | 64 | 65 | # def single_channel_loader( 66 | # path, 67 | # training=True, 68 | # channel=0, 69 | # ): 70 | # img = io.imread(path) 71 | # # img = io.imread(path).transpose(1, 2, 0) 72 | # img = img[:, :, [channel]] 73 | # img = np.repeat(img, 3, axis=2) 74 | # return img.astype(np.uint8) 75 | 76 | 77 | def norm_loader(path, training=True): 78 | img = io.imread(path) 79 | img = img.astype(float) 80 | img -= np.min(img, axis=(0, 1)) 81 | img *= 255 / np.max(img, axis=(0, 1)) 82 | return img.astype(np.uint8) 83 | 84 | 85 | def pandas_reader_binary_labels(flist, target_labels=None, sample_single_cells=False): 86 | """ 87 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 88 
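    Note: despite the caffe-style description above, in this repository the reader
    expects a CSV (or DataFrame) with columns ["file", "ID", "cell_type"] plus one
    binary column per entry of `target_labels`.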
| """ 89 | if isinstance(flist, pd.DataFrame): 90 | files = flist 91 | else: 92 | files = pd.read_csv(flist)[["file", "ID", "cell_type"] + target_labels] 93 | target_matrix = files[target_labels].values.astype(int) 94 | file_names = files["file"].values 95 | IDs = files["ID"].values 96 | cell_lines = files["cell_type"].values 97 | if sample_single_cells: 98 | ID_groups = files.groupby("ID").groups 99 | IDs = sorted(ID_groups.keys()) 100 | cell_lines = [cell_lines[ID_groups[ID]] for ID in sorted(ID_groups.keys())] 101 | file_names = [file_names[ID_groups[ID]] for ID in sorted(ID_groups.keys())] 102 | target_matrix = [ 103 | target_matrix[ID_groups[ID]] for ID in sorted(ID_groups.keys()) 104 | ] 105 | imlist = [] 106 | for impath, ID, cell_line, target in zip( 107 | file_names, IDs, cell_lines, target_matrix 108 | ): 109 | imlist.append((impath, target, cell_line, ID)) 110 | return imlist 111 | 112 | 113 | def pandas_reader(flist, ids=None): 114 | """ 115 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 116 | """ 117 | if isinstance(flist, pd.DataFrame): 118 | files = flist 119 | else: 120 | files = pd.read_csv(flist)[["file", "protein_location", "cell_type", "ID"]] 121 | if type(ids) is not type(None): 122 | files = files[files.ID.isin(ids)] 123 | files = np.array(files.to_records(index=False)) 124 | imlist = [] 125 | for impath, protein_location, cell_type, ID in files: 126 | imlist.append((impath, protein_location, cell_type, ID)) 127 | return imlist 128 | 129 | 130 | def pandas_reader_no_labels(flist, target_labels=None): 131 | """ 132 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 133 | """ 134 | files = pd.read_csv(flist)[["file", "ID"]] 135 | print(len(files)) 136 | files = np.array(files.to_records(index=False)) 137 | imlist = [] 138 | for impath, imlabel in files: 139 | imlist.append((impath, imlabel)) 140 | return imlist 141 | 142 | 143 | def default_flist_reader(flist): 144 | """ 145 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 146 | """ 147 | imlist = [] 148 | with open(flist, "r") as rf: 149 | for line in rf.readlines(): 150 | impath, imlabel = line.strip().split(",") 151 | imlist.append((impath, imlabel)) 152 | return imlist 153 | 154 | 155 | def pandas_reader_only_file(flist, ids=None, sample_single_cells=False): 156 | """ 157 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 158 | """ 159 | if isinstance(flist, pd.DataFrame): 160 | files = flist 161 | else: 162 | files = pd.read_csv(flist)[["file", "ID"]] 163 | if type(ids) is not type(None): 164 | files = files[files.ID.isin(ids)] 165 | if sample_single_cells: 166 | ID_groups = files.groupby("ID").groups 167 | IDs = sorted(ID_groups.keys()) 168 | file_names = [ 169 | files.iloc[ID_groups[ID]]["file"].values for ID in sorted(ID_groups.keys()) 170 | ] 171 | imlist = [] 172 | for impath, ID in zip(file_names, IDs): 173 | imlist.append((impath, ID)) 174 | else: 175 | files = np.array(files.to_records(index=False)) 176 | imlist = [] 177 | for impath, ID in files: 178 | imlist.append((impath, ID)) 179 | return imlist 180 | 181 | 182 | class ImageFileList(data.Dataset): 183 | def __init__( 184 | self, 185 | flist, 186 | root, 187 | transform=None, 188 | balance=True, 189 | flist_reader=pandas_reader_binary_labels, 190 | loader=default_loader, 191 | training=True, 192 | with_labels=True, 193 | target_labels=None, 194 | single_cells=False, 195 | sample_single_cells=False, 196 | ): 197 | self.balance = balance 198 | self.imdf = 
pd.read_csv(flist) 199 | self.idx = [] 200 | if type(target_labels) is not type(None): 201 | self.target_labels = sorted(target_labels) 202 | self.with_target_labels = True 203 | self.imlist = flist_reader(flist, self.target_labels) 204 | else: 205 | self.imlist = flist_reader(flist) 206 | self.with_target_labels = False 207 | self.target_labels = None 208 | if balance: 209 | self.parse_labels() 210 | self.transform = transform 211 | self.loader = loader 212 | self.training = training 213 | self.with_labels = with_labels 214 | self.root = root 215 | self.single_cells = single_cells 216 | self.sample_single_cells = sample_single_cells 217 | 218 | def parse_labels(self): 219 | if type(self.target_labels) is not type(None): 220 | self.unique = self.target_labels 221 | else: 222 | labels = [x for x in self.imdf.protein_location.apply(eval)] 223 | self.unique = set() 224 | result = [[self.unique.add(x) for x in y] for y in labels] 225 | self.unique = list(self.unique) 226 | self.unique.sort() 227 | for u in self.unique: 228 | self.imdf[u] = False 229 | for k in range(len(labels)): 230 | for c in labels[k]: 231 | self.imdf.loc[k, c] = True 232 | stats = pd.DataFrame( 233 | data=[{"class": u, "freq": np.sum(self.imdf[u])} for u in self.unique] 234 | ) 235 | self.stats = stats.sort_values(by="freq") 236 | N = int(self.stats.freq.mean()) 237 | print("Sampling", N, "images per class for", len(self.unique), "classes") 238 | self.N = N * len(self.unique) 239 | 240 | def __getitem__(self, index): 241 | # Mapping index to a virtual table of classes 242 | if self.balance: 243 | class_id = self.unique[index % len(self.unique)] 244 | sample_idx = self.imdf[self.imdf[class_id]].sample(n=1).index[0] 245 | else: 246 | sample_idx = index 247 | 248 | # Identify the sample 249 | if type(self.target_labels) is not type(None): 250 | impath, protein, cell, ID = self.imlist[sample_idx] 251 | elif self.with_labels: 252 | impath, protein, cell, ID = self.imlist[sample_idx] 253 | else: 254 | if self.sample_single_cells: 255 | files, ID = self.imlist[sample_idx] 256 | # Randomly choose a cell from the whole image 257 | impath = np.random.choice(files, 1)[0] 258 | else: 259 | impath, ID = self.imlist[sample_idx] 260 | img = self.loader(self.root + impath, self.training) 261 | 262 | # Transform the image 263 | if self.transform is not None: 264 | if isinstance(img, list): 265 | img = [self.transform(i) for i in img] 266 | else: 267 | img = self.transform(img) 268 | 269 | # Return the item 270 | if self.training: 271 | return img, ID 272 | elif type(self.target_labels) is not type(None): 273 | return img, protein.astype(int), cell, ID 274 | elif self.with_labels: 275 | return img, protein, cell, ID 276 | elif self.single_cells: 277 | return img, ID, Path(impath).stem 278 | else: 279 | return img, ID 280 | 281 | def __len__(self): 282 | return len(self.imlist) 283 | 284 | 285 | class AutoBalancedPrecomputedFeatures(data.Dataset): 286 | def __init__(self, source, balance, target_column, scaler=None, **kwargs): 287 | features, proteins, cells, IDs = torch.load(source) 288 | 289 | if isinstance(features, np.ndarray): 290 | self.features = torch.Tensor(features) 291 | elif isinstance(features, torch.Tensor): 292 | self.features = features.detach().cpu() 293 | self.IDs = np.array(IDs) 294 | self.proteins = proteins 295 | self.cells = cells 296 | if target_column == "proteins": 297 | if isinstance(proteins, np.ndarray): 298 | self.target = torch.Tensor(proteins) 299 | elif isinstance(proteins, torch.Tensor): 300 | self.target = 
proteins.detach().cpu() 301 | elif target_column == "cells": 302 | if isinstance(cells, np.ndarray): 303 | self.target = torch.Tensor(self.cells) 304 | elif isinstance(cells, torch.Tensor): 305 | self.target = self.cells.detach().cpu() 306 | # the following line removes all IDs from the origianl kaggle 307 | # competition (Since they had no cell type) 308 | indices = np.where(pd.DataFrame(IDs)[0].str.contains("-") == False)[0] 309 | self.features = self.features[indices, :] 310 | self.proteins = self.proteins[indices, :] 311 | self.cells = self.cells[indices, :] 312 | self.IDs = self.IDs[indices] 313 | self.target = self.target[indices, :] 314 | self.idx = [] 315 | self.df = pd.DataFrame(range(len(self.features)), columns=["ind"]) 316 | if balance: 317 | self.parse_labels() 318 | self.balance = balance 319 | 320 | def scale_features(self, scaler): 321 | self.scaler = scaler 322 | if self.scaler == "find_statistics": 323 | print("scaling training data!") 324 | self.scaler = StandardScaler().fit(self.features) 325 | self.features = torch.Tensor(self.scaler.transform(self.features.numpy())) 326 | elif self.scaler is not None: 327 | print("scaling validation / testing data!") 328 | self.features = torch.Tensor(self.scaler.transform(self.features.numpy())) 329 | 330 | def parse_labels(self): 331 | stats = pd.DataFrame( 332 | data=[ 333 | {"class": u, "freq": self.target[:, u].sum().item()} 334 | for u in range(self.target.shape[1]) 335 | ] 336 | ) 337 | self.stats = stats.sort_values(by="freq") 338 | N = int(self.stats.freq.mean()) 339 | print("Sampling", N, "samples per class for", self.target.shape[1], "classes") 340 | self.N = N * self.target.shape[1] 341 | 342 | def __getitem__(self, index): 343 | # Mapping index to a virtual table of classes 344 | # class_id = list(range(self.target.shape[1]))[index % self.target.shape[1]] 345 | class_id = np.random.choice(list(range(self.target.shape[1]))) 346 | while self.target[:, class_id].sum() == 0: 347 | class_id = np.random.choice(list(range(self.target.shape[1]))) 348 | if self.balance: 349 | sample_idx = np.random.choice(np.where(self.target[:, class_id])[0], 1) 350 | else: 351 | sample_idx = index 352 | # sample_idx = self.df[self.df[str(class_id)]].sample(n=1).index[0] 353 | return ( 354 | self.features[sample_idx], 355 | self.target[sample_idx], 356 | ) 357 | 358 | def __len__(self): 359 | return len(self.df) 360 | 361 | 362 | class AutoBalancedFileList(ImageFileList): 363 | def __init__( 364 | self, 365 | flist, 366 | root, 367 | transform=None, 368 | flist_reader=None, 369 | loader=default_loader, 370 | training=True, 371 | with_labels=True, 372 | target_labels=None, 373 | ): 374 | self.target_labels = sorted(target_labels) 375 | self.imdf = pd.read_csv(flist) 376 | self.parse_labels() 377 | self.transform = transform 378 | self.loader = loader 379 | self.training = training 380 | self.with_labels = with_labels 381 | self.root = root 382 | self.idx = [] 383 | 384 | def parse_labels(self): 385 | if type(self.target_labels) is not type(None): 386 | self.unique = self.target_labels 387 | else: 388 | labels = [x for x in self.imdf.protein_location.apply(eval)] 389 | self.unique = set() 390 | result = [[self.unique.add(x) for x in y] for y in labels] 391 | self.unique = list(self.unique) 392 | self.unique.sort() 393 | for u in self.unique: 394 | self.imdf[u] = False 395 | for k in range(len(labels)): 396 | for c in labels[k]: 397 | self.imdf.loc[k, c] = True 398 | stats = pd.DataFrame( 399 | data=[{"class": u, "freq": np.sum(self.imdf[u])} for u 
in self.unique] 400 | ) 401 | self.stats = stats.sort_values(by="freq") 402 | N = int(self.stats.freq.mean()) 403 | print("Sampling", N, "images per class for", len(self.unique), "classes") 404 | self.N = N * len(self.unique) 405 | 406 | def __getitem__(self, index): 407 | # Mapping index to a virtual table of classes 408 | class_id = self.unique[index % len(self.unique)] 409 | sample_idx = self.imdf[self.imdf[class_id]].sample(n=1).index[0] 410 | 411 | # Identify the sample 412 | if type(self.target_labels) is not type(None): 413 | impath, ID = self.imdf.iloc[sample_idx][["file", "ID"]] 414 | protein = self.imdf.iloc[sample_idx][self.target_labels].values.astype(int) 415 | elif self.with_labels: 416 | impath, protein, cell, ID = self.imdf.iloc[sample_idx][ 417 | ["file", "protein_location", "cell_type", "ID"] 418 | ] 419 | else: 420 | impath, ID = self.imdf.iloc[sample_idx][["file", "ID"]] 421 | img = self.loader(self.root + impath, self.training) 422 | 423 | # Transform the image 424 | if self.transform is not None: 425 | if isinstance(img, list): 426 | img = [self.transform(i) for i in img] 427 | else: 428 | img = self.transform(img) 429 | 430 | # Return the item 431 | if self.training: 432 | return img, protein.astype(int) 433 | elif type(self.target_labels) is not type(None): 434 | return img, protein.astype(int), cell, ID 435 | elif self.with_labels: 436 | return img, protein, cell, ID 437 | else: 438 | return img, ID 439 | 440 | def __len__(self): 441 | return self.N 442 | 443 | 444 | data_loaders = { 445 | "HPA": ImageFileList, 446 | "HPABalanced": AutoBalancedFileList, 447 | "CellPainting": cellpainting_dataset.SingleCellDataset, 448 | } 449 | 450 | image_modes = { 451 | "normalized_3_channels": default_loader, 452 | "normalized_4_channels": default_loader, 453 | "unnormalized_4_channels": norm_loader, 454 | "single_channel_r": partial(single_channel_loader, channel=0, triple=True), 455 | "single_channel_g": partial(single_channel_loader, channel=1, triple=True), 456 | "single_channel_b": partial(single_channel_loader, channel=2, triple=True), 457 | "single_channel_y": partial(single_channel_loader, channel=3, triple=True), 458 | "single_channel_r_no_triple": partial( 459 | single_channel_loader, channel=0, triple=False 460 | ), 461 | "single_channel_g_no_triple": partial( 462 | single_channel_loader, channel=1, triple=False 463 | ), 464 | "single_channel_b_no_triple": partial( 465 | single_channel_loader, channel=2, triple=False 466 | ), 467 | "single_channel_y_no_triple": partial( 468 | single_channel_loader, channel=3, triple=False 469 | ), 470 | "one_channel": one_channel_loader, 471 | "two_channels": two_channel_loader, 472 | "protein_channel": protein_channel_loader, 473 | } 474 | 475 | readers = {True: pandas_reader_binary_labels, False: pandas_reader_no_labels} 476 | # readers = {True: pandas_reader, False: pandas_reader_no_labels} 477 | 478 | 479 | def scKaggle_loader(fn, reader=io.imread): 480 | image = reader(fn) 481 | return image 482 | 483 | 484 | def scKaggle_df_reader(csvpath): 485 | """ 486 | Used for training MLP for scKaggle competition. 487 | Will hijack ID in ImageFileList as labels... 
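    (Reads the "Path" column as image paths and treats every column from the
    tenth onward as a float target vector.)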
488 | """ 489 | df = pd.read_csv(csvpath) 490 | targs = torch.tensor(df.iloc[:, 9:].values.astype(np.float32)) 491 | lbls = df["Path"].values 492 | imlist = [[impath, targ] for impath, targ in zip(lbls, targs)] 493 | return imlist 494 | -------------------------------------------------------------------------------- /archs/cnns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | from torchvision.models.utils import load_state_dict_from_url 5 | from typing import Type, Any, Callable, Union, List, Optional 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2','resnet50_4_channels'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 17 | 'resnet50_4_channels': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', 20 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 21 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion: int = 1 40 | 41 | def __init__( 42 | self, 43 | inplanes: int, 44 | planes: int, 45 | stride: int = 1, 46 | downsample: Optional[nn.Module] = None, 47 | groups: int = 1, 48 | base_width: int = 64, 49 | dilation: int = 1, 50 | norm_layer: Optional[Callable[..., nn.Module]] = None 51 | ) -> None: 52 | super(BasicBlock, self).__init__() 53 | if norm_layer is None: 54 | norm_layer = nn.BatchNorm2d 55 | if groups != 1 or base_width != 64: 56 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 57 | if dilation > 1: 58 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 59 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 60 | self.conv1 = conv3x3(inplanes, planes, stride) 61 | self.bn1 = norm_layer(planes) 62 | self.relu = nn.ReLU(inplace=True) 63 | self.conv2 = conv3x3(planes, planes) 64 | self.bn2 = norm_layer(planes) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | 78 | if self.downsample is not None: 79 | identity = self.downsample(x) 80 | 81 | out += 
identity 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class Bottleneck(nn.Module): 88 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 89 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 90 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 91 | # This variant is also known as ResNet V1.5 and improves accuracy according to 92 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 93 | 94 | expansion: int = 4 95 | 96 | def __init__( 97 | self, 98 | inplanes: int, 99 | planes: int, 100 | stride: int = 1, 101 | downsample: Optional[nn.Module] = None, 102 | groups: int = 1, 103 | base_width: int = 64, 104 | dilation: int = 1, 105 | norm_layer: Optional[Callable[..., nn.Module]] = None 106 | ) -> None: 107 | super(Bottleneck, self).__init__() 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | width = int(planes * (base_width / 64.)) * groups 111 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 112 | self.conv1 = conv1x1(inplanes, width) 113 | self.bn1 = norm_layer(width) 114 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 115 | self.bn2 = norm_layer(width) 116 | self.conv3 = conv1x1(width, planes * self.expansion) 117 | self.bn3 = norm_layer(planes * self.expansion) 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | 122 | def forward(self, x: Tensor) -> Tensor: 123 | identity = x 124 | 125 | out = self.conv1(x) 126 | out = self.bn1(out) 127 | out = self.relu(out) 128 | 129 | out = self.conv2(out) 130 | out = self.bn2(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv3(out) 134 | out = self.bn3(out) 135 | 136 | if self.downsample is not None: 137 | identity = self.downsample(x) 138 | 139 | out += identity 140 | out = self.relu(out) 141 | 142 | return out 143 | 144 | 145 | class ResNet(nn.Module): 146 | def __init__( 147 | self, 148 | block: Type[Union[BasicBlock, Bottleneck]], 149 | layers: List[int], 150 | num_classes: int = 1000, 151 | zero_init_residual: bool = False, 152 | groups: int = 1, 153 | width_per_group: int = 64, 154 | replace_stride_with_dilation: Optional[List[bool]] = None, 155 | norm_layer: Optional[Callable[..., nn.Module]] = None 156 | ) -> None: 157 | super(ResNet, self).__init__() 158 | if norm_layer is None: 159 | norm_layer = nn.BatchNorm2d 160 | self._norm_layer = norm_layer 161 | 162 | self.inplanes = 64 163 | self.dilation = 1 164 | if replace_stride_with_dilation is None: 165 | # each element in the tuple indicates if we should replace 166 | # the 2x2 stride with a dilated convolution instead 167 | replace_stride_with_dilation = [False, False, False] 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError("replace_stride_with_dilation should be None " 170 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 171 | self.groups = groups 172 | self.base_width = width_per_group 173 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 174 | bias=False) 175 | self.bn1 = norm_layer(self.inplanes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 178 | self.layer1 = self._make_layer(block, 64, layers[0]) 179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 180 | dilate=replace_stride_with_dilation[0]) 181 | self.layer3 = 
self._make_layer(block, 256, layers[2], stride=2, 182 | dilate=replace_stride_with_dilation[1]) 183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 184 | dilate=replace_stride_with_dilation[2]) 185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 186 | self.fc = nn.Linear(512 * block.expansion, num_classes) 187 | 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 204 | 205 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 206 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 207 | norm_layer = self._norm_layer 208 | downsample = None 209 | previous_dilation = self.dilation 210 | if dilate: 211 | self.dilation *= stride 212 | stride = 1 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | conv1x1(self.inplanes, planes * block.expansion, stride), 216 | norm_layer(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 221 | self.base_width, previous_dilation, norm_layer)) 222 | self.inplanes = planes * block.expansion 223 | for _ in range(1, blocks): 224 | layers.append(block(self.inplanes, planes, groups=self.groups, 225 | base_width=self.base_width, dilation=self.dilation, 226 | norm_layer=norm_layer)) 227 | 228 | return nn.Sequential(*layers) 229 | 230 | def _forward_impl(self, x: Tensor) -> Tensor: 231 | # See note [TorchScript super()] 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.avgpool(x) 243 | x = torch.flatten(x, 1) 244 | x = self.fc(x) 245 | 246 | return x 247 | 248 | def forward(self, x: Tensor) -> Tensor: 249 | return self._forward_impl(x) 250 | 251 | class ResNet_4_channels(nn.Module): 252 | def __init__( 253 | self, 254 | block: Type[Union[BasicBlock, Bottleneck]], 255 | layers: List[int], 256 | num_classes: int = 1000, 257 | zero_init_residual: bool = False, 258 | groups: int = 1, 259 | width_per_group: int = 64, 260 | replace_stride_with_dilation: Optional[List[bool]] = None, 261 | norm_layer: Optional[Callable[..., nn.Module]] = None 262 | ) -> None: 263 | super(ResNet, self).__init__() 264 | if norm_layer is None: 265 | norm_layer = nn.BatchNorm2d 266 | self._norm_layer = norm_layer 267 | 268 | self.inplanes = 64 269 | self.dilation = 1 270 | if replace_stride_with_dilation is None: 271 | # each element in the tuple indicates if we should replace 272 | # the 2x2 stride with a dilated convolution instead 273 | replace_stride_with_dilation = [False, False, False] 274 | if len(replace_stride_with_dilation) != 3: 275 | raise ValueError("replace_stride_with_dilation 
should be None " 276 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 277 | self.groups = groups 278 | self.base_width = width_per_group 279 | self.conv1 = nn.Conv2d(4, self.inplanes, kernel_size=7, stride=2, padding=3, 280 | bias=False) 281 | self.bn1 = norm_layer(self.inplanes) 282 | self.relu = nn.ReLU(inplace=True) 283 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 284 | self.layer1 = self._make_layer(block, 64, layers[0]) 285 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 286 | dilate=replace_stride_with_dilation[0]) 287 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 288 | dilate=replace_stride_with_dilation[1]) 289 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 290 | dilate=replace_stride_with_dilation[2]) 291 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 292 | self.fc = nn.Linear(512 * block.expansion, num_classes) 293 | 294 | for m in self.modules(): 295 | if isinstance(m, nn.Conv2d): 296 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 297 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 298 | nn.init.constant_(m.weight, 1) 299 | nn.init.constant_(m.bias, 0) 300 | 301 | # Zero-initialize the last BN in each residual branch, 302 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 303 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 304 | if zero_init_residual: 305 | for m in self.modules(): 306 | if isinstance(m, Bottleneck): 307 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 308 | elif isinstance(m, BasicBlock): 309 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 310 | 311 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 312 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 313 | norm_layer = self._norm_layer 314 | downsample = None 315 | previous_dilation = self.dilation 316 | if dilate: 317 | self.dilation *= stride 318 | stride = 1 319 | if stride != 1 or self.inplanes != planes * block.expansion: 320 | downsample = nn.Sequential( 321 | conv1x1(self.inplanes, planes * block.expansion, stride), 322 | norm_layer(planes * block.expansion), 323 | ) 324 | 325 | layers = [] 326 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 327 | self.base_width, previous_dilation, norm_layer)) 328 | self.inplanes = planes * block.expansion 329 | for _ in range(1, blocks): 330 | layers.append(block(self.inplanes, planes, groups=self.groups, 331 | base_width=self.base_width, dilation=self.dilation, 332 | norm_layer=norm_layer)) 333 | 334 | return nn.Sequential(*layers) 335 | 336 | def _forward_impl(self, x: Tensor) -> Tensor: 337 | # See note [TorchScript super()] 338 | x = self.conv1(x) 339 | x = self.bn1(x) 340 | x = self.relu(x) 341 | x = self.maxpool(x) 342 | 343 | x = self.layer1(x) 344 | x = self.layer2(x) 345 | x = self.layer3(x) 346 | x = self.layer4(x) 347 | 348 | x = self.avgpool(x) 349 | x = torch.flatten(x, 1) 350 | x = self.fc(x) 351 | 352 | return x 353 | 354 | def forward(self, x: Tensor) -> Tensor: 355 | return self._forward_impl(x) 356 | 357 | 358 | 359 | def _resnet_4_channels( 360 | arch: str, 361 | block: Type[Union[BasicBlock, Bottleneck]], 362 | layers: List[int], 363 | pretrained: bool, 364 | progress: bool, 365 | **kwargs: Any 366 | ) -> ResNet: 367 | model = ResNet_4_channels(block, layers, **kwargs) 368 | if pretrained: 369 | state_dict = 
load_state_dict_from_url(model_urls[arch], 370 | progress=progress) 371 | model.load_state_dict(state_dict) 372 | return model 373 | 374 | def _resnet( 375 | arch: str, 376 | block: Type[Union[BasicBlock, Bottleneck]], 377 | layers: List[int], 378 | pretrained: bool, 379 | progress: bool, 380 | **kwargs: Any 381 | ) -> ResNet: 382 | model = ResNet(block, layers, **kwargs) 383 | if pretrained: 384 | state_dict = load_state_dict_from_url(model_urls[arch], 385 | progress=progress) 386 | model.load_state_dict(state_dict) 387 | return model 388 | 389 | 390 | 391 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 392 | r"""ResNet-18 model from 393 | `"Deep Residual Learning for Image Recognition" `_. 394 | 395 | Args: 396 | pretrained (bool): If True, returns a model pre-trained on ImageNet 397 | progress (bool): If True, displays a progress bar of the download to stderr 398 | """ 399 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 400 | **kwargs) 401 | 402 | 403 | 404 | 405 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 406 | r"""ResNet-34 model from 407 | `"Deep Residual Learning for Image Recognition" `_. 408 | 409 | Args: 410 | pretrained (bool): If True, returns a model pre-trained on ImageNet 411 | progress (bool): If True, displays a progress bar of the download to stderr 412 | """ 413 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 414 | **kwargs) 415 | 416 | 417 | 418 | 419 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 420 | r"""ResNet-50 model from 421 | `"Deep Residual Learning for Image Recognition" `_. 422 | 423 | Args: 424 | pretrained (bool): If True, returns a model pre-trained on ImageNet 425 | progress (bool): If True, displays a progress bar of the download to stderr 426 | """ 427 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 428 | **kwargs) 429 | 430 | 431 | def resnet50_4_channels(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet_4_channels: 432 | r"""ResNet-50 model from 433 | `"Deep Residual Learning for Image Recognition" `_. 434 | 435 | Args: 436 | pretrained (bool): If True, returns a model pre-trained on ImageNet 437 | progress (bool): If True, displays a progress bar of the download to stderr 438 | """ 439 | return _resnet_4_channels('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 440 | **kwargs) 441 | 442 | 443 | 444 | 445 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 446 | r"""ResNet-101 model from 447 | `"Deep Residual Learning for Image Recognition" `_. 448 | 449 | Args: 450 | pretrained (bool): If True, returns a model pre-trained on ImageNet 451 | progress (bool): If True, displays a progress bar of the download to stderr 452 | """ 453 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 454 | **kwargs) 455 | 456 | 457 | 458 | 459 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 460 | r"""ResNet-152 model from 461 | `"Deep Residual Learning for Image Recognition" `_. 
462 | 463 | Args: 464 | pretrained (bool): If True, returns a model pre-trained on ImageNet 465 | progress (bool): If True, displays a progress bar of the download to stderr 466 | """ 467 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 468 | **kwargs) 469 | 470 | 471 | 472 | 473 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 474 | r"""ResNeXt-50 32x4d model from 475 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 476 | 477 | Args: 478 | pretrained (bool): If True, returns a model pre-trained on ImageNet 479 | progress (bool): If True, displays a progress bar of the download to stderr 480 | """ 481 | kwargs['groups'] = 32 482 | kwargs['width_per_group'] = 4 483 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 484 | pretrained, progress, **kwargs) 485 | 486 | 487 | 488 | 489 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 490 | r"""ResNeXt-101 32x8d model from 491 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 492 | 493 | Args: 494 | pretrained (bool): If True, returns a model pre-trained on ImageNet 495 | progress (bool): If True, displays a progress bar of the download to stderr 496 | """ 497 | kwargs['groups'] = 32 498 | kwargs['width_per_group'] = 8 499 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 500 | pretrained, progress, **kwargs) 501 | 502 | 503 | 504 | 505 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 506 | r"""Wide ResNet-50-2 model from 507 | `"Wide Residual Networks" `_. 508 | 509 | The model is the same as ResNet except for the bottleneck number of channels 510 | which is twice larger in every block. The number of channels in outer 1x1 511 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 512 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 513 | 514 | Args: 515 | pretrained (bool): If True, returns a model pre-trained on ImageNet 516 | progress (bool): If True, displays a progress bar of the download to stderr 517 | """ 518 | kwargs['width_per_group'] = 64 * 2 519 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 520 | pretrained, progress, **kwargs) 521 | 522 | 523 | 524 | 525 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 526 | r"""Wide ResNet-101-2 model from 527 | `"Wide Residual Networks" `_. 528 | 529 | The model is the same as ResNet except for the bottleneck number of channels 530 | which is twice larger in every block. The number of channels in outer 1x1 531 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 532 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 533 | 534 | Args: 535 | pretrained (bool): If True, returns a model pre-trained on ImageNet 536 | progress (bool): If True, displays a progress bar of the download to stderr 537 | """ 538 | kwargs['width_per_group'] = 64 * 2 539 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 540 | pretrained, progress, **kwargs) 541 | 542 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Misc functions. 16 | Mostly copy-paste from torchvision references or other public repos like DETR: 17 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 18 | """ 19 | import numbers 20 | from typing import Tuple, List, Optional 21 | from sklearn.metrics import f1_score 22 | import argparse 23 | import math 24 | from torch import Tensor 25 | import warnings 26 | from collections.abc import Sequence 27 | from PIL import Image 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F 31 | import torchvision 32 | from torchvision import transforms 33 | from skimage.filters import threshold_otsu 34 | from PIL import Image, ImageFilter, ImageOps 35 | from tqdm import tqdm 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.functional as F 39 | from torch.autograd import Variable 40 | from torchvision.ops import sigmoid_focal_loss 41 | from torch import nn, optim 42 | from torch.optim.lr_scheduler import CosineAnnealingLR 43 | import yaml 44 | from matplotlib import cm 45 | from sklearn.metrics.pairwise import cosine_similarity, cosine_distances 46 | from scipy.cluster.hierarchy import dendrogram, linkage, fcluster 47 | from scipy.spatial.distance import pdist, cdist 48 | from seaborn import clustermap 49 | from scipy.spatial.distance import squareform 50 | from sklearn.metrics import average_precision_score 51 | import pandas as pd 52 | from skimage import io 53 | import matplotlib.ticker as mtick 54 | import kornia.geometry.transform.imgwarp as K 55 | import os 56 | import sys 57 | import time 58 | import datetime 59 | import random 60 | import subprocess 61 | from collections import defaultdict, deque 62 | import torch.distributed as dist 63 | 64 | 65 | def load_pretrained_weights( 66 | model, pretrained_weights, checkpoint_key, model_name, patch_size 67 | ): 68 | if os.path.isfile(pretrained_weights): 69 | state_dict = torch.load(pretrained_weights, map_location="cpu") 70 | if checkpoint_key is not None and checkpoint_key in state_dict: 71 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 72 | state_dict = state_dict[checkpoint_key] 73 | # remove `module.` prefix 74 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 75 | # remove `backbone.` prefix induced by multicrop wrapper 76 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 77 | msg = model.load_state_dict(state_dict, strict=False) 78 | print( 79 | "Pretrained weights found at {} and loaded with msg: {}".format( 80 | pretrained_weights, msg 81 | ) 82 | ) 83 | else: 84 | print( 85 | "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate." 
86 | ) 87 | url = None 88 | if model_name == "vit_small" and patch_size == 16: 89 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 90 | elif model_name == "vit_small" and patch_size == 8: 91 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" 92 | elif model_name == "vit_base" and patch_size == 16: 93 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 94 | elif model_name == "vit_base" and patch_size == 8: 95 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 96 | if url is not None: 97 | print( 98 | "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." 99 | ) 100 | state_dict = torch.hub.load_state_dict_from_url( 101 | url="https://dl.fbaipublicfiles.com/dino/" + url 102 | ) 103 | model.load_state_dict(state_dict, strict=True) 104 | else: 105 | print( 106 | "There is no reference weights available for this model => We use random weights." 107 | ) 108 | 109 | 110 | def clip_gradients(model, clip): 111 | norms = [] 112 | for name, p in model.named_parameters(): 113 | if p.grad is not None: 114 | param_norm = p.grad.data.norm(2) 115 | norms.append(param_norm.item()) 116 | clip_coef = clip / (param_norm + 1e-6) 117 | if clip_coef < 1: 118 | p.grad.data.mul_(clip_coef) 119 | return norms 120 | 121 | 122 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 123 | if epoch >= freeze_last_layer: 124 | return 125 | for n, p in model.named_parameters(): 126 | if "last_layer" in n: 127 | p.grad = None 128 | 129 | 130 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): 131 | """ 132 | Re-start from checkpoint 133 | """ 134 | if not os.path.isfile(ckp_path): 135 | return 136 | print("Found checkpoint at {}".format(ckp_path)) 137 | 138 | # open checkpoint file 139 | checkpoint = torch.load(ckp_path, map_location="cpu") 140 | 141 | # key is what to look for in the checkpoint file 142 | # value is the object to load 143 | # example: {'state_dict': model} 144 | for key, value in kwargs.items(): 145 | if key in checkpoint and value is not None: 146 | try: 147 | msg = value.load_state_dict(checkpoint[key], strict=False) 148 | print( 149 | "=> loaded '{}' from checkpoint '{}' with msg {}".format( 150 | key, ckp_path, msg 151 | ) 152 | ) 153 | except TypeError: 154 | try: 155 | msg = value.load_state_dict(checkpoint[key]) 156 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path)) 157 | except ValueError: 158 | print( 159 | "=> failed to load '{}' from checkpoint: '{}'".format( 160 | key, ckp_path 161 | ) 162 | ) 163 | else: 164 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path)) 165 | 166 | # re load variable important for the run 167 | if run_variables is not None: 168 | for var_name in run_variables: 169 | if var_name in checkpoint: 170 | run_variables[var_name] = checkpoint[var_name] 171 | 172 | 173 | def cosine_scheduler( 174 | base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0 175 | ): 176 | warmup_schedule = np.array([]) 177 | warmup_iters = warmup_epochs * niter_per_ep 178 | if warmup_epochs > 0: 179 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 180 | 181 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 182 | schedule = final_value + 0.5 * (base_value - final_value) * ( 183 | 1 + np.cos(np.pi * iters / len(iters)) 184 | ) 185 | 186 | schedule = np.concatenate((warmup_schedule, schedule)) 187 | assert len(schedule) == epochs * niter_per_ep 188 | return schedule 189 | 
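# Minimal usage sketch for cosine_scheduler, assuming `batch_size`, `world_size`,
# `data_loader`, `optimizer`, `epoch` and `step` exist in the caller's scope.
# The function returns one value per training iteration, so callers index it
# with the global iteration, mirroring how main_dino.py consumes the schedule:
#
#     lr_schedule = cosine_scheduler(
#         base_value=0.0005 * (batch_size * world_size) / 256.0,  # linear LR scaling rule
#         final_value=1e-6,
#         epochs=100,
#         niter_per_ep=len(data_loader),
#         warmup_epochs=10,
#     )
#     it = len(data_loader) * epoch + step  # global iteration index
#     for param_group in optimizer.param_groups:
#         param_group["lr"] = lr_schedule[it]  # one schedule entry per iteration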
190 | 191 | def bool_flag(s): 192 | """ 193 | Parse boolean arguments from the command line. 194 | """ 195 | FALSY_STRINGS = {"off", "false", "0"} 196 | TRUTHY_STRINGS = {"on", "true", "1"} 197 | if s.lower() in FALSY_STRINGS: 198 | return False 199 | elif s.lower() in TRUTHY_STRINGS: 200 | return True 201 | else: 202 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 203 | 204 | 205 | def fix_random_seeds(seed=31): 206 | """ 207 | Fix random seeds. 208 | """ 209 | torch.manual_seed(seed) 210 | torch.cuda.manual_seed_all(seed) 211 | np.random.seed(seed) 212 | 213 | 214 | class SmoothedValue(object): 215 | """Track a series of values and provide access to smoothed values over a 216 | window or the global series average. 217 | """ 218 | 219 | def __init__(self, window_size=20, fmt=None): 220 | if fmt is None: 221 | fmt = "{median:.6f} ({global_avg:.6f})" 222 | self.deque = deque(maxlen=window_size) 223 | self.total = 0.0 224 | self.count = 0 225 | self.fmt = fmt 226 | 227 | def update(self, value, n=1): 228 | self.deque.append(value) 229 | self.count += n 230 | self.total += value * n 231 | 232 | def synchronize_between_processes(self): 233 | """ 234 | Warning: does not synchronize the deque! 235 | """ 236 | if not is_dist_avail_and_initialized(): 237 | return 238 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 239 | dist.barrier() 240 | dist.all_reduce(t) 241 | t = t.tolist() 242 | self.count = int(t[0]) 243 | self.total = t[1] 244 | 245 | @property 246 | def median(self): 247 | d = torch.tensor(list(self.deque)) 248 | return d.median().item() 249 | 250 | @property 251 | def avg(self): 252 | d = torch.tensor(list(self.deque), dtype=torch.float32) 253 | return d.mean().item() 254 | 255 | @property 256 | def global_avg(self): 257 | return self.total / self.count 258 | 259 | @property 260 | def max(self): 261 | return max(self.deque) 262 | 263 | @property 264 | def value(self): 265 | return self.deque[-1] 266 | 267 | def __str__(self): 268 | return self.fmt.format( 269 | median=self.median, 270 | avg=self.avg, 271 | global_avg=self.global_avg, 272 | max=self.max, 273 | value=self.value, 274 | ) 275 | 276 | 277 | def reduce_dict(input_dict, average=True): 278 | """ 279 | Args: 280 | input_dict (dict): all the values will be reduced 281 | average (bool): whether to do average or sum 282 | Reduce the values in the dictionary from all processes so that all processes 283 | have the averaged results. Returns a dict with the same fields as 284 | input_dict, after reduction. 
285 | """ 286 | world_size = get_world_size() 287 | if world_size < 2: 288 | return input_dict 289 | with torch.no_grad(): 290 | names = [] 291 | values = [] 292 | # sort the keys so that they are consistent across processes 293 | for k in sorted(input_dict.keys()): 294 | names.append(k) 295 | values.append(input_dict[k]) 296 | values = torch.stack(values, dim=0) 297 | dist.all_reduce(values) 298 | if average: 299 | values /= world_size 300 | reduced_dict = {k: v for k, v in zip(names, values)} 301 | return reduced_dict 302 | 303 | 304 | class MetricLogger(object): 305 | def __init__(self, delimiter="\t"): 306 | self.meters = defaultdict(SmoothedValue) 307 | self.delimiter = delimiter 308 | 309 | def update(self, **kwargs): 310 | for k, v in kwargs.items(): 311 | if isinstance(v, torch.Tensor): 312 | v = v.item() 313 | assert isinstance(v, (float, int)) 314 | self.meters[k].update(v) 315 | 316 | def __getattr__(self, attr): 317 | if attr in self.meters: 318 | return self.meters[attr] 319 | if attr in self.__dict__: 320 | return self.__dict__[attr] 321 | raise AttributeError( 322 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 323 | ) 324 | 325 | def __str__(self): 326 | loss_str = [] 327 | for name, meter in self.meters.items(): 328 | loss_str.append("{}: {}".format(name, str(meter))) 329 | return self.delimiter.join(loss_str) 330 | 331 | def synchronize_between_processes(self): 332 | for meter in self.meters.values(): 333 | meter.synchronize_between_processes() 334 | 335 | def add_meter(self, name, meter): 336 | self.meters[name] = meter 337 | 338 | def log_every(self, iterable, print_freq, header=None): 339 | i = 0 340 | if not header: 341 | header = "" 342 | start_time = time.time() 343 | end = time.time() 344 | iter_time = SmoothedValue(fmt="{avg:.6f}") 345 | data_time = SmoothedValue(fmt="{avg:.6f}") 346 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 347 | if torch.cuda.is_available(): 348 | log_msg = self.delimiter.join( 349 | [ 350 | header, 351 | "[{0" + space_fmt + "}/{1}]", 352 | "eta: {eta}", 353 | "{meters}", 354 | "time: {time}", 355 | "data: {data}", 356 | "max mem: {memory:.0f}", 357 | ] 358 | ) 359 | else: 360 | log_msg = self.delimiter.join( 361 | [ 362 | header, 363 | "[{0" + space_fmt + "}/{1}]", 364 | "eta: {eta}", 365 | "{meters}", 366 | "time: {time}", 367 | "data: {data}", 368 | ] 369 | ) 370 | MB = 1024.0 * 1024.0 371 | for obj in iterable: 372 | data_time.update(time.time() - end) 373 | yield obj 374 | iter_time.update(time.time() - end) 375 | if i % print_freq == 0 or i == len(iterable) - 1: 376 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 377 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 378 | if torch.cuda.is_available(): 379 | print( 380 | log_msg.format( 381 | i, 382 | len(iterable), 383 | eta=eta_string, 384 | meters=str(self), 385 | time=str(iter_time), 386 | data=str(data_time), 387 | memory=torch.cuda.max_memory_allocated() / MB, 388 | ) 389 | ) 390 | else: 391 | print( 392 | log_msg.format( 393 | i, 394 | len(iterable), 395 | eta=eta_string, 396 | meters=str(self), 397 | time=str(iter_time), 398 | data=str(data_time), 399 | ) 400 | ) 401 | i += 1 402 | end = time.time() 403 | total_time = time.time() - start_time 404 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 405 | print( 406 | "{} Total time: {} ({:.6f} s / it)".format( 407 | header, total_time_str, total_time / len(iterable) 408 | ) 409 | ) 410 | 411 | 412 | def get_sha(): 413 | cwd = 
os.path.dirname(os.path.abspath(__file__)) 414 | 415 | def _run(command): 416 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 417 | 418 | sha = "N/A" 419 | diff = "clean" 420 | branch = "N/A" 421 | try: 422 | sha = _run(["git", "rev-parse", "HEAD"]) 423 | subprocess.check_output(["git", "diff"], cwd=cwd) 424 | diff = _run(["git", "diff-index", "HEAD"]) 425 | diff = "has uncommited changes" if diff else "clean" 426 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 427 | except Exception: 428 | pass 429 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 430 | return message 431 | 432 | 433 | def is_dist_avail_and_initialized(): 434 | if not dist.is_available(): 435 | return False 436 | if not dist.is_initialized(): 437 | return False 438 | return True 439 | 440 | 441 | def get_world_size(): 442 | if not is_dist_avail_and_initialized(): 443 | return 1 444 | return dist.get_world_size() 445 | 446 | 447 | def get_rank(): 448 | if not is_dist_avail_and_initialized(): 449 | return 0 450 | return dist.get_rank() 451 | 452 | 453 | def is_main_process(): 454 | return get_rank() == 0 455 | 456 | 457 | def save_on_master(*args, **kwargs): 458 | if is_main_process(): 459 | torch.save(*args, **kwargs) 460 | 461 | 462 | def setup_for_distributed(is_master): 463 | """ 464 | This function disables printing when not in master process 465 | """ 466 | import builtins as __builtin__ 467 | 468 | builtin_print = __builtin__.print 469 | 470 | def print(*args, **kwargs): 471 | force = kwargs.pop("force", False) 472 | if is_master or force: 473 | builtin_print(*args, **kwargs) 474 | 475 | __builtin__.print = print 476 | 477 | 478 | def init_distributed_mode(args): 479 | # launched with torch.distributed.launch 480 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 481 | args.rank = int(os.environ["RANK"]) 482 | args.world_size = int(os.environ["WORLD_SIZE"]) 483 | args.gpu = int(os.environ["LOCAL_RANK"]) 484 | # launched with submitit on a slurm cluster 485 | elif "SLURM_PROCID" in os.environ: 486 | args.rank = int(os.environ["SLURM_PROCID"]) 487 | args.gpu = args.rank % torch.cuda.device_count() 488 | args.world_size = int(os.environ["SLURM_NNODES"]) * int( 489 | os.environ["SLURM_TASKS_PER_NODE"][0] 490 | ) 491 | # launched naively with `python main_dino.py` 492 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 493 | elif torch.cuda.is_available(): 494 | print("Will run the code on one GPU.") 495 | args.rank, args.gpu, args.world_size = 0, 0, 1 496 | os.environ["MASTER_ADDR"] = "127.0.0.1" 497 | os.environ["MASTER_PORT"] = "29500" 498 | else: 499 | print("Does not support training without GPU.") 500 | sys.exit(1) 501 | 502 | dist.init_process_group( 503 | backend="nccl", 504 | init_method=args.dist_url, 505 | world_size=args.world_size, 506 | rank=args.rank, 507 | ) 508 | 509 | torch.cuda.set_device(args.gpu) 510 | print( 511 | "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True 512 | ) 513 | dist.barrier() 514 | setup_for_distributed(args.rank == 0) 515 | 516 | 517 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 518 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 519 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 520 | def norm_cdf(x): 521 | # Computes standard normal cumulative distribution function 522 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 523 | 524 | if (mean < a - 2 * std) or (mean > b + 2 * std): 525 | 
warnings.warn( 526 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 527 | "The distribution of values may be incorrect.", 528 | stacklevel=2, 529 | ) 530 | 531 | with torch.no_grad(): 532 | # Values are generated by using a truncated uniform distribution and 533 | # then using the inverse CDF for the normal distribution. 534 | # Get upper and lower cdf values 535 | l = norm_cdf((a - mean) / std) 536 | u = norm_cdf((b - mean) / std) 537 | 538 | # Uniformly fill tensor with values from [l, u], then translate to 539 | # [2l-1, 2u-1]. 540 | tensor.uniform_(2 * l - 1, 2 * u - 1) 541 | 542 | # Use inverse cdf transform for normal distribution to get truncated 543 | # standard normal 544 | tensor.erfinv_() 545 | 546 | # Transform to proper mean, std 547 | tensor.mul_(std * math.sqrt(2.0)) 548 | tensor.add_(mean) 549 | 550 | # Clamp to ensure it's in the proper range 551 | tensor.clamp_(min=a, max=b) 552 | return tensor 553 | 554 | 555 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 556 | # type: (Tensor, float, float, float, float) -> Tensor 557 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 558 | 559 | 560 | class LARS(torch.optim.Optimizer): 561 | """ 562 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py 563 | """ 564 | 565 | def __init__( 566 | self, 567 | params, 568 | lr=0, 569 | weight_decay=0, 570 | momentum=0.9, 571 | eta=0.001, 572 | weight_decay_filter=None, 573 | lars_adaptation_filter=None, 574 | ): 575 | defaults = dict( 576 | lr=lr, 577 | weight_decay=weight_decay, 578 | momentum=momentum, 579 | eta=eta, 580 | weight_decay_filter=weight_decay_filter, 581 | lars_adaptation_filter=lars_adaptation_filter, 582 | ) 583 | super().__init__(params, defaults) 584 | 585 | @torch.no_grad() 586 | def step(self): 587 | for g in self.param_groups: 588 | for p in g["params"]: 589 | dp = p.grad 590 | 591 | if dp is None: 592 | continue 593 | 594 | if p.ndim != 1: 595 | dp = dp.add(p, alpha=g["weight_decay"]) 596 | 597 | if p.ndim != 1: 598 | param_norm = torch.norm(p) 599 | update_norm = torch.norm(dp) 600 | one = torch.ones_like(param_norm) 601 | q = torch.where( 602 | param_norm > 0.0, 603 | torch.where( 604 | update_norm > 0, (g["eta"] * param_norm / update_norm), one 605 | ), 606 | one, 607 | ) 608 | dp = dp.mul(q) 609 | 610 | param_state = self.state[p] 611 | if "mu" not in param_state: 612 | param_state["mu"] = torch.zeros_like(p) 613 | mu = param_state["mu"] 614 | mu.mul_(g["momentum"]).add_(dp) 615 | 616 | p.add_(mu, alpha=-g["lr"]) 617 | 618 | 619 | class MultiCropWrapper(nn.Module): 620 | """ 621 | Perform forward pass separately on each resolution input. 622 | The inputs corresponding to a single resolution are clubbed and single 623 | forward is run on the same resolution inputs. Hence we do several 624 | forward passes = number of different resolutions used. We then 625 | concatenate all the output features and run the head forward on these 626 | concatenated features. 
627 | """ 628 | 629 | def __init__(self, backbone, head): 630 | super(MultiCropWrapper, self).__init__() 631 | # disable layers dedicated to ImageNet labels classification 632 | ## compatibility with xresnet ## 633 | try: 634 | backbone[-1], backbone.head = nn.Identity(), nn.Identity() 635 | print("Caught missing fc of xresnet in MultiCropWrapper") 636 | except: 637 | backbone.fc, backbone.head = ( 638 | nn.Identity(), 639 | nn.Identity(), 640 | ) # original single line 641 | 642 | self.backbone = backbone 643 | self.head = head 644 | 645 | def forward(self, x): 646 | # convert to list 647 | if not isinstance(x, list): 648 | x = [x] 649 | idx_crops = torch.cumsum( 650 | torch.unique_consecutive( 651 | torch.tensor([inp.shape[-1] for inp in x]), 652 | return_counts=True, 653 | )[1], 654 | 0, 655 | ) 656 | start_idx = 0 657 | for end_idx in idx_crops: 658 | _out = self.backbone(torch.cat(x[start_idx:end_idx])) 659 | if start_idx == 0: 660 | output = _out 661 | else: 662 | output = torch.cat((output, _out)) 663 | start_idx = end_idx 664 | # Run the head forward on the concatenated features. 665 | return self.head(output) 666 | 667 | 668 | def get_params_groups(*models): 669 | regularized = [] 670 | not_regularized = [] 671 | for model in models: 672 | for name, param in model.named_parameters(): 673 | if not param.requires_grad: 674 | continue 675 | # we do not regularize biases nor Norm parameters 676 | if name.endswith(".bias") or len(param.shape) == 1: 677 | not_regularized.append(param) 678 | else: 679 | regularized.append(param) 680 | return [{"params": regularized}, {"params": not_regularized, "weight_decay": 0.0}] 681 | 682 | 683 | def has_batchnorms(model): 684 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 685 | for name, module in model.named_modules(): 686 | if isinstance(module, bn_types): 687 | return True 688 | return False 689 | -------------------------------------------------------------------------------- /main_dino.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import argparse 15 | import os 16 | import sys 17 | import datetime 18 | import time 19 | import math 20 | import json 21 | from pathlib import Path 22 | from PIL import Image 23 | import yaml 24 | 25 | import numpy as np 26 | import torch 27 | import torch.nn as nn 28 | import torch.distributed as dist 29 | import torch.backends.cudnn as cudnn 30 | import torch.nn.functional as F 31 | from torchvision import datasets, transforms 32 | from torchvision import models as torchvision_models 33 | 34 | from utils import utils 35 | from utils.slurm import trigger_job_requeue, init_signal_handler 36 | from data_utils import file_dataset 37 | from functools import partial 38 | from archs import xresnet as cell_models 39 | from archs import vision_transformer as vits 40 | from archs.vision_transformer import DINOHead 41 | from utils.yaml_tfms import tfms_from_config 42 | 43 | 44 | torchvision_archs = sorted( 45 | name 46 | for name in torchvision_models.__dict__ 47 | if name.islower() 48 | and not name.startswith("__") 49 | and callable(torchvision_models.__dict__[name]) 50 | ) 51 | 52 | 53 | cell_archs = sorted( 54 | name 55 | for name in cell_models.__dict__ 56 | if name.islower() 57 | and not name.startswith("__") 58 | and callable(cell_models.__dict__[name]) 59 | ) 60 | 61 | 62 | def get_args_parser(): 63 | parser = argparse.ArgumentParser("DINO", add_help=False) 64 | 65 | # Model parameters 66 | parser.add_argument( 67 | "--arch", 68 | default="vit_small", 69 | type=str, 70 | choices=["vit_tiny", "vit_small", "vit_base", "deit_tiny", "deit_small"] 71 | + torchvision_archs 72 | + cell_archs, 73 | help="""Name of architecture to train. For quick experiments with ViTs, 74 | we recommend using vit_tiny or vit_small.""", 75 | ) 76 | parser.add_argument( 77 | "--patch_size", 78 | default=16, 79 | type=int, 80 | help="""Size in pixels 81 | of input square patches - default 16 (for 16x16 patches). Using smaller 82 | values leads to better performance but requires more memory. Applies only 83 | for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling 84 | mixed precision training (--use_fp16 false) to avoid unstabilities.""", 85 | ) 86 | parser.add_argument( 87 | "--out_dim", 88 | default=65536, 89 | type=int, 90 | help="""Dimensionality of 91 | the DINO head output. For complex and large datasets large values (like 65k) work well.""", 92 | ) 93 | parser.add_argument( 94 | "--norm_last_layer", 95 | default=True, 96 | type=utils.bool_flag, 97 | help="""Whether or not to weight normalize the last layer of the DINO head. 98 | Not normalizing leads to better performance but can make the training unstable. 99 | In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""", 100 | ) 101 | parser.add_argument( 102 | "--momentum_teacher", 103 | default=0.996, 104 | type=float, 105 | help="""Base EMA 106 | parameter for teacher update. The value is increased to 1 during training with cosine schedule. 
107 | We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""", 108 | ) 109 | parser.add_argument( 110 | "--use_bn_in_head", 111 | default=False, 112 | type=utils.bool_flag, 113 | help="Whether to use batch normalizations in projection head (Default: False)", 114 | ) 115 | parser.add_argument( 116 | "--num_channels", 117 | default=3, 118 | type=int, 119 | help="""Number of channels for the Vision transformer""", 120 | ) 121 | 122 | # Temperature teacher parameters 123 | parser.add_argument( 124 | "--warmup_teacher_temp", 125 | default=0.04, 126 | type=float, 127 | help="""Initial value for the teacher temperature: 0.04 works well in most cases. 128 | Try decreasing it if the training loss does not decrease.""", 129 | ) 130 | parser.add_argument( 131 | "--teacher_temp", 132 | default=0.04, 133 | type=float, 134 | help="""Final value (after linear warmup) 135 | of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend 136 | starting with the default value of 0.04 and increase this slightly if needed.""", 137 | ) 138 | parser.add_argument( 139 | "--student_temp", 140 | default=0.1, 141 | type=float, 142 | ) 143 | parser.add_argument("--sample_single_cells", action="store_true") 144 | parser.add_argument( 145 | "--center_momentum", 146 | default=0.9, 147 | type=float, 148 | ) 149 | parser.add_argument( 150 | "--warmup_teacher_temp_epochs", 151 | default=0, 152 | type=int, 153 | help="Number of warmup epochs for the teacher temperature (Default: 30).", 154 | ) 155 | 156 | # Training/Optimization parameters 157 | parser.add_argument( 158 | "--use_fp16", 159 | type=utils.bool_flag, 160 | default=True, 161 | help="""Whether or not 162 | to use half precision for training. Improves training time and memory requirements, 163 | but can provoke instability and slight decay of performance. We recommend disabling 164 | mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""", 165 | ) 166 | parser.add_argument( 167 | "--weight_decay", 168 | type=float, 169 | default=0.04, 170 | help="""Initial value of the 171 | weight decay. With ViT, a smaller value at the beginning of training works well.""", 172 | ) 173 | parser.add_argument( 174 | "--weight_decay_end", 175 | type=float, 176 | default=0.4, 177 | help="""Final value of the 178 | weight decay. We use a cosine schedule for WD and using a larger decay by 179 | the end of training improves performance for ViTs.""", 180 | ) 181 | parser.add_argument( 182 | "--clip_grad", 183 | type=float, 184 | default=3.0, 185 | help="""Maximal parameter 186 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can 187 | help optimization for larger ViT architectures. 0 for disabling.""", 188 | ) 189 | parser.add_argument( 190 | "--batch_size_per_gpu", 191 | default=64, 192 | type=int, 193 | help="Per-GPU batch-size : number of distinct images loaded on one GPU.", 194 | ) 195 | parser.add_argument( 196 | "--epochs", default=100, type=int, help="Number of epochs of training." 197 | ) 198 | parser.add_argument( 199 | "--freeze_last_layer", 200 | default=1, 201 | type=int, 202 | help="""Number of epochs 203 | during which we keep the output layer fixed. Typically doing so during 204 | the first epoch helps training. 
Try increasing this value if the loss does not decrease.""", 205 | ) 206 | parser.add_argument( 207 | "--lr", 208 | default=0.0005, 209 | type=float, 210 | help="""Learning rate at the end of 211 | linear warmup (highest LR used during training). The learning rate is linearly scaled 212 | with the batch size, and specified here for a reference batch size of 256.""", 213 | ) 214 | parser.add_argument( 215 | "--warmup_epochs", 216 | default=10, 217 | type=int, 218 | help="Number of epochs for the linear learning-rate warm up.", 219 | ) 220 | parser.add_argument( 221 | "--min_lr", 222 | type=float, 223 | default=1e-6, 224 | help="""Target LR at the 225 | end of optimization. We use a cosine LR schedule with linear warmup.""", 226 | ) 227 | parser.add_argument( 228 | "--optimizer", 229 | default="adamw", 230 | type=str, 231 | choices=["adamw", "sgd", "lars"], 232 | help="""Type of optimizer. We recommend using adamw with ViTs.""", 233 | ) 234 | 235 | # Multi-crop parameters 236 | parser.add_argument( 237 | "--global_crops_scale", 238 | type=float, 239 | nargs="+", 240 | default=(0.4, 1.0), 241 | help="""Scale range of the cropped image before resizing, relatively to the origin image. 242 | Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we 243 | recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""", 244 | ) 245 | parser.add_argument( 246 | "--local_crops_number", 247 | type=int, 248 | default=8, 249 | help="""Number of small 250 | local views to generate. Set this parameter to 0 to disable multi-crop training. 251 | When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """, 252 | ) 253 | parser.add_argument( 254 | "--local_crops_scale", 255 | type=float, 256 | nargs="+", 257 | default=(0.05, 0.4), 258 | help="""Scale range of the cropped image before resizing, relatively to the origin image. 259 | Used for small local view cropping of multi-crop.""", 260 | ) 261 | 262 | # Misc 263 | parser.add_argument( 264 | "--data_path", 265 | default="/path/to/imagenet/train/", 266 | type=str, 267 | help="Please specify path to the ImageNet training data.", 268 | ) 269 | parser.add_argument( 270 | "--output_dir", default=".", type=str, help="Path to save logs and checkpoints." 271 | ) 272 | parser.add_argument( 273 | "--loader", 274 | default="folder", 275 | type=str, 276 | choices=["folder", "png_file_list", "tif_file_list"], 277 | help="Type of data loader", 278 | ) 279 | parser.add_argument( 280 | "--saveckp_freq", default=20, type=int, help="Save checkpoint every x epochs." 
281 | ) 282 | parser.add_argument("--seed", default=0, type=int, help="Random seed.") 283 | parser.add_argument( 284 | "--num_workers", 285 | default=10, 286 | type=int, 287 | help="Number of data loading workers per GPU.", 288 | ) 289 | parser.add_argument( 290 | "--dist_url", 291 | default="env://", 292 | type=str, 293 | help="""url used to set up 294 | distributed training; see https://pytorch.org/docs/stable/distributed.html""", 295 | ) 296 | parser.add_argument( 297 | "--local_rank", 298 | default=0, 299 | type=int, 300 | help="Please ignore and do not set this argument.", 301 | ) 302 | parser.add_argument( 303 | "--root_dir_path", 304 | default="/home/ubuntu/data/CellNet_data/Hirano3D_v2.0/data/", 305 | type=str, 306 | help="root_dir_path for cells_dataloader.", 307 | ) 308 | parser.add_argument( 309 | "--RGBmode", default=False, type=utils.bool_flag, help="""enforce 3-channels""" 310 | ) 311 | parser.add_argument( 312 | "--scale-factor", 313 | default=1.0, 314 | type=float, 315 | help="""Set factor by which to resize source images.""", 316 | ) 317 | parser.add_argument("--config", default=".", type=str) 318 | return parser 319 | 320 | 321 | def train_dino(args, config): 322 | utils.init_distributed_mode(args) 323 | utils.fix_random_seeds(args.seed) 324 | init_signal_handler() 325 | print("git:\n {}\n".format(utils.get_sha())) 326 | print( 327 | "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())) 328 | ) 329 | cudnn.benchmark = True 330 | 331 | chosen_loader = file_dataset.image_modes[config["model"]["image_mode"]] 332 | FileList = file_dataset.data_loaders[config["model"]["datatype"]] 333 | 334 | # ============ preparing data ... ============ 335 | transform = DataAugmentationDINO(config=config) 336 | dataset = FileList( 337 | args.data_path, 338 | config["model"]["root"], 339 | transform=transform, 340 | loader=chosen_loader, 341 | flist_reader=partial( 342 | file_dataset.pandas_reader_only_file, 343 | sample_single_cells=args.sample_single_cells, 344 | ), 345 | with_labels=False, 346 | balance=False, 347 | sample_single_cells=args.sample_single_cells, 348 | ) 349 | sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True) 350 | data_loader = torch.utils.data.DataLoader( 351 | dataset, 352 | sampler=sampler, 353 | batch_size=args.batch_size_per_gpu, 354 | num_workers=args.num_workers, 355 | pin_memory=True, 356 | drop_last=True, 357 | ) 358 | print(f"Data loaded: there are {len(dataset)} images.") 359 | 360 | # ============ building student and teacher networks ... ============ 361 | # we changed the name DeiT-S for ViT-S to avoid confusions 362 | args.arch = args.arch.replace("deit", "vit") 363 | # if the network is a vision transformer (i.e. 
vit_tiny, vit_small, vit_base) 364 | if args.arch in vits.__dict__.keys(): 365 | student = vits.__dict__[args.arch]( 366 | img_size=[config["embedding"]["image_size"]], 367 | patch_size=args.patch_size, 368 | drop_path_rate=0.1, # stochastic depth 369 | in_chans=args.num_channels, 370 | ) 371 | teacher = vits.__dict__[args.arch]( 372 | img_size=[config["embedding"]["image_size"]], 373 | patch_size=args.patch_size, 374 | in_chans=args.num_channels, 375 | ) 376 | embed_dim = student.embed_dim 377 | # otherwise, we check if the architecture is in torchvision models 378 | elif args.arch in torchvision_models.__dict__.keys(): 379 | student = torchvision_models.__dict__[args.arch]() 380 | teacher = torchvision_models.__dict__[args.arch]() 381 | embed_dim = student.fc.weight.shape[1] 382 | elif args.arch in cell_models.__dict__.keys(): 383 | student = partial(cell_models.__dict__[args.arch], c_in=args.num_channels)( 384 | False 385 | ) 386 | teacher = partial(cell_models.__dict__[args.arch], c_in=args.num_channels)( 387 | False 388 | ) 389 | embed_dim = student[-1].weight.shape[1] 390 | else: 391 | print(f"Unknow architecture: {args.arch}") 392 | 393 | # multi-crop wrapper handles forward with inputs of different resolutions 394 | student = utils.MultiCropWrapper( 395 | student, 396 | DINOHead( 397 | embed_dim, 398 | args.out_dim, 399 | use_bn=args.use_bn_in_head, 400 | norm_last_layer=args.norm_last_layer, 401 | ), 402 | ) 403 | teacher = utils.MultiCropWrapper( 404 | teacher, 405 | DINOHead(embed_dim, args.out_dim, args.use_bn_in_head), 406 | ) 407 | 408 | # move networks to gpu 409 | student, teacher = student.cuda(), teacher.cuda() 410 | # synchronize batch norms (if any) 411 | if utils.has_batchnorms(student): 412 | student = nn.SyncBatchNorm.convert_sync_batchnorm(student) 413 | teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher) 414 | # we need DDP wrapper to have synchro batch norms working... 415 | teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu]) 416 | teacher_without_ddp = teacher.module 417 | else: 418 | # teacher_without_ddp and teacher are the same thing 419 | teacher_without_ddp = teacher 420 | student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu]) 421 | # teacher and student start with the same weights 422 | teacher_without_ddp.load_state_dict(student.module.state_dict()) 423 | # there is no backpropagation through the teacher, so no need for gradients 424 | for p in teacher.parameters(): 425 | p.requires_grad = False 426 | print(f"Student and Teacher are built: they are both {args.arch} network.") 427 | 428 | # ============ preparing loss ... ============ 429 | dino_loss = DINOLoss( 430 | args.out_dim, 431 | args.local_crops_number 432 | + 2, # total number of crops = 2 global crops + local_crops_number 433 | args.warmup_teacher_temp, 434 | args.teacher_temp, 435 | args.warmup_teacher_temp_epochs, 436 | args.epochs, 437 | args.student_temp, 438 | args.center_momentum, 439 | ).cuda() 440 | 441 | # ============ preparing optimizer ... 
============ 442 | params_groups = utils.get_params_groups(student) 443 | if args.optimizer == "adamw": 444 | optimizer = torch.optim.AdamW(params_groups) # to use with ViTs 445 | elif args.optimizer == "sgd": 446 | optimizer = torch.optim.SGD( 447 | params_groups, lr=0, momentum=0.9 448 | ) # lr is set by scheduler 449 | elif args.optimizer == "lars": 450 | optimizer = utils.LARS(params_groups) # to use with convnet and large batches 451 | # for mixed precision training 452 | fp16_scaler = None 453 | if args.use_fp16: 454 | fp16_scaler = torch.cuda.amp.GradScaler() 455 | 456 | # ============ init schedulers ... ============ 457 | lr_schedule = utils.cosine_scheduler( 458 | args.lr 459 | * (args.batch_size_per_gpu * utils.get_world_size()) 460 | / 256.0, # linear scaling rule 461 | args.min_lr, 462 | args.epochs, 463 | len(data_loader), 464 | warmup_epochs=args.warmup_epochs, 465 | ) 466 | wd_schedule = utils.cosine_scheduler( 467 | args.weight_decay, 468 | args.weight_decay_end, 469 | args.epochs, 470 | len(data_loader), 471 | ) 472 | # momentum parameter is increased to 1. during training with a cosine schedule 473 | momentum_schedule = utils.cosine_scheduler( 474 | args.momentum_teacher, 1, args.epochs, len(data_loader) 475 | ) 476 | print("Loss, optimizer and schedulers ready.") 477 | 478 | # ============ optionally resume training ... ============ 479 | to_restore = {"epoch": 0} 480 | utils.restart_from_checkpoint( 481 | os.path.join(args.output_dir, "checkpoint.pth"), 482 | run_variables=to_restore, 483 | student=student, 484 | teacher=teacher, 485 | optimizer=optimizer, 486 | fp16_scaler=fp16_scaler, 487 | dino_loss=dino_loss, 488 | ) 489 | start_epoch = to_restore["epoch"] 490 | 491 | start_time = time.time() 492 | print("Starting DINO training !") 493 | for epoch in range(start_epoch, args.epochs): 494 | data_loader.sampler.set_epoch(epoch) 495 | 496 | # ============ training one epoch of DINO ... ============ 497 | train_stats = train_one_epoch( 498 | student, 499 | teacher, 500 | teacher_without_ddp, 501 | dino_loss, 502 | data_loader, 503 | optimizer, 504 | lr_schedule, 505 | wd_schedule, 506 | momentum_schedule, 507 | epoch, 508 | fp16_scaler, 509 | args, 510 | ) 511 | 512 | # ============ writing logs ... 
============ 513 | save_dict = { 514 | "student": student.state_dict(), 515 | "teacher": teacher.state_dict(), 516 | "optimizer": optimizer.state_dict(), 517 | "epoch": epoch + 1, 518 | "args": args, 519 | "dino_loss": dino_loss.state_dict(), 520 | } 521 | if fp16_scaler is not None: 522 | save_dict["fp16_scaler"] = fp16_scaler.state_dict() 523 | utils.save_on_master(save_dict, os.path.join(args.output_dir, "checkpoint.pth")) 524 | if args.saveckp_freq and epoch % args.saveckp_freq == 0: 525 | utils.save_on_master( 526 | save_dict, os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth") 527 | ) 528 | log_stats = { 529 | **{f"train_{k}": v for k, v in train_stats.items()}, 530 | "epoch": epoch, 531 | } 532 | if utils.is_main_process(): 533 | with (Path(args.output_dir) / "log.txt").open("a") as f: 534 | f.write(json.dumps(log_stats) + "\n") 535 | total_time = time.time() - start_time 536 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 537 | print("Training time {}".format(total_time_str)) 538 | 539 | 540 | def train_one_epoch( 541 | student, 542 | teacher, 543 | teacher_without_ddp, 544 | dino_loss, 545 | data_loader, 546 | optimizer, 547 | lr_schedule, 548 | wd_schedule, 549 | momentum_schedule, 550 | epoch, 551 | fp16_scaler, 552 | args, 553 | ): 554 | metric_logger = utils.MetricLogger(delimiter=" ") 555 | header = "Epoch: [{}/{}]".format(epoch, args.epochs) 556 | for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)): 557 | # update weight decay and learning rate according to their schedule 558 | it = len(data_loader) * epoch + it # global training iteration 559 | for i, param_group in enumerate(optimizer.param_groups): 560 | param_group["lr"] = lr_schedule[it] 561 | if i == 0: # only the first group is regularized 562 | param_group["weight_decay"] = wd_schedule[it] 563 | 564 | # move images to gpu 565 | images = [im.cuda(non_blocking=True) for im in images] 566 | # teacher and student forward passes + compute dino loss 567 | with torch.cuda.amp.autocast(fp16_scaler is not None): 568 | teacher_output = teacher( 569 | images[:2] 570 | ) # only the 2 global views pass through the teacher 571 | student_output = student(images) 572 | loss = dino_loss(student_output, teacher_output, epoch) 573 | 574 | if not math.isfinite(loss.item()): 575 | print("Loss is {}, stopping training".format(loss.item()), force=True) 576 | sys.exit(1) 577 | 578 | # student update 579 | optimizer.zero_grad() 580 | param_norms = None 581 | if fp16_scaler is None: 582 | loss.backward() 583 | if args.clip_grad: 584 | param_norms = utils.clip_gradients(student, args.clip_grad) 585 | utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer) 586 | optimizer.step() 587 | else: 588 | fp16_scaler.scale(loss).backward() 589 | if args.clip_grad: 590 | fp16_scaler.unscale_( 591 | optimizer 592 | ) # unscale the gradients of optimizer's assigned params in-place 593 | param_norms = utils.clip_gradients(student, args.clip_grad) 594 | utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer) 595 | fp16_scaler.step(optimizer) 596 | fp16_scaler.update() 597 | 598 | # EMA update for the teacher 599 | with torch.no_grad(): 600 | m = momentum_schedule[it] # momentum parameter 601 | for param_q, param_k in zip( 602 | student.module.parameters(), teacher_without_ddp.parameters() 603 | ): 604 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 605 | 606 | # logging 607 | torch.cuda.synchronize() 608 | metric_logger.update(loss=loss.item()) 609 | 
metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 610 | metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"]) 611 | 612 | if utils.get_rank() == 0 and os.environ["SIGNAL_RECEIVED"] == "True": 613 | trigger_job_requeue() 614 | 615 | # gather the stats from all processes 616 | metric_logger.synchronize_between_processes() 617 | print("Averaged stats:", metric_logger) 618 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 619 | 620 | 621 | class DINOLoss(nn.Module): 622 | def __init__( 623 | self, 624 | out_dim, 625 | ncrops, 626 | warmup_teacher_temp, 627 | teacher_temp, 628 | warmup_teacher_temp_epochs, 629 | nepochs, 630 | student_temp=0.1, 631 | center_momentum=0.9, 632 | ): 633 | super().__init__() 634 | self.student_temp = student_temp 635 | self.center_momentum = center_momentum 636 | self.ncrops = ncrops 637 | self.register_buffer("center", torch.zeros(1, out_dim)) 638 | # we apply a warm up for the teacher temperature because 639 | # a too high temperature makes the training instable at the beginning 640 | self.teacher_temp_schedule = np.concatenate( 641 | ( 642 | np.linspace( 643 | warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs 644 | ), 645 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp, 646 | ) 647 | ) 648 | 649 | def forward(self, student_output, teacher_output, epoch): 650 | """ 651 | Cross-entropy between softmax outputs of the teacher and student networks. 652 | """ 653 | student_out = student_output / self.student_temp 654 | student_out = student_out.chunk(self.ncrops) 655 | 656 | # teacher centering and sharpening 657 | temp = self.teacher_temp_schedule[epoch] 658 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) 659 | teacher_out = teacher_out.detach().chunk(2) 660 | 661 | total_loss = 0 662 | n_loss_terms = 0 663 | for iq, q in enumerate(teacher_out): 664 | for v in range(len(student_out)): 665 | if v == iq: 666 | # we skip cases where student and teacher operate on the same view 667 | continue 668 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 669 | total_loss += loss.mean() 670 | n_loss_terms += 1 671 | total_loss /= n_loss_terms 672 | self.update_center(teacher_output) 673 | return total_loss 674 | 675 | @torch.no_grad() 676 | def update_center(self, teacher_output): 677 | """ 678 | Update center used for teacher output. 
679 | """ 680 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 681 | dist.all_reduce(batch_center) 682 | batch_center = batch_center / (len(teacher_output) * dist.get_world_size()) 683 | 684 | # ema update 685 | self.center = self.center * self.center_momentum + batch_center * ( 686 | 1 - self.center_momentum 687 | ) 688 | 689 | 690 | class DataAugmentationDINO(object): 691 | def __init__(self, config): 692 | self.config = config 693 | ( 694 | self.global_transfo1, 695 | self.global_transfo2, 696 | self.local_transfo, 697 | _, 698 | ) = tfms_from_config(self.config) 699 | 700 | def __call__(self, image): 701 | crops = [] 702 | crops.append(self.global_transfo1(image)) 703 | crops.append(self.global_transfo2(image)) 704 | for _ in range(self.config["local_crops_number"]): 705 | crops.append(self.local_transfo(image)) 706 | return crops 707 | 708 | 709 | if __name__ == "__main__": 710 | parser = argparse.ArgumentParser("DINO", parents=[get_args_parser()]) 711 | args = parser.parse_args() 712 | config = yaml.safe_load(open(args.config, "r")) 713 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 714 | train_dino(args, config) 715 | --------------------------------------------------------------------------------