├── object-localization
│   ├── __init__.py
│   ├── visualizations.py
│   ├── networks.py
│   ├── object_discovery.py
│   ├── datasets.py
│   └── main.py
├── semantic-segmentation
│   ├── __init__.py
│   ├── config
│   │   ├── eval.yaml
│   │   ├── base.yaml
│   │   └── train.yaml
│   ├── model
│   │   ├── __init__.py
│   │   └── model.py
│   ├── eval_utils.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── voc.py
│   ├── README.md
│   ├── eval.py
│   ├── util.py
│   └── train.py
├── requirements.txt
├── object-segmentation
│   ├── config
│   │   └── eval.yaml
│   ├── dataset.py
│   ├── metrics.py
│   ├── main.py
│   └── util.py
├── .gitignore
├── extract
│   ├── extract_utils.py
│   └── extract.py
└── README.md
/object-localization/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/semantic-segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | fire 3 | opencv-python-headless 4 | pillow 5 | scikit-image 6 | scipy 7 | torch 8 | torchvision 9 | tqdm 10 | pymatting 11 | --------------------------------------------------------------------------------
/semantic-segmentation/config/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base 4 | - _self_ 5 | 6 | name: 'eval' 7 | job_type: 'eval' 8 | 9 | segments_dir: "" 10 | --------------------------------------------------------------------------------
/semantic-segmentation/model/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | from .model import get_deeplab_resnet, get_deeplab_vit 7 | 8 | 9 | def get_model(name: str, num_classes: int): 10 | if 'resnet' in name: 11 | model = get_deeplab_resnet(num_classes=(num_classes + 1)) # add 1 for bg 12 | elif 'vit' in name: 13 | model = get_deeplab_vit(backbone_name=name, num_classes=(num_classes + 1)) # add 1 for bg 14 | else: 15 | raise NotImplementedError() 16 | return model 17 | 18 | --------------------------------------------------------------------------------
/semantic-segmentation/config/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | run: 4 | dir: ./outputs/${name}/${now:%Y-%m-%d--%H-%M-%S} 5 | 6 | name: "debug" 7 | seed: 1 8 | job_type: 'train' 9 | fp16: False 10 | cpu: False 11 | wandb: False 12 | wandb_kwargs: 13 | project: deep-spectral-segmentation 14 | 15 | data: 16 | num_classes: 20 17 | dataset: pascal 18 | train_kwargs: 19 | root: ${oc.env:HOME}/machine-learning-datasets/semantic-segmentation/PASCAL_VOC/VOC2012 20 | year: "2012" 21 | image_set: train 22 | download: False 23 | val_kwargs: 24 | root: ${oc.env:HOME}/machine-learning-datasets/semantic-segmentation/PASCAL_VOC/VOC2012 25 | year: "2012" 26 | image_set: "val" 27 | download: False 28 | loader: 29 | batch_size: 144 30 | num_workers: 8 31 | pin_memory: False 32 | transform: 33 | resize_size: 256 34 | crop_size: 224 35 | img_mean: [0.485, 0.456, 0.406] 36 | img_std: [0.229, 0.224, 0.225] 37 | 38 | segments_dir: "" 39 | 40 | logging: 41 | print_freq: 50 --------------------------------------------------------------------------------
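For orientation, the sketch below shows how the pieces above fit together: `get_model` in `semantic-segmentation/model/__init__.py` builds a DeepLab-style segmenter on a DINO backbone, while `data.num_classes` and `data.transform.crop_size` in `base.yaml` determine its output shape. This is a minimal illustrative sketch, not part of the repository; the backbone name `vits16`, a working directory of `semantic-segmentation/`, and `torch.hub` access to the pretrained DINO weights are assumptions.

```python
# Hypothetical usage sketch (not part of the repository).
# Assumptions: run from the `semantic-segmentation` directory; torch.hub has
# network access to download the pretrained DINO ViT-S/16 backbone.
import torch

from model import get_model  # semantic-segmentation/model/__init__.py

# data.num_classes = 20 (PASCAL VOC); get_model adds one class for the background
model = get_model(name='vits16', num_classes=20)
model.eval()

# A single input shaped like the training crops (data.transform.crop_size = 224)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)  # expected: (1, 21, 224, 224) per-pixel class scores
print(logits.shape)
```

In `train.py`, a model of this kind is trained against the pseudolabel segmentations referenced by `segments_dir`, as described in `semantic-segmentation/README.md`.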
/semantic-segmentation/config/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - base 4 | - _self_ 5 | 6 | job_type: 'train' 7 | eval_every: 1 # eval every this many epochs 8 | checkpoint_every: 10 # checkpoint every this many epochs 9 | 10 | unfrozen_backbone_layers: 1 # -1 to train all, 0 to freeze entirely, > 0 to specify 11 | model: 12 | name: resnet50 13 | num_classes: ${data.num_classes} 14 | 15 | # Please change these 16 | segments_dir: "" 17 | matching: "" 18 | 19 | checkpoint: 20 | resume: null 21 | resume_training: True 22 | resume_optimizer_only: False 23 | 24 | # Exponential moving average of model parameters 25 | ema: 26 | use_ema: False 27 | decay: 0.999 28 | update_every: 10 29 | 30 | # Training steps/epochs 31 | max_train_steps: 5000 32 | max_train_epochs: null 33 | 34 | # Optimization 35 | lr: 0.005 36 | gradient_accumulation_steps: 1 37 | optimizer: 38 | scale_learning_rate_with_batch_size: False 39 | clip_grad_norm: null 40 | 41 | # Timm optimizer 42 | kind: 'timm' 43 | kwargs: 44 | opt: 'adamw' 45 | weight_decay: 1e-8 46 | 47 | # Learning rate scheduling 48 | scheduler: 49 | 50 | # Transformers scheduler 51 | kind: 'transformers' 52 | stepwise: True 53 | kwargs: 54 | name: linear 55 | num_warmup_steps: 0 56 | num_training_steps: ${max_train_steps} 57 | -------------------------------------------------------------------------------- /object-segmentation/config/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | 5 | hydra: 6 | run: 7 | dir: ./outputs/${name}/${now:%Y-%m-%d_%H-%M-%S} 8 | 9 | # General 10 | name: "debug" 11 | seed: 1 12 | job_type: 'eval' 13 | fp16: False 14 | cpu: False 15 | wandb: False 16 | wandb_kwargs: 17 | project: deep-spectral-segmentation 18 | 19 | # Data 20 | data_root: ${env:GANSEG_DATA_SEG_ROOT} # <- REPLACE THIS WITH YOUR DIRECTORY 21 | data: 22 | - name: 'CUB' 23 | images_dir: "${data_root}/CUB_200_2011/test_images" 24 | labels_dir: "${data_root}/CUB_200_2011/test_segmentations" 25 | crop: True 26 | image_size: null 27 | - name: 'DUT_OMRON' 28 | images_dir: "${data_root}/DUT_OMRON/DUT-OMRON-image" 29 | labels_dir: "${data_root}/DUT_OMRON/pixelwiseGT-new-PNG" 30 | crop: False 31 | image_size: null 32 | - name: 'DUTS' 33 | images_dir: "${data_root}/DUTS/DUTS-TE/DUTS-TE-Image" 34 | labels_dir: "${data_root}/DUTS/DUTS-TE/DUTS-TE-Mask" 35 | crop: False 36 | image_size: null 37 | - name: 'ECSSD' 38 | images_dir: "${data_root}/ECSSD/images" 39 | labels_dir: "${data_root}/ECSSD/ground_truth_mask" 40 | crop: False 41 | image_size: null 42 | 43 | dataloader: 44 | batch_size: 128 45 | num_workers: 16 46 | 47 | # Predictions 48 | predictions: 49 | root: "/path/to/object-segmentation-data" 50 | run: "run_name" 51 | downsample: 16 # null 52 | 53 | # The paths to the predictions 54 | CUB: ${predictions.root}/CUB_200_2011/${predictions.run} 55 | DUT_OMRON: ${predictions.root}/DUT_OMRON/${predictions.run} 56 | DUTS: ${predictions.root}/DUTS/${predictions.run} 57 | ECSSD: ${predictions.root}/ECSSD/${predictions.run} 58 | -------------------------------------------------------------------------------- /semantic-segmentation/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from joblib import Parallel 3 | from joblib.parallel import delayed 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | 7 | def 
hungarian_match(flat_preds, flat_targets, preds_k, targets_k, metric='acc', n_jobs=16): 8 | assert (preds_k == targets_k) # one to one 9 | num_k = preds_k 10 | 11 | # perform hungarian matching 12 | print('Using iou as metric') 13 | results = Parallel(n_jobs=n_jobs, backend='multiprocessing')(delayed(get_iou)( 14 | flat_preds, flat_targets, c1, c2) for c2 in range(num_k) for c1 in range(num_k)) 15 | results = np.array(results) 16 | results = results.reshape((num_k, num_k)).T 17 | match = linear_sum_assignment(flat_targets.shape[0] - results) 18 | match = np.array(list(zip(*match))) 19 | res = [] 20 | for out_c, gt_c in match: 21 | res.append((out_c, gt_c)) 22 | 23 | return res 24 | 25 | 26 | def majority_vote(flat_preds, flat_targets, preds_k, targets_k, n_jobs=16): 27 | iou_mat = Parallel(n_jobs=n_jobs, backend='multiprocessing')(delayed(get_iou)( 28 | flat_preds, flat_targets, c1, c2) for c2 in range(targets_k) for c1 in range(preds_k)) 29 | iou_mat = np.array(iou_mat) 30 | results = iou_mat.reshape((targets_k, preds_k)).T 31 | results = np.argmax(results, axis=1) 32 | match = np.array(list(zip(range(preds_k), results))) 33 | return match 34 | 35 | 36 | def get_iou(flat_preds, flat_targets, c1, c2): 37 | tp = 0 38 | fn = 0 39 | fp = 0 40 | tmp_all_gt = (flat_preds == c1) 41 | tmp_pred = (flat_targets == c2) 42 | tp += np.sum(tmp_all_gt & tmp_pred) 43 | fp += np.sum(~tmp_all_gt & tmp_pred) 44 | fn += np.sum(tmp_all_gt & ~tmp_pred) 45 | jac = float(tp) / max(float(tp + fp + fn), 1e-8) 46 | return jac 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | .vscode 3 | wandb 4 | outputs 5 | tmp* 6 | slurm-logs 7 | preprocess/extract-features/features_* 8 | preprocess/extract-features/eigensegments_* 9 | preprocess/extract-features/mattes_* 10 | preprocess/extract-features/masks_* 11 | preprocess/extract-features/multilabel_masks_* 12 | preprocess/extract-features/semantic_segmentations_* 13 | preprocess/extract-features/crf_semantic_segmentations_* 14 | old 15 | dino 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | .github 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # Lightning /research 46 | test_tube_exp/ 47 | tests/tests_tt_dir/ 48 | tests/save_dir 49 | default/ 50 | data/ 51 | test_tube_logs/ 52 | test_tube_data/ 53 | datasets/ 54 | model_weights/ 55 | tests/save_dir 56 | tests/tests_tt_dir/ 57 | processed/ 58 | raw/ 59 | 60 | # PyInstaller 61 | # Usually these files are written by a python script from a template 62 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
63 | *.manifest 64 | *.spec 65 | 66 | # Installer logs 67 | pip-log.txt 68 | pip-delete-this-directory.txt 69 | 70 | # Unit test / coverage reports 71 | htmlcov/ 72 | .tox/ 73 | .coverage 74 | .coverage.* 75 | .cache 76 | nosetests.xml 77 | coverage.xml 78 | *.cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | 91 | # Flask stuff: 92 | instance/ 93 | .webassets-cache 94 | 95 | # Scrapy stuff: 96 | .scrapy 97 | 98 | # Sphinx documentation 99 | docs/_build/ 100 | 101 | # PyBuilder 102 | target/ 103 | 104 | # Jupyter Notebook 105 | .ipynb_checkpoints 106 | 107 | # pyenv 108 | .python-version 109 | 110 | # celery beat schedule file 111 | celerybeat-schedule 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | 138 | # IDEs 139 | .idea 140 | .vscode 141 | 142 | # seed project 143 | lightning_logs/ 144 | MNIST 145 | .DS_Store 146 | -------------------------------------------------------------------------------- /semantic-segmentation/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import albumentations.pytorch as AP 3 | import cv2 4 | from torch.utils.data._utils.collate import default_collate 5 | 6 | from .voc import VOCSegmentationWithPseudolabels 7 | 8 | 9 | def get_transforms(resize_size, crop_size, img_mean, img_std): 10 | 11 | # Multiple training transforms for contrastive learning 12 | train_joint_transform = A.Compose([ 13 | A.SmallestMaxSize(resize_size, interpolation=cv2.INTER_CUBIC), 14 | A.RandomCrop(crop_size, crop_size), 15 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 16 | train_geometric_transform = A.ReplayCompose([ 17 | A.RandomResizedCrop(crop_size, crop_size, interpolation=cv2.INTER_CUBIC), 18 | A.HorizontalFlip(), 19 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 20 | train_separate_transform = A.Compose([ 21 | A.ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8), 22 | A.ToGray(p=0.2), A.GaussianBlur(p=0.1), # A.Solarize(p=0.1) 23 | A.Normalize(mean=img_mean, std=img_std), AP.ToTensorV2(), 24 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 25 | 26 | # Validation transform -- no resizing! 
27 | val_transform = A.Compose([ 28 | # A.Resize(resize_size, resize_size, interpolation=cv2.INTER_CUBIC), A.CenterCrop(crop_size, crop_size), 29 | A.Normalize(mean=img_mean, std=img_std), AP.ToTensorV2() 30 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'}) 31 | 32 | train_transforms_tuple = (train_joint_transform, train_geometric_transform, train_separate_transform) 33 | return train_transforms_tuple, val_transform 34 | 35 | 36 | def collate_fn(batch): 37 | everything_but_metadata = [t[:-1] for t in batch] 38 | metadata = [t[-1] for t in batch] 39 | return (*default_collate(everything_but_metadata), metadata) 40 | 41 | 42 | def get_datasets(cfg): 43 | 44 | # Get transforms 45 | train_transforms_tuple, val_transform = get_transforms(**cfg.data.transform) 46 | 47 | # Get the label map 48 | if cfg.matching: 49 | matching = dict(eval(str(cfg.matching))) 50 | print(f'Using matching: {matching}') 51 | else: 52 | matching = None 53 | 54 | # Training dataset 55 | dataset_train = VOCSegmentationWithPseudolabels( 56 | **cfg.data.train_kwargs, 57 | segments_dir=cfg.segments_dir, 58 | transforms_tuple=train_transforms_tuple, 59 | label_map=matching 60 | ) 61 | 62 | # Validation dataset 63 | dataset_val = VOCSegmentationWithPseudolabels( 64 | **cfg.data.val_kwargs, 65 | segments_dir=cfg.segments_dir, 66 | transform=val_transform, 67 | label_map=matching 68 | ) 69 | 70 | return dataset_train, dataset_val, collate_fn 71 | -------------------------------------------------------------------------------- /object-segmentation/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms as T 6 | 7 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 8 | 9 | 10 | def get_paths_from_folders(images_dir): 11 | """Returns list of files in folders of input""" 12 | paths = [] 13 | for folder in Path(images_dir).iterdir(): 14 | for p in folder.iterdir(): 15 | paths.append(p) 16 | return paths 17 | 18 | 19 | def central_crop(x): 20 | dims = x.size 21 | crop = T.CenterCrop(min(dims[0], dims[1])) 22 | return crop(x) 23 | 24 | 25 | class SegmentationDataset(Dataset): 26 | 27 | def __init__( 28 | self, 29 | images_dir: str, 30 | labels_dir: str, 31 | image_size: Optional[int] = None, 32 | resize_image=True, 33 | resize_mask=None, 34 | crop=True, 35 | mean=[0.5, 0.5, 0.5], 36 | std=[0.5, 0.5, 0.5], 37 | name: Optional[str] = None, 38 | ): 39 | self.name = name 40 | self.crop = crop 41 | 42 | # Find out if dataset is organized into folders or not 43 | has_folders = not any(str(next(Path(images_dir).iterdir())).endswith(ext) for ext in IMG_EXTENSIONS) 44 | 45 | # Get and sort list of paths 46 | if has_folders: 47 | image_paths = get_paths_from_folders(images_dir) 48 | label_paths = get_paths_from_folders(labels_dir) 49 | else: 50 | image_paths = Path(images_dir).iterdir() 51 | label_paths = Path(labels_dir).iterdir() 52 | self.image_paths = list(sorted(image_paths)) 53 | self.label_paths = list(sorted(label_paths)) 54 | assert len(self.image_paths) == len(self.label_paths) 55 | 56 | # Transformation 57 | resize_image = (image_size is not None) and resize_image 58 | resize_mask = resize_image if resize_mask is None else resize_mask 59 | image_transform = [T.ToTensor(), T.Normalize(mean=mean, std=std)] 60 | mask_transform = [T.ToTensor()] 61 | if resize_image: 62 | 
image_transform.insert(0, T.Resize(image_size)) 63 | if resize_mask: 64 | mask_transform.insert(0, T.Resize(image_size)) 65 | if crop: 66 | image_transform.insert(0, central_crop) 67 | mask_transform.insert(0, central_crop) 68 | self.image_transform = T.Compose(image_transform) 69 | self.mask_transform = T.Compose(mask_transform) 70 | 71 | def __len__(self): 72 | return len(self.image_paths) 73 | 74 | def __getitem__(self, idx): 75 | 76 | # Load 77 | image = Image.open(self.image_paths[idx]) 78 | mask = Image.open(self.label_paths[idx]) 79 | metadata = {'image_file': str(self.image_paths[idx])} 80 | 81 | # Transform 82 | image = image.convert('RGB') 83 | mask = mask.convert('RGB') 84 | image = self.image_transform(image) 85 | mask = self.mask_transform(mask) 86 | mask = (mask > 0.5)[0].long() # TODO: this could be improved 87 | return image, mask, metadata 88 | -------------------------------------------------------------------------------- /semantic-segmentation/README.md: -------------------------------------------------------------------------------- 1 | ## Semantic Segmentation 2 | 3 | We begin by extracting features and eigenvectors from our images. For instructions on this process, follow the steps in "Extraction" in the main `README`. 4 | 5 | Next, we obtain coarse (i.e. patch-level) semantic segmentations. This process involves (1) extracting segments from the eigenvectors, (2) taking a bounding box around them, (3) extracting features for these boxes, (4) clustering these features, (5) obtaining coarse semantic segmentations. 6 | 7 | For example, you can run the following in the `extract` directory. 8 | 9 | ```bash 10 | # Example parameters for the semantic segmentation experiments 11 | DATASET="VOC2012" 12 | MODEL="dino_vits16" 13 | MATRIX="laplacian" 14 | DOWNSAMPLE=16 15 | N_SEG=15 16 | N_ERODE=2 17 | N_DILATE=5 18 | 19 | # Extract segments 20 | python extract.py extract_multi_region_segmentations \ 21 | --non_adaptive_num_segments ${N_SEG} \ 22 | --features_dir "./data/${DATASET}/features/${MODEL}" \ 23 | --eigs_dir "./data/${DATASET}/eigs/${MATRIX}" \ 24 | --output_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" 25 | 26 | # Extract bounding boxes 27 | python extract.py extract_bboxes \ 28 | --features_dir "./data/${DATASET}/features/${MODEL}" \ 29 | --segmentations_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" \ 30 | --num_erode ${N_ERODE} \ 31 | --num_dilate ${N_DILATE} \ 32 | --downsample_factor ${DOWNSAMPLE} \ 33 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bboxes.pth" 34 | 35 | # Extract bounding box features 36 | python extract.py extract_bbox_features \ 37 | --model_name ${MODEL} \ 38 | --images_root "./data/${DATASET}/images" \ 39 | --bbox_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bboxes.pth" \ 40 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_features.pth" 41 | 42 | # Extract clusters 43 | python extract.py extract_bbox_clusters \ 44 | --bbox_features_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_features.pth" \ 45 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_clusters.pth" 46 | 47 | # Create semantic segmentations 48 | python extract.py extract_semantic_segmentations \ 49 | --segmentations_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" \ 50 | --bbox_clusters_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_clusters.pth" \ 51 | --output_dir "./data/${DATASET}/semantic_segmentations/patches/${MATRIX}/segmaps" 52 | ``` 53 | 54 | At 
this point, you can evaluate the segmentations using `eval.py` in this directory. For example: 55 | ```bash 56 | python eval.py segments_dir="/output_dir/from/above" 57 | ``` 58 | 59 | Optionally, you can also perform self-training using `train.py`. You can specify the correct matching using `matching="\"[(0, 0), ... (19, 6), (20, 7)]\""`. This matching may be obtained by first evaluating using `python eval.py`. For example: 60 | ```bash 61 | python train.py lr=2e-4 data.loader.batch_size=96 segments_dir="/path/to/segmaps" matching="\"[(0, 0), ... (19, 6), (20, 7)]\"" 62 | ``` 63 | 64 | Please note that the unsupervised semantic segmentation results have very high variance; some runs are much better than others. This variance is primarily due to the random seeds of the K-means clustering steps above, and it is secondarily due to randomness in the self-training stage. Also please note that this code has been heavily re-factored for its public release. Although we try to ensure that there are no bugs, it is nevertheless possible that there is a bug we have overlooked. 65 | -------------------------------------------------------------------------------- /object-localization/visualizations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vis utilities. Code adapted from LOST: https://github.com/valeoai/LOST 3 | """ 4 | import cv2 5 | import torch 6 | import skimage.io 7 | import numpy as np 8 | import torch.nn as nn 9 | from PIL import Image 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, plot_seed=False): 14 | """ 15 | Visualization of the predicted box and the corresponding seed patch. 16 | """ 17 | w_featmap, h_featmap = dims 18 | 19 | # Plot the box 20 | cv2.rectangle( 21 | image, 22 | (int(pred[0]), int(pred[1])), 23 | (int(pred[2]), int(pred[3])), 24 | (255, 0, 0), 3, 25 | ) 26 | 27 | # Plot the seed 28 | if plot_seed: 29 | s_ = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap)) 30 | size_ = np.asarray(scales) / 2 31 | cv2.rectangle( 32 | image, 33 | (int(s_[1] * scales[1] - (size_[1] / 2)), int(s_[0] * scales[0] - (size_[0] / 2))), 34 | (int(s_[1] * scales[1] + (size_[1] / 2)), int(s_[0] * scales[0] + (size_[0] / 2))), 35 | (0, 255, 0), -1, 36 | ) 37 | 38 | pltname = f"{vis_folder}/LOST_{im_name}.png" 39 | Image.fromarray(image).save(pltname) 40 | print(f"Predictions saved at {pltname}.") 41 | 42 | def visualize_fms(A, seed, scores, dims, scales, output_folder, im_name): 43 | """ 44 | Visualization of the maps presented in Figure 2 of the paper. 
45 | """ 46 | w_featmap, h_featmap = dims 47 | 48 | # Binarized similarity 49 | binA = A.copy() 50 | binA[binA < 0] = 0 51 | binA[binA > 0] = 1 52 | 53 | # Get binarized correlation for this pixel and make it appear in gray 54 | im_corr = np.zeros((3, len(scores))) 55 | where = binA[seed, :] > 0 56 | im_corr[:, where] = np.array([128 / 255, 133 / 255, 133 / 255]).reshape((3, 1)) 57 | # Show selected pixel in green 58 | im_corr[:, seed] = [204 / 255, 37 / 255, 41 / 255] 59 | # Reshape and rescale 60 | im_corr = im_corr.reshape((3, w_featmap, h_featmap)) 61 | im_corr = ( 62 | nn.functional.interpolate( 63 | torch.from_numpy(im_corr).unsqueeze(0), 64 | scale_factor=scales, 65 | mode="nearest", 66 | )[0].cpu().numpy() 67 | ) 68 | 69 | # Save correlations 70 | skimage.io.imsave( 71 | fname=f"{output_folder}/corr_{im_name}.png", 72 | arr=im_corr.transpose((1, 2, 0)), 73 | ) 74 | print(f"Image saved at {output_folder}/corr_{im_name}.png .") 75 | 76 | # Save inverse degree 77 | im_deg = ( 78 | nn.functional.interpolate( 79 | torch.from_numpy(1 / binA.sum(-1)).reshape(1, 1, w_featmap, h_featmap), 80 | scale_factor=scales, 81 | mode="nearest", 82 | )[0][0].cpu().numpy() 83 | ) 84 | plt.imsave(fname=f"{output_folder}/deg_{im_name}.png", arr=im_deg) 85 | print(f"Image saved at {output_folder}/deg_{im_name}.png .") 86 | 87 | def visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims, vis_folder, im_name): 88 | """ 89 | Visualization of the seed expansion presented in Figure 3 of the paper. 90 | """ 91 | w_featmap, h_featmap = dims 92 | 93 | # Before expansion 94 | cv2.rectangle( 95 | image, 96 | (int(pred_seed[0]), int(pred_seed[1])), 97 | (int(pred_seed[2]), int(pred_seed[3])), 98 | (204, 204, 0), # Yellow 99 | 3, 100 | ) 101 | 102 | # After expansion 103 | cv2.rectangle( 104 | image, 105 | (int(pred[0]), int(pred[1])), 106 | (int(pred[2]), int(pred[3])), 107 | (204, 0, 204), # Magenta 108 | 3, 109 | ) 110 | 111 | # Position of the seed 112 | center = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap)) 113 | start_1 = center[0] * scales[0] 114 | end_1 = center[0] * scales[0] + scales[0] 115 | start_2 = center[1] * scales[1] 116 | end_2 = center[1] * scales[1] + scales[1] 117 | image[start_1:end_1, start_2:end_2, 0] = 204 118 | image[start_1:end_1, start_2:end_2, 1] = 37 119 | image[start_1:end_1, start_2:end_2, 2] = 41 120 | 121 | pltname = f"{vis_folder}/LOST_seed_expansion_{im_name}.png" 122 | Image.fromarray(image).save(pltname) 123 | print(f"Image saved at {pltname}.") 124 | -------------------------------------------------------------------------------- /object-localization/networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loads model. Depends on DINO repo. 
3 | Code adapted from LOST: https://github.com/valeoai/LOST 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.models.resnet import resnet50 8 | from torchvision.models.vgg import vgg16 9 | 10 | import dino.vision_transformer as vits 11 | 12 | 13 | def get_model(arch, patch_size, resnet_dilate, device): 14 | if "resnet" in arch: 15 | if resnet_dilate == 1: 16 | replace_stride_with_dilation = [False, False, False] 17 | elif resnet_dilate == 2: 18 | replace_stride_with_dilation = [False, False, True] 19 | elif resnet_dilate == 4: 20 | replace_stride_with_dilation = [False, True, True] 21 | 22 | if "imagenet" in arch: 23 | model = resnet50( 24 | pretrained=True, 25 | replace_stride_with_dilation=replace_stride_with_dilation, 26 | ) 27 | else: 28 | model = resnet50( 29 | pretrained=False, 30 | replace_stride_with_dilation=replace_stride_with_dilation, 31 | ) 32 | elif "vgg16" in arch: 33 | if "imagenet" in arch: 34 | model = vgg16(pretrained=True) 35 | else: 36 | model = vgg16(pretrained=False) 37 | else: 38 | model = vits.__dict__[arch](patch_size=patch_size, num_classes=0) 39 | 40 | for p in model.parameters(): 41 | p.requires_grad = False 42 | 43 | # Initialize model with pretraining 44 | if "imagenet" not in arch: 45 | url = None 46 | if arch == "vit_small" and patch_size == 16: 47 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 48 | elif arch == "vit_small" and patch_size == 8: 49 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper 50 | elif arch == "vit_base" and patch_size == 16: 51 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 52 | elif arch == "vit_base" and patch_size == 8: 53 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 54 | elif arch == "resnet50": 55 | url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth" 56 | if url is not None: 57 | print( 58 | "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." 59 | ) 60 | state_dict = torch.hub.load_state_dict_from_url( 61 | url="https://dl.fbaipublicfiles.com/dino/" + url 62 | ) 63 | strict_loading = False if "resnet" in arch else True 64 | msg = model.load_state_dict(state_dict, strict=strict_loading) 65 | print( 66 | "Pretrained weights found at {} and loaded with msg: {}".format( 67 | url, msg 68 | ) 69 | ) 70 | else: 71 | print( 72 | "There is no reference weights available for this model => We use random weights." 
73 | ) 74 | 75 | # If ResNet or VGG16 loose the last fully connected layer 76 | if "resnet" in arch: 77 | model = ResNet50Bottom(model) 78 | elif "vgg16" in arch: 79 | model = vgg16Bottom(model) 80 | 81 | model.eval() 82 | model.to(device) 83 | return model 84 | 85 | 86 | class ResNet50Bottom(nn.Module): 87 | # https://forums.fast.ai/t/pytorch-best-way-to-get-at-intermediate-layers-in-vgg-and-resnet/5707/2 88 | def __init__(self, original_model): 89 | super(ResNet50Bottom, self).__init__() 90 | # Remove avgpool and fc layers 91 | self.features = nn.Sequential(*list(original_model.children())[:-2]) 92 | 93 | def forward(self, x): 94 | x = self.features(x) 95 | return x 96 | 97 | 98 | class vgg16Bottom(nn.Module): 99 | # https://forums.fast.ai/t/pytorch-best-way-to-get-at-intermediate-layers-in-vgg-and-resnet/5707/2 100 | def __init__(self, original_model): 101 | super(vgg16Bottom, self).__init__() 102 | # Remove avgpool and the classifier 103 | self.features = nn.Sequential(*list(original_model.children())[:-2]) 104 | # Remove the last maxPool2d 105 | self.features = nn.Sequential(*list(self.features[0][:-1])) 106 | 107 | def forward(self, x): 108 | x = self.features(x) 109 | return x 110 | -------------------------------------------------------------------------------- /semantic-segmentation/model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torchvision.models._utils import IntermediateLayerGetter 5 | from torchvision.models.segmentation.deeplabv3 import (ASPP, DeepLabHead, DeepLabV3) 6 | 7 | 8 | def get_deeplab_resnet(num_classes: int, name: str = 'deeplabv3plus', output_stride: int = 8): 9 | 10 | if output_stride == 8: 11 | replace_stride_with_dilation = [False, True, True] 12 | aspp_dilate = [12, 24, 36] 13 | elif output_stride == 16: 14 | replace_stride_with_dilation = [False, False, True] 15 | aspp_dilate = [6, 12, 18] 16 | else: 17 | raise NotImplementedError() 18 | 19 | backbone = torch.hub.load( 20 | 'facebookresearch/dino:main', 21 | 'dino_resnet50', 22 | replace_stride_with_dilation=replace_stride_with_dilation 23 | ) 24 | 25 | inplanes = 2048 26 | low_level_planes = 256 27 | 28 | if name == 'deeplabv3plus': 29 | return_layers = {'layer4': 'out', 'layer1': 'low_level'} 30 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 31 | DeepLab = DeepLabV3Plus 32 | elif name == 'deeplabv3': 33 | return_layers = {'layer4': 'out'} 34 | DeepLab = DeepLabV3 35 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate) 36 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 37 | 38 | model = DeepLab(backbone, classifier) 39 | return model 40 | 41 | 42 | def get_deeplab_vit(num_classes: int, backbone_name: str = 'vits16', name: str = 'deeplabv3plus'): 43 | 44 | # Backbone 45 | backbone = torch.hub.load('facebookresearch/dino:main', f'dino_{backbone_name}') 46 | 47 | # Classifier 48 | aspp_dilate = [12, 24, 36] 49 | inplanes = low_level_planes = backbone.embed_dim 50 | if name == 'deeplabv3plus': 51 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 52 | DeepLab = DeepLabV3Plus 53 | elif name == 'deeplabv3': 54 | DeepLab = DeepLabV3 55 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate) 56 | 57 | # Wrap 58 | backbone = VisionTransformerWrapper(backbone) 59 | model = DeepLab(backbone, classifier) 60 | return model 61 | 62 | 63 | class 
VisionTransformerWrapper(nn.Module): 64 | def __init__(self, backbone): 65 | super().__init__() 66 | self.backbone = backbone 67 | 68 | def forward(self, x): 69 | # Forward 70 | output = self.backbone.get_intermediate_layers(x, n=5) 71 | # Reshaping 72 | assert (len(output) == 5), f'{output.shape=}' 73 | H_patch = x.shape[-2] // self.backbone.patch_embed.patch_size 74 | W_patch = x.shape[-1] // self.backbone.patch_embed.patch_size 75 | out_ll = output[0][:, 1:, :].transpose(-2, -1).unflatten(-1, (H_patch, W_patch)) 76 | out = output[-1][:, 1:, :].transpose(-2, -1).unflatten(-1, (H_patch, W_patch)) 77 | return {'low_level': out_ll, 'out': out} 78 | 79 | 80 | class DeepLabHeadV3Plus(nn.Module): 81 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): 82 | super(DeepLabHeadV3Plus, self).__init__() 83 | self.project = nn.Sequential( 84 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 85 | nn.BatchNorm2d(48), 86 | nn.ReLU(inplace=True), 87 | ) 88 | 89 | self.aspp = ASPP(in_channels, aspp_dilate) 90 | 91 | self.classifier = nn.Sequential( 92 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 93 | nn.BatchNorm2d(256), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(256, num_classes, 1) 96 | ) 97 | self._init_weight() 98 | 99 | def forward(self, feature): 100 | low_level_feature = self.project(feature['low_level']) 101 | output_feature = self.aspp(feature['out']) 102 | output_feature = F.interpolate( 103 | output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) 104 | return self.classifier(torch.cat([low_level_feature, output_feature], dim=1)) 105 | 106 | def _init_weight(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight) 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | 115 | class DeepLabV3Plus(nn.Module): 116 | def __init__(self, backbone, classifier): 117 | super().__init__() 118 | self.backbone = backbone 119 | self.classifier = classifier 120 | 121 | def forward(self, x): 122 | input_shape = x.shape[-2:] 123 | features = self.backbone(x) 124 | x = self.classifier(features) 125 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 126 | return x 127 | -------------------------------------------------------------------------------- /object-segmentation/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | import numpy as np 4 | 5 | 6 | @torch.no_grad() 7 | def compute_metrics(preds, targets, metrics=['f_max', 'acc', 'iou'], threshold=0.5, swap_dims=False, preds_are_soft=False): 8 | 9 | # Move to CPU 10 | preds = preds.detach() # .cpu() 11 | targets = targets.detach() # .cpu() 12 | assert len(targets.shape) == 3 13 | if preds_are_soft: 14 | assert len(preds.shape) == 4 15 | soft_preds = torch.softmax(preds, dim=1)[:, (0 if swap_dims else 1)] # convert to probabilities 16 | hard_preds = soft_preds > threshold 17 | else: 18 | assert 'f_max' not in metrics, 'must have soft preds for f_max' 19 | assert (len(preds.shape) == 3) 20 | assert (preds.dtype == torch.bool) or (preds.dtype == torch.uint8) or (preds.dtype == torch.long) 21 | assert (preds.max() <= 1) and (preds.min() >= 0) 22 | soft_preds = [None] * len(preds) 23 | hard_preds = preds.bool() 24 | 25 | # Compute 26 | results = defaultdict(list) 27 | for soft_pred, hard_pred, target in zip(soft_preds, 
hard_preds, targets): 28 | if 'f_max' in metrics: 29 | precision, recall = compute_prs(soft_pred, target, prob_bins=255) 30 | results['f_max_precision'].append(precision) 31 | results['f_max_recall'].append(recall) 32 | if 'f_beta' in metrics: 33 | precision, recall = precision_recall(target, hard_preds) 34 | results['f_beta_precision'].append([precision]) 35 | results['f_beta_recall'].append([recall]) 36 | if 'acc' in metrics: 37 | acc = compute_accuracy(hard_pred, target) 38 | results['acc'].append(acc) 39 | if 'iou' in metrics: 40 | iou = compute_iou(hard_pred, target) 41 | results['iou'].append(iou) 42 | return dict(results) 43 | 44 | 45 | @torch.no_grad() 46 | def aggregate_metrics(totals): 47 | results = defaultdict(list) 48 | if 'acc' in totals: 49 | results['acc'] = mean(totals['acc']) 50 | if 'iou' in totals: 51 | results['iou'] = mean(totals['iou']) 52 | if 'loss' in totals: 53 | results['loss'] = mean(totals['loss']) 54 | if 'f_max_precision' in totals and 'f_max_recall' in totals: 55 | precisions = torch.tensor(totals['f_max_precision']) 56 | recalls = torch.tensor(totals['f_max_recall']) 57 | results['f_max'] = F_max(precisions, recalls) 58 | if 'f_beta_precision' in totals and 'f_beta_recall' in totals: 59 | precisions = torch.tensor(totals['f_beta_precision']) 60 | recalls = torch.tensor(totals['f_beta_recall']) 61 | results['f_beta'] = F_max(precisions, recalls) 62 | return results 63 | 64 | 65 | def compute_accuracy(pred, target): 66 | pred, target = pred.to(torch.bool), target.to(torch.bool) 67 | return torch.mean((pred == target).to(torch.float)).item() 68 | 69 | 70 | def compute_iou(pred, target): 71 | pred, target = pred.to(torch.bool), target.to(torch.bool) 72 | intersection = torch.sum(pred * (pred == target), dim=[-1, -2]).squeeze() 73 | union = torch.sum(pred + target, dim=[-1, -2]).squeeze() 74 | iou = (intersection.to(torch.float) / union).mean() 75 | iou = iou.item() if (iou == iou) else 0 # deal with nans, i.e. 
torch.nan_to_num(iou, nan=0.0) 76 | return iou 77 | 78 | 79 | def compute_prs(pred, target, prob_bins=255): 80 | p = [] 81 | r = [] 82 | for split in np.arange(0.0, 1.0, 1.0 / prob_bins): 83 | if split == 0.0: 84 | continue 85 | pr = precision_recall(target, pred > split) 86 | p.append(pr[0]) 87 | r.append(pr[1]) 88 | return p, r 89 | 90 | 91 | def precision_recall(mask_gt, mask): 92 | mask_gt, mask = mask_gt.to(torch.bool), mask.to(torch.bool) 93 | true_positive = torch.sum(mask_gt * (mask_gt == mask), dim=[-1, -2]).squeeze() 94 | mask_area = torch.sum(mask, dim=[-1, -2]).to(torch.float) 95 | mask_gt_area = torch.sum(mask_gt, dim=[-1, -2]).to(torch.float) 96 | precision = true_positive / mask_area 97 | precision[mask_area == 0.0] = 1.0 98 | recall = true_positive / mask_gt_area 99 | recall[mask_gt_area == 0.0] = 1.0 100 | return precision.item(), recall.item() 101 | 102 | 103 | def F_scores(p, r, betta_sq=0.3): 104 | f_scores = ((1 + betta_sq) * p * r) / (betta_sq * p + r) 105 | f_scores[f_scores != f_scores] = 0.0 # handle nans 106 | return f_scores 107 | 108 | 109 | def F_max(precisions, recalls, betta_sq=0.3): 110 | f_scores = F_scores(precisions, recalls, betta_sq) 111 | f_scores = f_scores.mean(dim=0) 112 | # print('f_scores.shape: ', f_scores.shape) 113 | # print('torch.argmax(f_scores): ', torch.argmax(f_scores)) 114 | return f_scores.max().item() 115 | 116 | 117 | def mean(x): 118 | return sum(x) / len(x) 119 | 120 | 121 | def list_of_dicts_to_dict_of_lists(LD): 122 | return {k: [dic[k] for dic in LD] for k in LD[0]} 123 | 124 | 125 | def list_of_dict_of_lists_to_dict_of_lists(LD): 126 | return {k: [v for dic in LD for v in dic[k]] for k in LD[0]} 127 | 128 | 129 | def dict_of_lists_to_list_of_dicts(DL): 130 | return [dict(zip(DL, t)) for t in zip(*DL.values())] 131 | -------------------------------------------------------------------------------- /object-segmentation/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import sys 4 | import math 5 | import datetime 6 | import torch 7 | import hydra 8 | import wandb 9 | import cv2 10 | import numpy as np 11 | from contextlib import nullcontext 12 | from collections import namedtuple 13 | from pathlib import Path 14 | from typing import Callable, Iterable, Optional 15 | from PIL import Image 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader 18 | from torchvision.transforms import functional as TF 19 | from accelerate import Accelerator 20 | from omegaconf import OmegaConf, DictConfig 21 | from tqdm import tqdm 22 | 23 | import metrics 24 | import util as utils 25 | from dataset import SegmentationDataset, central_crop 26 | 27 | 28 | @hydra.main(config_path='config', config_name='eval') 29 | def main(cfg: DictConfig): 30 | 31 | # Accelerator 32 | accelerator = Accelerator(fp16=cfg.fp16, cpu=cfg.cpu) 33 | 34 | # Logging 35 | utils.setup_distributed_print(accelerator.is_local_main_process) 36 | if cfg.wandb and accelerator.is_local_main_process: 37 | wandb.init(name=cfg.name, job_type=cfg.job_type, config=OmegaConf.to_container(cfg), save_code=True, **cfg.wandb_kwargs) 38 | cfg = DictConfig(wandb.config.as_dict()) # get the config back from wandb for hyperparameter sweeps 39 | 40 | # Configuration 41 | print(OmegaConf.to_yaml(cfg)) 42 | print(f'Current working directory: {os.getcwd()}') 43 | 44 | # Set random seed 45 | utils.set_seed(cfg.seed) 46 | 47 | # Datasets 48 | # NOTE: The batch size must be 1 for test because the masks 
are different sizes, 49 | # and evaluation should be done using the mask at the original resolution. 50 | test_dataloaders = [] 51 | for data_cfg in cfg.data: 52 | test_dataset = SegmentationDataset(**data_cfg) 53 | test_dataloader = DataLoader(test_dataset, **{**cfg.dataloader, 'batch_size': 1}) 54 | test_dataloaders.append(test_dataloader) 55 | 56 | # Evaluation 57 | if cfg.job_type == 'eval': 58 | for dataloader in test_dataloaders: 59 | evaluate_predictions(cfg=cfg, accelerator=accelerator, dataloader_val=dataloader) 60 | else: 61 | raise NotImplementedError() 62 | 63 | 64 | @torch.no_grad() 65 | def evaluate_predictions( 66 | *, 67 | cfg: DictConfig, 68 | dataloader_val: Iterable, 69 | accelerator: Accelerator, 70 | **_unused_kwargs): 71 | 72 | # Evaluate 73 | name = dataloader_val.dataset.name 74 | all_results = [] 75 | pbar = tqdm(dataloader_val, desc=f'Evaluating {name}') 76 | for i, (images, targets, metadatas) in enumerate(pbar): 77 | 78 | # Convert 79 | targets = targets.squeeze(0) 80 | 81 | # Load predictions 82 | id = Path(metadatas['image_file'][0]).stem 83 | predictions_file = os.path.join(cfg.predictions[name], f'{id}.png') 84 | preds = np.array(Image.open(predictions_file).convert('L')) # (H_patch, W_patch) 85 | assert set(np.unique(preds).tolist()) in [{0, 255}, {0, 1}, {0}], set(np.unique(preds).tolist()) 86 | preds[preds == 255] = 1 87 | 88 | # Resize if segmentation is patchwise 89 | if cfg.predictions.downsample is not None: 90 | H, W = targets.shape 91 | H_pred, W_pred = preds.shape 92 | H_pad, W_pad = H_pred * cfg.predictions.downsample, W_pred * cfg.predictions.downsample 93 | H_max, W_max = max(H_pad, H), max(W_pad, W) 94 | preds = cv2.resize(preds, dsize=(W_max, H_max), interpolation=cv2.INTER_NEAREST) 95 | preds[:H_pad, :W_pad] = cv2.resize(preds, dsize=(W_pad, H_pad), interpolation=cv2.INTER_NEAREST) 96 | 97 | # Convert, optional center crop, and unsqueeze 98 | preds = torch.from_numpy(preds) 99 | if dataloader_val.dataset.crop: 100 | preds = TF.center_crop(preds, output_size=min(preds.shape)) 101 | preds = torch.unsqueeze(preds, dim=0) 102 | targets = torch.unsqueeze(targets, dim=0) 103 | 104 | # Compute metrics 105 | results = metrics.compute_metrics(preds=preds, targets=targets, metrics=['acc', 'iou']) 106 | all_results.append(results) 107 | 108 | # Aggregate results 109 | all_results = metrics.list_of_dict_of_lists_to_dict_of_lists(all_results) 110 | results = metrics.aggregate_metrics(all_results) 111 | for metric_name, value in results.items(): 112 | print(f'[{name}] {metric_name}: {value}') 113 | 114 | 115 | @torch.no_grad() 116 | def visualize( 117 | *, 118 | cfg: DictConfig, 119 | model: torch.nn.Module, 120 | dataloader_vis: Iterable, 121 | accelerator: Accelerator, 122 | identifier: str = '', 123 | num_batches: Optional[int] = None, 124 | **_unused_kwargs): 125 | 126 | # Eval mode 127 | model.eval() 128 | metric_logger = utils.MetricLogger(delimiter=" ") 129 | progress_bar = metric_logger.log_every(dataloader_vis, cfg.logging.print_freq, "Vis") 130 | 131 | # Visualize 132 | for batch_idx, (inputs, target) in enumerate(progress_bar): 133 | if num_batches is not None and batch_idx >= num_batches: 134 | break 135 | 136 | # Inverse normalization 137 | Inv = utils.NormalizeInverse(mean=cfg.data.transform.img_mean, std=cfg.data.transform.img_std) 138 | image = Inv(inputs).clamp(0, 1) 139 | vis_dict = dict(image=image) 140 | 141 | # Save images 142 | wandb_log_dict = {} 143 | for name, images in vis_dict.items(): 144 | for i, image in enumerate(images): 145 
| pil_image = utils.tensor_to_pil(image) 146 | filename = f'vis/{name}/p-{accelerator.process_index}-b-{batch_idx}-img-{i}-{name}-{identifier}.png' 147 | Path(filename).parent.mkdir(exist_ok=True, parents=True) 148 | pil_image.save(filename) 149 | if i < 2: # log to Weights and Biases 150 | wandb_filename = f'vis/{name}/p-{accelerator.process_index}-b-{batch_idx}-img-{i}-{name}' 151 | wandb_log_dict[wandb_filename] = [wandb.Image(pil_image)] 152 | if cfg.wandb and accelerator.is_local_main_process: 153 | wandb.log(wandb_log_dict, commit=False) 154 | print(f'Saved visualizations to {Path("vis").absolute()}') 155 | 156 | 157 | if __name__ == '__main__': 158 | main() -------------------------------------------------------------------------------- /semantic-segmentation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Iterable, Optional 4 | 5 | import hydra 6 | import numpy as np 7 | import torch 8 | import wandb 9 | from accelerate import Accelerator 10 | from matplotlib.cm import get_cmap 11 | from omegaconf import DictConfig, OmegaConf 12 | from PIL import Image 13 | from skimage.color import label2rgb 14 | from tqdm import tqdm, trange 15 | 16 | import eval_utils 17 | import util as utils 18 | from dataset.voc import VOCSegmentationWithPseudolabels 19 | 20 | 21 | @hydra.main(config_path='config', config_name='eval') 22 | def main(cfg: DictConfig): 23 | 24 | # Accelerator 25 | accelerator = Accelerator(fp16=cfg.fp16, cpu=cfg.cpu) 26 | 27 | # Logging 28 | utils.setup_distributed_print(accelerator.is_local_main_process) 29 | if cfg.wandb and accelerator.is_local_main_process: 30 | wandb.init(name=cfg.name, job_type=cfg.job_type, config=OmegaConf.to_container(cfg), save_code=True, **cfg.wandb_kwargs) 31 | cfg = DictConfig(wandb.config.as_dict()) # get the config back from wandb for hyperparameter sweeps 32 | 33 | # Configuration 34 | print(OmegaConf.to_yaml(cfg)) 35 | print(f'Current working directory: {os.getcwd()}') 36 | 37 | # Set random seed 38 | utils.set_seed(cfg.seed) 39 | 40 | # Create dataset with segments/pseudolabels 41 | dataset_val = VOCSegmentationWithPseudolabels( 42 | **cfg.data.val_kwargs, 43 | segments_dir=cfg.segments_dir, 44 | transform=None, # no transform to evaluate at original resolution 45 | ) 46 | 47 | # Evaluate 48 | eval_stats, match = evaluate(cfg=cfg, dataset_val=dataset_val, n_clusters=cfg.get('n_clusters', None)) 49 | print(eval_stats) 50 | if cfg.wandb and accelerator.is_local_main_process: 51 | wandb.summary['mIoU'] = eval_stats['mIoU'] 52 | 53 | # Visualize 54 | visualize(cfg=cfg, dataset_val=dataset_val) 55 | 56 | 57 | def visualize( 58 | *, 59 | cfg: DictConfig, 60 | dataset_val: Iterable, 61 | vis_dir: str = './vis'): 62 | 63 | # Visualize 64 | num_vis = 40 65 | vis_dir = Path(vis_dir) 66 | colors = get_cmap('tab20', cfg.data.num_classes + 1).colors[:,:3] 67 | pbar = tqdm(dataset_val, total=num_vis, desc='Saving visualizations: ') 68 | for i, (image, target, mask, metadata) in enumerate(pbar): 69 | if i >= num_vis: break 70 | image = np.array(image) 71 | target = np.array(target) 72 | target[target == 255] = 0 # set the "unknown" regions to background for visualization 73 | # Overlay mask on image 74 | image_pred_overlay = label2rgb(label=mask, image=image, colors=colors[np.unique(target)[1:]], bg_label=0, alpha=0.45) 75 | image_target_overlay = label2rgb(label=target, image=image, colors=colors[np.unique(target)[1:]], bg_label=0, alpha=0.45) 76 | # 
Save 77 | image_id = metadata["id"] 78 | path_pred = vis_dir / 'pred' / f'{image_id}-pred.png' 79 | path_target = vis_dir / 'target' / f'{image_id}-target.png' 80 | path_pred.parent.mkdir(exist_ok=True, parents=True) 81 | path_target.parent.mkdir(exist_ok=True, parents=True) 82 | Image.fromarray((image_pred_overlay * 255).astype(np.uint8)).save(str(path_pred)) 83 | Image.fromarray((image_target_overlay * 255).astype(np.uint8)).save(str(path_target)) 84 | print(f'Saved visualizations to {vis_dir.absolute()}') 85 | 86 | 87 | def evaluate( 88 | *, 89 | cfg: DictConfig, 90 | dataset_val: Iterable, 91 | n_clusters: Optional[int] = None): 92 | 93 | # Add background class 94 | n_classes = cfg.data.num_classes + 1 95 | if n_clusters is None: 96 | n_clusters = n_classes 97 | 98 | # Iterate 99 | tp = [0] * n_classes 100 | fp = [0] * n_classes 101 | fn = [0] * n_classes 102 | 103 | # Load all pixel embeddings 104 | all_preds = np.zeros((len(dataset_val) * 500 * 500), dtype=np.float32) 105 | all_gt = np.zeros((len(dataset_val) * 500 * 500), dtype=np.float32) 106 | offset_ = 0 107 | 108 | # Add all pixels to our arrays 109 | _alread_warned = 0 110 | for i in trange(len(dataset_val), desc='Concatenating all predictions'): 111 | image, target, mask, metadata = dataset_val[i] 112 | # Check where ground-truth is valid and append valid pixels to the array 113 | valid = (target != 255) 114 | n_valid = np.sum(valid) 115 | all_gt[offset_:offset_+n_valid] = target[valid] 116 | # Append the predicted targets in the array 117 | all_preds[offset_:offset_+n_valid, ] = mask[valid] 118 | all_gt[offset_:offset_+n_valid, ] = target[valid] 119 | # Update offset_ 120 | offset_ += n_valid 121 | 122 | # Truncate to the actual number of pixels 123 | all_preds = all_preds[:offset_, ] 124 | all_gt = all_gt[:offset_, ] 125 | 126 | # Do hungarian matching 127 | num_elems = offset_ 128 | if n_clusters == n_classes: 129 | print('Using hungarian algorithm for matching') 130 | match = eval_utils.hungarian_match(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes, metric='iou') 131 | else: 132 | print('Using majority voting for matching') 133 | match = eval_utils.majority_vote(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes) 134 | print(f'Optimal matching: {match}') 135 | 136 | # Remap predictions 137 | reordered_preds = np.zeros(num_elems, dtype=all_preds.dtype) 138 | for pred_i, target_i in match: 139 | reordered_preds[all_preds == int(pred_i)] = int(target_i) 140 | 141 | # TP, FP, and FN evaluation 142 | for i_part in range(0, n_classes): 143 | tmp_all_gt = (all_gt == i_part) 144 | tmp_pred = (reordered_preds == i_part) 145 | tp[i_part] += np.sum(tmp_all_gt & tmp_pred) 146 | fp[i_part] += np.sum(~tmp_all_gt & tmp_pred) 147 | fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred) 148 | 149 | # Calculate Jaccard index 150 | jac = [0] * n_classes 151 | for i_part in range(0, n_classes): 152 | jac[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8) 153 | 154 | # Print results 155 | eval_result = dict() 156 | eval_result['jaccards_all_categs'] = jac 157 | eval_result['mIoU'] = np.mean(jac) 158 | print('Evaluation of semantic segmentation ') 159 | print('mIoU is %.2f' % (100*eval_result['mIoU'])) 160 | return eval_result, match 161 | 162 | 163 | if __name__ == '__main__': 164 | torch.set_grad_enabled(False) 165 | main() 166 | -------------------------------------------------------------------------------- /semantic-segmentation/dataset/voc.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import Any, Callable, Dict, List, Optional, Tuple 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from torchvision.datasets.voc import (DATASET_YEAR_DICT, VisionDataset, os, verify_str_arg) 10 | 11 | 12 | def _resize_pseudolabel(pseudolabel, img): 13 | if ( 14 | (pseudolabel.shape[0] == img.shape[0] // 16) or 15 | (pseudolabel.shape[0] == img.shape[0] // 8) or 16 | (pseudolabel.shape[0] == 2 * (img.shape[0] // 16)) 17 | ): 18 | return cv2.resize(pseudolabel, dsize=img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST) 19 | return pseudolabel 20 | 21 | 22 | class VOCSegmentationWithPseudolabelsBase(VisionDataset): 23 | 24 | _SPLITS_DIR = "Segmentation" 25 | _TARGET_DIR = "SegmentationClass" 26 | _TARGET_FILE_EXT = ".png" 27 | 28 | def __init__( 29 | self, 30 | root: str, 31 | year: str = "2012", 32 | image_set: str = "train", 33 | download: bool = False, 34 | transform: Optional[Callable] = None, 35 | target_transform: Optional[Callable] = None, 36 | transforms: Optional[Callable] = None, 37 | ): 38 | super().__init__(root, transforms, transform, target_transform) 39 | if year == "2007-test": 40 | if image_set == "test": 41 | warnings.warn( 42 | "Acessing the test image set of the year 2007 with year='2007-test' is deprecated. " 43 | "Please use the combination year='2007' and image_set='test' instead." 44 | ) 45 | year = "2007" 46 | else: 47 | raise ValueError( 48 | "In the test image set of the year 2007 only image_set='test' is allowed. " 49 | "For all other image sets use year='2007' instead." 50 | ) 51 | self.year = year 52 | 53 | valid_image_sets = ["train", "trainval", "val"] 54 | if year == "2007": 55 | valid_image_sets.append("test") 56 | self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets) 57 | 58 | key = "2007-test" if year == "2007" and image_set == "test" else year 59 | dataset_year_dict = DATASET_YEAR_DICT[key] 60 | 61 | self.url = dataset_year_dict["url"] 62 | self.filename = dataset_year_dict["filename"] 63 | self.md5 = dataset_year_dict["md5"] 64 | 65 | base_dir = dataset_year_dict["base_dir"] 66 | voc_root = os.path.join(self.root, base_dir) 67 | 68 | if download: 69 | from torchvision.datasets.voc import download_and_extract_archive 70 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) 71 | 72 | if not os.path.isdir(voc_root): 73 | raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") 74 | 75 | splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR) 76 | split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt") 77 | 78 | if self.image_set == 'train': # everything except val 79 | image_dir = os.path.join(voc_root, "JPEGImages") 80 | with open(os.path.join(splits_dir, "val.txt"), "r") as f: 81 | val_file_stems = set([stem.strip() for stem in f.readlines()]) 82 | all_image_paths = [p for p in Path(image_dir).iterdir()] 83 | train_image_paths = [str(p) for p in all_image_paths if p.stem not in val_file_stems] 84 | self.images = sorted(train_image_paths) 85 | # For the targets, we will just replicate the same target however many times 86 | target_dir = os.path.join(voc_root, self._TARGET_DIR) 87 | self.targets = [str(next(Path(target_dir).iterdir()))] * len(self.images) 88 | 89 | else: 90 | 91 | with open(os.path.join(split_f), "r") as f: 92 | file_names = [x.strip() for x in f.readlines()] 93 | 94 | image_dir = os.path.join(voc_root, "JPEGImages") 95 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] 96 | 97 | target_dir = os.path.join(voc_root, self._TARGET_DIR) 98 | self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names] 99 | 100 | assert len(self.images) == len(self.targets), ( len(self.images), len(self.targets)) 101 | 102 | @property 103 | def masks(self) -> List[str]: 104 | return self.targets 105 | 106 | def _prepare_label_map(self, label_map): 107 | if label_map is not None: 108 | self.label_map_fn = np.vectorize(label_map.__getitem__) 109 | else: 110 | self.label_map_fn = None 111 | 112 | def _prepare_segments_dir(self, segments_dir): 113 | self.segments_dir = segments_dir 114 | # Get segment and image files, which are assumed to be in correspondence 115 | all_segment_files = sorted(map(str, Path(segments_dir).iterdir())) 116 | all_img_files = sorted(Path(self.images[0]).parent.iterdir()) 117 | assert len(all_img_files) == len(all_segment_files), (len(all_img_files), len(all_segment_files)) 118 | # Create mapping because I named the segment files badly (sequentially instead of by image id) 119 | all_img_stems = [p.stem for p in all_img_files] 120 | valid_img_stems = set([Path(p).stem for p in self.images]) # in our split (e.g. 
'val') 121 | segment_files = [] 122 | for i in range(len(all_img_stems)): 123 | if all_img_stems[i] in valid_img_stems: 124 | segment_files.append(all_segment_files[i]) 125 | self.segments = segment_files 126 | assert len(self.segments) == len(self.images), f'{len(self.segments)=} and {len(self.images)=}' 127 | print('Loaded segments and images') 128 | print(f'First image filepath: {self.images[0]}') 129 | print(f'First segmap filepath: {self.segments[0]}') 130 | print(f'Last image filepath: {self.images[-1]}') 131 | print(f'Last segmap filepath: {self.segments[-1]}') 132 | 133 | def _load(self, index: int): 134 | # Load image 135 | img = np.array(Image.open(self.images[index]).convert("RGB")) 136 | target = np.array(Image.open(self.masks[index])) 137 | metadata = {'id': Path(self.images[index]).stem, 'path': self.images[index], 'shape': tuple(img.shape[:2])} 138 | # New: load segmap and accompanying metedata 139 | pseudolabel = np.array(Image.open(self.segments[index])) 140 | pseudolabel = _resize_pseudolabel(pseudolabel, img) # HACK HACK HACK 141 | if self.label_map_fn is not None: 142 | pseudolabel = self.label_map_fn(pseudolabel) 143 | return (img, target, pseudolabel, metadata) 144 | 145 | def __len__(self) -> int: 146 | return len(self.images) 147 | 148 | 149 | class VOCSegmentationWithPseudolabels(VOCSegmentationWithPseudolabelsBase): 150 | 151 | def __init__(self, *args, segments_dir, transform = None, label_map = None, **kwargs): 152 | super().__init__(*args, **kwargs) 153 | self._prepare_segments_dir(segments_dir) 154 | self.transform = transform 155 | self._prepare_label_map(label_map) 156 | 157 | def __getitem__(self, index: int): 158 | img, target, pseudolabel, metadata = self._load(index) 159 | if self.transform is not None: 160 | # Transform 161 | data = self.transform(image=img, mask1=target, mask2=pseudolabel) 162 | # Unpack 163 | img, target, pseudolabel = data['image'], data['mask1'], data['mask2'] 164 | if torch.is_tensor(target): 165 | target = target.long() 166 | if torch.is_tensor(pseudolabel): 167 | pseudolabel = pseudolabel.long() 168 | return img, target, pseudolabel, metadata 169 | -------------------------------------------------------------------------------- /extract/extract_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from multiprocessing import Pool 4 | from pathlib import Path 5 | from typing import Any, Callable, Iterable, Optional, Tuple, Union 6 | 7 | import cv2 8 | import numpy as np 9 | import scipy.sparse 10 | import torch 11 | from skimage.morphology import binary_dilation, binary_erosion 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms 14 | from tqdm import tqdm 15 | 16 | 17 | class ImagesDataset(Dataset): 18 | """A very simple dataset for loading images.""" 19 | 20 | def __init__(self, filenames: str, images_root: Optional[str] = None, transform: Optional[Callable] = None, 21 | prepare_filenames: bool = True) -> None: 22 | self.root = None if images_root is None else Path(images_root) 23 | self.filenames = sorted(list(set(filenames))) if prepare_filenames else filenames 24 | self.transform = transform 25 | 26 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 27 | path = self.filenames[index] 28 | full_path = Path(path) if self.root is None else self.root / path 29 | assert full_path.is_file(), f'Not a file: {full_path}' 30 | image = cv2.imread(str(full_path)) 31 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 32 | if 
self.transform is not None: 33 | image = self.transform(image) 34 | return image, path, index 35 | 36 | def __len__(self) -> int: 37 | return len(self.filenames) 38 | 39 | 40 | def get_model(name: str): 41 | if 'dino' in name: 42 | model = torch.hub.load('facebookresearch/dino:main', name) 43 | model.fc = torch.nn.Identity() 44 | val_transform = get_transform(name) 45 | patch_size = model.patch_embed.patch_size 46 | num_heads = model.blocks[0].attn.num_heads 47 | else: 48 | raise ValueError(f'Cannot get model: {name}') 49 | model = model.eval() 50 | return model, val_transform, patch_size, num_heads 51 | 52 | 53 | def get_transform(name: str): 54 | if any(x in name for x in ('dino', 'mocov3', 'convnext', )): 55 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 56 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 57 | else: 58 | raise NotImplementedError() 59 | return transform 60 | 61 | 62 | def get_inverse_transform(name: str): 63 | if 'dino' in name: 64 | inv_normalize = transforms.Normalize( 65 | [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 66 | [1 / 0.229, 1 / 0.224, 1 / 0.225]) 67 | transform = transforms.Compose([transforms.ToTensor(), inv_normalize]) 68 | else: 69 | raise NotImplementedError() 70 | return transform 71 | 72 | 73 | def get_image_sizes(data_dict: dict, downsample_factor: Optional[int] = None): 74 | P = data_dict['patch_size'] if downsample_factor is None else downsample_factor 75 | B, C, H, W = data_dict['shape'] 76 | assert B == 1, 'assumption violated :(' 77 | H_patch, W_patch = H // P, W // P 78 | H_pad, W_pad = H_patch * P, W_patch * P 79 | return (B, C, H, W, P, H_patch, W_patch, H_pad, W_pad) 80 | 81 | 82 | def _get_files(p: str): 83 | if Path(p).is_dir(): 84 | return sorted(Path(p).iterdir()) 85 | elif Path(p).is_file(): 86 | return Path(p).read_text().splitlines() 87 | else: 88 | raise ValueError(p) 89 | 90 | 91 | def get_paired_input_files(path1: str, path2: str): 92 | files1 = _get_files(path1) 93 | files2 = _get_files(path2) 94 | assert len(files1) == len(files2) 95 | return list(enumerate(zip(files1, files2))) 96 | 97 | 98 | def make_output_dir(output_dir, check_if_empty=True): 99 | output_dir = Path(output_dir) 100 | output_dir.mkdir(exist_ok=True, parents=True) 101 | if check_if_empty and (len(list(output_dir.iterdir())) > 0): 102 | print(f'Output dir: {str(output_dir)}') 103 | if input(f'Output dir already contains files. Continue? 
(y/n) >> ') != 'y': 104 | sys.exit() # skip because already generated 105 | 106 | 107 | def get_largest_cc(mask: np.array): 108 | from skimage.measure import label as measure_label 109 | labels = measure_label(mask) # get connected components 110 | largest_cc_index = np.argmax(np.bincount(labels.flat)[1:]) + 1 111 | largest_cc_mask = (labels == largest_cc_index) 112 | return largest_cc_mask 113 | 114 | 115 | def erode_or_dilate_mask(x: Union[torch.Tensor, np.ndarray], r: int = 0, erode=True): 116 | fn = binary_erosion if erode else binary_dilation 117 | for _ in range(r): 118 | x_new = fn(x) 119 | if x_new.sum() > 0: # do not erode the entire mask away 120 | x = x_new 121 | return x 122 | 123 | 124 | def get_border_fraction(segmap: np.array): 125 | num_border_pixels = 2 * (segmap.shape[0] + segmap.shape[1]) 126 | counts_map = {idx: 0 for idx in np.unique(segmap)} 127 | np.zeros(len(np.unique(segmap))) 128 | for border in [segmap[:, 0], segmap[:, -1], segmap[0, :], segmap[-1, :]]: 129 | unique, counts = np.unique(border, return_counts=True) 130 | for idx, count in zip(unique.tolist(), counts.tolist()): 131 | counts_map[idx] += count 132 | # normlized_counts_map = {idx: count / num_border_pixels for idx, count in counts_map.items()} 133 | indices = np.array(list(counts_map.keys())) 134 | normlized_counts = np.array(list(counts_map.values())) / num_border_pixels 135 | return indices, normlized_counts 136 | 137 | 138 | def parallel_process(inputs: Iterable, fn: Callable, multiprocessing: int = 0): 139 | start = time.time() 140 | if multiprocessing: 141 | print('Starting multiprocessing') 142 | with Pool(multiprocessing) as pool: 143 | for _ in tqdm(pool.imap(fn, inputs), total=len(inputs)): 144 | pass 145 | else: 146 | for inp in tqdm(inputs): 147 | fn(inp) 148 | print(f'Finished in {time.time() - start:.1f}s') 149 | 150 | 151 | def knn_affinity(image, n_neighbors=[20, 10], distance_weights=[2.0, 0.1]): 152 | """Computes a KNN-based affinity matrix. Note that this function requires pymatting""" 153 | try: 154 | from pymatting.util.kdtree import knn 155 | except: 156 | raise ImportError( 157 | 'Please install pymatting to compute KNN affinity matrices:\n' 158 | 'pip3 install pymatting' 159 | ) 160 | 161 | h, w = image.shape[:2] 162 | r, g, b = image.reshape(-1, 3).T 163 | n = w * h 164 | 165 | x = np.tile(np.linspace(0, 1, w), h) 166 | y = np.repeat(np.linspace(0, 1, h), w) 167 | 168 | i, j = [], [] 169 | 170 | for k, distance_weight in zip(n_neighbors, distance_weights): 171 | f = np.stack( 172 | [r, g, b, distance_weight * x, distance_weight * y], 173 | axis=1, 174 | out=np.zeros((n, 5), dtype=np.float32), 175 | ) 176 | 177 | distances, neighbors = knn(f, f, k=k) 178 | 179 | i.append(np.repeat(np.arange(n), k)) 180 | j.append(neighbors.flatten()) 181 | 182 | ij = np.concatenate(i + j) 183 | ji = np.concatenate(j + i) 184 | coo_data = np.ones(2 * sum(n_neighbors) * n) 185 | 186 | # This is our affinity matrix 187 | W = scipy.sparse.csr_matrix((coo_data, (ij, ji)), (n, n)) 188 | return W 189 | 190 | 191 | def rw_affinity(image, sigma=0.033, radius=1): 192 | """Computes a random walk-based affinity matrix. 
Note that this function requires pymatting""" 193 | try: 194 | from pymatting.laplacian.rw_laplacian import _rw_laplacian 195 | except: 196 | raise ImportError( 197 | 'Please install pymatting to compute RW affinity matrices:\n' 198 | 'pip3 install pymatting' 199 | ) 200 | h, w = image.shape[:2] 201 | n = h * w 202 | values, i_inds, j_inds = _rw_laplacian(image, sigma, radius) 203 | W = scipy.sparse.csr_matrix((values, (i_inds, j_inds)), shape=(n, n)) 204 | return W 205 | 206 | 207 | def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12): 208 | """Gets the diagonal sum of a sparse matrix""" 209 | try: 210 | from pymatting.util.util import row_sum 211 | except: 212 | raise ImportError( 213 | 'Please install pymatting to compute the diagonal sums:\n' 214 | 'pip3 install pymatting' 215 | ) 216 | 217 | D = row_sum(W) 218 | D[D < threshold] = 1.0 # Prevent division by zero. 219 | D = scipy.sparse.diags(D) 220 | return D 221 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ## Deep Spectral Methods for Unsupervised Localization and Segmentation (CVPR 2022 - Oral) 4 | 5 | [![Project](http://img.shields.io/badge/Project%20Page-3d3d8f.svg)](https://lukemelas.github.io/deep-spectral-segmentation/) 6 | [![Demo](http://img.shields.io/badge/Demo-9acbff.svg)](https://huggingface.co/spaces/lukemelas/deep-spectral-segmentation) 7 | [![Conference](http://img.shields.io/badge/CVPR-2022-4b44ce.svg)](#) 8 | [![Paper](http://img.shields.io/badge/Paper-arxiv.1001.2234-B31B1B.svg)](#) 9 | 10 |
11 | 12 | ### Description 13 | This code accompanies the paper [Deep Spectral Methods: A Surprisingly Strong Baseline for Unsupervised Semantic Segmentation and Localization](https://lukemelas.github.io/deep-spectral-segmentation/). 14 | 15 | ### Abstract 16 | 17 | Unsupervised localization and segmentation are long-standing computer vision challenges that involve decomposing an image into semantically-meaningful segments without any labeled data. These tasks are particularly interesting in an unsupervised setting due to the difficulty and cost of obtaining dense image annotations, but existing unsupervised approaches struggle with complex scenes containing multiple objects. Differently from existing methods, which are purely based on deep learning, we take inspiration from traditional spectral segmentation methods by reframing image decomposition as a graph partitioning problem. Specifically, we examine the eigenvectors of the Laplacian of a feature affinity matrix from self-supervised networks. We find that these eigenvectors already decompose an image into meaningful segments, and can be readily used to localize objects in a scene. Furthermore, by clustering the features associated with these segments across a dataset, we can obtain well-delineated, nameable regions, i.e. semantic segmentations. Experiments on complex datasets (Pascal VOC, MS-COCO) demonstrate that our simple spectral method outperforms the state-of-the-art in unsupervised localization and segmentation by a significant margin. Furthermore, our method can be readily used for a variety of complex image editing tasks, such as background removal and compositing. 18 | 19 | ### Demo 20 | Please check out our interactive demo on [Huggingface Spaces](https://huggingface.co/spaces/lukemelas/deep-spectral-segmentation)! The demo enables you to upload an image and outputs the eigenvectors extracted by our method. It does not perform the downstream tasks in our paper (e.g. semantic segmentation), but it should give you some intuition for how you might utilize our method for your own research/use-case. 21 | 22 | ### Examples 23 | 24 | ![Examples](https://lukemelas.github.io/deep-spectral-segmentation/images/example.png) 25 | 26 | ### How to run 27 | 28 | #### Dependencies 29 | The minimal set of dependencies is listed in `requirements.txt`. 30 | 31 | #### Data Preparation 32 | 33 | The data preparation process simply consists of collecting your images into a single folder. Here, we describe the process for [Pascal VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012//). Pascal VOC 2007 and MS-COCO are similar. 34 | 35 | Download the images into a single folder. Then create a text file where each line contains the name of an image file. For example, here is our initial data layout: 36 | ``` 37 | data 38 | └── VOC2012 39 |    ├── images 40 | │ └── {image_id}.jpg 41 | └── lists 42 | └── images.txt 43 | ``` 44 | 45 | #### Extraction 46 | 47 | We first extract features from images and store them in files. We then extract eigenvectors from these features. Once we have the eigenvectors, we can perform downstream tasks such as object segmentation and object localization. 48 | 49 | The primary script for this extraction process is `extract.py` in the `extract/` directory. All functions in `extract.py` have helpful docstrings with example usage. 50 | 51 | ##### Step 1: Feature Extraction 52 | 53 | First, we extract features from our images and save them to `.pth` files.
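After running the extraction command below, it can be useful to sanity-check one of the saved files. The following is a minimal sketch, assuming each `.pth` file stores a dictionary that includes the `shape` and `patch_size` entries consumed by `get_image_sizes` in `extract_utils.py`; the exact set of keys may differ.

```python
# Quick sanity check of one extracted feature file (assumed dictionary layout).
from pathlib import Path
import torch

features_dir = Path("./data/VOC2012/features/dino_vits16")
example_file = next(features_dir.glob("*.pth"))  # any extracted image id

data_dict = torch.load(example_file, map_location="cpu")
print(sorted(data_dict.keys()))                     # expected to include 'shape' and 'patch_size'
print(data_dict["shape"], data_dict["patch_size"])  # 'shape' is (B, C, H, W) with B == 1
```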
54 | 55 | With regard to models, our repository currently only supports DINO, but other models are easy to add (see the `get_model` function in `extract_utils.py`). The DINO model is downloaded automatically using `torch.hub`. 56 | 57 | Here is an example using `dino_vits16`: 58 | 59 | ```bash 60 | python extract.py extract_features \ 61 | --images_list "./data/VOC2012/lists/images.txt" \ 62 | --images_root "./data/VOC2012/images" \ 63 | --output_dir "./data/VOC2012/features/dino_vits16" \ 64 | --model_name dino_vits16 \ 65 | --batch_size 1 66 | ``` 67 | 68 | ##### Step 2: Eigenvector Computation 69 | 70 | Second, we extract eigenvectors from our features and save them to `.pth` files. 71 | 72 | Here, we extract the top `K=5` eigenvectors of the Laplacian matrix of our features: 73 | 74 | ```bash 75 | python extract.py extract_eigs \ 76 | --images_root "./data/VOC2012/images" \ 77 | --features_dir "./data/VOC2012/features/dino_vits16" \ 78 | --which_matrix "laplacian" \ 79 | --output_dir "./data/VOC2012/eigs/laplacian" \ 80 | --K 5 81 | ``` 82 | 83 | The final data structure after extracting eigenvectors looks like: 84 | ``` 85 | data 86 | ├── VOC2012 87 | │   ├── eigs 88 | │ │ └── {output_dir_name} 89 | │ │ └── {image_id}.pth 90 | │   ├── features 91 | │ │ └── {model_name} 92 | │ │ └── {image_id}.pth 93 | │   ├── images 94 | │ │ └── {image_id}.jpg 95 | │ └── lists 96 | │ └── images.txt 97 | └── VOC2007 98 | └── ... 99 | ``` 100 | 101 | At this point, you are ready to use the eigenvectors for downstream tasks such as object localization, object segmentation, and semantic segmentation. 102 | 103 | #### Object Localization 104 | 105 | First, clone the `dino` repo inside this project root (or symlink it). 106 | ```bash 107 | git clone https://github.com/facebookresearch/dino 108 | ``` 109 | 110 | Run the steps above to save your eigenvectors inside a directory, which we will now call `${EIGS_DIR}`. You can then move to the `object-localization` directory and evaluate object localization with: 111 | ```bash 112 | python main.py \ 113 | --eigenseg \ 114 | --precomputed_eigs_dir ${EIGS_DIR} \ 115 | --dataset VOC12 \ 116 | --name "example_eigs" 117 | ``` 118 | 119 | #### Object Segmentation 120 | 121 | To perform object segmentation (i.e. single-region segmentations), you first extract features and eigenvectors (as described above). You then extract coarse (i.e. patch-level) single-region segmentations from the eigenvectors, and then turn these into high-resolution segmentations using a CRF. 122 | 123 | Below, we will give example commands for the CUB bird dataset (`CUB_200_2011`). To download this dataset, as well as the three other object segmentation datasets used in our paper, you can follow the instructions in [unsupervised-image-segmentation](https://github.com/lukemelas/unsupervised-image-segmentation). Then make sure to specify the `data_root` parameter in `config/eval.yaml` (inside the `object-segmentation` directory).
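The commands below assume that each dataset folder follows the same `images`/`lists` layout as in the Data Preparation section above, i.e. a `lists/images.txt` file with one image path per line (relative to the images root). If you still need to create this list, the following is a minimal sketch; the directory names are taken from the commands below and may need to be adapted to how you downloaded the data.

```python
# Hypothetical helper: write lists/images.txt with one image path (relative to
# the images root) per line, matching the --images_list/--images_root flags below.
from pathlib import Path

dataset_root = Path("./data/object-segmentation/CUB_200_2011")  # assumed layout
images_root = dataset_root / "images"
lists_dir = dataset_root / "lists"
lists_dir.mkdir(parents=True, exist_ok=True)

paths = sorted(
    str(p.relative_to(images_root))
    for p in images_root.rglob("*")
    if p.suffix.lower() in {".jpg", ".jpeg", ".png"}
)
(lists_dir / "images.txt").write_text("\n".join(paths) + "\n")
print(f"Wrote {len(paths)} entries to {lists_dir / 'images.txt'}")
```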
124 | 125 | For example: 126 | ```bash 127 | 128 | # Example dataset 129 | DATASET=CUB_200_2011 130 | 131 | # Features 132 | python extract.py extract_features \ 133 | --images_list "./data/object-segmentation/${DATASET}/lists/images.txt" \ 134 | --images_root "./data/object-segmentation/${DATASET}/images" \ 135 | --output_dir "./data/object-segmentation/${DATASET}/features/dino_vits16" \ 136 | --model_name dino_vits16 \ 137 | --batch_size 1 138 | 139 | # Eigenvectors 140 | python extract.py extract_eigs \ 141 | --images_root "./data/object-segmentation/${DATASET}/images" \ 142 | --features_dir "./data/object-segmentation/${DATASET}/features/dino_vits16/" \ 143 | --which_matrix "laplacian" \ 144 | --output_dir "./data/object-segmentation/${DATASET}/eigs/laplacian_dino_vits16" \ 145 | --K 2 146 | 147 | 148 | # Extract single-region segmentations 149 | python extract.py extract_single_region_segmentations \ 150 | --features_dir "./data/object-segmentation/${DATASET}/features/dino_vits16" \ 151 | --eigs_dir "./data/object-segmentation/${DATASET}/eigs/laplacian_dino_vits16" \ 152 | --output_dir "./data/object-segmentation/${DATASET}/single_region_segmentation/patches/laplacian_dino_vits16" 153 | 154 | # With CRF 155 | # Optionally, you can also use `--multiprocessing 64` to speed up computation by running on 64 processes 156 | python extract.py extract_crf_segmentations \ 157 | --images_list "./data/object-segmentation/${DATASET}/lists/images.txt" \ 158 | --images_root "./data/object-segmentation/${DATASET}/images" \ 159 | --segmentations_dir "./data/object-segmentation/${DATASET}/single_region_segmentation/patches/laplacian_dino_vits16" \ 160 | --output_dir "./data/object-segmentation/${DATASET}/single_region_segmentation/crf/laplacian_dino_vits16" \ 161 | --downsample_factor 16 \ 162 | --num_classes 2 163 | ``` 164 | 165 | After this extraction process, you should have a directory of full-resolution segmentations. Then to evaluate on object segmentation, you can move into the `object-segmentation` directory and run `python main.py`. For example: 166 | 167 | ```bash 168 | python main.py predictions.root="./data/object-segmentation" predictions.run="single_region_segmentation/crf/laplacian_dino_vits16" 169 | ``` 170 | 171 | By default, this assumes that all four object segmentation datasets are available. To run on a custom dataset or only a subset of these datasets, simply edit `config/eval.yaml`. 172 | 173 | Also, if you want to visualize your segmentations, you should be able to use `streamlit run extract.py vis_segmentations` (after installing streamlit). 174 | 175 | #### Semantic Segmentation 176 | 177 | For semantic segmentation, we provide full instructions in the `semantic-segmentation` subfolder. 178 | 179 | #### Acknowledgements 180 | 181 | L. M. K. acknowledges the generous support of the Rhodes Trust. C. R. is supported by Innovate UK (project 71653) on behalf of UK Research and Innovation (UKRI) and by the European Research Council (ERC) IDIU-638009. I. L. and A. V. are supported by the VisualAI EPSRC programme grant (EP/T028572/1). 182 | 183 | We would like to acknowledge LOST ([paper](https://arxiv.org/abs/2109.14279) and [code](https://github.com/valeoai/LOST)), whose code we adapt for our object localization experiments. If you are interested in object localization, we suggest checking out their work!
184 | 185 | #### Citation 186 | ``` 187 | @inproceedings{ 188 | melaskyriazi2022deep, 189 | title={Deep Spectral Methods: A Surprisingly Strong Baseline for Unsupervised Semantic Segmentation and Localization}, 190 | author={Luke Melas-Kyriazi and Christian Rupprecht and Iro Laina and Andrea Vedaldi}, 191 | year={2022}, 192 | booktitle={CVPR} 193 | } 194 | ``` 195 | -------------------------------------------------------------------------------- /object-localization/object_discovery.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main functions for object discovery. 3 | Code adapted from LOST: https://github.com/valeoai/LOST 4 | """ 5 | from collections import namedtuple 6 | from typing import Optional, Tuple 7 | import torch 8 | import torch.nn.functional as F 9 | import scipy 10 | import scipy.ndimage 11 | 12 | import numpy as np 13 | from datasets import bbox_iou 14 | 15 | 16 | def get_eigenvectors_from_features(feats, which_matrix: str = 'affinity_torch', K=2): 17 | from scipy.sparse.linalg import eigsh 18 | 19 | # Eigenvectors of affinity matrix 20 | if which_matrix == 'affinity_torch': 21 | A = feats @ feats.T 22 | eigenvalues, eigenvectors = torch.eig(A, eigenvectors=True) 23 | 24 | # Eigenvectors of affinity matrix with scipy 25 | elif which_matrix == 'affinity': 26 | A = (feats @ feats.T).cpu().numpy() 27 | eigenvalues, eigenvectors = eigsh(A, which='LM', k=K) 28 | eigenvectors = torch.flip(torch.from_numpy(eigenvectors), dims=(-1,)) 29 | 30 | # Eigenvectors of laplacian matrix 31 | elif which_matrix == 'laplacian': 32 | A = (feats @ feats.T).cpu().numpy() 33 | _W_semantic = (A * (A > 0)) 34 | _W_semantic = _W_semantic / _W_semantic.max() 35 | diag = _W_semantic @ np.ones(_W_semantic.shape[0]) 36 | diag[diag < 1e-12] = 1.0 37 | D = np.diag(diag) # row sum 38 | try: 39 | eigenvalues, eigenvectors = eigsh(D - _W_semantic, k=K, sigma=0, which='LM', M=D) 40 | except: 41 | eigenvalues, eigenvectors = eigsh(D - _W_semantic, k=K, which='SM', M=D) 42 | eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float() 43 | 44 | # Eigenvectors of matting laplacian matrix 45 | elif which_matrix == 'matting_laplacian': 46 | 47 | raise NotImplementedError() 48 | 49 | # # Get sizes 50 | # B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict) 51 | # H_pad_lr, W_pad_lr = H_pad // image_downsample_factor, W_pad // image_downsample_factor 52 | 53 | # # Load image 54 | # image_file = str(Path(images_root) / f'{image_id}.jpg') 55 | # image_lr = Image.open(image_file).resize((W_pad_lr, H_pad_lr), Image.BILINEAR) 56 | # image_lr = np.array(image_lr) / 255.
57 | 58 | # # Get color affinities 59 | # W_lr = utils.knn_affinity(image_lr / 255) 60 | 61 | # # Get semantic affinities 62 | # k_feats_lr = F.interpolate( 63 | # k_feats.T.reshape(1, -1, H_patch, W_patch), 64 | # size=(H_pad_lr, W_pad_lr), mode='bilinear', align_corners=False 65 | # ).reshape(-1, H_pad_lr * W_pad_lr).T 66 | # A_sm_lr = k_feats_lr @ k_feats_lr.T 67 | # W_sm_lr = (A_sm_lr * (A_sm_lr > 0)).cpu().numpy() 68 | # W_sm_lr = W_sm_lr / W_sm_lr.max() 69 | 70 | # # Combine 71 | # W_color = np.array(W_lr.todense().astype(np.float32)) 72 | # W_comb = W_sm_lr + W_color * image_color_lambda # combination 73 | # D_comb = utils.get_diagonal(W_comb) 74 | 75 | # # Extract eigenvectors 76 | # try: 77 | # eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM', M=D_comb) 78 | # except: 79 | # eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM', M=D_comb) 80 | # eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float() 81 | 82 | return eigenvectors 83 | 84 | 85 | def get_bbox_from_patch_mask(patch_mask, init_image_size, img_np: Optional[np.array] = None): 86 | 87 | # Sizing 88 | H, W = init_image_size[1:] 89 | T = patch_mask.numel() 90 | if (H // 8) * (W // 8) == T: 91 | P, H_lr, W_lr = (8, H // 8, W // 8) 92 | elif (H // 16) * (W // 16) == T: 93 | P, H_lr, W_lr = (16, H // 16, W // 16) 94 | elif 4 * (H // 16) * (W // 16) == T: 95 | P, H_lr, W_lr = (8, 2 * (H // 16), 2 * (W // 16)) 96 | elif 16 * (H // 32) * (W // 32) == T: 97 | P, H_lr, W_lr = (8, 4 * (H // 32), 4 * (W // 32)) 98 | else: 99 | raise ValueError(f'{init_image_size=}, {patch_mask.shape=}') 100 | 101 | # Create patch mask 102 | patch_mask = patch_mask.reshape(H_lr, W_lr).cpu().numpy() 103 | 104 | # Possibly reverse mask 105 | # print(np.mean(patch_mask).item()) 106 | if 0.5 < np.mean(patch_mask).item() < 1.0: 107 | patch_mask = (1 - patch_mask).astype(np.uint8) 108 | elif np.sum(patch_mask).item() == 0: # nothing detected at all, so cover the entire image 109 | patch_mask = (1 - patch_mask).astype(np.uint8) 110 | 111 | # Get the box corresponding to the largest connected component of the first eigenvector 112 | xmin, ymin, xmax, ymax = get_largest_cc_box(patch_mask) 113 | # pred = [xmin, ymin, xmax, ymax] 114 | 115 | # Rescale to image size 116 | r_xmin, r_xmax = P * xmin, P * xmax 117 | r_ymin, r_ymax = P * ymin, P * ymax 118 | 119 | # Prediction bounding box 120 | pred = [r_xmin, r_ymin, r_xmax, r_ymax] 121 | 122 | # Check not out of image size (used when padding) 123 | pred[2] = min(pred[2], W) 124 | pred[3] = min(pred[3], H) 125 | 126 | return np.asarray(pred) 127 | 128 | 129 | def lost(feats, dims, scales, init_image_size, k_patches=100): 130 | """ 131 | Implementation of LOST method. 
132 | Inputs 133 | feats: the pixel/patche features of an image 134 | dims: dimension of the map from which the features are used 135 | scales: from image to map scale 136 | init_image_size: size of the image 137 | k_patches: number of k patches retrieved that are compared to the seed at seed expansion 138 | Outputs 139 | pred: box predictions 140 | A: binary affinity matrix 141 | scores: lowest degree scores for all patches 142 | seed: selected patch corresponding to an object 143 | """ 144 | # Compute the similarity 145 | A = (feats @ feats.transpose(1, 2)).squeeze() 146 | 147 | # Compute the inverse degree centrality measure per patch 148 | sorted_patches, scores = patch_scoring(A) 149 | 150 | # Select the initial seed 151 | seed = sorted_patches[0] 152 | 153 | # Seed expansion 154 | potentials = sorted_patches[:k_patches] 155 | similars = potentials[A[seed, potentials] > 0.0] 156 | M = torch.sum(A[similars, :], dim=0) 157 | 158 | # Box extraction 159 | pred, _ = detect_box( 160 | M, seed, dims, scales=scales, initial_im_size=init_image_size[1:] 161 | ) 162 | 163 | return np.asarray(pred), A, M, scores, seed 164 | 165 | 166 | def patch_scoring(M, threshold=0.): 167 | """ 168 | Patch scoring based on the inverse degree. 169 | """ 170 | # Cloning important 171 | A = M.clone() 172 | 173 | # Zero diagonal 174 | A.fill_diagonal_(0) 175 | 176 | # Make sure symmetric and non nul 177 | A[A < 0] = 0 178 | C = A + A.t() # NOTE: this was not used. should this be used? 179 | 180 | # Sort pixels by inverse degree 181 | cent = -torch.sum(A > threshold, dim=1).type(torch.float32) 182 | sel = torch.argsort(cent, descending=True) 183 | 184 | return sel, cent 185 | 186 | 187 | def detect_box(A, seed, dims, initial_im_size=None, scales=None): 188 | """ 189 | Extract a box corresponding to the seed patch. Among connected components extract from the affinity matrix, select the one corresponding to the seed patch. 190 | """ 191 | w_featmap, h_featmap = dims 192 | 193 | correl = A.reshape(w_featmap, h_featmap).float() 194 | 195 | # Compute connected components 196 | labeled_array, num_features = scipy.ndimage.label(correl.cpu().numpy() > 0.0) 197 | 198 | # Find connected component corresponding to the initial seed 199 | cc = labeled_array[np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap))] 200 | 201 | # Should not happen with LOST 202 | if cc == 0: 203 | raise ValueError("The seed is in the background component.") 204 | 205 | # Find box 206 | mask = np.where(labeled_array == cc) 207 | 208 | # Add +1 because excluded max 209 | ymin, ymax = min(mask[0]), max(mask[0]) + 1 210 | xmin, xmax = min(mask[1]), max(mask[1]) + 1 211 | 212 | # Rescale to image size 213 | r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax 214 | r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax 215 | 216 | pred = [r_xmin, r_ymin, r_xmax, r_ymax] 217 | 218 | # Check not out of image size (used when padding) 219 | if initial_im_size: 220 | pred[2] = min(pred[2], initial_im_size[1]) 221 | pred[3] = min(pred[3], initial_im_size[0]) 222 | 223 | # Coordinate predictions for the feature space 224 | # Axis different then in image space 225 | pred_feats = [ymin, xmin, ymax, xmax] 226 | 227 | return pred, pred_feats 228 | 229 | 230 | def dino_seg(attn, dims, patch_size, head=0): 231 | """ 232 | Extraction of boxes based on the DINO segmentation method proposed in https://github.com/facebookresearch/dino. 
233 | """ 234 | w_featmap, h_featmap = dims 235 | nh = attn.shape[1] 236 | official_th = 0.6 237 | 238 | # We keep only the output patch attention 239 | # Get the attentions corresponding to [CLS] token 240 | attentions = attn[0, :, 0, 1:].reshape(nh, -1) 241 | 242 | # we keep only a certain percentage of the mass 243 | val, idx = torch.sort(attentions) 244 | val /= torch.sum(val, dim=1, keepdim=True) 245 | cumval = torch.cumsum(val, dim=1) 246 | th_attn = cumval > (1 - official_th) 247 | idx2 = torch.argsort(idx) 248 | for h in range(nh): 249 | th_attn[h] = th_attn[h][idx2[h]] 250 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 251 | 252 | # Connected components 253 | labeled_array, num_features = scipy.ndimage.label(th_attn[head].cpu().numpy()) 254 | 255 | # Find the biggest component 256 | size_components = [np.sum(labeled_array == c) for c in range(np.max(labeled_array))] 257 | 258 | if len(size_components) > 1: 259 | # Select the biggest component avoiding component 0 corresponding to background 260 | biggest_component = np.argmax(size_components[1:]) + 1 261 | else: 262 | # Cases of a single component 263 | biggest_component = 0 264 | 265 | # Mask corresponding to connected component 266 | mask = np.where(labeled_array == biggest_component) 267 | 268 | # Add +1 because excluded max 269 | ymin, ymax = min(mask[0]), max(mask[0]) + 1 270 | xmin, xmax = min(mask[1]), max(mask[1]) + 1 271 | 272 | # Rescale to image 273 | r_xmin, r_xmax = xmin * patch_size, xmax * patch_size 274 | r_ymin, r_ymax = ymin * patch_size, ymax * patch_size 275 | pred = [r_xmin, r_ymin, r_xmax, r_ymax] 276 | 277 | return pred 278 | 279 | 280 | def get_largest_cc_box(mask: np.array): 281 | from skimage.measure import label as measure_label 282 | labels = measure_label(mask) # get connected components 283 | largest_cc_index = np.argmax(np.bincount(labels.flat)[1:]) + 1 284 | mask = np.where(labels == largest_cc_index) 285 | ymin, ymax = min(mask[0]), max(mask[0]) + 1 286 | xmin, xmax = min(mask[1]), max(mask[1]) + 1 287 | return [xmin, ymin, xmax, ymax] 288 | -------------------------------------------------------------------------------- /object-segmentation/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers, mostly from torchvision 3 | """ 4 | import time 5 | import datetime 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torchvision 11 | from dataclasses import dataclass 12 | from collections import defaultdict, deque 13 | from typing import Callable, Optional 14 | from PIL import Image 15 | from accelerate import Accelerator 16 | from omegaconf import DictConfig 17 | 18 | 19 | @dataclass 20 | class TrainState: 21 | epoch: int = 0 22 | step: int = 0 23 | best_val: Optional[float] = None 24 | 25 | 26 | def get_optimizer(cfg: DictConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer: 27 | # Determine the learning rate 28 | if cfg.optimizer.scale_learning_rate_with_batch_size: 29 | lr = accelerator.state.num_processes * cfg.data.loader.batch_size * cfg.optimizer.base_lr 30 | print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format( 31 | ws=accelerator.state.num_processes, bs=cfg.data.loader.batch_size, blr=cfg.lr, lr=lr)) 32 | else: # scale base learning rate by batch size 33 | lr = cfg.lr 34 | print('lr = {lr} (absolute learning rate)'.format(lr=lr)) 35 | # Construct optimizer 36 | if 
cfg.optimizer.kind == 'torch': 37 | parameters = [p for p in model.parameters() if p.requires_grad] 38 | optimizer = getattr(torch.optim, cfg.optimizer.cls)(parameters, lr=lr, **cfg.optimizer.kwargs) 39 | elif cfg.optimizer.kind == 'timm': 40 | from timm.optim import create_optimizer_v2 41 | optimizer = create_optimizer_v2(model, lr=lr, **cfg.optimizer.kwargs) 42 | elif cfg.optimizer.kind == 'transformers': 43 | import transformers 44 | parameters = [p for p in model.parameters() if p.requires_grad] 45 | optimizer = getattr(transformers, cfg.optimizer.name)(parameters, lr=lr, **cfg.optimizer.kwargs) 46 | else: 47 | raise NotImplementedError(f'invalid optimizer config: {cfg.optimizer}') 48 | return optimizer 49 | 50 | 51 | def get_scheduler(cfg: DictConfig, optimizer: torch.optim.Optimizer) -> Callable: 52 | if cfg.scheduler.kind == 'torch': 53 | Sch = getattr(torch.optim.lr_scheduler, cfg.scheduler.cls) 54 | scheduler = Sch(optimizer=optimizer, **cfg.scheduler.kwargs) 55 | if cfg.scheduler.warmup: 56 | from warmup_scheduler import GradualWarmupScheduler 57 | scheduler = GradualWarmupScheduler( # wrap scheduler with warmup 58 | optimizer, multiplier=1, total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler) 59 | elif cfg.scheduler.kind == 'timm': 60 | from timm.scheduler import create_scheduler 61 | scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs) 62 | elif cfg.scheduler.kind == 'transformers': 63 | from transformers import get_scheduler 64 | scheduler = get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs) 65 | else: 66 | raise NotImplementedError(f'invalid scheduler config: {cfg.scheduler}') 67 | return scheduler 68 | 69 | 70 | @torch.no_grad() 71 | def accuracy(output, target, topk=(1,)): 72 | """Computes the accuracy over the k top predictions for the specified values of k""" 73 | # reshape 74 | target = target.reshape(-1) 75 | output = output.reshape(target.size(0), -1) 76 | 77 | maxk = max(topk) 78 | batch_size = target.size(0) 79 | 80 | _, pred = output.topk(maxk, 1, True, True) 81 | pred = pred.t() 82 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 83 | 84 | res = [] 85 | for k in topk: 86 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 87 | res.append(correct_k.mul_(100.0 / batch_size)) 88 | return res 89 | 90 | 91 | class SmoothedValue(object): 92 | """Track a series of values and provide access to smoothed values over a 93 | window or the global series average. 94 | """ 95 | 96 | def __init__(self, window_size=20, fmt=None): 97 | if fmt is None: 98 | fmt = "{median:.4f} ({global_avg:.4f})" 99 | self.deque = deque(maxlen=window_size) 100 | self.total = 0.0 101 | self.count = 0 102 | self.fmt = fmt 103 | 104 | def update(self, value, n=1): 105 | self.deque.append(value) 106 | self.count += n 107 | self.total += value * n 108 | 109 | def synchronize_between_processes(self, device='cuda'): 110 | """ 111 | Warning: does not synchronize the deque! 
112 | """ 113 | if not using_distributed(): 114 | return 115 | print(f"device={device}") 116 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device) 117 | dist.barrier() 118 | dist.all_reduce(t) 119 | t = t.tolist() 120 | self.count = int(t[0]) 121 | self.total = t[1] 122 | 123 | @property 124 | def median(self): 125 | d = torch.tensor(list(self.deque)) 126 | return d.median().item() 127 | 128 | @property 129 | def avg(self): 130 | d = torch.tensor(list(self.deque), dtype=torch.float32) 131 | return d.mean().item() 132 | 133 | @property 134 | def global_avg(self): 135 | return self.total / self.count 136 | 137 | @property 138 | def max(self): 139 | return max(self.deque) 140 | 141 | @property 142 | def value(self): 143 | return self.deque[-1] 144 | 145 | def __str__(self): 146 | return self.fmt.format( 147 | median=self.median, 148 | avg=self.avg, 149 | global_avg=self.global_avg, 150 | max=self.max, 151 | value=self.value) 152 | 153 | 154 | class MetricLogger(object): 155 | def __init__(self, delimiter="\t"): 156 | self.meters = defaultdict(SmoothedValue) 157 | self.delimiter = delimiter 158 | 159 | def update(self, **kwargs): 160 | n = kwargs.pop('n', 1) 161 | for k, v in kwargs.items(): 162 | if isinstance(v, torch.Tensor): 163 | v = v.item() 164 | assert isinstance(v, (float, int)) 165 | self.meters[k].update(v, n=n) 166 | 167 | def __getattr__(self, attr): 168 | if attr in self.meters: 169 | return self.meters[attr] 170 | if attr in self.__dict__: 171 | return self.__dict__[attr] 172 | raise AttributeError("'{}' object has no attribute '{}'".format( 173 | type(self).__name__, attr)) 174 | 175 | def __str__(self): 176 | loss_str = [] 177 | for name, meter in self.meters.items(): 178 | loss_str.append( 179 | "{}: {}".format(name, str(meter)) 180 | ) 181 | return self.delimiter.join(loss_str) 182 | 183 | def synchronize_between_processes(self, device='cuda'): 184 | for meter in self.meters.values(): 185 | meter.synchronize_between_processes(device=device) 186 | 187 | def add_meter(self, name, meter): 188 | self.meters[name] = meter 189 | 190 | def log_every(self, iterable, print_freq, header=None): 191 | i = 0 192 | if not header: 193 | header = '' 194 | start_time = time.time() 195 | end = time.time() 196 | iter_time = SmoothedValue(fmt='{avg:.4f}') 197 | data_time = SmoothedValue(fmt='{avg:.4f}') 198 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 199 | log_msg = [ 200 | header, 201 | '[{0' + space_fmt + '}/{1}]', 202 | 'eta: {eta}', 203 | '{meters}', 204 | 'time: {time}', 205 | 'data: {data}' 206 | ] 207 | if torch.cuda.is_available(): 208 | log_msg.append('max mem: {memory:.0f}') 209 | log_msg = self.delimiter.join(log_msg) 210 | MB = 1024.0 * 1024.0 211 | for obj in iterable: 212 | data_time.update(time.time() - end) 213 | yield obj 214 | iter_time.update(time.time() - end) 215 | if i % print_freq == 0 or i == len(iterable) - 1: 216 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 217 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 218 | if torch.cuda.is_available(): 219 | print(log_msg.format( 220 | i, len(iterable), eta=eta_string, 221 | meters=str(self), 222 | time=str(iter_time), data=str(data_time), 223 | memory=torch.cuda.max_memory_allocated() / MB)) 224 | else: 225 | print(log_msg.format( 226 | i, len(iterable), eta=eta_string, 227 | meters=str(self), 228 | time=str(iter_time), data=str(data_time))) 229 | i += 1 230 | end = time.time() 231 | total_time = time.time() - start_time 232 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 233 | print('{} Total time: {} ({:.4f} s / it)'.format( 234 | header, total_time_str, total_time / len(iterable))) 235 | 236 | 237 | class NormalizeInverse(torchvision.transforms.Normalize): 238 | """ 239 | Undoes the normalization and returns the reconstructed images in the input domain. 240 | """ 241 | 242 | def __init__(self, mean, std): 243 | mean = torch.as_tensor(mean) 244 | std = torch.as_tensor(std) 245 | std_inv = 1 / (std + 1e-7) 246 | mean_inv = -mean * std_inv 247 | super().__init__(mean=mean_inv, std=std_inv) 248 | 249 | def __call__(self, tensor): 250 | return super().__call__(tensor.clone()) 251 | 252 | 253 | def set_requires_grad(module, requires_grad=True): 254 | for p in module.parameters(): 255 | p.requires_grad = requires_grad 256 | 257 | 258 | def resume_from_checkpoint(cfg, model, optimizer=None, scheduler=None, model_ema=None): 259 | 260 | # Resume model state dict 261 | checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu') 262 | if 'model' in checkpoint: 263 | state_dict, key = checkpoint['model'], 'model' 264 | else: 265 | state_dict, key = checkpoint, 'N/A' 266 | if any(k.startswith('module.') for k in state_dict.keys()): 267 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 268 | print('Removed "module." from checkpoint state dict') 269 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 270 | print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}') 271 | if len(missing_keys): 272 | print(f' - Missing_keys: {missing_keys}') 273 | if len(unexpected_keys): 274 | print(f' - Unexpected_keys: {unexpected_keys}') 275 | # Resume model ema 276 | if cfg.ema.use_ema: 277 | if checkpoint['model_ema']: 278 | model_ema.load_state_dict(checkpoint['model_ema']) 279 | print('Loaded model ema from checkpoint') 280 | else: 281 | model_ema.load_state_dict(model.parameters()) 282 | print('No model ema in checkpoint; loaded current parameters into model') 283 | else: 284 | if 'model_ema' in checkpoint: 285 | print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)') 286 | else: 287 | print('Not using model ema, and no model_ema found in checkpoint.') 288 | 289 | # Resume optimization state 290 | if cfg.checkpoint.resume_training and 'train' in cfg.job_type: 291 | if 'steps' in checkpoint: 292 | checkpoint['step'] = checkpoint['steps'] 293 | assert {'optimizer', 'scheduler', 'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys())) 294 | optimizer.load_state_dict(checkpoint['optimizer']) 295 | scheduler.load_state_dict(checkpoint['scheduler']) 296 | epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val'] 297 | train_state = TrainState(epoch=epoch, step=step, best_val=best_val) 298 | print(f'Loaded optimizer/scheduler at epoch {epoch} from checkpoint') 299 | elif cfg.checkpoint.resume_optimizer_only: 300 | assert 'optimizer' in set(checkpoint.keys()) 301 | optimizer.load_state_dict(checkpoint['optimizer']) 302 | print(f'Loaded optimizer from checkpoint, but did not load scheduler/epoch') 303 | else: 304 | train_state = TrainState() 305 | print('Did not resume training (i.e. 
optimizer/scheduler/epoch)') 306 | 307 | return train_state 308 | 309 | 310 | def setup_distributed_print(is_master): 311 | """ 312 | This function disables printing when not in master process 313 | """ 314 | import builtins as __builtin__ 315 | builtin_print = __builtin__.print 316 | 317 | def print(*args, **kwargs): 318 | force = kwargs.pop('force', False) 319 | if is_master or force: 320 | builtin_print(*args, **kwargs) 321 | 322 | __builtin__.print = print 323 | 324 | 325 | def using_distributed(): 326 | return dist.is_available() and dist.is_initialized() 327 | 328 | 329 | def get_rank(): 330 | return dist.get_rank() if using_distributed() else 0 331 | 332 | 333 | def set_seed(seed): 334 | rank = get_rank() 335 | seed = seed + rank 336 | torch.manual_seed(seed) 337 | torch.cuda.manual_seed(seed) 338 | np.random.seed(seed) 339 | random.seed(seed) 340 | torch.backends.cudnn.enabled = True 341 | torch.backends.cudnn.benchmark = True 342 | if using_distributed(): 343 | print(f'Seeding node {rank} with seed {seed}', force=True) 344 | else: 345 | print(f'Seeding node {rank} with seed {seed}') 346 | 347 | 348 | def tensor_to_pil(image: torch.Tensor): 349 | assert len(image.shape) and image.shape[0] == 3, f"{image.shape=}" 350 | image = (image.float() * 0.5 + 0.5).clamp(0, 1).detach().cpu().requires_grad_(False) 351 | ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 352 | return Image.fromarray(ndarr) 353 | 354 | 355 | def albumentations_to_torch(transform): 356 | def _transform(img, target): 357 | augmented = transform(image=img, mask=target) 358 | return augmented['image'], augmented['mask'] 359 | return _transform 360 | -------------------------------------------------------------------------------- /semantic-segmentation/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers, mostly from torchvision 3 | """ 4 | import time 5 | import datetime 6 | import random 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torchvision 11 | from dataclasses import dataclass 12 | from collections import defaultdict, deque 13 | from typing import Callable, Optional 14 | from PIL import Image 15 | from accelerate import Accelerator 16 | from omegaconf import DictConfig 17 | 18 | 19 | @dataclass 20 | class TrainState: 21 | epoch: int = 0 22 | step: int = 0 23 | best_val: Optional[float] = None 24 | 25 | 26 | def get_optimizer(cfg: DictConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer: 27 | # Determine the learning rate 28 | if cfg.optimizer.scale_learning_rate_with_batch_size: 29 | lr = accelerator.state.num_processes * cfg.data.loader.batch_size * cfg.optimizer.base_lr 30 | print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format( 31 | ws=accelerator.state.num_processes, bs=cfg.data.loader.batch_size, blr=cfg.lr, lr=lr)) 32 | else: # scale base learning rate by batch size 33 | lr = cfg.lr 34 | print('lr = {lr} (absolute learning rate)'.format(lr=lr)) 35 | # Construct optimizer 36 | if cfg.optimizer.kind == 'torch': 37 | parameters = [p for p in model.parameters() if p.requires_grad] 38 | optimizer = getattr(torch.optim, cfg.optimizer.cls)(parameters, lr=lr, **cfg.optimizer.kwargs) 39 | elif cfg.optimizer.kind == 'timm': 40 | from timm.optim import create_optimizer_v2 41 | optimizer = create_optimizer_v2(model, lr=lr, **cfg.optimizer.kwargs) 42 | elif 
cfg.optimizer.kind == 'transformers': 43 | import transformers 44 | parameters = [p for p in model.parameters() if p.requires_grad] 45 | optimizer = getattr(transformers, cfg.optimizer.name)(parameters, lr=lr, **cfg.optimizer.kwargs) 46 | else: 47 | raise NotImplementedError(f'invalid optimizer config: {cfg.optimizer}') 48 | return optimizer 49 | 50 | 51 | def get_scheduler(cfg: DictConfig, optimizer: torch.optim.Optimizer) -> Callable: 52 | if cfg.scheduler.kind == 'torch': 53 | Sch = getattr(torch.optim.lr_scheduler, cfg.scheduler.cls) 54 | scheduler = Sch(optimizer=optimizer, **cfg.scheduler.kwargs) 55 | if cfg.scheduler.warmup: 56 | from warmup_scheduler import GradualWarmupScheduler 57 | scheduler = GradualWarmupScheduler( # wrap scheduler with warmup 58 | optimizer, multiplier=1, total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler) 59 | elif cfg.scheduler.kind == 'timm': 60 | from timm.scheduler import create_scheduler 61 | scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs) 62 | elif cfg.scheduler.kind == 'transformers': 63 | from transformers import get_scheduler 64 | scheduler = get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs) 65 | else: 66 | raise NotImplementedError(f'invalid scheduler config: {cfg.scheduler}') 67 | return scheduler 68 | 69 | 70 | @torch.no_grad() 71 | def accuracy(output, target, topk=(1,)): 72 | """Computes the accuracy over the k top predictions for the specified values of k""" 73 | # reshape 74 | target = target.reshape(-1) 75 | output = output.reshape(target.size(0), -1) 76 | 77 | maxk = max(topk) 78 | batch_size = target.size(0) 79 | 80 | _, pred = output.topk(maxk, 1, True, True) 81 | pred = pred.t() 82 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 83 | 84 | res = [] 85 | for k in topk: 86 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 87 | res.append(correct_k.mul_(100.0 / batch_size)) 88 | return res 89 | 90 | 91 | class SmoothedValue(object): 92 | """Track a series of values and provide access to smoothed values over a 93 | window or the global series average. 94 | """ 95 | 96 | def __init__(self, window_size=20, fmt=None): 97 | if fmt is None: 98 | fmt = "{median:.4f} ({global_avg:.4f})" 99 | self.deque = deque(maxlen=window_size) 100 | self.total = 0.0 101 | self.count = 0 102 | self.fmt = fmt 103 | 104 | def update(self, value, n=1): 105 | self.deque.append(value) 106 | self.count += n 107 | self.total += value * n 108 | 109 | def synchronize_between_processes(self, device='cuda'): 110 | """ 111 | Warning: does not synchronize the deque! 
112 | """ 113 | if not using_distributed(): 114 | return 115 | print(f"device={device}") 116 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device) 117 | dist.barrier() 118 | dist.all_reduce(t) 119 | t = t.tolist() 120 | self.count = int(t[0]) 121 | self.total = t[1] 122 | 123 | @property 124 | def median(self): 125 | d = torch.tensor(list(self.deque)) 126 | return d.median().item() 127 | 128 | @property 129 | def avg(self): 130 | d = torch.tensor(list(self.deque), dtype=torch.float32) 131 | return d.mean().item() 132 | 133 | @property 134 | def global_avg(self): 135 | return self.total / self.count 136 | 137 | @property 138 | def max(self): 139 | return max(self.deque) 140 | 141 | @property 142 | def value(self): 143 | return self.deque[-1] 144 | 145 | def __str__(self): 146 | return self.fmt.format( 147 | median=self.median, 148 | avg=self.avg, 149 | global_avg=self.global_avg, 150 | max=self.max, 151 | value=self.value) 152 | 153 | 154 | class MetricLogger(object): 155 | def __init__(self, delimiter="\t"): 156 | self.meters = defaultdict(SmoothedValue) 157 | self.delimiter = delimiter 158 | 159 | def update(self, **kwargs): 160 | n = kwargs.pop('n', 1) 161 | for k, v in kwargs.items(): 162 | if isinstance(v, torch.Tensor): 163 | v = v.item() 164 | assert isinstance(v, (float, int)) 165 | self.meters[k].update(v, n=n) 166 | 167 | def __getattr__(self, attr): 168 | if attr in self.meters: 169 | return self.meters[attr] 170 | if attr in self.__dict__: 171 | return self.__dict__[attr] 172 | raise AttributeError("'{}' object has no attribute '{}'".format( 173 | type(self).__name__, attr)) 174 | 175 | def __str__(self): 176 | loss_str = [] 177 | for name, meter in self.meters.items(): 178 | loss_str.append( 179 | "{}: {}".format(name, str(meter)) 180 | ) 181 | return self.delimiter.join(loss_str) 182 | 183 | def synchronize_between_processes(self, device='cuda'): 184 | for meter in self.meters.values(): 185 | meter.synchronize_between_processes(device=device) 186 | 187 | def add_meter(self, name, meter): 188 | self.meters[name] = meter 189 | 190 | def log_every(self, iterable, print_freq, header=None): 191 | i = 0 192 | if not header: 193 | header = '' 194 | start_time = time.time() 195 | end = time.time() 196 | iter_time = SmoothedValue(fmt='{avg:.4f}') 197 | data_time = SmoothedValue(fmt='{avg:.4f}') 198 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 199 | log_msg = [ 200 | header, 201 | '[{0' + space_fmt + '}/{1}]', 202 | 'eta: {eta}', 203 | '{meters}', 204 | 'time: {time}', 205 | 'data: {data}' 206 | ] 207 | if torch.cuda.is_available(): 208 | log_msg.append('max mem: {memory:.0f}') 209 | log_msg = self.delimiter.join(log_msg) 210 | MB = 1024.0 * 1024.0 211 | for obj in iterable: 212 | data_time.update(time.time() - end) 213 | yield obj 214 | iter_time.update(time.time() - end) 215 | if i % print_freq == 0 or i == len(iterable) - 1: 216 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 217 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 218 | if torch.cuda.is_available(): 219 | print(log_msg.format( 220 | i, len(iterable), eta=eta_string, 221 | meters=str(self), 222 | time=str(iter_time), data=str(data_time), 223 | memory=torch.cuda.max_memory_allocated() / MB)) 224 | else: 225 | print(log_msg.format( 226 | i, len(iterable), eta=eta_string, 227 | meters=str(self), 228 | time=str(iter_time), data=str(data_time))) 229 | i += 1 230 | end = time.time() 231 | total_time = time.time() - start_time 232 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 233 | print('{} Total time: {} ({:.4f} s / it)'.format( 234 | header, total_time_str, total_time / len(iterable))) 235 | 236 | 237 | class NormalizeInverse(torchvision.transforms.Normalize): 238 | """ 239 | Undoes the normalization and returns the reconstructed images in the input domain. 240 | """ 241 | 242 | def __init__(self, mean, std): 243 | mean = torch.as_tensor(mean) 244 | std = torch.as_tensor(std) 245 | std_inv = 1 / (std + 1e-7) 246 | mean_inv = -mean * std_inv 247 | super().__init__(mean=mean_inv, std=std_inv) 248 | 249 | def __call__(self, tensor): 250 | return super().__call__(tensor.clone()) 251 | 252 | 253 | def set_requires_grad(module, requires_grad=True): 254 | for p in module.parameters(): 255 | p.requires_grad = requires_grad 256 | 257 | 258 | def resume_from_checkpoint(cfg, model, optimizer=None, scheduler=None, model_ema=None): 259 | 260 | # Resume model state dict 261 | checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu') 262 | if 'model' in checkpoint: 263 | state_dict, key = checkpoint['model'], 'model' 264 | else: 265 | state_dict, key = checkpoint, 'N/A' 266 | if any(k.startswith('module.') for k in state_dict.keys()): 267 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 268 | print('Removed "module." from checkpoint state dict') 269 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 270 | print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}') 271 | if len(missing_keys): 272 | print(f' - Missing_keys: {missing_keys}') 273 | if len(unexpected_keys): 274 | print(f' - Unexpected_keys: {unexpected_keys}') 275 | # Resume model ema 276 | if cfg.ema.use_ema: 277 | if checkpoint['model_ema']: 278 | model_ema.load_state_dict(checkpoint['model_ema']) 279 | print('Loaded model ema from checkpoint') 280 | else: 281 | model_ema.load_state_dict(model.parameters()) 282 | print('No model ema in checkpoint; loaded current parameters into model') 283 | else: 284 | if 'model_ema' in checkpoint: 285 | print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)') 286 | else: 287 | print('Not using model ema, and no model_ema found in checkpoint.') 288 | 289 | # Resume optimization state 290 | if cfg.checkpoint.resume_training and 'train' in cfg.job_type: 291 | if 'steps' in checkpoint: 292 | checkpoint['step'] = checkpoint['steps'] 293 | assert {'optimizer', 'scheduler', 'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys())) 294 | optimizer.load_state_dict(checkpoint['optimizer']) 295 | scheduler.load_state_dict(checkpoint['scheduler']) 296 | epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val'] 297 | train_state = TrainState(epoch=epoch, step=step, best_val=best_val) 298 | print(f'Loaded optimizer/scheduler at epoch {epoch} from checkpoint') 299 | elif cfg.checkpoint.resume_optimizer_only: 300 | assert 'optimizer' in set(checkpoint.keys()) 301 | optimizer.load_state_dict(checkpoint['optimizer']) 302 | print(f'Loaded optimizer from checkpoint, but did not load scheduler/epoch') 303 | else: 304 | train_state = TrainState() 305 | print('Did not resume training (i.e. 
optimizer/scheduler/epoch)') 306 | 307 | return train_state 308 | 309 | 310 | def setup_distributed_print(is_master): 311 | """ 312 | This function disables printing when not in master process 313 | """ 314 | import builtins as __builtin__ 315 | builtin_print = __builtin__.print 316 | 317 | def print(*args, **kwargs): 318 | force = kwargs.pop('force', False) 319 | if is_master or force: 320 | builtin_print(*args, **kwargs) 321 | 322 | __builtin__.print = print 323 | 324 | 325 | def using_distributed(): 326 | return dist.is_available() and dist.is_initialized() 327 | 328 | 329 | def get_rank(): 330 | return dist.get_rank() if using_distributed() else 0 331 | 332 | 333 | def set_seed(seed): 334 | rank = get_rank() 335 | seed = seed + rank 336 | torch.manual_seed(seed) 337 | torch.cuda.manual_seed(seed) 338 | np.random.seed(seed) 339 | random.seed(seed) 340 | torch.backends.cudnn.enabled = True 341 | torch.backends.cudnn.benchmark = True 342 | if using_distributed(): 343 | print(f'Seeding node {rank} with seed {seed}', force=True) 344 | else: 345 | print(f'Seeding node {rank} with seed {seed}') 346 | 347 | 348 | def tensor_to_pil(image: torch.Tensor): 349 | assert len(image.shape) and image.shape[0] == 3, f"{image.shape=}" 350 | image = (image.float() * 0.5 + 0.5).clamp(0, 1).detach().cpu().requires_grad_(False) 351 | ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 352 | return Image.fromarray(ndarr) 353 | 354 | 355 | def albumentations_to_torch(transform): 356 | def _transform(img, target): 357 | augmented = transform(image=img, mask=target) 358 | return augmented['image'], augmented['mask'] 359 | return _transform 360 | -------------------------------------------------------------------------------- /object-localization/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Datasets file. 
Code adapted from LOST: https://github.com/valeoai/LOST 3 | """ 4 | import os 5 | import torch 6 | import json 7 | import torchvision 8 | import numpy as np 9 | import skimage.io 10 | 11 | from PIL import Image 12 | from tqdm import tqdm 13 | from torchvision import transforms as pth_transforms 14 | from traitlets.traitlets import default 15 | 16 | 17 | class ImageDataset: 18 | def __init__(self, image_path, transform): 19 | 20 | self.image_path = image_path 21 | self.transform = transform 22 | self.name = image_path.split("/")[-1] 23 | 24 | # Read the image 25 | with open(image_path, "rb") as f: 26 | img = Image.open(f) 27 | img = img.convert("RGB") 28 | 29 | # Build a dataloader 30 | img = self.transform(img) 31 | self.dataloader = [[img, image_path]] 32 | 33 | def get_image_name(self, *args, **kwargs): 34 | return self.image_path.split("/")[-1].split(".")[0] 35 | 36 | def load_image(self, *args, **kwargs): 37 | return skimage.io.imread(self.image_path) 38 | 39 | 40 | class Dataset: 41 | def __init__(self, dataset_name, dataset_set, remove_hards, transform): 42 | """ 43 | Build the dataloader 44 | """ 45 | 46 | self.dataset_name = dataset_name 47 | self.set = dataset_set 48 | self.transform = transform 49 | 50 | if dataset_name == "VOC07": 51 | self.root_path = "datasets/VOC2007" 52 | self.year = "2007" 53 | elif dataset_name == "VOC12": 54 | self.root_path = "datasets/VOC2012" 55 | self.year = "2012" 56 | elif dataset_name == "COCO20k": 57 | self.year = "2014" 58 | self.root_path = f"datasets/COCO/images/{dataset_set}{self.year}" 59 | self.sel20k = 'datasets/coco_20k_filenames.txt' 60 | # JSON file constructed based on COCO train2014 gt 61 | self.all_annfile = "datasets/COCO/annotations/instances_train2014.json" 62 | self.annfile = "datasets/instances_train2014_sel20k.json" 63 | if not os.path.exists(self.annfile): 64 | select_coco_20k(self.sel20k, self.all_annfile) 65 | else: 66 | raise ValueError("Unknown dataset.") 67 | 68 | if not os.path.exists(self.root_path): 69 | print(self.root_path) 70 | raise ValueError("Please follow the README to setup the datasets.") 71 | 72 | self.name = f"{self.dataset_name}_{self.set}" 73 | 74 | # Build the dataloader 75 | if "VOC" in dataset_name: 76 | self.dataloader = torchvision.datasets.VOCDetection( 77 | self.root_path, 78 | year=self.year, 79 | image_set=self.set, 80 | transform=self.transform, 81 | download=False, 82 | ) 83 | elif "COCO20k" == dataset_name: 84 | self.dataloader = torchvision.datasets.CocoDetection( 85 | self.root_path, annFile=self.annfile, transform=self.transform 86 | ) 87 | else: 88 | raise ValueError("Unknown dataset.") 89 | 90 | # Set hards images that are not included 91 | self.remove_hards = remove_hards 92 | self.hards = [] 93 | if remove_hards: 94 | self.name += f"-nohards" 95 | self.hards = self.get_hards() 96 | print(f"Nb images discarded {len(self.hards)}") 97 | 98 | def load_image(self, im_name): 99 | """ 100 | Load the image corresponding to the im_name 101 | """ 102 | if "VOC" in self.dataset_name: 103 | image = skimage.io.imread(f"/datasets_local/VOC{self.year}/JPEGImages/{im_name}") 104 | elif "COCO" in self.dataset_name: 105 | im_path = self.path_20k[self.sel_20k.index(im_name)] 106 | image = skimage.io.imread(f"/datasets_local/COCO/images/{im_path}") 107 | else: 108 | raise ValueError("Unkown dataset.") 109 | return image 110 | 111 | def get_image_name(self, inp): 112 | """ 113 | Return the image name 114 | """ 115 | if "VOC" in self.dataset_name: 116 | im_name = inp["annotation"]["filename"] 117 | elif "COCO" 
in self.dataset_name: 118 | im_name = str(inp[0]["image_id"]) 119 | 120 | return im_name 121 | 122 | def extract_gt(self, targets, im_name): 123 | if "VOC" in self.dataset_name: 124 | return extract_gt_VOC(targets, remove_hards=self.remove_hards) 125 | elif "COCO" in self.dataset_name: 126 | return extract_gt_COCO(targets, remove_iscrowd=True) 127 | else: 128 | raise ValueError("Unknown dataset") 129 | 130 | def extract_classes(self): 131 | if "VOC" in self.dataset_name: 132 | cls_path = f"classes_{self.set}_{self.year}.txt" 133 | elif "COCO" in self.dataset_name: 134 | cls_path = f"classes_{self.dataset}_{self.set}_{self.year}.txt" 135 | 136 | # Load if exists 137 | if os.path.exists(cls_path): 138 | all_classes = [] 139 | with open(cls_path, "r") as f: 140 | for line in f: 141 | all_classes.append(line.strip()) 142 | else: 143 | print("Extract all classes from the dataset") 144 | if "VOC" in self.dataset_name: 145 | all_classes = self.extract_classes_VOC() 146 | elif "COCO" in self.dataset_name: 147 | all_classes = self.extract_classes_COCO() 148 | 149 | with open(cls_path, "w") as f: 150 | for s in all_classes: 151 | f.write(str(s) + "\n") 152 | 153 | return all_classes 154 | 155 | def extract_classes_VOC(self): 156 | all_classes = [] 157 | for im_id, inp in enumerate(tqdm(self.dataloader)): 158 | objects = inp[1]["annotation"]["object"] 159 | 160 | for o in range(len(objects)): 161 | if objects[o]["name"] not in all_classes: 162 | all_classes.append(objects[o]["name"]) 163 | 164 | return all_classes 165 | 166 | def extract_classes_COCO(self): 167 | all_classes = [] 168 | for im_id, inp in enumerate(tqdm(self.dataloader)): 169 | objects = inp[1] 170 | 171 | for o in range(len(objects)): 172 | if objects[o]["category_id"] not in all_classes: 173 | all_classes.append(objects[o]["category_id"]) 174 | 175 | return all_classes 176 | 177 | def get_hards(self): 178 | hard_path = "datasets/hard_%s_%s_%s.txt" % (self.dataset_name, self.set, self.year) 179 | if os.path.exists(hard_path): 180 | hards = [] 181 | with open(hard_path, "r") as f: 182 | for line in f: 183 | hards.append(int(line.strip())) 184 | else: 185 | print("Discover hard images that should be discarded") 186 | 187 | if "VOC" in self.dataset_name: 188 | # set the hards 189 | hards = discard_hard_voc(self.dataloader) 190 | 191 | with open(hard_path, "w") as f: 192 | for s in hards: 193 | f.write(str(s) + "\n") 194 | 195 | return hards 196 | 197 | 198 | def discard_hard_voc(dataloader): 199 | hards = [] 200 | for im_id, inp in enumerate(tqdm(dataloader)): 201 | objects = inp[1]["annotation"]["object"] 202 | nb_obj = len(objects) 203 | 204 | hard = np.zeros(nb_obj) 205 | for i, o in enumerate(range(nb_obj)): 206 | hard[i] = ( 207 | 1 208 | if (objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1") 209 | else 0 210 | ) 211 | 212 | # all images with only truncated or difficult objects 213 | if np.sum(hard) == nb_obj: 214 | hards.append(im_id) 215 | return hards 216 | 217 | 218 | def extract_gt_COCO(targets, remove_iscrowd=True): 219 | objects = targets 220 | nb_obj = len(objects) 221 | 222 | gt_bbxs = [] 223 | gt_clss = [] 224 | for o in range(nb_obj): 225 | # Remove iscrowd boxes 226 | if remove_iscrowd and objects[o]["iscrowd"] == 1: 227 | continue 228 | gt_cls = objects[o]["category_id"] 229 | gt_clss.append(gt_cls) 230 | bbx = objects[o]["bbox"] 231 | x1y1x2y2 = [bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]] 232 | x1y1x2y2 = [int(round(x)) for x in x1y1x2y2] 233 | gt_bbxs.append(x1y1x2y2) 234 | 235 | return 
np.asarray(gt_bbxs), gt_clss 236 | 237 | 238 | def extract_gt_VOC(targets, remove_hards=False): 239 | objects = targets["annotation"]["object"] 240 | nb_obj = len(objects) 241 | 242 | gt_bbxs = [] 243 | gt_clss = [] 244 | for o in range(nb_obj): 245 | if remove_hards and ( 246 | objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1" 247 | ): 248 | continue 249 | gt_cls = objects[o]["name"] 250 | gt_clss.append(gt_cls) 251 | obj = objects[o]["bndbox"] 252 | x1y1x2y2 = [ 253 | int(obj["xmin"]), 254 | int(obj["ymin"]), 255 | int(obj["xmax"]), 256 | int(obj["ymax"]), 257 | ] 258 | # Original annotations are integers in the range [1, W or H] 259 | # Assuming they mean 1-based pixel indices (inclusive), 260 | # a box with annotation (xmin=1, xmax=W) covers the whole image. 261 | # In coordinate space this is represented by (xmin=0, xmax=W) 262 | x1y1x2y2[0] -= 1 263 | x1y1x2y2[1] -= 1 264 | gt_bbxs.append(x1y1x2y2) 265 | 266 | return np.asarray(gt_bbxs), gt_clss 267 | 268 | 269 | def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): 270 | # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py 271 | # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4 272 | box2 = box2.T 273 | 274 | # Get the coordinates of bounding boxes 275 | if x1y1x2y2: # x1, y1, x2, y2 = box1 276 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] 277 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] 278 | else: # transform from xywh to xyxy 279 | b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 280 | b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 281 | b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 282 | b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 283 | 284 | # Intersection area 285 | inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * ( 286 | torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1) 287 | ).clamp(0) 288 | 289 | # Union Area 290 | w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps 291 | w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps 292 | union = w1 * h1 + w2 * h2 - inter + eps 293 | 294 | iou = inter / union 295 | if GIoU or DIoU or CIoU: 296 | cw = torch.max(b1_x2, b2_x2) - torch.min( 297 | b1_x1, b2_x1 298 | ) # convex (smallest enclosing box) width 299 | ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height 300 | if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 301 | c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared 302 | rho2 = ( 303 | (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + 304 | (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2 305 | ) / 4 # center distance squared 306 | if DIoU: 307 | return iou - rho2 / c2 # DIoU 308 | elif ( 309 | CIoU 310 | ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 311 | v = (4 / math.pi ** 2) * torch.pow( 312 | torch.atan(w2 / h2) - torch.atan(w1 / h1), 2 313 | ) 314 | with torch.no_grad(): 315 | alpha = v / (v - iou + (1 + eps)) 316 | return iou - (rho2 / c2 + v * alpha) # CIoU 317 | else: # GIoU https://arxiv.org/pdf/1902.09630.pdf 318 | c_area = cw * ch + eps # convex area 319 | return iou - (c_area - union) / c_area # GIoU 320 | else: 321 | return iou # IoU 322 | 323 | 324 | def select_coco_20k(sel_file, all_annotations_file): 325 | print('Building COCO 20k dataset.') 326 | 327 | # load all annotations 328 | with open(all_annotations_file, "r") as f: 329 | train2014 = json.load(f) 330 | 331 | # load selected images 332 | with open(sel_file, "r") as 
f: 333 | sel_20k = f.readlines() 334 | sel_20k = [s.replace("\n", "") for s in sel_20k] 335 | im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in sel_20k] 336 | 337 | # # OLD 338 | # new_anno = [] 339 | # new_images = [] 340 | # for i in tqdm(im20k): 341 | # new_anno.extend( 342 | # [a for a in train2014["annotations"] if a["image_id"] == int(i)] 343 | # ) 344 | # new_images.extend([a for a in train2014["images"] if a["id"] == int(i)]) 345 | 346 | # NEW 347 | from collections import defaultdict 348 | id_to_ann = defaultdict(list) # image_id --> [annotations] 349 | id_to_img = defaultdict(list) # image_id --> [images] 350 | for a in tqdm(train2014["annotations"]): 351 | id_to_ann[a["image_id"]].append(a) 352 | for im in tqdm(train2014["images"]): 353 | id_to_img[im["id"]].append(im) 354 | new_anno = [a for id in im20k for a in id_to_ann[int(id)]] # flat list, matching the original COCO annotation format 355 | new_images = [im for id in im20k for im in id_to_img[int(id)]] 356 | print(len(im20k)) 357 | print(len(new_anno)) 358 | print(len(new_images)) 359 | 360 | train2014_20k = {} 361 | train2014_20k["images"] = new_images 362 | train2014_20k["annotations"] = new_anno 363 | train2014_20k["categories"] = train2014["categories"] 364 | 365 | with open("datasets/instances_train2014_sel20k.json", "w") as outfile: 366 | json.dump(train2014_20k, outfile) 367 | 368 | print('Done.') 369 | -------------------------------------------------------------------------------- /semantic-segmentation/train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import os 4 | import sys 5 | from contextlib import nullcontext 6 | from pathlib import Path 7 | from typing import Callable, Iterable, Optional 8 | 9 | import hydra 10 | import numpy as np 11 | import torch 12 | import wandb 13 | from accelerate import Accelerator 14 | from omegaconf import DictConfig, OmegaConf 15 | from PIL import Image 16 | from torch.nn import functional as F 17 | from torch.utils.data import DataLoader 18 | from tqdm import tqdm 19 | 20 | import util as utils 21 | from dataset import get_datasets 22 | from model import get_model 23 | 24 | 25 | @hydra.main(config_path='config', config_name='train') 26 | def main(cfg: DictConfig): 27 | 28 | # Accelerator 29 | accelerator = Accelerator(fp16=cfg.fp16, cpu=cfg.cpu) 30 | 31 | # Logging 32 | utils.setup_distributed_print(accelerator.is_local_main_process) 33 | if cfg.wandb and accelerator.is_local_main_process: 34 | wandb.init(name=cfg.name, job_type=cfg.job_type, config=OmegaConf.to_container(cfg), save_code=True, **cfg.wandb_kwargs) 35 | cfg = DictConfig(wandb.config.as_dict()) # get the config back from wandb for hyperparameter sweeps 36 | 37 | # Configuration 38 | print(OmegaConf.to_yaml(cfg)) 39 | print(f'Current working directory: {os.getcwd()}') 40 | 41 | # Set random seed 42 | utils.set_seed(cfg.seed) 43 | 44 | # Create model 45 | model = get_model(**cfg.model) 46 | 47 | # Freeze layers, if desired 48 | if cfg.unfrozen_backbone_layers >= 0: 49 | num_unfrozen = None if (cfg.unfrozen_backbone_layers == 0) else (-cfg.unfrozen_backbone_layers) 50 | for module in list(model.backbone.children())[:num_unfrozen]: 51 | for p in module.parameters(): 52 | p.requires_grad_(False) 53 | 54 | print(f'Parameters (total): {sum(p.numel() for p in model.parameters()):_d}') 55 | print(f'Parameters (train): {sum(p.numel() for p in model.parameters() if p.requires_grad):_d}') 56 | print(f'Backbone parameters (total): {sum(p.numel() for p in model.backbone.parameters()):_d}') 57 | print(f'Backbone parameters 
(train): {sum(p.numel() for p in model.backbone.parameters() if p.requires_grad):_d}') 58 | 59 | # Optimizer and scheduler 60 | optimizer = utils.get_optimizer(cfg, model, accelerator) 61 | scheduler = utils.get_scheduler(cfg, optimizer) 62 | 63 | # Resume from checkpoint and create the initial training state 64 | if cfg.checkpoint.resume: 65 | train_state: utils.TrainState = utils.resume_from_checkpoint(cfg, model, optimizer, scheduler, model_ema=None) 66 | else: 67 | train_state = utils.TrainState() # start training from scratch 68 | 69 | # Data 70 | dataset_train, dataset_val, collate_fn = get_datasets(cfg) 71 | dataloader_train = DataLoader(dataset_train, shuffle=True, drop_last=True, 72 | collate_fn=collate_fn, **cfg.data.loader) 73 | dataloader_val = DataLoader(dataset_val, shuffle=False, drop_last=False, 74 | collate_fn=collate_fn, **{**cfg.data.loader, 'batch_size': 1}) 75 | total_batch_size = cfg.data.loader.batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps 76 | 77 | # SyncBatchNorm 78 | if accelerator.num_processes > 1: 79 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 80 | 81 | # Setup 82 | model, optimizer, dataloader_train = accelerator.prepare(model, optimizer, dataloader_train) 83 | 84 | # Exponential moving average of model parameters 85 | if cfg.ema.use_ema: 86 | from torch_ema import ExponentialMovingAverage 87 | model_ema = ExponentialMovingAverage((p for p in model.parameters() if p.requires_grad), decay=cfg.ema.decay) 88 | print('Initialized model EMA') 89 | else: 90 | model_ema = None 91 | print('Not using model EMA') 92 | 93 | # Shared training, evaluation, and visualization args 94 | kwargs = dict( 95 | cfg=cfg, 96 | model=model, 97 | dataloader_train=dataloader_train, 98 | dataloader_val=dataloader_val, 99 | optimizer=optimizer, 100 | scheduler=scheduler, 101 | accelerator=accelerator, 102 | model_ema=model_ema, 103 | train_state=train_state) 104 | 105 | # Evaluation 106 | if cfg.job_type == 'generate': 107 | test_stats = generate(**kwargs) 108 | return 0 109 | 110 | # Evaluation 111 | if cfg.job_type == 'eval': 112 | test_stats = evaluate(**kwargs) 113 | return test_stats['val_loss'] 114 | 115 | # Info 116 | print(f'***** Starting training at {datetime.datetime.now()} *****') 117 | print(f' Dataset train size: {len(dataset_train):_}') 118 | print(f' Dataset val size: {len(dataset_val):_}') 119 | print(f' Dataloader train size: {len(dataloader_train):_}') 120 | print(f' Dataloader val size: {len(dataloader_val):_}') 121 | print(f' Batch size per device = {cfg.data.loader.batch_size}') 122 | print(f' Total train batch size (w. 
parallel, dist & accum) = {total_batch_size}') 123 | print(f' Gradient Accumulation steps = {cfg.gradient_accumulation_steps}') 124 | print(f' Max optimization steps = {cfg.max_train_steps}') 125 | print(f' Max optimization epochs = {cfg.max_train_epochs}') 126 | print(f' Training state = {train_state}') 127 | 128 | # Evaluate masks before training 129 | if cfg.get('eval_masks_before_training', True): 130 | print('Evaluating masks before training...') 131 | if accelerator.is_main_process: 132 | evaluate(**kwargs, evaluate_dataset_pseudolabels=True) # <-- to evaluate the self-training masks 133 | torch.cuda.synchronize() 134 | 135 | # Training loop 136 | while True: 137 | 138 | # Single epoch of training 139 | train_state = train_one_epoch(**kwargs) 140 | 141 | # Save checkpoint on only 1 process 142 | if accelerator.is_local_main_process: 143 | checkpoint_dict = { 144 | 'model': accelerator.unwrap_model(model).state_dict(), 145 | 'optimizer': optimizer.state_dict(), 146 | 'scheduler': scheduler.state_dict(), 147 | 'epoch': train_state.epoch, 148 | 'step': train_state.step, 149 | 'best_val': train_state.best_val, 150 | 'model_ema': model_ema.state_dict() if model_ema else {}, 151 | 'cfg': cfg 152 | } 153 | print(f'Saved checkpoint to {str(Path(".").resolve())}') 154 | accelerator.save(checkpoint_dict, 'checkpoint-latest.pth') 155 | if (train_state.epoch > 0) and (train_state.epoch % cfg.checkpoint_every == 0): 156 | accelerator.save(checkpoint_dict, f'checkpoint-{train_state.epoch:04d}.pth') 157 | 158 | # Evaluate 159 | if train_state.epoch % cfg.get('eval_every', 1) == 0: 160 | test_stats = evaluate(**kwargs) 161 | if accelerator.is_local_main_process: 162 | if (train_state.best_val is None) or (test_stats['mIoU'] > train_state.best_val): 163 | train_state.best_val = test_stats['mIoU'] 164 | torch.save(checkpoint_dict, 'checkpoint-best.pth') 165 | if cfg.wandb: 166 | wandb.log(test_stats) 167 | wandb.run.summary["best_mIoU"] = train_state.best_val 168 | 169 | # End training 170 | if ((cfg.max_train_steps is not None and train_state.step >= cfg.max_train_steps) or 171 | (cfg.max_train_epochs is not None and train_state.epoch >= cfg.max_train_epochs)): 172 | print(f'Ending training at: {datetime.datetime.now()}') 173 | print(f'Final train state: {train_state}') 174 | sys.exit() 175 | 176 | 177 | def train_one_epoch( 178 | *, 179 | cfg: DictConfig, 180 | model: torch.nn.Module, 181 | dataloader_train: Iterable, 182 | optimizer: torch.optim.Optimizer, 183 | accelerator: Accelerator, 184 | scheduler: Callable, 185 | train_state: utils.TrainState, 186 | model_ema: Optional[object] = None, 187 | **_unused_kwargs 188 | ): 189 | 190 | # Train mode 191 | model.train() 192 | log_header = f'Epoch: [{train_state.epoch}]' 193 | metric_logger = utils.MetricLogger(delimiter=" ") 194 | metric_logger.add_meter('step', utils.SmoothedValue(window_size=1, fmt='{value:.0f}')) 195 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 196 | progress_bar = metric_logger.log_every(dataloader_train, cfg.logging.print_freq, header=log_header) 197 | 198 | # Train 199 | for i, (images, _, pseudolabels, _) in enumerate(progress_bar): 200 | if i >= cfg.get('limit_train_batches', math.inf): 201 | break 202 | 203 | # Forward 204 | output = model(images) # (B, C, H, W) 205 | 206 | # Cross-entropy loss 207 | loss = F.cross_entropy(output, pseudolabels) 208 | 209 | # Measure accuracy 210 | acc1, acc5 = utils.accuracy(output, pseudolabels, topk=(1, 5)) 211 | 212 | # Exit if loss is NaN 213 | 
loss_value = loss.item() 214 | if not math.isfinite(loss_value): 215 | print("Loss is {}, stopping training".format(loss_value)) 216 | sys.exit(1) 217 | 218 | # Loss scaling and backward 219 | accelerator.backward(loss) 220 | 221 | # Gradient accumulation, optimizer step, scheduler step 222 | if i % cfg.gradient_accumulation_steps == 0: 223 | optimizer.step() 224 | optimizer.zero_grad() 225 | torch.cuda.synchronize() 226 | if cfg.scheduler.stepwise: 227 | scheduler.step() 228 | train_state.step += 1 229 | 230 | # Model EMA 231 | if model_ema is not None and (train_state.epoch % cfg.ema.update_every) == 0: 232 | model_ema.update((p for p in model.parameters() if p.requires_grad)) 233 | 234 | # Logging 235 | log_dict = dict( 236 | lr=optimizer.param_groups[0]["lr"], step=train_state.step, 237 | train_loss=loss_value, 238 | train_top1=acc1[0], train_top5=acc5[0], 239 | ) 240 | metric_logger.update(**log_dict) 241 | if cfg.wandb and accelerator.is_local_main_process: 242 | wandb.log(log_dict) 243 | 244 | # Scheduler 245 | if not cfg.scheduler.stepwise: 246 | scheduler.step() 247 | 248 | # Epoch complete 249 | train_state.epoch += 1 250 | 251 | # Gather stats from all processes 252 | metric_logger.synchronize_between_processes(device=accelerator.device) 253 | print("Averaged stats:", metric_logger) 254 | return train_state 255 | 256 | 257 | @torch.no_grad() 258 | def evaluate( 259 | *, 260 | cfg: DictConfig, 261 | model: torch.nn.Module, 262 | dataloader_val: Iterable, 263 | accelerator: Accelerator, 264 | model_ema: Optional[object] = None, 265 | evaluate_dataset_pseudolabels: bool = False, 266 | **_unused_kwargs 267 | ): 268 | 269 | # To avoid CUDA errors on my machine 270 | torch.backends.cudnn.benchmark = False 271 | 272 | # Eval mode 273 | model.eval() 274 | torch.cuda.synchronize() 275 | eval_context = model_ema.average_parameters if cfg.ema.use_ema else nullcontext 276 | 277 | # Add background class 278 | n_classes = cfg.data.num_classes + 1 279 | 280 | # Iterate 281 | tp = [0] * n_classes 282 | fp = [0] * n_classes 283 | fn = [0] * n_classes 284 | 285 | # Check 286 | assert dataloader_val.batch_size == 1, 'Please use batch_size=1 for val to compute mIoU' 287 | 288 | # Load all pixel embeddings 289 | all_preds = np.zeros((len(dataloader_val) * 500 * 500), dtype=np.float32) 290 | all_gt = np.zeros((len(dataloader_val) * 500 * 500), dtype=np.float32) 291 | offset_ = 0 292 | 293 | # Add all pixels to our arrays 294 | with eval_context(): 295 | for (inputs, targets, mask, _) in tqdm(dataloader_val, desc='Concatenating all predictions'): 296 | 297 | # Predict 298 | if evaluate_dataset_pseudolabels: 299 | mask = mask # evaluate the dataset's pseudolabel masks directly 300 | else: 301 | logits = model(inputs.to(accelerator.device).contiguous()).squeeze(0) # (C, H, W) 302 | mask = torch.argmax(logits, dim=0) # (H, W) 303 | 304 | # Convert 305 | target = targets.numpy().squeeze() 306 | mask = mask.cpu().numpy().squeeze() 307 | 308 | # Check where ground-truth is valid and append valid pixels to the array 309 | valid = (target != 255) 310 | n_valid = np.sum(valid) 311 | all_gt[offset_:offset_+n_valid] = target[valid] 312 | 313 | # Possibly reshape embedding to match gt. 
314 | if mask.shape != target.shape: 315 | raise ValueError(f'{mask.shape=} != {target.shape=}') 316 | 317 | # Append the predicted targets in the array 318 | all_preds[offset_:offset_+n_valid, ] = mask[valid] 319 | all_gt[offset_:offset_+n_valid, ] = target[valid] 320 | 321 | # Update offset_ 322 | offset_ += n_valid 323 | 324 | # Truncate to the actual number of pixels 325 | all_preds = all_preds[:offset_, ] 326 | all_gt = all_gt[:offset_, ] 327 | 328 | # TP, FP, and FN evaluation 329 | for i_part in range(0, n_classes): 330 | tmp_all_gt = (all_gt == i_part) 331 | tmp_pred = (all_preds == i_part) 332 | tp[i_part] += np.sum(tmp_all_gt & tmp_pred) 333 | fp[i_part] += np.sum(~tmp_all_gt & tmp_pred) 334 | fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred) 335 | 336 | # Calculate Jaccard index 337 | jac = [0] * n_classes 338 | for i_part in range(0, n_classes): 339 | jac[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8) 340 | 341 | # Print results 342 | eval_result = dict() 343 | eval_result['jaccards_all_categs'] = jac 344 | eval_result['mIoU'] = np.mean(jac) 345 | print('Evaluation of semantic segmentation ') 346 | print(f'Full eval result: {eval_result}') 347 | print('mIoU is %.2f' % (100*eval_result['mIoU'])) 348 | return eval_result 349 | 350 | 351 | @torch.no_grad() 352 | def generate( 353 | *, 354 | cfg: DictConfig, 355 | model: torch.nn.Module, 356 | dataloader_val: Iterable, 357 | accelerator: Accelerator, 358 | model_ema: Optional[object] = None, 359 | **_unused_kwargs 360 | ): 361 | 362 | # To avoid CUDA errors on my machine 363 | torch.backends.cudnn.benchmark = False 364 | 365 | # Eval mode 366 | model.eval() 367 | torch.cuda.synchronize() 368 | eval_context = model_ema.average_parameters if cfg.ema.use_ema else nullcontext 369 | 370 | # Create paths 371 | preds_dir = Path('preds') 372 | gt_dir = Path('gt') 373 | preds_dir.mkdir(exist_ok=True, parents=True) 374 | gt_dir.mkdir(exist_ok=True, parents=True) 375 | 376 | # Generate and save 377 | with eval_context(): 378 | for (inputs, targets, _, metadata) in tqdm(dataloader_val, desc='Concatenating all predictions'): 379 | # Predict 380 | logits = model(inputs.to(accelerator.device).contiguous()).squeeze(0) # (C, H, W) 381 | # Convert 382 | preds = torch.argmax(logits, dim=0).cpu().numpy().astype(np.uint8) 383 | gt = targets.squeeze().numpy().astype(np.uint8) 384 | # Save 385 | Image.fromarray(preds).convert('L').save(preds_dir / f"{metadata[0]['id']}.png") 386 | Image.fromarray(gt).convert('L').save(gt_dir / f"{metadata[0]['id']}.png") 387 | 388 | print(f'Saved to {Path(".").absolute()}') 389 | 390 | 391 | if __name__ == '__main__': 392 | main() 393 | -------------------------------------------------------------------------------- /object-localization/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main experiment file. 
Code adapted from LOST: https://github.com/valeoai/LOST 3 | """ 4 | import os 5 | import sys 6 | import argparse 7 | import random 8 | import pickle 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from pprint import pprint 14 | from typing import Union 15 | from pathlib import Path 16 | from torchvision import transforms 17 | from tqdm import tqdm 18 | from PIL import Image 19 | 20 | from networks import get_model 21 | from datasets import ImageDataset, Dataset, bbox_iou 22 | from visualizations import visualize_fms, visualize_predictions, visualize_seed_expansion 23 | from object_discovery import lost, detect_box, dino_seg, get_eigenvectors_from_features, get_largest_cc_box, get_bbox_from_patch_mask 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser("Visualize Self-Attention maps") 28 | parser.add_argument( 29 | "--arch", 30 | default="vit_small", 31 | type=str, 32 | choices=[ 33 | "vit_tiny", 34 | "vit_small", 35 | "vit_base", 36 | "resnet50", 37 | "vgg16_imagenet", 38 | "resnet50_imagenet", 39 | ], 40 | help="Model architecture.", 41 | ) 42 | parser.add_argument( 43 | "--patch_size", default=16, type=int, help="Patch resolution of the model." 44 | ) 45 | 46 | # Use a dataset 47 | parser.add_argument( 48 | "--dataset", 49 | default="VOC07", 50 | type=str, 51 | choices=[None, "VOC07", "VOC12", "COCO20k"], 52 | help="Dataset name.", 53 | ) 54 | parser.add_argument( 55 | "--set", 56 | default="trainval", 57 | type=str, 58 | choices=["val", "train", "trainval", "test"], 59 | help="Path of the image to load.", 60 | ) 61 | # Or use a single image 62 | parser.add_argument( 63 | "--image_path", 64 | type=str, 65 | default=None, 66 | help="If want to apply only on one image, give file path.", 67 | ) 68 | 69 | # Folder used to output visualizations and 70 | parser.add_argument( 71 | "--output_dir", type=str, default="outputs", help="Output directory to store predictions and visualizations." 72 | ) 73 | 74 | # Evaluation setup 75 | parser.add_argument("--no_hard", action="store_true", help="Only used in the case of the VOC_all setup (see the paper).") 76 | parser.add_argument("--no_evaluation", action="store_true", help="Compute the evaluation.") 77 | parser.add_argument("--save_predictions", default=True, type=bool, help="Save predicted bouding boxes.") 78 | 79 | # Visualization 80 | parser.add_argument( 81 | "--visualize", 82 | type=str, 83 | choices=["fms", "seed_expansion", "pred", None], 84 | default=None, 85 | help="Select the different type of visualizations.", 86 | ) 87 | 88 | # For ResNet dilation 89 | parser.add_argument("--resnet_dilate", type=int, default=2, help="Dilation level of the resnet model.") 90 | 91 | # LOST parameters 92 | parser.add_argument( 93 | "--which_features", 94 | type=str, 95 | default="k", 96 | choices=["k", "q", "v"], 97 | help="Which features to use", 98 | ) 99 | parser.add_argument( 100 | "--k_patches", 101 | type=int, 102 | default=100, 103 | help="Number of patches with the lowest degree considered." 
104 | ) 105 | 106 | # Misc 107 | parser.add_argument("--name", type=str, default=None, help='Experiment name') 108 | parser.add_argument("--skip_if_exists", action='store_true', help='If the results dir already exists, exit') 109 | 110 | # Use dino-seg proposed method 111 | parser.add_argument("--dinoseg", action="store_true", help="Apply DINO-seg baseline.") 112 | parser.add_argument("--dinoseg_head", type=int, default=4) 113 | 114 | # Use eigenvalue method 115 | parser.add_argument("--eigenseg", action='store_true', help='Apply eigenvalue method') 116 | parser.add_argument("--precomputed_eigs_dir", default=None, type=str, 117 | help='Apply eigenvalue method with precomputed bboxes') 118 | parser.add_argument("--precomputed_eigs_downsample", default=16, type=int) 119 | parser.add_argument("--which_matrix", choices=['infer', 'affinity', 'laplacian'], 120 | default='infer', help='Which matrix to use for eigenvector calculation') 121 | 122 | # Parse 123 | args = parser.parse_args() 124 | 125 | # Modify 126 | if args.image_path is not None: 127 | args.save_predictions = False 128 | args.no_evaluation = True 129 | args.dataset = None 130 | 131 | return args 132 | 133 | 134 | @torch.no_grad() 135 | def main(): 136 | 137 | # Args 138 | args = parse_args() 139 | 140 | # ------------------------------------------------------------------------------------------------------- 141 | # Dataset 142 | 143 | # Transform 144 | transform = transforms.Compose([ 145 | transforms.ToTensor(), 146 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 147 | ]) 148 | 149 | # If an image_path is given, apply the method only to the image 150 | if args.image_path is not None: 151 | dataset = ImageDataset(args.image_path, transform) 152 | else: 153 | dataset = Dataset(args.dataset, args.set, args.no_hard, transform) 154 | 155 | # Naming 156 | if args.name is not None: 157 | exp_name = args.name 158 | elif args.dinoseg: 159 | # Experiment with the baseline DINO-seg 160 | if "vit" not in args.arch: 161 | raise ValueError("DINO-seg can only be applied to transformer networks.") 162 | exp_name = f"{args.arch}-{args.patch_size}_dinoseg-head{args.dinoseg_head}" 163 | else: 164 | # Experiment with LOST 165 | exp_name = f"LOST-{args.arch}" 166 | if "resnet" in args.arch: 167 | exp_name += f"dilate{args.resnet_dilate}" 168 | elif "vit" in args.arch: 169 | exp_name += f"{args.patch_size}_{args.which_features}" 170 | 171 | # ------------------------------------------------------------------------------------------------------- 172 | # Directories 173 | if args.image_path is None: 174 | args.output_dir = os.path.join(args.output_dir, dataset.name) 175 | 176 | # Skip if already exists 177 | exp_dir = Path(args.output_dir) / exp_name 178 | if args.skip_if_exists and exp_dir.is_dir() and len(list(exp_dir.iterdir())) > 0: 179 | print(f'Directory already exists and is not empty: {str(exp_dir)}') 180 | print(f'Exiting...') 181 | sys.exit() 182 | os.makedirs(args.output_dir, exist_ok=True) 183 | 184 | # ------------------------------------------------------------------------------------------------------- 185 | # Model 186 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 187 | model = get_model(args.arch, args.patch_size, args.resnet_dilate, device) 188 | 189 | print(f"Running LOST on the dataset {dataset.name} (exp: {exp_name})") 190 | print(f"Args:") 191 | pprint(args.__dict__) 192 | 193 | # Visualization 194 | if args.visualize: 195 | vis_folder = 
f"{args.output_dir}/visualizations/{exp_name}" 196 | os.makedirs(vis_folder, exist_ok=True) 197 | 198 | # ------------------------------------------------------------------------------------------------------- 199 | # Loop over images 200 | preds_dict = {} 201 | gt_dict = {} 202 | cnt = 0 203 | corloc = np.zeros(len(dataset.dataloader)) 204 | 205 | pbar = tqdm(dataset.dataloader) 206 | for im_id, inp in enumerate(pbar): 207 | 208 | # ------------ IMAGE PROCESSING ------------------------------------------- 209 | img = inp[0] 210 | init_image_size = img.shape 211 | 212 | # Get the name of the image 213 | im_name = dataset.get_image_name(inp[1]) 214 | 215 | # Pass in case of no gt boxes in the image 216 | if im_name is None: 217 | continue 218 | 219 | # Padding the image with zeros to fit multiple of patch-size 220 | if args.eigenseg: 221 | size_im = ( 222 | img.shape[0], 223 | int(np.floor(img.shape[1] / args.patch_size) * args.patch_size), 224 | int(np.floor(img.shape[2] / args.patch_size) * args.patch_size), 225 | ) 226 | img = paded = img[:, :size_im[1], :size_im[2]] 227 | else: 228 | size_im = ( 229 | img.shape[0], 230 | int(np.ceil(img.shape[1] / args.patch_size) * args.patch_size), 231 | int(np.ceil(img.shape[2] / args.patch_size) * args.patch_size), 232 | ) 233 | paded = torch.zeros(size_im) 234 | paded[:, : img.shape[1], : img.shape[2]] = img 235 | img = paded 236 | 237 | # Size for transformers 238 | w_featmap = img.shape[-2] // args.patch_size 239 | h_featmap = img.shape[-1] // args.patch_size 240 | 241 | # ------------ GROUND-TRUTH ------------------------------------------- 242 | if not args.no_evaluation: 243 | gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name) 244 | 245 | if gt_bbxs is not None: 246 | # Discard images with no gt annotations 247 | # Happens only in the case of VOC07 and VOC12 248 | if gt_bbxs.shape[0] == 0 and args.no_hard: 249 | continue 250 | 251 | # ------------ EXTRACT FEATURES ------------------------------------------- 252 | 253 | # Load precomputer bounding boxes 254 | if args.eigenseg and args.precomputed_eigs_dir is not None: 255 | 256 | # Load 257 | if 'COCO' in dataset.name: 258 | fname = f"COCO_train2014_{int(im_name):012d}.pth" 259 | elif 'VOC' in dataset.name: 260 | fname = im_name.replace('.jpg', '.pth') 261 | precomputed_eigs_file = os.path.join(args.precomputed_eigs_dir, fname) 262 | precomputed_eigs = torch.load(precomputed_eigs_file, map_location='cpu') 263 | eigenvectors = precomputed_eigs['eigenvectors'] # tensor of shape (K, H_lr * W_lr) 264 | 265 | # Get eigenvectors 266 | if args.which_matrix == 'infer': # get the type 267 | which_matrix = Path(args.precomputed_eigs_dir).name.split('_')[0] 268 | else: 269 | which_matrix = args.which_matrix 270 | segment_index = {'matting': 1, 'laplacian': 1, 'affinity': 0}[which_matrix] 271 | patch_mask = (eigenvectors[segment_index] > 0) 272 | pred = get_bbox_from_patch_mask(patch_mask, init_image_size) 273 | 274 | # Extract features from self-supervised model 275 | else: 276 | 277 | # Move to GPU 278 | img = img.cuda(non_blocking=True) 279 | 280 | # ------------ FORWARD PASS ------------------------------------------- 281 | if "vit" in args.arch: 282 | # Store the outputs of qkv layer from the last attention layer 283 | feat_out = {} 284 | def hook_fn_forward_qkv(module, input, output): 285 | feat_out["qkv"] = output 286 | model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) 287 | 288 | # Forward pass in the model 289 | attentions = 
model.get_last_selfattention(img[None, :, :, :]) 290 | 291 | # Scaling factor 292 | scales = [args.patch_size, args.patch_size] 293 | 294 | # Dimensions 295 | nb_im = attentions.shape[0] # Batch size 296 | nh = attentions.shape[1] # Number of heads 297 | nb_tokens = attentions.shape[2] # Number of tokens 298 | 299 | # Baseline: compute DINO segmentation technique proposed in the DINO paper 300 | # and select the biggest component 301 | if args.dinoseg: 302 | pred = dino_seg(attentions, (w_featmap, h_featmap), args.patch_size, head=args.dinoseg_head) 303 | pred = np.asarray(pred) 304 | else: 305 | # Extract the qkv features of the last attention layer 306 | qkv = ( 307 | feat_out["qkv"] 308 | .reshape(nb_im, nb_tokens, 3, nh, -1 // nh) 309 | .permute(2, 0, 3, 1, 4) 310 | ) 311 | q, k, v = qkv[0], qkv[1], qkv[2] 312 | k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1) 313 | q = q.transpose(1, 2).reshape(nb_im, nb_tokens, -1) 314 | v = v.transpose(1, 2).reshape(nb_im, nb_tokens, -1) 315 | 316 | # Modality selection 317 | if args.which_features == "k": 318 | feats = k[:, 1:, :] 319 | elif args.which_features == "q": 320 | feats = q[:, 1:, :] 321 | elif args.which_features == "v": 322 | feats = v[:, 1:, :] 323 | elif "resnet" in args.arch: 324 | x = model.forward(img[None, :, :, :]) 325 | d, w_featmap, h_featmap = x.shape[1:] 326 | feats = x.reshape((1, d, -1)).transpose(2, 1) 327 | # Apply layernorm 328 | layernorm = nn.LayerNorm(feats.size()[1:]).to(device) 329 | feats = layernorm(feats) 330 | # Scaling factor 331 | scales = [ 332 | float(img.shape[1]) / x.shape[2], 333 | float(img.shape[2]) / x.shape[3], 334 | ] 335 | elif "vgg16" in args.arch: 336 | x = model.forward(img[None, :, :, :]) 337 | d, w_featmap, h_featmap = x.shape[1:] 338 | feats = x.reshape((1, d, -1)).transpose(2, 1) 339 | # Apply layernorm 340 | layernorm = nn.LayerNorm(feats.size()[1:]).to(device) 341 | feats = layernorm(feats) 342 | # Scaling factor 343 | scales = [ 344 | float(img.shape[1]) / x.shape[2], 345 | float(img.shape[2]) / x.shape[3], 346 | ] 347 | else: 348 | raise ValueError("Unknown model.") 349 | 350 | # Sizes 351 | dims_wh = [w_featmap, h_featmap] 352 | 353 | # ------------ Apply LOST ------------------------------------------- 354 | if not args.dinoseg: 355 | if args.eigenseg: 356 | 357 | # Get eigenvectors 358 | eigenvectors = get_eigenvectors_from_features(feats, args.which_matrix) 359 | 360 | # Get bounding box 361 | assert ('affinity' in args.which_matrix) ^ ('laplacian' in args.which_matrix) 362 | eig_index = 0 if 'affinity' in args.which_matrix else 1 363 | patch_mask = (eigenvectors[:, eig_index] > 0) 364 | pred = get_bbox_from_patch_mask(patch_mask, init_image_size) 365 | 366 | else: 367 | 368 | pred, A, M, scores, seed = lost(feats, dims_wh, scales, init_image_size, k_patches=args.k_patches) 369 | 370 | # ------------ Visualizations ------------------------------------------- 371 | if args.visualize == "fms": 372 | visualize_fms(A.clone().cpu().numpy(), seed, scores, dims_wh, scales, vis_folder, im_name) 373 | 374 | elif args.visualize == "seed_expansion": 375 | image = dataset.load_image(im_name) 376 | 377 | # Before expansion 378 | pred_seed, _ = detect_box(A[seed, :], seed, dims_wh, scales=scales, initial_im_size=init_image_size[1:]) 379 | visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims_wh, vis_folder, im_name) 380 | 381 | elif args.visualize == "pred": 382 | image = dataset.load_image(im_name) 383 | visualize_predictions(image, pred, seed, scales, dims_wh, vis_folder, 
im_name) 384 | 385 | # Save the prediction 386 | preds_dict[im_name] = pred 387 | gt_dict[im_name] = gt_bbxs 388 | 389 | # Evaluation 390 | if args.no_evaluation: 391 | continue 392 | 393 | # Compare prediction to GT boxes 394 | ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(gt_bbxs)) 395 | 396 | if torch.any(ious >= 0.5): 397 | corloc[im_id] = 1 398 | 399 | cnt += 1 400 | if cnt % 10 == 0: 401 | pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt} ({int(np.sum(corloc))/cnt * 100:.1f}%)") 402 | 403 | # Save predicted bounding boxes 404 | if args.save_predictions: 405 | folder = f"{args.output_dir}/{exp_name}" 406 | os.makedirs(folder, exist_ok=True) 407 | with open(os.path.join(folder, "preds.pkl"), "wb") as f: 408 | pickle.dump(preds_dict, f) 409 | with open(os.path.join(folder, "gt.pkl"), "wb") as f: 410 | pickle.dump(gt_dict, f) 411 | print(f"Predictions saved to {folder}") 412 | 413 | # Evaluate 414 | if not args.no_evaluation: 415 | print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})") 416 | result_file = os.path.join(folder, 'results.txt') 417 | with open(result_file, 'w') as f: 418 | f.write('corloc,%.1f,,\n'%(100*np.sum(corloc)/cnt)) 419 | print('File saved at %s'%result_file) 420 | 421 | 422 | if __name__ == "__main__": 423 | main() -------------------------------------------------------------------------------- /extract/extract.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from pathlib import Path 3 | from typing import Optional, Tuple 4 | 5 | import cv2 6 | import fire 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from accelerate import Accelerator 11 | from PIL import Image 12 | from scipy.sparse.linalg import eigsh 13 | from sklearn.cluster import KMeans, MiniBatchKMeans 14 | from sklearn.decomposition import PCA 15 | from torchvision.utils import draw_bounding_boxes 16 | from tqdm import tqdm 17 | 18 | import extract_utils as utils 19 | 20 | 21 | def extract_features( 22 | images_list: str, 23 | images_root: Optional[str], 24 | model_name: str, 25 | batch_size: int, 26 | output_dir: str, 27 | which_block: int = -1, 28 | ): 29 | """ 30 | Extract features from a list of images. 
31 | 32 | Example: 33 | python extract.py extract_features \ 34 | --images_list "./data/VOC2012/lists/images.txt" \ 35 | --images_root "./data/VOC2012/images" \ 36 | --output_dir "./data/VOC2012/features/dino_vits16" \ 37 | --model_name dino_vits16 \ 38 | --batch_size 1 39 | """ 40 | 41 | # Output 42 | utils.make_output_dir(output_dir) 43 | 44 | # Models 45 | model_name = model_name.lower() 46 | model, val_transform, patch_size, num_heads = utils.get_model(model_name) 47 | 48 | # Add hook 49 | if 'dino' in model_name or 'mocov3' in model_name: 50 | feat_out = {} 51 | def hook_fn_forward_qkv(module, input, output): 52 | feat_out["qkv"] = output 53 | model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv) 54 | else: 55 | raise ValueError(model_name) 56 | 57 | # Dataset 58 | filenames = Path(images_list).read_text().splitlines() 59 | dataset = utils.ImagesDataset(filenames=filenames, images_root=images_root, transform=val_transform) 60 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=8) 61 | print(f'Dataset size: {len(dataset)=}') 62 | print(f'Dataloader size: {len(dataloader)=}') 63 | 64 | # Prepare 65 | accelerator = Accelerator(fp16=True, cpu=False) 66 | # model, dataloader = accelerator.prepare(model, dataloader) 67 | model = model.to(accelerator.device) 68 | 69 | # Process 70 | pbar = tqdm(dataloader, desc='Processing') 71 | for i, (images, files, indices) in enumerate(pbar): 72 | output_dict = {} 73 | 74 | # Check if file already exists 75 | id = Path(files[0]).stem 76 | output_file = Path(output_dir) / f'{id}.pth' 77 | if output_file.is_file(): 78 | pbar.write(f'Skipping existing file {str(output_file)}') 79 | continue 80 | 81 | # Reshape image 82 | P = patch_size 83 | B, C, H, W = images.shape 84 | H_patch, W_patch = H // P, W // P 85 | H_pad, W_pad = H_patch * P, W_patch * P 86 | T = H_patch * W_patch + 1 # number of tokens, add 1 for [CLS] 87 | # images = F.interpolate(images, size=(H_pad, W_pad), mode='bilinear') # resize image 88 | images = images[:, :, :H_pad, :W_pad] 89 | images = images.to(accelerator.device) 90 | 91 | # Forward and collect features into output dict 92 | if 'dino' in model_name or 'mocov3' in model_name: 93 | # accelerator.unwrap_model(model).get_intermediate_layers(images)[0].squeeze(0) 94 | model.get_intermediate_layers(images)[0].squeeze(0) 95 | # output_dict['out'] = out 96 | output_qkv = feat_out["qkv"].reshape(B, T, 3, num_heads, -1 // num_heads).permute(2, 0, 3, 1, 4) 97 | # output_dict['q'] = output_qkv[0].transpose(1, 2).reshape(B, T, -1)[:, 1:, :] 98 | output_dict['k'] = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :] 99 | # output_dict['v'] = output_qkv[2].transpose(1, 2).reshape(B, T, -1)[:, 1:, :] 100 | else: 101 | raise ValueError(model_name) 102 | 103 | # Metadata 104 | output_dict['indices'] = indices[0] 105 | output_dict['file'] = files[0] 106 | output_dict['id'] = id 107 | output_dict['model_name'] = model_name 108 | output_dict['patch_size'] = patch_size 109 | output_dict['shape'] = (B, C, H, W) 110 | output_dict = {k: (v.detach().cpu() if torch.is_tensor(v) else v) for k, v in output_dict.items()} 111 | 112 | # Save 113 | accelerator.save(output_dict, str(output_file)) 114 | accelerator.wait_for_everyone() 115 | 116 | print(f'Saved features to {output_dir}') 117 | 118 | 119 | def _extract_eig( 120 | inp: Tuple[int, str], 121 | K: int, 122 | images_root: str, 123 | output_dir: str, 124 | which_matrix: str = 'laplacian', 125 | 
which_features: str = 'k', 126 | normalize: bool = True, 127 | lapnorm: bool = True, 128 | which_color_matrix: str = 'knn', 129 | threshold_at_zero: bool = True, 130 | image_downsample_factor: Optional[int] = None, 131 | image_color_lambda: float = 10, 132 | ): 133 | index, features_file = inp 134 | 135 | # Load 136 | data_dict = torch.load(features_file, map_location='cpu') 137 | image_id = data_dict['file'][:-4] 138 | 139 | # Load 140 | output_file = str(Path(output_dir) / f'{image_id}.pth') 141 | if Path(output_file).is_file(): 142 | print(f'Skipping existing file {str(output_file)}') 143 | return # skip because already generated 144 | 145 | # Load affinity matrix 146 | feats = data_dict[which_features].squeeze().cuda() 147 | if normalize: 148 | feats = F.normalize(feats, p=2, dim=-1) 149 | 150 | # Eigenvectors of affinity matrix 151 | if which_matrix == 'affinity_torch': 152 | W = feats @ feats.T 153 | if threshold_at_zero: 154 | W = (W * (W > 0)) 155 | eigenvalues, eigenvectors = torch.eig(W, eigenvectors=True) 156 | eigenvalues = eigenvalues.cpu() 157 | eigenvectors = eigenvectors.cpu() 158 | 159 | # Eigenvectors of affinity matrix with scipy 160 | elif which_matrix == 'affinity_svd': 161 | USV = torch.linalg.svd(feats, full_matrices=False) 162 | eigenvectors = USV[0][:, :K].T.to('cpu', non_blocking=True) 163 | eigenvalues = USV[1][:K].to('cpu', non_blocking=True) 164 | 165 | # Eigenvectors of affinity matrix with scipy 166 | elif which_matrix == 'affinity': 167 | W = (feats @ feats.T) 168 | if threshold_at_zero: 169 | W = (W * (W > 0)) 170 | W = W.cpu().numpy() 171 | eigenvalues, eigenvectors = eigsh(W, which='LM', k=K) 172 | eigenvectors = torch.flip(torch.from_numpy(eigenvectors), dims=(-1,)).T 173 | 174 | # Eigenvectors of matting laplacian matrix 175 | elif which_matrix in ['matting_laplacian', 'laplacian']: 176 | 177 | # Get sizes 178 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict) 179 | if image_downsample_factor is None: 180 | image_downsample_factor = P 181 | H_pad_lr, W_pad_lr = H_pad // image_downsample_factor, W_pad // image_downsample_factor 182 | 183 | # Upscale features to match the resolution 184 | if (H_patch, W_patch) != (H_pad_lr, W_pad_lr): 185 | feats = F.interpolate( 186 | feats.T.reshape(1, -1, H_patch, W_patch), 187 | size=(H_pad_lr, W_pad_lr), mode='bilinear', align_corners=False 188 | ).reshape(-1, H_pad_lr * W_pad_lr).T 189 | 190 | ### Feature affinities 191 | W_feat = (feats @ feats.T) 192 | if threshold_at_zero: 193 | W_feat = (W_feat * (W_feat > 0)) 194 | W_feat = W_feat / W_feat.max() # NOTE: If features are normalized, this naturally does nothing 195 | W_feat = W_feat.cpu().numpy() 196 | 197 | ### Color affinities 198 | # If we are fusing with color affinites, then load the image and compute 199 | if image_color_lambda > 0: 200 | 201 | # Load image 202 | image_file = str(Path(images_root) / f'{image_id}.jpg') 203 | image_lr = Image.open(image_file).resize((W_pad_lr, H_pad_lr), Image.BILINEAR) 204 | image_lr = np.array(image_lr) / 255. 
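# NOTE: the remainder of this branch fuses semantic and color affinities and then solves an
# eigenproblem. In the notation of the variables below:
#   W_comb = W_feat + image_color_lambda * W_color   (fused affinity matrix)
#   D_comb = utils.get_diagonal(W_comb)              (presumably the diagonal degree matrix)
# eigsh is then used to extract the K eigenvectors of the Laplacian D_comb - W_comb with the
# smallest eigenvalues, solving the generalized problem with M=D_comb when lapnorm is True.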
205 | 206 | # Color affinities (of type scipy.sparse.csr_matrix) 207 | if which_color_matrix == 'knn': 208 | W_lr = utils.knn_affinity(image_lr) 209 | elif which_color_matrix == 'rw': 210 | W_lr = utils.rw_affinity(image_lr) 211 | 212 | # Convert to dense numpy array 213 | W_color = np.array(W_lr.todense().astype(np.float32)) 214 | 215 | else: 216 | 217 | # No color affinity 218 | W_color = 0 219 | 220 | # Combine 221 | W_comb = W_feat + W_color * image_color_lambda # combination 222 | D_comb = np.array(utils.get_diagonal(W_comb).todense()) # is dense or sparse faster? not sure, should check 223 | 224 | # Extract eigenvectors 225 | if lapnorm: 226 | try: 227 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM', M=D_comb) 228 | except: 229 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM', M=D_comb) 230 | else: 231 | try: 232 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM') 233 | except: 234 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM') 235 | eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float() 236 | 237 | # Sign ambiguity 238 | for k in range(eigenvectors.shape[0]): 239 | if 0.5 < torch.mean((eigenvectors[k] > 0).float()).item() < 1.0: # reverse segment 240 | eigenvectors[k] = 0 - eigenvectors[k] 241 | 242 | # Save dict 243 | output_dict = {'eigenvalues': eigenvalues, 'eigenvectors': eigenvectors} 244 | torch.save(output_dict, output_file) 245 | 246 | 247 | def extract_eigs( 248 | images_root: str, 249 | features_dir: str, 250 | output_dir: str, 251 | which_matrix: str = 'laplacian', 252 | which_color_matrix: str = 'knn', 253 | which_features: str = 'k', 254 | normalize: bool = True, 255 | threshold_at_zero: bool = True, 256 | lapnorm: bool = True, 257 | K: int = 20, 258 | image_downsample_factor: Optional[int] = None, 259 | image_color_lambda: float = 0.0, 260 | multiprocessing: int = 0 261 | ): 262 | """ 263 | Extracts eigenvalues from features. 
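For each file in features_dir, both the eigenvalues and the eigenvectors of the chosen matrix (a feature affinity matrix or a Laplacian, optionally fused with KNN or RW color affinities) are computed and written to a .pth dictionary in output_dir.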
264 | 265 | Example: 266 | python extract.py extract_eigs \ 267 | --images_root "./data/VOC2012/images" \ 268 | --features_dir "./data/VOC2012/features/dino_vits16" \ 269 | --which_matrix "laplacian" \ 270 | --output_dir "./data/VOC2012/eigs/laplacian" \ 271 | --K 5 272 | """ 273 | utils.make_output_dir(output_dir) 274 | kwargs = dict(K=K, which_matrix=which_matrix, which_features=which_features, which_color_matrix=which_color_matrix, 275 | normalize=normalize, threshold_at_zero=threshold_at_zero, images_root=images_root, output_dir=output_dir, 276 | image_downsample_factor=image_downsample_factor, image_color_lambda=image_color_lambda, lapnorm=lapnorm) 277 | print(kwargs) 278 | fn = partial(_extract_eig, **kwargs) 279 | inputs = list(enumerate(sorted(Path(features_dir).iterdir()))) 280 | utils.parallel_process(inputs, fn, multiprocessing) 281 | 282 | 283 | def _extract_multi_region_segmentations( 284 | inp: Tuple[int, Tuple[str, str]], 285 | adaptive: bool, 286 | non_adaptive_num_segments: int, 287 | infer_bg_index: bool, 288 | kmeans_baseline: bool, 289 | output_dir: str, 290 | num_eigenvectors: int, 291 | ): 292 | index, (feature_path, eigs_path) = inp 293 | 294 | # Load 295 | data_dict = torch.load(feature_path, map_location='cpu') 296 | data_dict.update(torch.load(eigs_path, map_location='cpu')) 297 | 298 | # Output file 299 | id = Path(data_dict['id']) 300 | output_file = str(Path(output_dir) / f'{id}.png') 301 | if Path(output_file).is_file(): 302 | print(f'Skipping existing file {str(output_file)}') 303 | return # skip because already generated 304 | 305 | # Sizes 306 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict) 307 | 308 | # If adaptive, we use the gaps between eigenvalues to determine the number of 309 | # segments per image. If not, we use non_adaptive_num_segments to get a fixed 310 | # number of segments per image. 311 | if adaptive: 312 | indices_by_gap = np.argsort(np.diff(data_dict['eigenvalues'].numpy()))[::-1] 313 | index_largest_gap = indices_by_gap[indices_by_gap != 0][0] # remove zero and take the biggest 314 | n_clusters = index_largest_gap + 1 315 | # print(f'Number of clusters: {n_clusters}') 316 | else: 317 | n_clusters = non_adaptive_num_segments 318 | 319 | # K-Means 320 | kmeans = KMeans(n_clusters=n_clusters) 321 | 322 | # Compute segments using eigenvector or baseline K-means 323 | if kmeans_baseline: 324 | feats = data_dict['k'].squeeze().numpy() 325 | clusters = kmeans.fit_predict(feats) 326 | else: 327 | eigenvectors = data_dict['eigenvectors'][1:1+num_eigenvectors].numpy() # take non-constant eigenvectors 328 | # import pdb; pdb.set_trace() 329 | clusters = kmeans.fit_predict(eigenvectors.T) 330 | 331 | # Reshape 332 | if clusters.size == H_patch * W_patch: # TODO: better solution might be to pass in patch index 333 | segmap = clusters.reshape(H_patch, W_patch) 334 | elif clusters.size == H_patch * W_patch * 4: 335 | segmap = clusters.reshape(H_patch * 2, W_patch * 2) 336 | else: 337 | raise ValueError() 338 | 339 | # TODO: Improve this step in the pipeline. 340 | # Background detection: we assume that the segment with the most border pixels is the 341 | # background region. We will always make this region equal 0. 
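# (Worked example of the swap below: if segment index 2 touches the border the most, then
# pixels labeled 2 become 0 and pixels previously labeled 0 become 2, so the inferred
# background always ends up with index 0.)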
342 | if infer_bg_index: 343 | indices, normlized_counts = utils.get_border_fraction(segmap) 344 | bg_index = indices[np.argmax(normlized_counts)].item() 345 | bg_region = (segmap == bg_index) 346 | zero_region = (segmap == 0) 347 | segmap[bg_region] = 0 348 | segmap[zero_region] = bg_index 349 | 350 | # Save dict 351 | Image.fromarray(segmap).convert('L').save(output_file) 352 | 353 | 354 | def extract_multi_region_segmentations( 355 | features_dir: str, 356 | eigs_dir: str, 357 | output_dir: str, 358 | adaptive: bool = False, 359 | non_adaptive_num_segments: int = 4, 360 | infer_bg_index: bool = True, 361 | kmeans_baseline: bool = False, 362 | num_eigenvectors: int = 1_000_000, 363 | multiprocessing: int = 0 364 | ): 365 | """ 366 | Example: 367 | python extract.py extract_multi_region_segmentations \ 368 | --features_dir "./data/VOC2012/features/dino_vits16" \ 369 | --eigs_dir "./data/VOC2012/eigs/laplacian" \ 370 | --output_dir "./data/VOC2012/multi_region_segmentation/fixed" \ 371 | """ 372 | utils.make_output_dir(output_dir) 373 | fn = partial(_extract_multi_region_segmentations, adaptive=adaptive, infer_bg_index=infer_bg_index, 374 | non_adaptive_num_segments=non_adaptive_num_segments, num_eigenvectors=num_eigenvectors, 375 | kmeans_baseline=kmeans_baseline, output_dir=output_dir) 376 | inputs = utils.get_paired_input_files(features_dir, eigs_dir) 377 | utils.parallel_process(inputs, fn, multiprocessing) 378 | 379 | 380 | def _extract_single_region_segmentations( 381 | inp: Tuple[int, Tuple[str, str]], 382 | threshold: float, 383 | output_dir: str, 384 | ): 385 | index, (feature_path, eigs_path) = inp 386 | 387 | # Load 388 | data_dict = torch.load(feature_path, map_location='cpu') 389 | data_dict.update(torch.load(eigs_path, map_location='cpu')) 390 | 391 | # Output file 392 | id = Path(data_dict['id']) 393 | output_file = str(Path(output_dir) / f'{id}.png') 394 | if Path(output_file).is_file(): 395 | print(f'Skipping existing file {str(output_file)}') 396 | return # skip because already generated 397 | 398 | # Sizes 399 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict) 400 | 401 | # Eigenvector 402 | eigenvector = data_dict['eigenvectors'][1].numpy() # take smallest non-zero eigenvector 403 | segmap = (eigenvector > threshold).reshape(H_patch, W_patch) 404 | 405 | # Save dict 406 | Image.fromarray(segmap).convert('L').save(output_file) 407 | 408 | 409 | def extract_single_region_segmentations( 410 | features_dir: str, 411 | eigs_dir: str, 412 | output_dir: str, 413 | threshold: float = 0.0, 414 | multiprocessing: int = 0 415 | ): 416 | """ 417 | Example: 418 | python extract.py extract_single_region_segmentations \ 419 | --features_dir "./data/VOC2012/features/dino_vits16" \ 420 | --eigs_dir "./data/VOC2012/eigs/laplacian" \ 421 | --output_dir "./data/VOC2012/single_region_segmentation/patches" \ 422 | """ 423 | utils.make_output_dir(output_dir) 424 | fn = partial(_extract_single_region_segmentations, threshold=threshold, output_dir=output_dir) 425 | inputs = utils.get_paired_input_files(features_dir, eigs_dir) 426 | utils.parallel_process(inputs, fn, multiprocessing) 427 | 428 | 429 | def _extract_bbox( 430 | inp: Tuple[str, str], 431 | num_erode: int, 432 | num_dilate: int, 433 | skip_bg_index: bool, 434 | downsample_factor: Optional[int] = None 435 | ): 436 | index, (feature_path, segmentation_path) = inp 437 | 438 | # Load 439 | data_dict = torch.load(feature_path, map_location='cpu') 440 | segmap = np.array(Image.open(str(segmentation_path))) 
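# segmap holds per-patch segment indices (index 0 is skipped as background when skip_bg_index
# is set); the loop below erodes and then dilates each segment's binary mask to clean up stray
# patches before taking the tight box around the remaining pixels.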
441 | image_id = data_dict['id'] 442 | 443 | # Sizes 444 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict, downsample_factor) 445 | 446 | # Get bounding boxes 447 | outputs = {'bboxes': [], 'bboxes_original_resolution': [], 'segment_indices': [], 'id': image_id, 448 | 'format': "(xmin, ymin, xmax, ymax)"} 449 | for segment_index in sorted(np.unique(segmap).tolist()): 450 | if (not skip_bg_index) or (segment_index > 0): # skip 0, because 0 is the background 451 | 452 | # Erode and dilate mask 453 | binary_mask = (segmap == segment_index) 454 | binary_mask = utils.erode_or_dilate_mask(binary_mask, r=num_erode, erode=True) 455 | binary_mask = utils.erode_or_dilate_mask(binary_mask, r=num_dilate, erode=False) 456 | 457 | # Find box 458 | mask = np.where(binary_mask == 1) 459 | ymin, ymax = min(mask[0]), max(mask[0]) + 1 # add +1 because excluded max 460 | xmin, xmax = min(mask[1]), max(mask[1]) + 1 # add +1 because excluded max 461 | bbox = [xmin, ymin, xmax, ymax] 462 | bbox_resized = [x * P for x in bbox] # rescale to image size 463 | bbox_features = [ymin, xmin, ymax, xmax] # feature space coordinates are different 464 | 465 | # Append 466 | outputs['segment_indices'].append(segment_index) 467 | outputs['bboxes'].append(bbox) 468 | outputs['bboxes_original_resolution'].append(bbox_resized) 469 | 470 | return outputs 471 | 472 | 473 | def extract_bboxes( 474 | features_dir: str, 475 | segmentations_dir: str, 476 | output_file: str, 477 | num_erode: int = 2, 478 | num_dilate: int = 3, 479 | skip_bg_index: bool = True, 480 | downsample_factor: Optional[int] = None, 481 | ): 482 | """ 483 | Note: There is no need for multiprocessing here, as it is more convenient to save 484 | the entire output as a single JSON file. Example: 485 | python extract.py extract_bboxes \ 486 | --features_dir "./data/VOC2012/features/dino_vits16" \ 487 | --segmentations_dir "./data/VOC2012/multi_region_segmentation/fixed" \ 488 | --num_erode 2 --num_dilate 5 \ 489 | --output_file "./data/VOC2012/multi_region_bboxes/fixed/bboxes_e2_d5.pth" \ 490 | """ 491 | utils.make_output_dir(str(Path(output_file).parent), check_if_empty=False) 492 | fn = partial(_extract_bbox, num_erode=num_erode, num_dilate=num_dilate, skip_bg_index=skip_bg_index, 493 | downsample_factor=downsample_factor) 494 | inputs = utils.get_paired_input_files(features_dir, segmentations_dir) 495 | all_outputs = [fn(inp) for inp in tqdm(inputs, desc='Extracting bounding boxes')] 496 | torch.save(all_outputs, output_file) 497 | print('Done') 498 | 499 | 500 | def extract_bbox_features( 501 | images_root: str, 502 | bbox_file: str, 503 | model_name: str, 504 | output_file: str, 505 | ): 506 | """ 507 | Example: 508 | python extract.py extract_bbox_features \ 509 | --model_name dino_vits16 \ 510 | --images_root "./data/VOC2012/images" \ 511 | --bbox_file "./data/VOC2012/multi_region_bboxes/fixed/bboxes_e2_d5.pth" \ 512 | --output_file "./data/VOC2012/features/dino_vits16" \ 513 | --output_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_features_e2_d5.pth" \ 514 | """ 515 | 516 | # Load bounding boxes 517 | bbox_list = torch.load(bbox_file) 518 | total_num_boxes = sum(len(d['bboxes']) for d in bbox_list) 519 | print(f'Loaded bounding box list. 
520 | 
521 |     # Models
522 |     model_name_lower = model_name.lower()
523 |     model, val_transform, patch_size, num_heads = utils.get_model(model_name_lower)
524 |     model.eval().to('cuda')
525 | 
526 |     # Loop over boxes
527 |     for bbox_dict in tqdm(bbox_list):
528 |         # Get image info
529 |         image_id = bbox_dict['id']
530 |         bboxes = bbox_dict['bboxes_original_resolution']
531 |         # Load image as tensor
532 |         image_filename = str(Path(images_root) / f'{image_id}.jpg')
533 |         image = val_transform(Image.open(image_filename).convert('RGB'))  # (3, H, W)
534 |         image = image.unsqueeze(0).to('cuda')  # (1, 3, H, W)
535 |         features_crops = []
536 |         for (xmin, ymin, xmax, ymax) in bboxes:
537 |             image_crop = image[:, :, ymin:ymax, xmin:xmax]
538 |             features_crop = model(image_crop).squeeze().cpu()
539 |             features_crops.append(features_crop)
540 |         bbox_dict['features'] = torch.stack(features_crops, dim=0)
541 | 
542 |     # Save
543 |     torch.save(bbox_list, output_file)
544 |     print(f'Saved features to {output_file}')
545 | 
546 | 
547 | def extract_bbox_clusters(
548 |     bbox_features_file: str,
549 |     output_file: str,
550 |     num_clusters: int = 20,
551 |     seed: int = 0,
552 |     pca_dim: Optional[int] = 0,
553 | ):
554 |     """
555 |     Example:
556 |         python extract.py extract_bbox_clusters \
557 |             --bbox_features_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_features_e2_d5.pth" \
558 |             --pca_dim 32 --num_clusters 21 --seed 0 \
559 |             --output_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_clusters_e2_d5_pca_32.pth" \
560 |     """
561 | 
562 |     # Load bounding boxes
563 |     bbox_list = torch.load(bbox_features_file)
564 |     total_num_boxes = sum(len(d['bboxes']) for d in bbox_list)
565 |     print(f'Loaded bounding box list. There are {total_num_boxes} total bounding boxes with features.')
566 | 
567 |     # Stack and normalize features with PyTorch, because NumPy is too slow
568 |     print('Stacking and normalizing features')
569 |     all_features = torch.cat([bbox_dict['features'] for bbox_dict in bbox_list], dim=0)  # (numBbox, D)
570 |     all_features = all_features / torch.norm(all_features, dim=-1, keepdim=True)  # (numBbox, D)
571 |     all_features = all_features.numpy()
572 | 
573 |     # Cluster: PCA
574 |     if pca_dim:
575 |         pca = PCA(pca_dim)
576 |         print(f'Computing PCA with dimension {pca_dim}')
577 |         all_features = pca.fit_transform(all_features)
578 | 
579 |     # Cluster: K-Means
580 |     print(f'Computing K-Means clustering with {num_clusters} clusters')
581 |     kmeans = MiniBatchKMeans(n_clusters=num_clusters, batch_size=4096, max_iter=5000, random_state=seed)
582 |     clusters = kmeans.fit_predict(all_features)
583 | 
584 |     # Print
585 |     _indices, _counts = np.unique(clusters, return_counts=True)
586 |     print(f'Cluster indices: {_indices.tolist()}')
587 |     print(f'Cluster counts: {_counts.tolist()}')
588 | 
589 |     # Loop over boxes and add clusters
590 |     idx = 0
591 |     for bbox_dict in bbox_list:
592 |         num_bboxes = len(bbox_dict['bboxes'])
593 |         del bbox_dict['features']  # features are no longer needed once clusters are assigned
594 |         bbox_dict['clusters'] = clusters[idx: idx + num_bboxes]
595 |         idx = idx + num_bboxes
596 | 
597 |     # Save
598 |     torch.save(bbox_list, output_file)
599 |     print(f'Saved clusters to {output_file}')
600 | 
601 | 
602 | def extract_semantic_segmentations(
603 |     segmentations_dir: str,
604 |     bbox_clusters_file: str,
605 |     output_dir: str,
606 | ):
607 |     """
608 |     Example:
609 |         python extract.py extract_semantic_segmentations \
610 |             --segmentations_dir "./data/VOC2012/multi_region_segmentation/fixed" \
"./data/VOC2012/multi_region_segmentation/fixed" \ 611 | --bbox_clusters_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_clusters_e2_d5_pca_32.pth" \ 612 | --output_dir "./data/VOC2012/semantic_segmentations/patches/fixed/segmaps_e2_d5_pca_32" \ 613 | """ 614 | 615 | # Load bounding boxes 616 | bbox_list = torch.load(bbox_clusters_file) 617 | total_num_boxes = sum(len(d['bboxes']) for d in bbox_list) 618 | print(f'Loaded bounding box list. There are {total_num_boxes} total bounding boxes with features and clusters.') 619 | 620 | # Output 621 | utils.make_output_dir(output_dir) 622 | 623 | # Loop over boxes 624 | for bbox_dict in tqdm(bbox_list): 625 | # Get image info 626 | image_id = bbox_dict['id'] 627 | # Load segmentation as tensor 628 | segmap_path = str(Path(segmentations_dir) / f'{image_id}.png') 629 | segmap = np.array(Image.open(segmap_path)) 630 | # Check if the segmap is a binary file with foreground pixels saved as 255 instead of 1 631 | # this will be the case for some of our baselines 632 | if set(np.unique(segmap).tolist()).issubset({0, 255}): 633 | segmap[segmap == 255] = 1 634 | # Semantic map 635 | if not len(bbox_dict['segment_indices']) == len(bbox_dict['clusters'].tolist()): 636 | import pdb 637 | pdb.set_trace() 638 | semantic_map = dict(zip(bbox_dict['segment_indices'], bbox_dict['clusters'].tolist())) 639 | assert 0 not in semantic_map, semantic_map 640 | semantic_map[0] = 0 # background region remains zero 641 | # Perform mapping 642 | semantic_segmap = np.vectorize(semantic_map.__getitem__)(segmap) 643 | # Save 644 | output_file = str(Path(output_dir) / f'{image_id}.png') 645 | Image.fromarray(semantic_segmap.astype(np.uint8)).convert('L').save(output_file) 646 | 647 | print(f'Saved features to {output_dir}') 648 | 649 | 650 | def _extract_crf_segmentations( 651 | inp: Tuple[int, Tuple[str, str]], 652 | images_root: str, 653 | num_classes: int, 654 | output_dir: str, 655 | crf_params: Tuple, 656 | downsample_factor: int = 16, 657 | ): 658 | index, (image_file, segmap_path) = inp 659 | 660 | # Output file 661 | id = Path(image_file).stem 662 | output_file = str(Path(output_dir) / f'{id}.png') 663 | if Path(output_file).is_file(): 664 | print(f'Skipping existing file {str(output_file)}') 665 | return # skip because already generated 666 | 667 | # Load image and segmap 668 | image_file = str(Path(images_root) / f'{id}.jpg') 669 | image = np.array(Image.open(image_file).convert('RGB')) # (H_patch, W_patch, 3) 670 | segmap = np.array(Image.open(segmap_path)) # (H_patch, W_patch) 671 | 672 | # Sizes 673 | P = downsample_factor 674 | H, W = image.shape[:2] 675 | H_patch, W_patch = H // P, W // P 676 | H_pad, W_pad = H_patch * P, W_patch * P 677 | 678 | # Resize and expand 679 | segmap_upscaled = cv2.resize(segmap, dsize=(W_pad, H_pad), interpolation=cv2.INTER_NEAREST) # (H_pad, W_pad) 680 | segmap_orig_res = cv2.resize(segmap, dsize=(W, H), interpolation=cv2.INTER_NEAREST) # (H, W) 681 | segmap_orig_res[:H_pad, :W_pad] = segmap_upscaled # replace with the correctly upscaled version, just in case they are different 682 | 683 | # Convert binary 684 | if set(np.unique(segmap_orig_res).tolist()) == {0, 255}: 685 | segmap_orig_res[segmap_orig_res == 255] = 1 686 | 687 | # CRF 688 | import denseCRF # make sure you've installed SimpleCRF 689 | unary_potentials = F.one_hot(torch.from_numpy(segmap_orig_res).long(), num_classes=num_classes) 690 | segmap_crf = denseCRF.densecrf(image, unary_potentials, crf_params) # (H_pad, W_pad) 691 | 692 | # Save 693 | 
693 |     Image.fromarray(segmap_crf).convert('L').save(output_file)
694 | 
695 | 
696 | def extract_crf_segmentations(
697 |     images_list: str,
698 |     images_root: str,
699 |     segmentations_dir: str,
700 |     output_dir: str,
701 |     num_classes: int = 21,
702 |     downsample_factor: int = 16,
703 |     multiprocessing: int = 0,
704 |     # CRF parameters
705 |     w1=10,     # weight of the bilateral term (default: 10.0)
706 |     alpha=80,  # spatial std of the bilateral term (default: 80)
707 |     beta=13,   # RGB std of the bilateral term (default: 13)
708 |     w2=3,      # weight of the spatial term (default: 3.0)
709 |     gamma=3,   # spatial std of the spatial term (default: 3)
710 |     it=5.0,    # number of iterations (default: 5.0)
711 | ):
712 |     """
713 |     Applies a CRF to segmentations in order to sharpen them.
714 | 
715 |     Example:
716 |         python extract.py extract_crf_segmentations \
717 |             --images_list "./data/VOC2012/lists/images.txt" \
718 |             --images_root "./data/VOC2012/images" \
719 |             --segmentations_dir "./data/VOC2012/semantic_segmentations/patches/fixed/segmaps_e2_d5_pca_32" \
720 |             --output_dir "./data/VOC2012/semantic_segmentations/crf/fixed/segmaps_e2_d5_pca_32" \
721 |     """
722 |     try:
723 |         import denseCRF
724 |     except ImportError:
725 |         raise ImportError(
726 |             'Please install SimpleCRF to compute CRF segmentations:\n'
727 |             'pip3 install SimpleCRF'
728 |         )
729 | 
730 |     utils.make_output_dir(output_dir)
731 |     fn = partial(_extract_crf_segmentations, images_root=images_root, num_classes=num_classes, output_dir=output_dir,
732 |                  crf_params=(w1, alpha, beta, w2, gamma, it), downsample_factor=downsample_factor)
733 |     inputs = utils.get_paired_input_files(images_list, segmentations_dir)
734 |     print(f'Found {len(inputs)} images and segmaps')
735 |     utils.parallel_process(inputs, fn, multiprocessing)
736 | 
737 | 
738 | def vis_segmentations(
739 |     images_list: str,
740 |     images_root: str,
741 |     segmentations_dir: str,
742 |     bbox_file: Optional[str] = None,
743 | ):
744 |     """
745 |     Example:
746 |         streamlit run extract.py vis_segmentations -- \
747 |             --images_list "./data/VOC2012/lists/images.txt" \
748 |             --images_root "./data/VOC2012/images" \
749 |             --segmentations_dir "./data/VOC2012/multi_region_segmentation/fixed"
750 |     or alternatively:
751 |             --segmentations_dir "./data/VOC2012/semantic_segmentations/crf/fixed/segmaps_e2_d5_pca_32/"
752 |     """
753 |     # Streamlit setup
754 |     import streamlit as st
755 |     from matplotlib.cm import get_cmap
756 |     from skimage.color import label2rgb
757 |     st.set_page_config(layout='wide')
758 | 
759 |     # Inputs
760 |     image_paths = []
761 |     segmap_paths = []
762 |     images_root = Path(images_root)
763 |     segmentations_dir = Path(segmentations_dir)
764 |     for image_file in Path(images_list).read_text().splitlines():
765 |         segmap_file = f'{Path(image_file).stem}.png'
766 |         image_paths.append(images_root / image_file)
767 |         segmap_paths.append(segmentations_dir / segmap_file)
768 |     print(f'Found {len(image_paths)} image and segmap paths')
769 | 
770 |     # Load optional bounding boxes
771 |     if bbox_file is not None:
772 |         bboxes_list = torch.load(bbox_file)
773 | 
774 |     # Colors
775 |     colors = get_cmap('tab20', 21).colors[:, :3]
776 | 
777 |     # Which index
778 |     which_index = st.number_input(label='Which index to view (0 for all)', value=0)
779 | 
780 |     # Load
781 |     total = 0
782 |     for i, (image_path, segmap_path) in enumerate(zip(image_paths, segmap_paths)):
783 |         if total > 40: break
784 |         image_id = image_path.stem
785 | 
786 |         # Streamlit
787 |         cols = []
788 | 
789 |         # Load
790 |         image = np.array(Image.open(image_path).convert('RGB'))
791 |         segmap = np.array(Image.open(segmap_path))
792 | 
793 |         # Convert binary
794 |         if set(np.unique(segmap).tolist()) == {0, 255}:
795 |             segmap[segmap == 255] = 1
796 | 
797 |         # Resize
798 |         segmap_fullres = cv2.resize(segmap, dsize=image.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
799 | 
800 |         # Only view images that contain the specified class
801 |         if which_index not in np.unique(segmap):
802 |             continue
803 |         total += 1
804 | 
805 |         # Streamlit
806 |         cols.append({'image': image, 'caption': image_id})
807 | 
808 |         # Load optional bounding boxes
809 |         bboxes = None
810 |         if bbox_file is not None:
811 |             bboxes = torch.tensor(bboxes_list[i]['bboxes_original_resolution'])
812 |             assert bboxes_list[i]['id'] == image_id, f"{bboxes_list[i]['id']=} but {image_id=}"
813 |             image_torch = torch.from_numpy(image).permute(2, 0, 1)
814 |             image_with_boxes_torch = draw_bounding_boxes(image_torch, bboxes)
815 |             image_with_boxes = image_with_boxes_torch.permute(1, 2, 0).numpy()
816 | 
817 |             # Streamlit
818 |             cols.append({'image': image_with_boxes})
819 | 
820 |         # Color
821 |         segmap_label_indices, segmap_label_counts = np.unique(segmap, return_counts=True)
822 |         blank_segmap_overlay = label2rgb(label=segmap_fullres, image=np.full_like(image, 128),
823 |                                          colors=colors[segmap_label_indices[segmap_label_indices != 0]], bg_label=0, alpha=1.0)
824 |         image_segmap_overlay = label2rgb(label=segmap_fullres, image=image,
825 |                                          colors=colors[segmap_label_indices[segmap_label_indices != 0]], bg_label=0, alpha=0.45)
826 |         segmap_caption = dict(zip(segmap_label_indices.tolist(), segmap_label_counts.tolist()))
827 | 
828 |         # Streamlit
829 |         cols.append({'image': blank_segmap_overlay, 'caption': segmap_caption})
830 |         cols.append({'image': image_segmap_overlay, 'caption': segmap_caption})
831 | 
832 |         # Display
833 |         for d, col in zip(cols, st.columns(len(cols))):
834 |             col.image(**d)
835 | 
836 | 
837 | if __name__ == '__main__':
838 |     torch.set_grad_enabled(False)
839 |     fire.Fire(dict(
840 |         extract_features=extract_features,
841 |         extract_eigs=extract_eigs,
842 |         extract_multi_region_segmentations=extract_multi_region_segmentations,
843 |         extract_bboxes=extract_bboxes,
844 |         extract_bbox_features=extract_bbox_features,
845 |         extract_bbox_clusters=extract_bbox_clusters,
846 |         extract_semantic_segmentations=extract_semantic_segmentations,
847 |         extract_crf_segmentations=extract_crf_segmentations,
848 |         extract_single_region_segmentations=extract_single_region_segmentations,
849 |         vis_segmentations=vis_segmentations,
850 |     ))
851 | 
--------------------------------------------------------------------------------
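For reference, the stages in extract.py chain together in the order multi-region segmentation -> bounding boxes -> box features -> clusters -> semantic segmentations -> CRF. Below is a minimal sketch of that pipeline as plain Python calls rather than the documented `python extract.py ...` fire commands; it assumes the script is importable as `extract` (e.g. when run from the extract/ directory), that `extract_features` and `extract_eigs` have already produced the feature and eigenvector directories, and that the example VOC2012 paths from the docstrings above exist on disk.

# Pipeline sketch (assumptions: importable as `extract`, example VOC2012 paths exist,
# features/eigenvectors already extracted, CUDA available for the box-feature step).
import torch
from extract import (
    extract_multi_region_segmentations, extract_bboxes, extract_bbox_features,
    extract_bbox_clusters, extract_semantic_segmentations, extract_crf_segmentations,
)

torch.set_grad_enabled(False)  # mirrors the __main__ block above

features_dir = './data/VOC2012/features/dino_vits16'
eigs_dir = './data/VOC2012/eigs/laplacian'

# 1) Discrete multi-region segmentations from the eigenvectors
extract_multi_region_segmentations(
    features_dir=features_dir, eigs_dir=eigs_dir,
    output_dir='./data/VOC2012/multi_region_segmentation/fixed')

# 2) Bounding boxes around each segment (after erosion/dilation of the masks)
extract_bboxes(
    features_dir=features_dir,
    segmentations_dir='./data/VOC2012/multi_region_segmentation/fixed',
    num_erode=2, num_dilate=5,
    output_file='./data/VOC2012/multi_region_bboxes/fixed/bboxes_e2_d5.pth')

# 3) DINO features for each box crop (requires a CUDA device)
extract_bbox_features(
    model_name='dino_vits16', images_root='./data/VOC2012/images',
    bbox_file='./data/VOC2012/multi_region_bboxes/fixed/bboxes_e2_d5.pth',
    output_file='./data/VOC2012/multi_region_bboxes/fixed/bbox_features_e2_d5.pth')

# 4) Cluster the box features into pseudo-classes
extract_bbox_clusters(
    bbox_features_file='./data/VOC2012/multi_region_bboxes/fixed/bbox_features_e2_d5.pth',
    pca_dim=32, num_clusters=21, seed=0,
    output_file='./data/VOC2012/multi_region_bboxes/fixed/bbox_clusters_e2_d5_pca_32.pth')

# 5) Map each segment to its cluster to obtain semantic segmentations
extract_semantic_segmentations(
    segmentations_dir='./data/VOC2012/multi_region_segmentation/fixed',
    bbox_clusters_file='./data/VOC2012/multi_region_bboxes/fixed/bbox_clusters_e2_d5_pca_32.pth',
    output_dir='./data/VOC2012/semantic_segmentations/patches/fixed/segmaps_e2_d5_pca_32')

# 6) Optionally sharpen the maps with a dense CRF (requires SimpleCRF)
extract_crf_segmentations(
    images_list='./data/VOC2012/lists/images.txt', images_root='./data/VOC2012/images',
    segmentations_dir='./data/VOC2012/semantic_segmentations/patches/fixed/segmaps_e2_d5_pca_32',
    output_dir='./data/VOC2012/semantic_segmentations/crf/fixed/segmaps_e2_d5_pca_32')

The same sequence can equivalently be run with the documented CLI examples in each docstring; the direct-call form is convenient when embedding the steps in a notebook or a larger script.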