├── object-localization
│   ├── __init__.py
│   ├── visualizations.py
│   ├── networks.py
│   ├── object_discovery.py
│   ├── datasets.py
│   └── main.py
├── semantic-segmentation
│   ├── __init__.py
│   ├── config
│   │   ├── eval.yaml
│   │   ├── base.yaml
│   │   └── train.yaml
│   ├── model
│   │   ├── __init__.py
│   │   └── model.py
│   ├── eval_utils.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── voc.py
│   ├── README.md
│   ├── eval.py
│   ├── util.py
│   └── train.py
├── requirements.txt
├── object-segmentation
│   ├── config
│   │   └── eval.yaml
│   ├── dataset.py
│   ├── metrics.py
│   ├── main.py
│   └── util.py
├── .gitignore
├── extract
│   ├── extract_utils.py
│   └── extract.py
└── README.md
/object-localization/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/semantic-segmentation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | fire
3 | opencv-python-headless
4 | pillow
5 | scikit-image
6 | scipy
7 | torch
8 | torchvision
9 | tqdm
10 | pymatting
11 |
--------------------------------------------------------------------------------
/semantic-segmentation/config/eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - base
4 | - _self_
5 |
6 | name: 'eval'
7 | job_type: 'eval'
8 |
9 | segments_dir: ""
10 |
--------------------------------------------------------------------------------
/semantic-segmentation/model/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 |
6 | from .model import get_deeplab_resnet, get_deeplab_vit
7 |
8 |
9 | def get_model(name: str, num_classes: int):
10 | if 'resnet' in name:
11 | model = get_deeplab_resnet(num_classes=(num_classes + 1)) # add 1 for bg
12 | elif 'vit' in name:
13 | model = get_deeplab_vit(backbone_name=name, num_classes=(num_classes + 1)) # add 1 for bg
14 | else:
15 | raise NotImplementedError()
16 | return model
17 |
18 |
--------------------------------------------------------------------------------
/semantic-segmentation/config/base.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | hydra:
3 | run:
4 | dir: ./outputs/${name}/${now:%Y-%m-%d--%H-%M-%S}
5 |
6 | name: "debug"
7 | seed: 1
8 | job_type: 'train'
9 | fp16: False
10 | cpu: False
11 | wandb: False
12 | wandb_kwargs:
13 | project: deep-spectral-segmentation
14 |
15 | data:
16 | num_classes: 20
17 | dataset: pascal
18 | train_kwargs:
19 | root: ${oc.env:HOME}/machine-learning-datasets/semantic-segmentation/PASCAL_VOC/VOC2012
20 | year: "2012"
21 | image_set: train
22 | download: False
23 | val_kwargs:
24 | root: ${oc.env:HOME}/machine-learning-datasets/semantic-segmentation/PASCAL_VOC/VOC2012
25 | year: "2012"
26 | image_set: "val"
27 | download: False
28 | loader:
29 | batch_size: 144
30 | num_workers: 8
31 | pin_memory: False
32 | transform:
33 | resize_size: 256
34 | crop_size: 224
35 | img_mean: [0.485, 0.456, 0.406]
36 | img_std: [0.229, 0.224, 0.225]
37 |
38 | segments_dir: ""
39 |
40 | logging:
41 | print_freq: 50
--------------------------------------------------------------------------------
/semantic-segmentation/config/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - base
4 | - _self_
5 |
6 | job_type: 'train'
7 | eval_every: 1 # eval every this many epochs
8 | checkpoint_every: 10 # checkpoint every this many epochs
9 |
10 | unfrozen_backbone_layers: 1 # -1 to train all, 0 to freeze entirely, > 0 to specify
11 | model:
12 | name: resnet50
13 | num_classes: ${data.num_classes}
14 |
15 | # Please change these
16 | segments_dir: ""
17 | matching: ""
18 |
19 | checkpoint:
20 | resume: null
21 | resume_training: True
22 | resume_optimizer_only: False
23 |
24 | # Exponential moving average of model parameters
25 | ema:
26 | use_ema: False
27 | decay: 0.999
28 | update_every: 10
29 |
30 | # Training steps/epochs
31 | max_train_steps: 5000
32 | max_train_epochs: null
33 |
34 | # Optimization
35 | lr: 0.005
36 | gradient_accumulation_steps: 1
37 | optimizer:
38 | scale_learning_rate_with_batch_size: False
39 | clip_grad_norm: null
40 |
41 | # Timm optimizer
42 | kind: 'timm'
43 | kwargs:
44 | opt: 'adamw'
45 | weight_decay: 1e-8
46 |
47 | # Learning rate scheduling
48 | scheduler:
49 |
50 | # Transformers scheduler
51 | kind: 'transformers'
52 | stepwise: True
53 | kwargs:
54 | name: linear
55 | num_warmup_steps: 0
56 | num_training_steps: ${max_train_steps}
57 |
--------------------------------------------------------------------------------
/object-segmentation/config/eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - _self_
4 |
5 | hydra:
6 | run:
7 | dir: ./outputs/${name}/${now:%Y-%m-%d_%H-%M-%S}
8 |
9 | # General
10 | name: "debug"
11 | seed: 1
12 | job_type: 'eval'
13 | fp16: False
14 | cpu: False
15 | wandb: False
16 | wandb_kwargs:
17 | project: deep-spectral-segmentation
18 |
19 | # Data
20 | data_root: ${env:GANSEG_DATA_SEG_ROOT} # <- REPLACE THIS WITH YOUR DIRECTORY
21 | data:
22 | - name: 'CUB'
23 | images_dir: "${data_root}/CUB_200_2011/test_images"
24 | labels_dir: "${data_root}/CUB_200_2011/test_segmentations"
25 | crop: True
26 | image_size: null
27 | - name: 'DUT_OMRON'
28 | images_dir: "${data_root}/DUT_OMRON/DUT-OMRON-image"
29 | labels_dir: "${data_root}/DUT_OMRON/pixelwiseGT-new-PNG"
30 | crop: False
31 | image_size: null
32 | - name: 'DUTS'
33 | images_dir: "${data_root}/DUTS/DUTS-TE/DUTS-TE-Image"
34 | labels_dir: "${data_root}/DUTS/DUTS-TE/DUTS-TE-Mask"
35 | crop: False
36 | image_size: null
37 | - name: 'ECSSD'
38 | images_dir: "${data_root}/ECSSD/images"
39 | labels_dir: "${data_root}/ECSSD/ground_truth_mask"
40 | crop: False
41 | image_size: null
42 |
43 | dataloader:
44 | batch_size: 128
45 | num_workers: 16
46 |
47 | # Predictions
48 | predictions:
49 | root: "/path/to/object-segmentation-data"
50 | run: "run_name"
51 | downsample: 16 # null
52 |
53 | # The paths to the predictions
54 | CUB: ${predictions.root}/CUB_200_2011/${predictions.run}
55 | DUT_OMRON: ${predictions.root}/DUT_OMRON/${predictions.run}
56 | DUTS: ${predictions.root}/DUTS/${predictions.run}
57 | ECSSD: ${predictions.root}/ECSSD/${predictions.run}
58 |
--------------------------------------------------------------------------------
/semantic-segmentation/eval_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from joblib import Parallel
3 | from joblib.parallel import delayed
4 | from scipy.optimize import linear_sum_assignment
5 |
6 |
7 | def hungarian_match(flat_preds, flat_targets, preds_k, targets_k, metric='acc', n_jobs=16):
8 | assert (preds_k == targets_k) # one to one
9 | num_k = preds_k
10 |
11 | # perform hungarian matching
12 | print('Using iou as metric')
13 | results = Parallel(n_jobs=n_jobs, backend='multiprocessing')(delayed(get_iou)(
14 | flat_preds, flat_targets, c1, c2) for c2 in range(num_k) for c1 in range(num_k))
15 | results = np.array(results)
16 | results = results.reshape((num_k, num_k)).T
17 | match = linear_sum_assignment(flat_targets.shape[0] - results)
18 | match = np.array(list(zip(*match)))
19 | res = []
20 | for out_c, gt_c in match:
21 | res.append((out_c, gt_c))
22 |
23 | return res
24 |
25 |
26 | def majority_vote(flat_preds, flat_targets, preds_k, targets_k, n_jobs=16):
27 | iou_mat = Parallel(n_jobs=n_jobs, backend='multiprocessing')(delayed(get_iou)(
28 | flat_preds, flat_targets, c1, c2) for c2 in range(targets_k) for c1 in range(preds_k))
29 | iou_mat = np.array(iou_mat)
30 | results = iou_mat.reshape((targets_k, preds_k)).T
31 | results = np.argmax(results, axis=1)
32 | match = np.array(list(zip(range(preds_k), results)))
33 | return match
34 |
35 |
36 | def get_iou(flat_preds, flat_targets, c1, c2):
37 | tp = 0
38 | fn = 0
39 | fp = 0
40 | tmp_all_gt = (flat_preds == c1)
41 | tmp_pred = (flat_targets == c2)
42 | tp += np.sum(tmp_all_gt & tmp_pred)
43 | fp += np.sum(~tmp_all_gt & tmp_pred)
44 | fn += np.sum(tmp_all_gt & ~tmp_pred)
45 | jac = float(tp) / max(float(tp + fp + fn), 1e-8)
46 | return jac
47 |
--------------------------------------------------------------------------------
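
The matching utilities above are used by `eval.py` for cluster-to-class assignment. Below is a minimal usage sketch on toy arrays (illustrative only, not part of the repository; it assumes `eval_utils.py` is importable from the working directory):

```python
# Toy example of the matching utilities (illustrative; arrays are made up).
import numpy as np
from eval_utils import hungarian_match, majority_vote

# Flattened predictions and ground-truth labels over 8 pixels, 3 classes each
flat_preds = np.array([0, 0, 1, 1, 2, 2, 2, 0])
flat_targets = np.array([1, 1, 0, 0, 2, 2, 2, 1])

# One-to-one matching (requires preds_k == targets_k); returns (pred, target) pairs
match = hungarian_match(flat_preds, flat_targets, preds_k=3, targets_k=3, n_jobs=1)
print(match)  # e.g. [(0, 1), (1, 0), (2, 2)]

# Many-to-one matching (used when there are more clusters than classes)
match = majority_vote(flat_preds, flat_targets, preds_k=3, targets_k=3, n_jobs=1)
print(match)
```
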
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | .vscode
3 | wandb
4 | outputs
5 | tmp*
6 | slurm-logs
7 | preprocess/extract-features/features_*
8 | preprocess/extract-features/eigensegments_*
9 | preprocess/extract-features/mattes_*
10 | preprocess/extract-features/masks_*
11 | preprocess/extract-features/multilabel_masks_*
12 | preprocess/extract-features/semantic_segmentations_*
13 | preprocess/extract-features/crf_semantic_segmentations_*
14 | old
15 | dino
16 |
17 | # Byte-compiled / optimized / DLL files
18 | __pycache__/
19 | *.py[cod]
20 | *$py.class
21 | .github
22 |
23 | # C extensions
24 | *.so
25 |
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # Lightning /research
46 | test_tube_exp/
47 | tests/tests_tt_dir/
48 | tests/save_dir
49 | default/
50 | data/
51 | test_tube_logs/
52 | test_tube_data/
53 | datasets/
54 | model_weights/
55 | tests/save_dir
56 | tests/tests_tt_dir/
57 | processed/
58 | raw/
59 |
60 | # PyInstaller
61 | # Usually these files are written by a python script from a template
62 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
63 | *.manifest
64 | *.spec
65 |
66 | # Installer logs
67 | pip-log.txt
68 | pip-delete-this-directory.txt
69 |
70 | # Unit test / coverage reports
71 | htmlcov/
72 | .tox/
73 | .coverage
74 | .coverage.*
75 | .cache
76 | nosetests.xml
77 | coverage.xml
78 | *.cover
79 | .hypothesis/
80 | .pytest_cache/
81 |
82 | # Translations
83 | *.mo
84 | *.pot
85 |
86 | # Django stuff:
87 | *.log
88 | local_settings.py
89 | db.sqlite3
90 |
91 | # Flask stuff:
92 | instance/
93 | .webassets-cache
94 |
95 | # Scrapy stuff:
96 | .scrapy
97 |
98 | # Sphinx documentation
99 | docs/_build/
100 |
101 | # PyBuilder
102 | target/
103 |
104 | # Jupyter Notebook
105 | .ipynb_checkpoints
106 |
107 | # pyenv
108 | .python-version
109 |
110 | # celery beat schedule file
111 | celerybeat-schedule
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 |
138 | # IDEs
139 | .idea
140 | .vscode
141 |
142 | # seed project
143 | lightning_logs/
144 | MNIST
145 | .DS_Store
146 |
--------------------------------------------------------------------------------
/semantic-segmentation/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | import albumentations as A
2 | import albumentations.pytorch as AP
3 | import cv2
4 | from torch.utils.data._utils.collate import default_collate
5 |
6 | from .voc import VOCSegmentationWithPseudolabels
7 |
8 |
9 | def get_transforms(resize_size, crop_size, img_mean, img_std):
10 |
11 | # Multiple training transforms for contrastive learning
12 | train_joint_transform = A.Compose([
13 | A.SmallestMaxSize(resize_size, interpolation=cv2.INTER_CUBIC),
14 | A.RandomCrop(crop_size, crop_size),
15 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'})
16 | train_geometric_transform = A.ReplayCompose([
17 | A.RandomResizedCrop(crop_size, crop_size, interpolation=cv2.INTER_CUBIC),
18 | A.HorizontalFlip(),
19 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'})
20 | train_separate_transform = A.Compose([
21 | A.ColorJitter(0.4, 0.4, 0.2, 0.1, p=0.8),
22 | A.ToGray(p=0.2), A.GaussianBlur(p=0.1), # A.Solarize(p=0.1)
23 | A.Normalize(mean=img_mean, std=img_std), AP.ToTensorV2(),
24 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'})
25 |
26 | # Validation transform -- no resizing!
27 | val_transform = A.Compose([
28 | # A.Resize(resize_size, resize_size, interpolation=cv2.INTER_CUBIC), A.CenterCrop(crop_size, crop_size),
29 | A.Normalize(mean=img_mean, std=img_std), AP.ToTensorV2()
30 | ], additional_targets={'mask1': 'mask', 'mask2': 'mask'})
31 |
32 | train_transforms_tuple = (train_joint_transform, train_geometric_transform, train_separate_transform)
33 | return train_transforms_tuple, val_transform
34 |
35 |
36 | def collate_fn(batch):
37 | everything_but_metadata = [t[:-1] for t in batch]
38 | metadata = [t[-1] for t in batch]
39 | return (*default_collate(everything_but_metadata), metadata)
40 |
41 |
42 | def get_datasets(cfg):
43 |
44 | # Get transforms
45 | train_transforms_tuple, val_transform = get_transforms(**cfg.data.transform)
46 |
47 | # Get the label map
48 | if cfg.matching:
49 | matching = dict(eval(str(cfg.matching)))
50 | print(f'Using matching: {matching}')
51 | else:
52 | matching = None
53 |
54 | # Training dataset
55 | dataset_train = VOCSegmentationWithPseudolabels(
56 | **cfg.data.train_kwargs,
57 | segments_dir=cfg.segments_dir,
58 | transforms_tuple=train_transforms_tuple,
59 | label_map=matching
60 | )
61 |
62 | # Validation dataset
63 | dataset_val = VOCSegmentationWithPseudolabels(
64 | **cfg.data.val_kwargs,
65 | segments_dir=cfg.segments_dir,
66 | transform=val_transform,
67 | label_map=matching
68 | )
69 |
70 | return dataset_train, dataset_val, collate_fn
71 |
--------------------------------------------------------------------------------
/object-segmentation/dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | from torchvision import transforms as T
6 |
7 | IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
8 |
9 |
10 | def get_paths_from_folders(images_dir):
11 | """Returns list of files in folders of input"""
12 | paths = []
13 | for folder in Path(images_dir).iterdir():
14 | for p in folder.iterdir():
15 | paths.append(p)
16 | return paths
17 |
18 |
19 | def central_crop(x):
20 | dims = x.size
21 | crop = T.CenterCrop(min(dims[0], dims[1]))
22 | return crop(x)
23 |
24 |
25 | class SegmentationDataset(Dataset):
26 |
27 | def __init__(
28 | self,
29 | images_dir: str,
30 | labels_dir: str,
31 | image_size: Optional[int] = None,
32 | resize_image=True,
33 | resize_mask=None,
34 | crop=True,
35 | mean=[0.5, 0.5, 0.5],
36 | std=[0.5, 0.5, 0.5],
37 | name: Optional[str] = None,
38 | ):
39 | self.name = name
40 | self.crop = crop
41 |
42 | # Find out if dataset is organized into folders or not
43 | has_folders = not any(str(next(Path(images_dir).iterdir())).endswith(ext) for ext in IMG_EXTENSIONS)
44 |
45 | # Get and sort list of paths
46 | if has_folders:
47 | image_paths = get_paths_from_folders(images_dir)
48 | label_paths = get_paths_from_folders(labels_dir)
49 | else:
50 | image_paths = Path(images_dir).iterdir()
51 | label_paths = Path(labels_dir).iterdir()
52 | self.image_paths = list(sorted(image_paths))
53 | self.label_paths = list(sorted(label_paths))
54 | assert len(self.image_paths) == len(self.label_paths)
55 |
56 | # Transformation
57 | resize_image = (image_size is not None) and resize_image
58 | resize_mask = resize_image if resize_mask is None else resize_mask
59 | image_transform = [T.ToTensor(), T.Normalize(mean=mean, std=std)]
60 | mask_transform = [T.ToTensor()]
61 | if resize_image:
62 | image_transform.insert(0, T.Resize(image_size))
63 | if resize_mask:
64 | mask_transform.insert(0, T.Resize(image_size))
65 | if crop:
66 | image_transform.insert(0, central_crop)
67 | mask_transform.insert(0, central_crop)
68 | self.image_transform = T.Compose(image_transform)
69 | self.mask_transform = T.Compose(mask_transform)
70 |
71 | def __len__(self):
72 | return len(self.image_paths)
73 |
74 | def __getitem__(self, idx):
75 |
76 | # Load
77 | image = Image.open(self.image_paths[idx])
78 | mask = Image.open(self.label_paths[idx])
79 | metadata = {'image_file': str(self.image_paths[idx])}
80 |
81 | # Transform
82 | image = image.convert('RGB')
83 | mask = mask.convert('RGB')
84 | image = self.image_transform(image)
85 | mask = self.mask_transform(mask)
86 | mask = (mask > 0.5)[0].long() # TODO: this could be improved
87 | return image, mask, metadata
88 |
--------------------------------------------------------------------------------
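
A minimal sketch of loading one of the saliency datasets with the `SegmentationDataset` above (illustrative only; the paths below are hypothetical placeholders):

```python
# Illustrative usage of SegmentationDataset; the paths are hypothetical.
from torch.utils.data import DataLoader
from dataset import SegmentationDataset  # assumes object-segmentation/ is the working directory

dataset = SegmentationDataset(
    images_dir='/path/to/ECSSD/images',             # hypothetical path
    labels_dir='/path/to/ECSSD/ground_truth_mask',  # hypothetical path
    crop=False,
    name='ECSSD',
)

# Batch size 1 because masks keep their original (varying) resolutions
loader = DataLoader(dataset, batch_size=1, num_workers=0)
image, mask, metadata = next(iter(loader))
print(image.shape, mask.shape, metadata['image_file'][0])
```
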
/semantic-segmentation/README.md:
--------------------------------------------------------------------------------
1 | ## Semantic Segmentation
2 |
3 | We begin by extracting features and eigenvectors from our images. For instructions on this process, follow the steps in "Extraction" in the main `README`.
4 |
5 | Next, we obtain coarse (i.e. patch-level) semantic segmentations. This process involves (1) extracting segments from the eigenvectors, (2) taking a bounding box around each segment, (3) extracting features for these boxes, (4) clustering these features, and (5) assigning the resulting cluster labels back to the segments to obtain coarse semantic segmentations.
6 |
7 | For example, you can run the following in the `extract` directory.
8 |
9 | ```bash
10 | # Example parameters for the semantic segmentation experiments
11 | DATASET="VOC2012"
12 | MODEL="dino_vits16"
13 | MATRIX="laplacian"
14 | DOWNSAMPLE=16
15 | N_SEG=15
16 | N_ERODE=2
17 | N_DILATE=5
18 |
19 | # Extract segments
20 | python extract.py extract_multi_region_segmentations \
21 | --non_adaptive_num_segments ${N_SEG} \
22 | --features_dir "./data/${DATASET}/features/${MODEL}" \
23 | --eigs_dir "./data/${DATASET}/eigs/${MATRIX}" \
24 | --output_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}"
25 |
26 | # Extract bounding boxes
27 | python extract.py extract_bboxes \
28 | --features_dir "./data/${DATASET}/features/${MODEL}" \
29 | --segmentations_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" \
30 | --num_erode ${N_ERODE} \
31 | --num_dilate ${N_DILATE} \
32 | --downsample_factor ${DOWNSAMPLE} \
33 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bboxes.pth"
34 |
35 | # Extract bounding box features
36 | python extract.py extract_bbox_features \
37 | --model_name ${MODEL} \
38 | --images_root "./data/${DATASET}/images" \
39 | --bbox_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bboxes.pth" \
40 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_features.pth"
41 |
42 | # Extract clusters
43 | python extract.py extract_bbox_clusters \
44 | --bbox_features_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_features.pth" \
45 | --output_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_clusters.pth"
46 |
47 | # Create semantic segmentations
48 | python extract.py extract_semantic_segmentations \
49 | --segmentations_dir "./data/${DATASET}/multi_region_segmentation/${MATRIX}" \
50 | --bbox_clusters_file "./data/${DATASET}/multi_region_bboxes/${MATRIX}/bbox_clusters.pth" \
51 | --output_dir "./data/${DATASET}/semantic_segmentations/patches/${MATRIX}/segmaps"
52 | ```
53 |
54 | At this point, you can evaluate the segmentations using `eval.py` in this directory. For example:
55 | ```bash
56 | python eval.py segments_dir="/output_dir/from/above"
57 | ```
58 |
59 | Optionally, you can also perform self-training using `train.py`. Specify the matching between predicted clusters and ground-truth classes with `matching="\"[(0, 0), ... (19, 6), (20, 7)]\""`; this matching is printed when you first evaluate with `python eval.py` (see the sketch below). For example:
60 | ```bash
61 | python train.py lr=2e-4 data.loader.batch_size=96 segments_dir="/path/to/segmaps" matching="\"[(0, 0), ... (19, 6), (20, 7)]\""
62 | ```
63 |
64 | Please note that the unsupervised semantic segmentation results have very high variance; some runs are much better than others. This variance is due primarily to the random seeds of the K-means clustering steps above, and secondarily to randomness in the self-training stage. Also note that this code has been heavily refactored for its public release; although we have tried to ensure that there are no bugs, it is possible that we have overlooked one.
65 |
--------------------------------------------------------------------------------
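
As a small illustration of the `matching` argument mentioned in the README above: `eval.py` prints `Optimal matching: ...` as a list of `(predicted_cluster, ground_truth_class)` pairs, and the training code parses `cfg.matching` with `dict(eval(str(cfg.matching)))` (see `dataset/__init__.py`), so the argument only needs to look like that printed list. A sketch with made-up values:

```python
# Hypothetical matching as printed by eval.py; the values here are made up.
match = [(0, 0), (1, 14), (2, 3)]  # truncated for brevity

# Build the string form that the training code can evaluate into a dict
matching_str = str([(int(p), int(t)) for p, t in match])
print(matching_str)  # -> "[(0, 0), (1, 14), (2, 3)]"

# On the command line this is passed with escaped quotes, e.g.:
#   python train.py segments_dir=/path/to/segmaps matching="\"[(0, 0), (1, 14), (2, 3)]\""
```
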
/object-localization/visualizations.py:
--------------------------------------------------------------------------------
1 | """
2 | Vis utilities. Code adapted from LOST: https://github.com/valeoai/LOST
3 | """
4 | import cv2
5 | import torch
6 | import skimage.io
7 | import numpy as np
8 | import torch.nn as nn
9 | from PIL import Image
10 |
11 | import matplotlib.pyplot as plt
12 |
13 | def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, plot_seed=False):
14 | """
15 | Visualization of the predicted box and the corresponding seed patch.
16 | """
17 | w_featmap, h_featmap = dims
18 |
19 | # Plot the box
20 | cv2.rectangle(
21 | image,
22 | (int(pred[0]), int(pred[1])),
23 | (int(pred[2]), int(pred[3])),
24 | (255, 0, 0), 3,
25 | )
26 |
27 | # Plot the seed
28 | if plot_seed:
29 | s_ = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap))
30 | size_ = np.asarray(scales) / 2
31 | cv2.rectangle(
32 | image,
33 | (int(s_[1] * scales[1] - (size_[1] / 2)), int(s_[0] * scales[0] - (size_[0] / 2))),
34 | (int(s_[1] * scales[1] + (size_[1] / 2)), int(s_[0] * scales[0] + (size_[0] / 2))),
35 | (0, 255, 0), -1,
36 | )
37 |
38 | pltname = f"{vis_folder}/LOST_{im_name}.png"
39 | Image.fromarray(image).save(pltname)
40 | print(f"Predictions saved at {pltname}.")
41 |
42 | def visualize_fms(A, seed, scores, dims, scales, output_folder, im_name):
43 | """
44 | Visualization of the maps presented in Figure 2 of the paper.
45 | """
46 | w_featmap, h_featmap = dims
47 |
48 | # Binarized similarity
49 | binA = A.copy()
50 | binA[binA < 0] = 0
51 | binA[binA > 0] = 1
52 |
53 | # Get binarized correlation for this pixel and make it appear in gray
54 | im_corr = np.zeros((3, len(scores)))
55 | where = binA[seed, :] > 0
56 | im_corr[:, where] = np.array([128 / 255, 133 / 255, 133 / 255]).reshape((3, 1))
57 | # Show selected pixel in green
58 | im_corr[:, seed] = [204 / 255, 37 / 255, 41 / 255]
59 | # Reshape and rescale
60 | im_corr = im_corr.reshape((3, w_featmap, h_featmap))
61 | im_corr = (
62 | nn.functional.interpolate(
63 | torch.from_numpy(im_corr).unsqueeze(0),
64 | scale_factor=scales,
65 | mode="nearest",
66 | )[0].cpu().numpy()
67 | )
68 |
69 | # Save correlations
70 | skimage.io.imsave(
71 | fname=f"{output_folder}/corr_{im_name}.png",
72 | arr=im_corr.transpose((1, 2, 0)),
73 | )
74 | print(f"Image saved at {output_folder}/corr_{im_name}.png .")
75 |
76 | # Save inverse degree
77 | im_deg = (
78 | nn.functional.interpolate(
79 | torch.from_numpy(1 / binA.sum(-1)).reshape(1, 1, w_featmap, h_featmap),
80 | scale_factor=scales,
81 | mode="nearest",
82 | )[0][0].cpu().numpy()
83 | )
84 | plt.imsave(fname=f"{output_folder}/deg_{im_name}.png", arr=im_deg)
85 | print(f"Image saved at {output_folder}/deg_{im_name}.png .")
86 |
87 | def visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims, vis_folder, im_name):
88 | """
89 | Visualization of the seed expansion presented in Figure 3 of the paper.
90 | """
91 | w_featmap, h_featmap = dims
92 |
93 | # Before expansion
94 | cv2.rectangle(
95 | image,
96 | (int(pred_seed[0]), int(pred_seed[1])),
97 | (int(pred_seed[2]), int(pred_seed[3])),
98 | (204, 204, 0), # Yellow
99 | 3,
100 | )
101 |
102 | # After expansion
103 | cv2.rectangle(
104 | image,
105 | (int(pred[0]), int(pred[1])),
106 | (int(pred[2]), int(pred[3])),
107 | (204, 0, 204), # Magenta
108 | 3,
109 | )
110 |
111 | # Position of the seed
112 | center = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap))
113 | start_1 = center[0] * scales[0]
114 | end_1 = center[0] * scales[0] + scales[0]
115 | start_2 = center[1] * scales[1]
116 | end_2 = center[1] * scales[1] + scales[1]
117 | image[start_1:end_1, start_2:end_2, 0] = 204
118 | image[start_1:end_1, start_2:end_2, 1] = 37
119 | image[start_1:end_1, start_2:end_2, 2] = 41
120 |
121 | pltname = f"{vis_folder}/LOST_seed_expansion_{im_name}.png"
122 | Image.fromarray(image).save(pltname)
123 | print(f"Image saved at {pltname}.")
124 |
--------------------------------------------------------------------------------
/object-localization/networks.py:
--------------------------------------------------------------------------------
1 | """
2 | Loads model. Depends on DINO repo.
3 | Code adapted from LOST: https://github.com/valeoai/LOST
4 | """
5 | import torch
6 | import torch.nn as nn
7 | from torchvision.models.resnet import resnet50
8 | from torchvision.models.vgg import vgg16
9 |
10 | import dino.vision_transformer as vits
11 |
12 |
13 | def get_model(arch, patch_size, resnet_dilate, device):
14 | if "resnet" in arch:
15 | if resnet_dilate == 1:
16 | replace_stride_with_dilation = [False, False, False]
17 | elif resnet_dilate == 2:
18 | replace_stride_with_dilation = [False, False, True]
19 | elif resnet_dilate == 4:
20 | replace_stride_with_dilation = [False, True, True]
21 |
22 | if "imagenet" in arch:
23 | model = resnet50(
24 | pretrained=True,
25 | replace_stride_with_dilation=replace_stride_with_dilation,
26 | )
27 | else:
28 | model = resnet50(
29 | pretrained=False,
30 | replace_stride_with_dilation=replace_stride_with_dilation,
31 | )
32 | elif "vgg16" in arch:
33 | if "imagenet" in arch:
34 | model = vgg16(pretrained=True)
35 | else:
36 | model = vgg16(pretrained=False)
37 | else:
38 | model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
39 |
40 | for p in model.parameters():
41 | p.requires_grad = False
42 |
43 | # Initialize model with pretraining
44 | if "imagenet" not in arch:
45 | url = None
46 | if arch == "vit_small" and patch_size == 16:
47 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
48 | elif arch == "vit_small" and patch_size == 8:
49 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
50 | elif arch == "vit_base" and patch_size == 16:
51 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
52 | elif arch == "vit_base" and patch_size == 8:
53 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
54 | elif arch == "resnet50":
55 | url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
56 | if url is not None:
57 | print(
58 | "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
59 | )
60 | state_dict = torch.hub.load_state_dict_from_url(
61 | url="https://dl.fbaipublicfiles.com/dino/" + url
62 | )
63 |             strict_loading = "resnet" not in arch  # non-strict loading for the ResNet checkpoint
64 | msg = model.load_state_dict(state_dict, strict=strict_loading)
65 | print(
66 | "Pretrained weights found at {} and loaded with msg: {}".format(
67 | url, msg
68 | )
69 | )
70 | else:
71 | print(
72 | "There is no reference weights available for this model => We use random weights."
73 | )
74 |
75 |     # If ResNet or VGG16, remove the final fully-connected layers
76 | if "resnet" in arch:
77 | model = ResNet50Bottom(model)
78 | elif "vgg16" in arch:
79 | model = vgg16Bottom(model)
80 |
81 | model.eval()
82 | model.to(device)
83 | return model
84 |
85 |
86 | class ResNet50Bottom(nn.Module):
87 | # https://forums.fast.ai/t/pytorch-best-way-to-get-at-intermediate-layers-in-vgg-and-resnet/5707/2
88 | def __init__(self, original_model):
89 | super(ResNet50Bottom, self).__init__()
90 | # Remove avgpool and fc layers
91 | self.features = nn.Sequential(*list(original_model.children())[:-2])
92 |
93 | def forward(self, x):
94 | x = self.features(x)
95 | return x
96 |
97 |
98 | class vgg16Bottom(nn.Module):
99 | # https://forums.fast.ai/t/pytorch-best-way-to-get-at-intermediate-layers-in-vgg-and-resnet/5707/2
100 | def __init__(self, original_model):
101 | super(vgg16Bottom, self).__init__()
102 | # Remove avgpool and the classifier
103 | self.features = nn.Sequential(*list(original_model.children())[:-2])
104 | # Remove the last maxPool2d
105 | self.features = nn.Sequential(*list(self.features[0][:-1]))
106 |
107 | def forward(self, x):
108 | x = self.features(x)
109 | return x
110 |
--------------------------------------------------------------------------------
/semantic-segmentation/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from torchvision.models._utils import IntermediateLayerGetter
5 | from torchvision.models.segmentation.deeplabv3 import (ASPP, DeepLabHead, DeepLabV3)
6 |
7 |
8 | def get_deeplab_resnet(num_classes: int, name: str = 'deeplabv3plus', output_stride: int = 8):
9 |
10 | if output_stride == 8:
11 | replace_stride_with_dilation = [False, True, True]
12 | aspp_dilate = [12, 24, 36]
13 | elif output_stride == 16:
14 | replace_stride_with_dilation = [False, False, True]
15 | aspp_dilate = [6, 12, 18]
16 | else:
17 | raise NotImplementedError()
18 |
19 | backbone = torch.hub.load(
20 | 'facebookresearch/dino:main',
21 | 'dino_resnet50',
22 | replace_stride_with_dilation=replace_stride_with_dilation
23 | )
24 |
25 | inplanes = 2048
26 | low_level_planes = 256
27 |
28 | if name == 'deeplabv3plus':
29 | return_layers = {'layer4': 'out', 'layer1': 'low_level'}
30 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
31 | DeepLab = DeepLabV3Plus
32 | elif name == 'deeplabv3':
33 | return_layers = {'layer4': 'out'}
34 | DeepLab = DeepLabV3
35 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
36 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
37 |
38 | model = DeepLab(backbone, classifier)
39 | return model
40 |
41 |
42 | def get_deeplab_vit(num_classes: int, backbone_name: str = 'vits16', name: str = 'deeplabv3plus'):
43 |
44 | # Backbone
45 | backbone = torch.hub.load('facebookresearch/dino:main', f'dino_{backbone_name}')
46 |
47 | # Classifier
48 | aspp_dilate = [12, 24, 36]
49 | inplanes = low_level_planes = backbone.embed_dim
50 | if name == 'deeplabv3plus':
51 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
52 | DeepLab = DeepLabV3Plus
53 | elif name == 'deeplabv3':
54 | DeepLab = DeepLabV3
55 | classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
56 |
57 | # Wrap
58 | backbone = VisionTransformerWrapper(backbone)
59 | model = DeepLab(backbone, classifier)
60 | return model
61 |
62 |
63 | class VisionTransformerWrapper(nn.Module):
64 | def __init__(self, backbone):
65 | super().__init__()
66 | self.backbone = backbone
67 |
68 | def forward(self, x):
69 | # Forward
70 | output = self.backbone.get_intermediate_layers(x, n=5)
71 | # Reshaping
72 |         assert len(output) == 5, f'{len(output)=}'
73 | H_patch = x.shape[-2] // self.backbone.patch_embed.patch_size
74 | W_patch = x.shape[-1] // self.backbone.patch_embed.patch_size
75 | out_ll = output[0][:, 1:, :].transpose(-2, -1).unflatten(-1, (H_patch, W_patch))
76 | out = output[-1][:, 1:, :].transpose(-2, -1).unflatten(-1, (H_patch, W_patch))
77 | return {'low_level': out_ll, 'out': out}
78 |
79 |
80 | class DeepLabHeadV3Plus(nn.Module):
81 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
82 | super(DeepLabHeadV3Plus, self).__init__()
83 | self.project = nn.Sequential(
84 | nn.Conv2d(low_level_channels, 48, 1, bias=False),
85 | nn.BatchNorm2d(48),
86 | nn.ReLU(inplace=True),
87 | )
88 |
89 | self.aspp = ASPP(in_channels, aspp_dilate)
90 |
91 | self.classifier = nn.Sequential(
92 | nn.Conv2d(304, 256, 3, padding=1, bias=False),
93 | nn.BatchNorm2d(256),
94 | nn.ReLU(inplace=True),
95 | nn.Conv2d(256, num_classes, 1)
96 | )
97 | self._init_weight()
98 |
99 | def forward(self, feature):
100 | low_level_feature = self.project(feature['low_level'])
101 | output_feature = self.aspp(feature['out'])
102 | output_feature = F.interpolate(
103 | output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
104 | return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
105 |
106 | def _init_weight(self):
107 | for m in self.modules():
108 | if isinstance(m, nn.Conv2d):
109 | nn.init.kaiming_normal_(m.weight)
110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111 | nn.init.constant_(m.weight, 1)
112 | nn.init.constant_(m.bias, 0)
113 |
114 |
115 | class DeepLabV3Plus(nn.Module):
116 | def __init__(self, backbone, classifier):
117 | super().__init__()
118 | self.backbone = backbone
119 | self.classifier = classifier
120 |
121 | def forward(self, x):
122 | input_shape = x.shape[-2:]
123 | features = self.backbone(x)
124 | x = self.classifier(features)
125 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
126 | return x
127 |
--------------------------------------------------------------------------------
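
A quick sanity-check sketch for the ViT-based model above (illustrative only; note that it downloads DINO weights via `torch.hub`, and it assumes the `semantic-segmentation` directory is the working directory so that `model` is importable):

```python
# Illustrative forward pass through the DINO ViT backbone + DeepLabV3+ head.
import torch
from model import get_deeplab_vit

model = get_deeplab_vit(num_classes=21, backbone_name='vits16')  # 20 classes + background
model.eval()

x = torch.randn(1, 3, 224, 224)  # 224 / 16 = 14 patches per side
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 21, 224, 224])
```
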
/object-segmentation/metrics.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import numpy as np
4 |
5 |
6 | @torch.no_grad()
7 | def compute_metrics(preds, targets, metrics=['f_max', 'acc', 'iou'], threshold=0.5, swap_dims=False, preds_are_soft=False):
8 |
9 | # Move to CPU
10 | preds = preds.detach() # .cpu()
11 | targets = targets.detach() # .cpu()
12 | assert len(targets.shape) == 3
13 | if preds_are_soft:
14 | assert len(preds.shape) == 4
15 | soft_preds = torch.softmax(preds, dim=1)[:, (0 if swap_dims else 1)] # convert to probabilities
16 | hard_preds = soft_preds > threshold
17 | else:
18 | assert 'f_max' not in metrics, 'must have soft preds for f_max'
19 | assert (len(preds.shape) == 3)
20 | assert (preds.dtype == torch.bool) or (preds.dtype == torch.uint8) or (preds.dtype == torch.long)
21 | assert (preds.max() <= 1) and (preds.min() >= 0)
22 | soft_preds = [None] * len(preds)
23 | hard_preds = preds.bool()
24 |
25 | # Compute
26 | results = defaultdict(list)
27 | for soft_pred, hard_pred, target in zip(soft_preds, hard_preds, targets):
28 | if 'f_max' in metrics:
29 | precision, recall = compute_prs(soft_pred, target, prob_bins=255)
30 | results['f_max_precision'].append(precision)
31 | results['f_max_recall'].append(recall)
32 | if 'f_beta' in metrics:
33 |             precision, recall = precision_recall(target, hard_pred)  # per-image prediction, not the whole batch
34 | results['f_beta_precision'].append([precision])
35 | results['f_beta_recall'].append([recall])
36 | if 'acc' in metrics:
37 | acc = compute_accuracy(hard_pred, target)
38 | results['acc'].append(acc)
39 | if 'iou' in metrics:
40 | iou = compute_iou(hard_pred, target)
41 | results['iou'].append(iou)
42 | return dict(results)
43 |
44 |
45 | @torch.no_grad()
46 | def aggregate_metrics(totals):
47 | results = defaultdict(list)
48 | if 'acc' in totals:
49 | results['acc'] = mean(totals['acc'])
50 | if 'iou' in totals:
51 | results['iou'] = mean(totals['iou'])
52 | if 'loss' in totals:
53 | results['loss'] = mean(totals['loss'])
54 | if 'f_max_precision' in totals and 'f_max_recall' in totals:
55 | precisions = torch.tensor(totals['f_max_precision'])
56 | recalls = torch.tensor(totals['f_max_recall'])
57 | results['f_max'] = F_max(precisions, recalls)
58 | if 'f_beta_precision' in totals and 'f_beta_recall' in totals:
59 | precisions = torch.tensor(totals['f_beta_precision'])
60 | recalls = torch.tensor(totals['f_beta_recall'])
61 | results['f_beta'] = F_max(precisions, recalls)
62 | return results
63 |
64 |
65 | def compute_accuracy(pred, target):
66 | pred, target = pred.to(torch.bool), target.to(torch.bool)
67 | return torch.mean((pred == target).to(torch.float)).item()
68 |
69 |
70 | def compute_iou(pred, target):
71 | pred, target = pred.to(torch.bool), target.to(torch.bool)
72 | intersection = torch.sum(pred * (pred == target), dim=[-1, -2]).squeeze()
73 | union = torch.sum(pred + target, dim=[-1, -2]).squeeze()
74 | iou = (intersection.to(torch.float) / union).mean()
75 | iou = iou.item() if (iou == iou) else 0 # deal with nans, i.e. torch.nan_to_num(iou, nan=0.0)
76 | return iou
77 |
78 |
79 | def compute_prs(pred, target, prob_bins=255):
80 | p = []
81 | r = []
82 | for split in np.arange(0.0, 1.0, 1.0 / prob_bins):
83 | if split == 0.0:
84 | continue
85 | pr = precision_recall(target, pred > split)
86 | p.append(pr[0])
87 | r.append(pr[1])
88 | return p, r
89 |
90 |
91 | def precision_recall(mask_gt, mask):
92 | mask_gt, mask = mask_gt.to(torch.bool), mask.to(torch.bool)
93 | true_positive = torch.sum(mask_gt * (mask_gt == mask), dim=[-1, -2]).squeeze()
94 | mask_area = torch.sum(mask, dim=[-1, -2]).to(torch.float)
95 | mask_gt_area = torch.sum(mask_gt, dim=[-1, -2]).to(torch.float)
96 | precision = true_positive / mask_area
97 | precision[mask_area == 0.0] = 1.0
98 | recall = true_positive / mask_gt_area
99 | recall[mask_gt_area == 0.0] = 1.0
100 | return precision.item(), recall.item()
101 |
102 |
103 | def F_scores(p, r, betta_sq=0.3):
104 | f_scores = ((1 + betta_sq) * p * r) / (betta_sq * p + r)
105 | f_scores[f_scores != f_scores] = 0.0 # handle nans
106 | return f_scores
107 |
108 |
109 | def F_max(precisions, recalls, betta_sq=0.3):
110 | f_scores = F_scores(precisions, recalls, betta_sq)
111 | f_scores = f_scores.mean(dim=0)
112 | # print('f_scores.shape: ', f_scores.shape)
113 | # print('torch.argmax(f_scores): ', torch.argmax(f_scores))
114 | return f_scores.max().item()
115 |
116 |
117 | def mean(x):
118 | return sum(x) / len(x)
119 |
120 |
121 | def list_of_dicts_to_dict_of_lists(LD):
122 | return {k: [dic[k] for dic in LD] for k in LD[0]}
123 |
124 |
125 | def list_of_dict_of_lists_to_dict_of_lists(LD):
126 | return {k: [v for dic in LD for v in dic[k]] for k in LD[0]}
127 |
128 |
129 | def dict_of_lists_to_list_of_dicts(DL):
130 | return [dict(zip(DL, t)) for t in zip(*DL.values())]
131 |
--------------------------------------------------------------------------------
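
A toy sketch of the metric helpers above with hard (already binarized) predictions (illustrative only; it assumes `metrics.py` is importable from the working directory):

```python
# Toy metrics example with two 4x4 binary masks.
import torch
import metrics

preds = torch.tensor([[[0, 0, 1, 1]] * 4, [[1, 1, 1, 1]] * 4], dtype=torch.long)
targets = torch.tensor([[[0, 1, 1, 1]] * 4, [[1, 1, 0, 0]] * 4], dtype=torch.long)

# Hard predictions: only 'acc' and 'iou' are available ('f_max' needs soft predictions)
per_image = metrics.compute_metrics(preds=preds, targets=targets, metrics=['acc', 'iou'])
totals = metrics.list_of_dict_of_lists_to_dict_of_lists([per_image])
print(metrics.aggregate_metrics(totals))  # {'acc': ..., 'iou': ...} averaged over the two masks
```
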
/object-segmentation/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | import sys
4 | import math
5 | import datetime
6 | import torch
7 | import hydra
8 | import wandb
9 | import cv2
10 | import numpy as np
11 | from contextlib import nullcontext
12 | from collections import namedtuple
13 | from pathlib import Path
14 | from typing import Callable, Iterable, Optional
15 | from PIL import Image
16 | from torch.nn import functional as F
17 | from torch.utils.data import DataLoader
18 | from torchvision.transforms import functional as TF
19 | from accelerate import Accelerator
20 | from omegaconf import OmegaConf, DictConfig
21 | from tqdm import tqdm
22 |
23 | import metrics
24 | import util as utils
25 | from dataset import SegmentationDataset, central_crop
26 |
27 |
28 | @hydra.main(config_path='config', config_name='eval')
29 | def main(cfg: DictConfig):
30 |
31 | # Accelerator
32 | accelerator = Accelerator(fp16=cfg.fp16, cpu=cfg.cpu)
33 |
34 | # Logging
35 | utils.setup_distributed_print(accelerator.is_local_main_process)
36 | if cfg.wandb and accelerator.is_local_main_process:
37 | wandb.init(name=cfg.name, job_type=cfg.job_type, config=OmegaConf.to_container(cfg), save_code=True, **cfg.wandb_kwargs)
38 | cfg = DictConfig(wandb.config.as_dict()) # get the config back from wandb for hyperparameter sweeps
39 |
40 | # Configuration
41 | print(OmegaConf.to_yaml(cfg))
42 | print(f'Current working directory: {os.getcwd()}')
43 |
44 | # Set random seed
45 | utils.set_seed(cfg.seed)
46 |
47 | # Datasets
48 | # NOTE: The batch size must be 1 for test because the masks are different sizes,
49 | # and evaluation should be done using the mask at the original resolution.
50 | test_dataloaders = []
51 | for data_cfg in cfg.data:
52 | test_dataset = SegmentationDataset(**data_cfg)
53 | test_dataloader = DataLoader(test_dataset, **{**cfg.dataloader, 'batch_size': 1})
54 | test_dataloaders.append(test_dataloader)
55 |
56 | # Evaluation
57 | if cfg.job_type == 'eval':
58 | for dataloader in test_dataloaders:
59 | evaluate_predictions(cfg=cfg, accelerator=accelerator, dataloader_val=dataloader)
60 | else:
61 | raise NotImplementedError()
62 |
63 |
64 | @torch.no_grad()
65 | def evaluate_predictions(
66 | *,
67 | cfg: DictConfig,
68 | dataloader_val: Iterable,
69 | accelerator: Accelerator,
70 | **_unused_kwargs):
71 |
72 | # Evaluate
73 | name = dataloader_val.dataset.name
74 | all_results = []
75 | pbar = tqdm(dataloader_val, desc=f'Evaluating {name}')
76 | for i, (images, targets, metadatas) in enumerate(pbar):
77 |
78 | # Convert
79 | targets = targets.squeeze(0)
80 |
81 | # Load predictions
82 | id = Path(metadatas['image_file'][0]).stem
83 | predictions_file = os.path.join(cfg.predictions[name], f'{id}.png')
84 | preds = np.array(Image.open(predictions_file).convert('L')) # (H_patch, W_patch)
85 | assert set(np.unique(preds).tolist()) in [{0, 255}, {0, 1}, {0}], set(np.unique(preds).tolist())
86 | preds[preds == 255] = 1
87 |
88 | # Resize if segmentation is patchwise
89 | if cfg.predictions.downsample is not None:
90 | H, W = targets.shape
91 | H_pred, W_pred = preds.shape
92 | H_pad, W_pad = H_pred * cfg.predictions.downsample, W_pred * cfg.predictions.downsample
93 | H_max, W_max = max(H_pad, H), max(W_pad, W)
94 | preds = cv2.resize(preds, dsize=(W_max, H_max), interpolation=cv2.INTER_NEAREST)
95 | preds[:H_pad, :W_pad] = cv2.resize(preds, dsize=(W_pad, H_pad), interpolation=cv2.INTER_NEAREST)
96 |
97 | # Convert, optional center crop, and unsqueeze
98 | preds = torch.from_numpy(preds)
99 | if dataloader_val.dataset.crop:
100 | preds = TF.center_crop(preds, output_size=min(preds.shape))
101 | preds = torch.unsqueeze(preds, dim=0)
102 | targets = torch.unsqueeze(targets, dim=0)
103 |
104 | # Compute metrics
105 | results = metrics.compute_metrics(preds=preds, targets=targets, metrics=['acc', 'iou'])
106 | all_results.append(results)
107 |
108 | # Aggregate results
109 | all_results = metrics.list_of_dict_of_lists_to_dict_of_lists(all_results)
110 | results = metrics.aggregate_metrics(all_results)
111 | for metric_name, value in results.items():
112 | print(f'[{name}] {metric_name}: {value}')
113 |
114 |
115 | @torch.no_grad()
116 | def visualize(
117 | *,
118 | cfg: DictConfig,
119 | model: torch.nn.Module,
120 | dataloader_vis: Iterable,
121 | accelerator: Accelerator,
122 | identifier: str = '',
123 | num_batches: Optional[int] = None,
124 | **_unused_kwargs):
125 |
126 | # Eval mode
127 | model.eval()
128 | metric_logger = utils.MetricLogger(delimiter=" ")
129 | progress_bar = metric_logger.log_every(dataloader_vis, cfg.logging.print_freq, "Vis")
130 |
131 | # Visualize
132 | for batch_idx, (inputs, target) in enumerate(progress_bar):
133 | if num_batches is not None and batch_idx >= num_batches:
134 | break
135 |
136 | # Inverse normalization
137 | Inv = utils.NormalizeInverse(mean=cfg.data.transform.img_mean, std=cfg.data.transform.img_std)
138 | image = Inv(inputs).clamp(0, 1)
139 | vis_dict = dict(image=image)
140 |
141 | # Save images
142 | wandb_log_dict = {}
143 | for name, images in vis_dict.items():
144 | for i, image in enumerate(images):
145 | pil_image = utils.tensor_to_pil(image)
146 | filename = f'vis/{name}/p-{accelerator.process_index}-b-{batch_idx}-img-{i}-{name}-{identifier}.png'
147 | Path(filename).parent.mkdir(exist_ok=True, parents=True)
148 | pil_image.save(filename)
149 | if i < 2: # log to Weights and Biases
150 | wandb_filename = f'vis/{name}/p-{accelerator.process_index}-b-{batch_idx}-img-{i}-{name}'
151 | wandb_log_dict[wandb_filename] = [wandb.Image(pil_image)]
152 | if cfg.wandb and accelerator.is_local_main_process:
153 | wandb.log(wandb_log_dict, commit=False)
154 | print(f'Saved visualizations to {Path("vis").absolute()}')
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
--------------------------------------------------------------------------------
/semantic-segmentation/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Iterable, Optional
4 |
5 | import hydra
6 | import numpy as np
7 | import torch
8 | import wandb
9 | from accelerate import Accelerator
10 | from matplotlib.cm import get_cmap
11 | from omegaconf import DictConfig, OmegaConf
12 | from PIL import Image
13 | from skimage.color import label2rgb
14 | from tqdm import tqdm, trange
15 |
16 | import eval_utils
17 | import util as utils
18 | from dataset.voc import VOCSegmentationWithPseudolabels
19 |
20 |
21 | @hydra.main(config_path='config', config_name='eval')
22 | def main(cfg: DictConfig):
23 |
24 | # Accelerator
25 | accelerator = Accelerator(fp16=cfg.fp16, cpu=cfg.cpu)
26 |
27 | # Logging
28 | utils.setup_distributed_print(accelerator.is_local_main_process)
29 | if cfg.wandb and accelerator.is_local_main_process:
30 | wandb.init(name=cfg.name, job_type=cfg.job_type, config=OmegaConf.to_container(cfg), save_code=True, **cfg.wandb_kwargs)
31 | cfg = DictConfig(wandb.config.as_dict()) # get the config back from wandb for hyperparameter sweeps
32 |
33 | # Configuration
34 | print(OmegaConf.to_yaml(cfg))
35 | print(f'Current working directory: {os.getcwd()}')
36 |
37 | # Set random seed
38 | utils.set_seed(cfg.seed)
39 |
40 | # Create dataset with segments/pseudolabels
41 | dataset_val = VOCSegmentationWithPseudolabels(
42 | **cfg.data.val_kwargs,
43 | segments_dir=cfg.segments_dir,
44 | transform=None, # no transform to evaluate at original resolution
45 | )
46 |
47 | # Evaluate
48 | eval_stats, match = evaluate(cfg=cfg, dataset_val=dataset_val, n_clusters=cfg.get('n_clusters', None))
49 | print(eval_stats)
50 | if cfg.wandb and accelerator.is_local_main_process:
51 | wandb.summary['mIoU'] = eval_stats['mIoU']
52 |
53 | # Visualize
54 | visualize(cfg=cfg, dataset_val=dataset_val)
55 |
56 |
57 | def visualize(
58 | *,
59 | cfg: DictConfig,
60 | dataset_val: Iterable,
61 | vis_dir: str = './vis'):
62 |
63 | # Visualize
64 | num_vis = 40
65 | vis_dir = Path(vis_dir)
66 | colors = get_cmap('tab20', cfg.data.num_classes + 1).colors[:,:3]
67 | pbar = tqdm(dataset_val, total=num_vis, desc='Saving visualizations: ')
68 | for i, (image, target, mask, metadata) in enumerate(pbar):
69 | if i >= num_vis: break
70 | image = np.array(image)
71 | target = np.array(target)
72 | target[target == 255] = 0 # set the "unknown" regions to background for visualization
73 | # Overlay mask on image
74 | image_pred_overlay = label2rgb(label=mask, image=image, colors=colors[np.unique(target)[1:]], bg_label=0, alpha=0.45)
75 | image_target_overlay = label2rgb(label=target, image=image, colors=colors[np.unique(target)[1:]], bg_label=0, alpha=0.45)
76 | # Save
77 | image_id = metadata["id"]
78 | path_pred = vis_dir / 'pred' / f'{image_id}-pred.png'
79 | path_target = vis_dir / 'target' / f'{image_id}-target.png'
80 | path_pred.parent.mkdir(exist_ok=True, parents=True)
81 | path_target.parent.mkdir(exist_ok=True, parents=True)
82 | Image.fromarray((image_pred_overlay * 255).astype(np.uint8)).save(str(path_pred))
83 | Image.fromarray((image_target_overlay * 255).astype(np.uint8)).save(str(path_target))
84 | print(f'Saved visualizations to {vis_dir.absolute()}')
85 |
86 |
87 | def evaluate(
88 | *,
89 | cfg: DictConfig,
90 | dataset_val: Iterable,
91 | n_clusters: Optional[int] = None):
92 |
93 | # Add background class
94 | n_classes = cfg.data.num_classes + 1
95 | if n_clusters is None:
96 | n_clusters = n_classes
97 |
98 | # Iterate
99 | tp = [0] * n_classes
100 | fp = [0] * n_classes
101 | fn = [0] * n_classes
102 |
103 | # Load all pixel embeddings
104 | all_preds = np.zeros((len(dataset_val) * 500 * 500), dtype=np.float32)
105 | all_gt = np.zeros((len(dataset_val) * 500 * 500), dtype=np.float32)
106 | offset_ = 0
107 |
108 | # Add all pixels to our arrays
109 | _alread_warned = 0
110 | for i in trange(len(dataset_val), desc='Concatenating all predictions'):
111 | image, target, mask, metadata = dataset_val[i]
112 | # Check where ground-truth is valid and append valid pixels to the array
113 | valid = (target != 255)
114 | n_valid = np.sum(valid)
115 | all_gt[offset_:offset_+n_valid] = target[valid]
116 | # Append the predicted targets in the array
117 | all_preds[offset_:offset_+n_valid, ] = mask[valid]
118 | all_gt[offset_:offset_+n_valid, ] = target[valid]
119 | # Update offset_
120 | offset_ += n_valid
121 |
122 | # Truncate to the actual number of pixels
123 | all_preds = all_preds[:offset_, ]
124 | all_gt = all_gt[:offset_, ]
125 |
126 | # Do hungarian matching
127 | num_elems = offset_
128 | if n_clusters == n_classes:
129 | print('Using hungarian algorithm for matching')
130 | match = eval_utils.hungarian_match(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes, metric='iou')
131 | else:
132 | print('Using majority voting for matching')
133 | match = eval_utils.majority_vote(all_preds, all_gt, preds_k=n_clusters, targets_k=n_classes)
134 | print(f'Optimal matching: {match}')
135 |
136 | # Remap predictions
137 | reordered_preds = np.zeros(num_elems, dtype=all_preds.dtype)
138 | for pred_i, target_i in match:
139 | reordered_preds[all_preds == int(pred_i)] = int(target_i)
140 |
141 | # TP, FP, and FN evaluation
142 | for i_part in range(0, n_classes):
143 | tmp_all_gt = (all_gt == i_part)
144 | tmp_pred = (reordered_preds == i_part)
145 | tp[i_part] += np.sum(tmp_all_gt & tmp_pred)
146 | fp[i_part] += np.sum(~tmp_all_gt & tmp_pred)
147 | fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred)
148 |
149 | # Calculate Jaccard index
150 | jac = [0] * n_classes
151 | for i_part in range(0, n_classes):
152 | jac[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8)
153 |
154 | # Print results
155 | eval_result = dict()
156 | eval_result['jaccards_all_categs'] = jac
157 | eval_result['mIoU'] = np.mean(jac)
158 | print('Evaluation of semantic segmentation ')
159 | print('mIoU is %.2f' % (100*eval_result['mIoU']))
160 | return eval_result, match
161 |
162 |
163 | if __name__ == '__main__':
164 | torch.set_grad_enabled(False)
165 | main()
166 |
--------------------------------------------------------------------------------
/semantic-segmentation/dataset/voc.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 | from typing import Any, Callable, Dict, List, Optional, Tuple
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | from torchvision.datasets.voc import (DATASET_YEAR_DICT, VisionDataset, os, verify_str_arg)
10 |
11 |
12 | def _resize_pseudolabel(pseudolabel, img):
13 | if (
14 | (pseudolabel.shape[0] == img.shape[0] // 16) or
15 | (pseudolabel.shape[0] == img.shape[0] // 8) or
16 | (pseudolabel.shape[0] == 2 * (img.shape[0] // 16))
17 | ):
18 | return cv2.resize(pseudolabel, dsize=img.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
19 | return pseudolabel
20 |
21 |
22 | class VOCSegmentationWithPseudolabelsBase(VisionDataset):
23 |
24 | _SPLITS_DIR = "Segmentation"
25 | _TARGET_DIR = "SegmentationClass"
26 | _TARGET_FILE_EXT = ".png"
27 |
28 | def __init__(
29 | self,
30 | root: str,
31 | year: str = "2012",
32 | image_set: str = "train",
33 | download: bool = False,
34 | transform: Optional[Callable] = None,
35 | target_transform: Optional[Callable] = None,
36 | transforms: Optional[Callable] = None,
37 | ):
38 | super().__init__(root, transforms, transform, target_transform)
39 | if year == "2007-test":
40 | if image_set == "test":
41 | warnings.warn(
42 |                     "Accessing the test image set of the year 2007 with year='2007-test' is deprecated. "
43 | "Please use the combination year='2007' and image_set='test' instead."
44 | )
45 | year = "2007"
46 | else:
47 | raise ValueError(
48 | "In the test image set of the year 2007 only image_set='test' is allowed. "
49 | "For all other image sets use year='2007' instead."
50 | )
51 | self.year = year
52 |
53 | valid_image_sets = ["train", "trainval", "val"]
54 | if year == "2007":
55 | valid_image_sets.append("test")
56 | self.image_set = verify_str_arg(image_set, "image_set", valid_image_sets)
57 |
58 | key = "2007-test" if year == "2007" and image_set == "test" else year
59 | dataset_year_dict = DATASET_YEAR_DICT[key]
60 |
61 | self.url = dataset_year_dict["url"]
62 | self.filename = dataset_year_dict["filename"]
63 | self.md5 = dataset_year_dict["md5"]
64 |
65 | base_dir = dataset_year_dict["base_dir"]
66 | voc_root = os.path.join(self.root, base_dir)
67 |
68 | if download:
69 | from torchvision.datasets.voc import download_and_extract_archive
70 | download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5)
71 |
72 | if not os.path.isdir(voc_root):
73 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
74 |
75 | splits_dir = os.path.join(voc_root, "ImageSets", self._SPLITS_DIR)
76 | split_f = os.path.join(splits_dir, image_set.rstrip("\n") + ".txt")
77 |
78 | if self.image_set == 'train': # everything except val
79 | image_dir = os.path.join(voc_root, "JPEGImages")
80 | with open(os.path.join(splits_dir, "val.txt"), "r") as f:
81 | val_file_stems = set([stem.strip() for stem in f.readlines()])
82 | all_image_paths = [p for p in Path(image_dir).iterdir()]
83 | train_image_paths = [str(p) for p in all_image_paths if p.stem not in val_file_stems]
84 | self.images = sorted(train_image_paths)
85 | # For the targets, we will just replicate the same target however many times
86 | target_dir = os.path.join(voc_root, self._TARGET_DIR)
87 | self.targets = [str(next(Path(target_dir).iterdir()))] * len(self.images)
88 |
89 | else:
90 |
91 | with open(os.path.join(split_f), "r") as f:
92 | file_names = [x.strip() for x in f.readlines()]
93 |
94 | image_dir = os.path.join(voc_root, "JPEGImages")
95 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
96 |
97 | target_dir = os.path.join(voc_root, self._TARGET_DIR)
98 | self.targets = [os.path.join(target_dir, x + self._TARGET_FILE_EXT) for x in file_names]
99 |
100 | assert len(self.images) == len(self.targets), ( len(self.images), len(self.targets))
101 |
102 | @property
103 | def masks(self) -> List[str]:
104 | return self.targets
105 |
106 | def _prepare_label_map(self, label_map):
107 | if label_map is not None:
108 | self.label_map_fn = np.vectorize(label_map.__getitem__)
109 | else:
110 | self.label_map_fn = None
111 |
112 | def _prepare_segments_dir(self, segments_dir):
113 | self.segments_dir = segments_dir
114 | # Get segment and image files, which are assumed to be in correspondence
115 | all_segment_files = sorted(map(str, Path(segments_dir).iterdir()))
116 | all_img_files = sorted(Path(self.images[0]).parent.iterdir())
117 | assert len(all_img_files) == len(all_segment_files), (len(all_img_files), len(all_segment_files))
118 | # Create mapping because I named the segment files badly (sequentially instead of by image id)
119 | all_img_stems = [p.stem for p in all_img_files]
120 | valid_img_stems = set([Path(p).stem for p in self.images]) # in our split (e.g. 'val')
121 | segment_files = []
122 | for i in range(len(all_img_stems)):
123 | if all_img_stems[i] in valid_img_stems:
124 | segment_files.append(all_segment_files[i])
125 | self.segments = segment_files
126 | assert len(self.segments) == len(self.images), f'{len(self.segments)=} and {len(self.images)=}'
127 | print('Loaded segments and images')
128 | print(f'First image filepath: {self.images[0]}')
129 | print(f'First segmap filepath: {self.segments[0]}')
130 | print(f'Last image filepath: {self.images[-1]}')
131 | print(f'Last segmap filepath: {self.segments[-1]}')
132 |
133 | def _load(self, index: int):
134 | # Load image
135 | img = np.array(Image.open(self.images[index]).convert("RGB"))
136 | target = np.array(Image.open(self.masks[index]))
137 | metadata = {'id': Path(self.images[index]).stem, 'path': self.images[index], 'shape': tuple(img.shape[:2])}
138 |         # New: load segmap and accompanying metadata
139 | pseudolabel = np.array(Image.open(self.segments[index]))
140 |         pseudolabel = _resize_pseudolabel(pseudolabel, img)  # HACK: upsample patch-level pseudolabels to image size
141 | if self.label_map_fn is not None:
142 | pseudolabel = self.label_map_fn(pseudolabel)
143 | return (img, target, pseudolabel, metadata)
144 |
145 | def __len__(self) -> int:
146 | return len(self.images)
147 |
148 |
149 | class VOCSegmentationWithPseudolabels(VOCSegmentationWithPseudolabelsBase):
150 |
151 | def __init__(self, *args, segments_dir, transform = None, label_map = None, **kwargs):
152 | super().__init__(*args, **kwargs)
153 | self._prepare_segments_dir(segments_dir)
154 | self.transform = transform
155 | self._prepare_label_map(label_map)
156 |
157 | def __getitem__(self, index: int):
158 | img, target, pseudolabel, metadata = self._load(index)
159 | if self.transform is not None:
160 | # Transform
161 | data = self.transform(image=img, mask1=target, mask2=pseudolabel)
162 | # Unpack
163 | img, target, pseudolabel = data['image'], data['mask1'], data['mask2']
164 | if torch.is_tensor(target):
165 | target = target.long()
166 | if torch.is_tensor(pseudolabel):
167 | pseudolabel = pseudolabel.long()
168 | return img, target, pseudolabel, metadata
169 |
--------------------------------------------------------------------------------
/extract/extract_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | from multiprocessing import Pool
4 | from pathlib import Path
5 | from typing import Any, Callable, Iterable, Optional, Tuple, Union
6 |
7 | import cv2
8 | import numpy as np
9 | import scipy.sparse
10 | import torch
11 | from skimage.morphology import binary_dilation, binary_erosion
12 | from torch.utils.data import Dataset
13 | from torchvision import transforms
14 | from tqdm import tqdm
15 |
16 |
17 | class ImagesDataset(Dataset):
18 | """A very simple dataset for loading images."""
19 |
20 | def __init__(self, filenames: Iterable[str], images_root: Optional[str] = None, transform: Optional[Callable] = None,
21 | prepare_filenames: bool = True) -> None:
22 | self.root = None if images_root is None else Path(images_root)
23 | self.filenames = sorted(list(set(filenames))) if prepare_filenames else filenames
24 | self.transform = transform
25 |
26 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
27 | path = self.filenames[index]
28 | full_path = Path(path) if self.root is None else self.root / path
29 | assert full_path.is_file(), f'Not a file: {full_path}'
30 | image = cv2.imread(str(full_path))
31 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32 | if self.transform is not None:
33 | image = self.transform(image)
34 | return image, path, index
35 |
36 | def __len__(self) -> int:
37 | return len(self.filenames)
38 |
39 |
40 | def get_model(name: str):
41 | if 'dino' in name:
42 | model = torch.hub.load('facebookresearch/dino:main', name)
43 | model.fc = torch.nn.Identity()
44 | val_transform = get_transform(name)
45 | patch_size = model.patch_embed.patch_size
46 | num_heads = model.blocks[0].attn.num_heads
47 | else:
48 | raise ValueError(f'Cannot get model: {name}')
49 | model = model.eval()
50 | return model, val_transform, patch_size, num_heads
51 |
52 |
53 | def get_transform(name: str):
54 | if any(x in name for x in ('dino', 'mocov3', 'convnext', )):
55 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
56 | transform = transforms.Compose([transforms.ToTensor(), normalize])
57 | else:
58 | raise NotImplementedError()
59 | return transform
60 |
61 |
62 | def get_inverse_transform(name: str):
63 | if 'dino' in name:
64 | inv_normalize = transforms.Normalize(
65 | [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
66 | [1 / 0.229, 1 / 0.224, 1 / 0.225])
67 | transform = transforms.Compose([transforms.ToTensor(), inv_normalize])
68 | else:
69 | raise NotImplementedError()
70 | return transform
71 |
72 |
73 | def get_image_sizes(data_dict: dict, downsample_factor: Optional[int] = None):
74 | P = data_dict['patch_size'] if downsample_factor is None else downsample_factor
75 | B, C, H, W = data_dict['shape']
76 | assert B == 1, 'assumption violated :('
77 | H_patch, W_patch = H // P, W // P
78 | H_pad, W_pad = H_patch * P, W_patch * P
79 | return (B, C, H, W, P, H_patch, W_patch, H_pad, W_pad)
80 |
81 |
82 | def _get_files(p: str):
83 | if Path(p).is_dir():
84 | return sorted(Path(p).iterdir())
85 | elif Path(p).is_file():
86 | return Path(p).read_text().splitlines()
87 | else:
88 | raise ValueError(p)
89 |
90 |
91 | def get_paired_input_files(path1: str, path2: str):
92 | files1 = _get_files(path1)
93 | files2 = _get_files(path2)
94 | assert len(files1) == len(files2)
95 | return list(enumerate(zip(files1, files2)))
96 |
97 |
98 | def make_output_dir(output_dir, check_if_empty=True):
99 | output_dir = Path(output_dir)
100 | output_dir.mkdir(exist_ok=True, parents=True)
101 | if check_if_empty and (len(list(output_dir.iterdir())) > 0):
102 | print(f'Output dir: {str(output_dir)}')
103 | if input(f'Output dir already contains files. Continue? (y/n) >> ') != 'y':
104 | sys.exit() # skip because already generated
105 |
106 |
107 | def get_largest_cc(mask: np.ndarray):
108 | from skimage.measure import label as measure_label
109 | labels = measure_label(mask) # get connected components
110 | largest_cc_index = np.argmax(np.bincount(labels.flat)[1:]) + 1
111 | largest_cc_mask = (labels == largest_cc_index)
112 | return largest_cc_mask
113 |
114 |
115 | def erode_or_dilate_mask(x: Union[torch.Tensor, np.ndarray], r: int = 0, erode=True):
116 | fn = binary_erosion if erode else binary_dilation
117 | for _ in range(r):
118 | x_new = fn(x)
119 | if x_new.sum() > 0: # do not erode the entire mask away
120 | x = x_new
121 | return x
122 |
123 |
124 | def get_border_fraction(segmap: np.ndarray):
125 | num_border_pixels = 2 * (segmap.shape[0] + segmap.shape[1])
126 | counts_map = {idx: 0 for idx in np.unique(segmap)}
127 | np.zeros(len(np.unique(segmap)))
128 | for border in [segmap[:, 0], segmap[:, -1], segmap[0, :], segmap[-1, :]]:
129 | unique, counts = np.unique(border, return_counts=True)
130 | for idx, count in zip(unique.tolist(), counts.tolist()):
131 | counts_map[idx] += count
132 | # normalized_counts_map = {idx: count / num_border_pixels for idx, count in counts_map.items()}
133 | indices = np.array(list(counts_map.keys()))
134 | normalized_counts = np.array(list(counts_map.values())) / num_border_pixels
135 | return indices, normalized_counts
136 |
137 |
138 | def parallel_process(inputs: Iterable, fn: Callable, multiprocessing: int = 0):
139 | start = time.time()
140 | if multiprocessing:
141 | print('Starting multiprocessing')
142 | with Pool(multiprocessing) as pool:
143 | for _ in tqdm(pool.imap(fn, inputs), total=len(inputs)):
144 | pass
145 | else:
146 | for inp in tqdm(inputs):
147 | fn(inp)
148 | print(f'Finished in {time.time() - start:.1f}s')
149 |
150 |
151 | def knn_affinity(image, n_neighbors=[20, 10], distance_weights=[2.0, 0.1]):
152 | """Computes a KNN-based affinity matrix. Note that this function requires pymatting"""
153 | try:
154 | from pymatting.util.kdtree import knn
155 | except ImportError:
156 | raise ImportError(
157 | 'Please install pymatting to compute KNN affinity matrices:\n'
158 | 'pip3 install pymatting'
159 | )
160 |
161 | h, w = image.shape[:2]
162 | r, g, b = image.reshape(-1, 3).T
163 | n = w * h
164 |
165 | x = np.tile(np.linspace(0, 1, w), h)
166 | y = np.repeat(np.linspace(0, 1, h), w)
167 |
168 | i, j = [], []
169 |
170 | for k, distance_weight in zip(n_neighbors, distance_weights):
171 | f = np.stack(
172 | [r, g, b, distance_weight * x, distance_weight * y],
173 | axis=1,
174 | out=np.zeros((n, 5), dtype=np.float32),
175 | )
176 |
177 | distances, neighbors = knn(f, f, k=k)
178 |
179 | i.append(np.repeat(np.arange(n), k))
180 | j.append(neighbors.flatten())
181 |
182 | ij = np.concatenate(i + j)
183 | ji = np.concatenate(j + i)
184 | coo_data = np.ones(2 * sum(n_neighbors) * n)
185 |
186 | # This is our affinity matrix
187 | W = scipy.sparse.csr_matrix((coo_data, (ij, ji)), (n, n))
188 | return W
189 |
190 |
191 | def rw_affinity(image, sigma=0.033, radius=1):
192 | """Computes a random walk-based affinity matrix. Note that this function requires pymatting"""
193 | try:
194 | from pymatting.laplacian.rw_laplacian import _rw_laplacian
195 | except ImportError:
196 | raise ImportError(
197 | 'Please install pymatting to compute RW affinity matrices:\n'
198 | 'pip3 install pymatting'
199 | )
200 | h, w = image.shape[:2]
201 | n = h * w
202 | values, i_inds, j_inds = _rw_laplacian(image, sigma, radius)
203 | W = scipy.sparse.csr_matrix((values, (i_inds, j_inds)), shape=(n, n))
204 | return W
205 |
206 |
207 | def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12):
208 | """Gets the diagonal sum of a sparse matrix"""
209 | try:
210 | from pymatting.util.util import row_sum
211 | except ImportError:
212 | raise ImportError(
213 | 'Please install pymatting to compute the diagonal sums:\n'
214 | 'pip3 install pymatting'
215 | )
216 |
217 | D = row_sum(W)
218 | D[D < threshold] = 1.0 # Prevent division by zero.
219 | D = scipy.sparse.diags(D)
220 | return D
221 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Deep Spectral Methods for Unsupervised Localization and Segmentation (CVPR 2022 - Oral)
4 |
5 | [](https://lukemelas.github.io/deep-spectral-segmentation/)
6 | [](https://huggingface.co/spaces/lukemelas/deep-spectral-segmentation)
7 | [](#)
8 | [](#)
9 |
10 |
11 |
12 | ### Description
13 | This code accompanies the paper [Deep Spectral Methods: A Surprisingly Strong Baseline for Unsupervised Semantic Segmentation and Localization](https://lukemelas.github.io/deep-spectral-segmentation/).
14 |
15 | ### Abstract
16 |
17 | Unsupervised localization and segmentation are long-standing computer vision challenges that involve decomposing an image into semantically-meaningful segments without any labeled data. These tasks are particularly interesting in an unsupervised setting due to the difficulty and cost of obtaining dense image annotations, but existing unsupervised approaches struggle with complex scenes containing multiple objects. Differently from existing methods, which are purely based on deep learning, we take inspiration from traditional spectral segmentation methods by reframing image decomposition as a graph partitioning problem. Specifically, we examine the eigenvectors of the Laplacian of a feature affinity matrix from self-supervised networks. We find that these eigenvectors already decompose an image into meaningful segments, and can be readily used to localize objects in a scene. Furthermore, by clustering the features associated with these segments across a dataset, we can obtain well-delineated, nameable regions, i.e. semantic segmentations. Experiments on complex datasets (Pascal VOC, MS-COCO) demonstrate that our simple spectral method outperforms the state-of-the-art in unsupervised localization and segmentation by a significant margin. Furthermore, our method can be readily used for a variety of complex image editing tasks, such as background removal and compositing.
18 |
19 | ### Demo
20 | Please check out our interactive demo on [Huggingface Spaces](https://huggingface.co/spaces/lukemelas/deep-spectral-segmentation)! The demo enables you to upload an image and outputs the eigenvectors extracted by our method. It does not perform the downstream tasks in our paper (e.g. semantic segmentation), but it should give you some intuition for how you might utilize our method for your own research/use-case.
21 |
22 | ### Examples
23 |
24 | 
25 |
26 | ### How to run
27 |
28 | #### Dependencies
29 | The minimal set of dependencies is listed in `requirements.txt`.
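
For example, they can be installed into a fresh environment with:
```bash
pip install -r requirements.txt
```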
30 |
31 | #### Data Preparation
32 |
33 | The data preparation process simply consists of collecting your images into a single folder. Here, we describe the process for [Pascal VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012//). Pascal VOC 2007 and MS-COCO are similar.
34 |
35 | Download the images into a single folder. Then create a text file where each line contains the name of an image file. For example, here is our initial data layout:
36 | ```
37 | data
38 | └── VOC2012
39 | ├── images
40 | │ └── {image_id}.jpg
41 | └── lists
42 | └── images.txt
43 | ```
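
A simple way to create `lists/images.txt` (assuming the layout above and that the images are `.jpg` files) is to list the contents of the image folder, for example:
```bash
cd data/VOC2012
mkdir -p lists
ls images > lists/images.txt
```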
44 |
45 | #### Extraction
46 |
47 | We first extract features from images and store them in files. We then extract eigenvectors from these features. Once we have the eigenvectors, we can perform downstream tasks such as object segmentation and object localization.
48 |
49 | The primary script for this extraction process is `extract.py` in the `extract/` directory. All functions in `extract.py` have helpful docstrings with example usage.
50 |
51 | ##### Step 1: Feature Extraction
52 |
53 | First, we extract features from our images and save them to `.pth` files.
54 |
55 | With regard to models, our repository currently only supports DINO, but other models are easy to add (see the `get_model` function in `extract_utils.py`). The DINO model is downloaded automatically using `torch.hub`.
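
To give a rough idea of what adding a backbone involves, below is a sketch of `get_model` (from `extract/extract_utils.py`) with an extra branch. The `my_backbone` name, the `some_org/some_repo` hub path, and the attribute accesses in that branch are placeholders rather than a supported configuration; a different architecture may expose its patch size and attention heads differently.
```python
import torch

from extract_utils import get_transform  # assumes this is run from the extract/ directory


def get_model(name: str):
    if 'dino' in name:
        # Existing DINO branch (as in the repository)
        model = torch.hub.load('facebookresearch/dino:main', name)
        model.fc = torch.nn.Identity()
        val_transform = get_transform(name)
        patch_size = model.patch_embed.patch_size
        num_heads = model.blocks[0].attn.num_heads
    elif 'my_backbone' in name:  # hypothetical new branch
        model = torch.hub.load('some_org/some_repo', name)  # placeholder hub repository
        val_transform = get_transform(name)                 # reuse an existing transform or register a new one
        patch_size = model.patch_embed.patch_size           # adapt to the new architecture
        num_heads = model.blocks[0].attn.num_heads          # adapt to the new architecture
    else:
        raise ValueError(f'Cannot get model: {name}')
    model = model.eval()
    return model, val_transform, patch_size, num_heads
```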
56 |
57 | Here is an example using `dino_vits16`:
58 |
59 | ```bash
60 | python extract.py extract_features \
61 | --images_list "./data/VOC2012/lists/images.txt" \
62 | --images_root "./data/VOC2012/images" \
63 | --output_dir "./data/VOC2012/features/dino_vits16" \
64 | --model_name dino_vits16 \
65 | --batch_size 1
66 | ```
67 |
68 | ##### Step 2: Eigenvector Computation
69 |
70 | Second, we extract eigenvectors from our features and save them to `.pth` files.
71 |
72 | Here, we extract the top `K=5` eigenvectors of the Laplacian matrix of our features:
73 |
74 | ```bash
75 | python extract.py extract_eigs \
76 | --images_root "./data/VOC2012/images" \
77 | --features_dir "./data/VOC2012/features/dino_vits16" \
78 | --which_matrix "laplacian" \
79 | --output_dir "./data/VOC2012/eigs/laplacian" \
80 | --K 5
81 | ```
82 |
83 | The final data structure after extracting eigenvectors looks like:
84 | ```
85 | data
86 | ├── VOC2012
87 | │ ├── eigs
88 | │ │ └── {output_dir_name}
89 | │ │ └── {image_id}.pth
90 | │ ├── features
91 | │ │ └── {model_name}
92 | │ │ └── {image_id}.pth
93 | │ ├── images
94 | │ │ └── {image_id}.jpg
95 | │ └── lists
96 | │ └── images.txt
97 | └── VOC2007
98 | └── ...
99 | ```
100 |
101 | At this point, you are ready to use the eigenvectors for downstream tasks such as object localization, object segmentation, and semantic segmentation.
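
If you would like to sanity-check the extracted files first, they are standard PyTorch files, so a quick inspection might look like the sketch below (the path is a placeholder, and the exact keys and shapes stored in each file depend on `extract.py` and your data):
```python
import torch

# Sketch: peek inside one per-image output file to see what was saved.
path = 'data/VOC2012/eigs/laplacian/{image_id}.pth'  # replace {image_id} with a real image id
data = torch.load(path, map_location='cpu')
if isinstance(data, dict):
    for key, value in data.items():
        print(key, getattr(value, 'shape', value))
else:
    print(type(data), getattr(data, 'shape', None))
```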
102 |
103 | #### Object Localization
104 |
105 | First, clone the `dino` repo inside this project root (or symlink it).
106 | ```bash
107 | git clone https://github.com/facebookresearch/dino
108 | ```
109 |
110 | Run the steps above to save your eigenvectors inside a directory, which we will now call `${EIGS_DIR}`. You can then move to the `object-localization` directory and evaluate object localization with:
111 | ```bash
112 | python main.py \
113 | --eigenseg \
114 | --precomputed_eigs_dir ${EIGS_DIR} \
115 | --dataset VOC12 \
116 | --name "example_eigs"
117 | ```
118 |
119 | #### Object Segmentation
120 |
121 | To perform object segmentation (i.e. single-region segmentations), you first extract features and eigenvectors (as described above). You then extract coarse (i.e. patch-level) single-region segmentations from the eigenvectors, and then turn these into high-resolution segmentations using a CRF.
122 |
123 | Below, we give example commands for the CUB bird dataset (`CUB_200_2011`). To download this dataset, as well as the three other object segmentation datasets used in our paper, you can follow the instructions in [unsupervised-image-segmentation](https://github.com/lukemelas/unsupervised-image-segmentation). Then make sure to specify the `data_root` parameter in `config/eval.yaml`.
124 |
125 | For example:
126 | ```bash
127 |
128 | # Example dataset
129 | DATASET=CUB_200_2011
130 |
131 | # Features
132 | python extract.py extract_features \
133 | --images_list "./data/object-segmentation/${DATASET}/lists/images.txt" \
134 | --images_root "./data/object-segmentation/${DATASET}/images" \
135 | --output_dir "./data/object-segmentation/${DATASET}/features/dino_vits16" \
136 | --model_name dino_vits16 \
137 | --batch_size 1
138 |
139 | # Eigenvectors
140 | python extract.py extract_eigs \
141 | --images_root "./data/object-segmentation/${DATASET}/images" \
142 | --features_dir "./data/object-segmentation/${DATASET}/features/dino_vits16/" \
143 | --which_matrix "laplacian" \
144 | --output_dir "./data/object-segmentation/${DATASET}/eigs/laplacian_dino_vits16" \
145 | --K 2
146 |
147 |
148 | # Extract single-region segmentations
149 | python extract.py extract_single_region_segmentations \
150 | --features_dir "./data/object-segmentation/${DATASET}/features/dino_vits16" \
151 | --eigs_dir "./data/object-segmentation/${DATASET}/eigs/laplacian_dino_vits16" \
152 | --output_dir "./data/object-segmentation/${DATASET}/single_region_segmentation/patches/laplacian_dino_vits16"
153 |
154 | # With CRF
155 | # Optionally, you can also use `--multiprocessing 64` to speed up computation by running on 64 processes
156 | python extract.py extract_crf_segmentations \
157 | --images_list "./data/object-segmentation/${DATASET}/lists/images.txt" \
158 | --images_root "./data/object-segmentation/${DATASET}/images" \
159 | --segmentations_dir "./data/object-segmentation/${DATASET}/single_region_segmentation/patches/laplacian_dino_vits16" \
160 | --output_dir "./data/object-segmentation/${DATASET}/single_region_segmentation/crf/laplacian_dino_vits16" \
161 | --downsample_factor 16 \
162 | --num_classes 2
163 | ```
164 |
165 | After this extraction process, you should have a directory of full-resolution segmentations. Then, to evaluate object segmentation, you can move into the `object-segmentation` directory and run `python main.py`. For example:
166 |
167 | ```bash
168 | python main.py predictions.root="./data/object-segmentation" predictions.run="single_region_segmentation/crf/laplacian_dino_vits16"
169 | ```
170 |
171 | By default, this assumes that all four object segmentation datasets are available. To run on a custom dataset or only a subset of these datasets, simply edit `config/eval.yaml`.
172 |
173 | Also, if you want to visualize your segmentations, you should be able to use `streamlit run extract.py vis_segmentations` (after installing streamlit).
174 |
175 | #### Semantic Segmentation
176 |
177 | For semantic segmentation, we provide full instructions in the `semantic-segmentation` subfolder.
178 |
179 | #### Acknowledgements
180 |
181 | L. M. K. acknowledges the generous support of the Rhodes Trust. C. R. is supported by Innovate UK (project 71653) on behalf of UK Research and Innovation (UKRI) and by the European Research Council (ERC) IDIU-638009. I. L. and A. V. are supported by the VisualAI EPSRC programme grant (EP/T028572/1).
182 |
183 | We would like to acknowledge LOST ([paper](https://arxiv.org/abs/2109.14279) and [code](https://github.com/valeoai/LOST)), whose code we adapt for our object localization experiments. If you are interested in object localization, we suggest checking out their work!
184 |
185 | #### Citation
186 | ```
187 | @inproceedings{
188 | melaskyriazi2022deep,
189 | title={Deep Spectral Methods: A Surprisingly Strong Baseline for Unsupervised Semantic Segmentation and Localization},
190 | author={Luke Melas-Kyriazi and Christian Rupprecht and Iro Laina and Andrea Vedaldi},
191 | year={2022},
192 | booktitle={CVPR}
193 | }
194 | ```
195 |
--------------------------------------------------------------------------------
/object-localization/object_discovery.py:
--------------------------------------------------------------------------------
1 | """
2 | Main functions for object discovery.
3 | Code adapted from LOST: https://github.com/valeoai/LOST
4 | """
5 | from collections import namedtuple
6 | from typing import Optional, Tuple
7 | import torch
8 | import torch.nn.functional as F
9 | import scipy
10 | import scipy.ndimage
11 |
12 | import numpy as np
13 | from datasets import bbox_iou
14 |
15 |
16 | def get_eigenvectors_from_features(feats, which_matrix: str = 'affinity_torch', K=2):
17 | from scipy.sparse.linalg import eigsh
18 |
19 | # Eigenvectors of affinity matrix
20 | if which_matrix == 'affinity_torch':
21 | A = feats @ feats.T
22 | eigenvalues, eigenvectors = torch.eig(A, eigenvectors=True)  # NOTE: torch.eig is deprecated/removed in newer PyTorch; torch.linalg.eigh(A) is the modern equivalent for this symmetric matrix
23 |
24 | # Eigenvectors of affinity matrix with scipy
25 | elif which_matrix == 'affinity':
26 | A = (feats @ feats.T).cpu().numpy()
27 | eigenvalues, eigenvectors = eigsh(A, which='LM', k=K)
28 | eigenvectors = torch.flip(torch.from_numpy(eigenvectors), dims=(-1,))
29 |
30 | # Eigenvectors of laplacian matrix
31 | elif which_matrix == 'laplacian':
32 | A = (feats @ feats.T).cpu().numpy()
33 | _W_semantic = (A * (A > 0))
34 | _W_semantic = _W_semantic / _W_semantic.max()
35 | diag = _W_semantic @ np.ones(_W_semantic.shape[0])
36 | diag[diag < 1e-12] = 1.0
37 | D = np.diag(diag) # row sum
38 | try:
39 | eigenvalues, eigenvectors = eigsh(D - _W_semantic, k=K, sigma=0, which='LM', M=D)
40 | except:
41 | eigenvalues, eigenvectors = eigsh(D - _W_semantic, k=K, which='SM', M=D)
42 | eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float()
43 |
44 | # Eigenvectors of matting laplacian matrix
45 | elif which_matrix == 'matting_laplacian':
46 |
47 | raise NotImplementedError()
48 |
49 | # # Get sizes
50 | # B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict)
51 | # H_pad_lr, W_pad_lr = H_pad // image_downsample_factor, W_pad // image_downsample_factor
52 |
53 | # # Load image
54 | # image_file = str(Path(images_root) / f'{image_id}.jpg')
55 | # image_lr = Image.open(image_file).resize((W_pad_lr, H_pad_lr), Image.BILINEAR)
56 | # image_lr = np.array(image_lr) / 255.
57 |
58 | # # Get color affinities
59 | # W_lr = utils.knn_affinity(image_lr / 255)
60 |
61 | # # Get semantic affinities
62 | # k_feats_lr = F.interpolate(
63 | # k_feats.T.reshape(1, -1, H_patch, W_patch),
64 | # size=(H_pad_lr, W_pad_lr), mode='bilinear', align_corners=False
65 | # ).reshape(-1, H_pad_lr * W_pad_lr).T
66 | # A_sm_lr = k_feats_lr @ k_feats_lr.T
67 | # W_sm_lr = (A_sm_lr * (A_sm_lr > 0)).cpu().numpy()
68 | # W_sm_lr = W_sm_lr / W_sm_lr.max()
69 |
70 | # # Combine
71 | # W_color = np.array(W_lr.todense().astype(np.float32))
72 | # W_comb = W_sm_lr + W_color * image_color_lambda # combination
73 | # D_comb = utils.get_diagonal(W_comb)
74 |
75 | # # Extract eigenvectors
76 | # try:
77 | # eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM', M=D_comb)
78 | # except:
79 | # eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM', M=D_comb)
80 | # eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float()
81 |
82 | return eigenvectors
83 |
84 |
85 | def get_bbox_from_patch_mask(patch_mask, init_image_size, img_np: Optional[np.ndarray] = None):
86 |
87 | # Sizing
88 | H, W = init_image_size[1:]
89 | T = patch_mask.numel()
90 | if (H // 8) * (W // 8) == T:
91 | P, H_lr, W_lr = (8, H // 8, W // 8)
92 | elif (H // 16) * (W // 16) == T:
93 | P, H_lr, W_lr = (16, H // 16, W // 16)
94 | elif 4 * (H // 16) * (W // 16) == T:
95 | P, H_lr, W_lr = (8, 2 * (H // 16), 2 * (W // 16))
96 | elif 16 * (H // 32) * (W // 32) == T:
97 | P, H_lr, W_lr = (8, 4 * (H // 32), 4 * (W // 32))
98 | else:
99 | raise ValueError(f'{init_image_size=}, {patch_mask.shape=}')
100 |
101 | # Create patch mask
102 | patch_mask = patch_mask.reshape(H_lr, W_lr).cpu().numpy()
103 |
104 | # Possibly reverse mask
105 | # print(np.mean(patch_mask).item())
106 | if 0.5 < np.mean(patch_mask).item() < 1.0:
107 | patch_mask = (1 - patch_mask).astype(np.uint8)
108 | elif np.sum(patch_mask).item() == 0: # nothing detected at all, so cover the entire image
109 | patch_mask = (1 - patch_mask).astype(np.uint8)
110 |
111 | # Get the box corresponding to the largest connected component of the first eigenvector
112 | xmin, ymin, xmax, ymax = get_largest_cc_box(patch_mask)
113 | # pred = [xmin, ymin, xmax, ymax]
114 |
115 | # Rescale to image size
116 | r_xmin, r_xmax = P * xmin, P * xmax
117 | r_ymin, r_ymax = P * ymin, P * ymax
118 |
119 | # Prediction bounding box
120 | pred = [r_xmin, r_ymin, r_xmax, r_ymax]
121 |
122 | # Check not out of image size (used when padding)
123 | pred[2] = min(pred[2], W)
124 | pred[3] = min(pred[3], H)
125 |
126 | return np.asarray(pred)
127 |
128 |
129 | def lost(feats, dims, scales, init_image_size, k_patches=100):
130 | """
131 | Implementation of LOST method.
132 | Inputs
133 | feats: the patch features of an image
134 | dims: dimensions of the feature map from which the features are taken
135 | scales: scale factors from feature-map coordinates to image coordinates
136 | init_image_size: size of the original image
137 | k_patches: number of patches retrieved and compared to the seed during seed expansion
138 | Outputs
139 | pred: box predictions
140 | A: binary affinity matrix
141 | scores: lowest degree scores for all patches
142 | seed: selected patch corresponding to an object
143 | """
144 | # Compute the similarity
145 | A = (feats @ feats.transpose(1, 2)).squeeze()
146 |
147 | # Compute the inverse degree centrality measure per patch
148 | sorted_patches, scores = patch_scoring(A)
149 |
150 | # Select the initial seed
151 | seed = sorted_patches[0]
152 |
153 | # Seed expansion
154 | potentials = sorted_patches[:k_patches]
155 | similars = potentials[A[seed, potentials] > 0.0]
156 | M = torch.sum(A[similars, :], dim=0)
157 |
158 | # Box extraction
159 | pred, _ = detect_box(
160 | M, seed, dims, scales=scales, initial_im_size=init_image_size[1:]
161 | )
162 |
163 | return np.asarray(pred), A, M, scores, seed
164 |
165 |
166 | def patch_scoring(M, threshold=0.):
167 | """
168 | Patch scoring based on the inverse degree.
169 | """
170 | # Cloning important
171 | A = M.clone()
172 |
173 | # Zero diagonal
174 | A.fill_diagonal_(0)
175 |
176 | # Make sure the affinity values are non-negative
177 | A[A < 0] = 0
178 | C = A + A.t() # NOTE: this was not used. should this be used?
179 |
180 | # Sort pixels by inverse degree
181 | cent = -torch.sum(A > threshold, dim=1).type(torch.float32)
182 | sel = torch.argsort(cent, descending=True)
183 |
184 | return sel, cent
185 |
186 |
187 | def detect_box(A, seed, dims, initial_im_size=None, scales=None):
188 | """
189 | Extract a box corresponding to the seed patch. Among the connected components extracted from the affinity matrix, select the one containing the seed patch.
190 | """
191 | w_featmap, h_featmap = dims
192 |
193 | correl = A.reshape(w_featmap, h_featmap).float()
194 |
195 | # Compute connected components
196 | labeled_array, num_features = scipy.ndimage.label(correl.cpu().numpy() > 0.0)
197 |
198 | # Find connected component corresponding to the initial seed
199 | cc = labeled_array[np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap))]
200 |
201 | # Should not happen with LOST
202 | if cc == 0:
203 | raise ValueError("The seed is in the background component.")
204 |
205 | # Find box
206 | mask = np.where(labeled_array == cc)
207 |
208 | # Add +1 because excluded max
209 | ymin, ymax = min(mask[0]), max(mask[0]) + 1
210 | xmin, xmax = min(mask[1]), max(mask[1]) + 1
211 |
212 | # Rescale to image size
213 | r_xmin, r_xmax = scales[1] * xmin, scales[1] * xmax
214 | r_ymin, r_ymax = scales[0] * ymin, scales[0] * ymax
215 |
216 | pred = [r_xmin, r_ymin, r_xmax, r_ymax]
217 |
218 | # Check not out of image size (used when padding)
219 | if initial_im_size:
220 | pred[2] = min(pred[2], initial_im_size[1])
221 | pred[3] = min(pred[3], initial_im_size[0])
222 |
223 | # Coordinate predictions for the feature space
224 | # Axes are ordered differently than in image space
225 | pred_feats = [ymin, xmin, ymax, xmax]
226 |
227 | return pred, pred_feats
228 |
229 |
230 | def dino_seg(attn, dims, patch_size, head=0):
231 | """
232 | Extraction of boxes based on the DINO segmentation method proposed in https://github.com/facebookresearch/dino.
233 | """
234 | w_featmap, h_featmap = dims
235 | nh = attn.shape[1]
236 | official_th = 0.6
237 |
238 | # We keep only the output patch attention
239 | # Get the attentions corresponding to [CLS] token
240 | attentions = attn[0, :, 0, 1:].reshape(nh, -1)
241 |
242 | # we keep only a certain percentage of the mass
243 | val, idx = torch.sort(attentions)
244 | val /= torch.sum(val, dim=1, keepdim=True)
245 | cumval = torch.cumsum(val, dim=1)
246 | th_attn = cumval > (1 - official_th)
247 | idx2 = torch.argsort(idx)
248 | for h in range(nh):
249 | th_attn[h] = th_attn[h][idx2[h]]
250 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
251 |
252 | # Connected components
253 | labeled_array, num_features = scipy.ndimage.label(th_attn[head].cpu().numpy())
254 |
255 | # Find the biggest component
256 | size_components = [np.sum(labeled_array == c) for c in range(np.max(labeled_array))]
257 |
258 | if len(size_components) > 1:
259 | # Select the biggest component avoiding component 0 corresponding to background
260 | biggest_component = np.argmax(size_components[1:]) + 1
261 | else:
262 | # Cases of a single component
263 | biggest_component = 0
264 |
265 | # Mask corresponding to connected component
266 | mask = np.where(labeled_array == biggest_component)
267 |
268 | # Add +1 because excluded max
269 | ymin, ymax = min(mask[0]), max(mask[0]) + 1
270 | xmin, xmax = min(mask[1]), max(mask[1]) + 1
271 |
272 | # Rescale to image
273 | r_xmin, r_xmax = xmin * patch_size, xmax * patch_size
274 | r_ymin, r_ymax = ymin * patch_size, ymax * patch_size
275 | pred = [r_xmin, r_ymin, r_xmax, r_ymax]
276 |
277 | return pred
278 |
279 |
280 | def get_largest_cc_box(mask: np.ndarray):
281 | from skimage.measure import label as measure_label
282 | labels = measure_label(mask) # get connected components
283 | largest_cc_index = np.argmax(np.bincount(labels.flat)[1:]) + 1
284 | mask = np.where(labels == largest_cc_index)
285 | ymin, ymax = min(mask[0]), max(mask[0]) + 1
286 | xmin, xmax = min(mask[1]), max(mask[1]) + 1
287 | return [xmin, ymin, xmax, ymax]
288 |
--------------------------------------------------------------------------------
/object-segmentation/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc functions, including distributed helpers, mostly from torchvision
3 | """
4 | import time
5 | import datetime
6 | import random
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | import torchvision
11 | from dataclasses import dataclass
12 | from collections import defaultdict, deque
13 | from typing import Callable, Optional
14 | from PIL import Image
15 | from accelerate import Accelerator
16 | from omegaconf import DictConfig
17 |
18 |
19 | @dataclass
20 | class TrainState:
21 | epoch: int = 0
22 | step: int = 0
23 | best_val: Optional[float] = None
24 |
25 |
26 | def get_optimizer(cfg: DictConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer:
27 | # Determine the learning rate
28 | if cfg.optimizer.scale_learning_rate_with_batch_size:
29 | lr = accelerator.state.num_processes * cfg.data.loader.batch_size * cfg.optimizer.base_lr
30 | print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format(
31 | ws=accelerator.state.num_processes, bs=cfg.data.loader.batch_size, blr=cfg.optimizer.base_lr, lr=lr))
32 | else: # use the given learning rate directly (no scaling by batch size)
33 | lr = cfg.lr
34 | print('lr = {lr} (absolute learning rate)'.format(lr=lr))
35 | # Construct optimizer
36 | if cfg.optimizer.kind == 'torch':
37 | parameters = [p for p in model.parameters() if p.requires_grad]
38 | optimizer = getattr(torch.optim, cfg.optimizer.cls)(parameters, lr=lr, **cfg.optimizer.kwargs)
39 | elif cfg.optimizer.kind == 'timm':
40 | from timm.optim import create_optimizer_v2
41 | optimizer = create_optimizer_v2(model, lr=lr, **cfg.optimizer.kwargs)
42 | elif cfg.optimizer.kind == 'transformers':
43 | import transformers
44 | parameters = [p for p in model.parameters() if p.requires_grad]
45 | optimizer = getattr(transformers, cfg.optimizer.name)(parameters, lr=lr, **cfg.optimizer.kwargs)
46 | else:
47 | raise NotImplementedError(f'invalid optimizer config: {cfg.optimizer}')
48 | return optimizer
49 |
50 |
51 | def get_scheduler(cfg: DictConfig, optimizer: torch.optim.Optimizer) -> Callable:
52 | if cfg.scheduler.kind == 'torch':
53 | Sch = getattr(torch.optim.lr_scheduler, cfg.scheduler.cls)
54 | scheduler = Sch(optimizer=optimizer, **cfg.scheduler.kwargs)
55 | if cfg.scheduler.warmup:
56 | from warmup_scheduler import GradualWarmupScheduler
57 | scheduler = GradualWarmupScheduler( # wrap scheduler with warmup
58 | optimizer, multiplier=1, total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler)
59 | elif cfg.scheduler.kind == 'timm':
60 | from timm.scheduler import create_scheduler
61 | scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs)
62 | elif cfg.scheduler.kind == 'transformers':
63 | from transformers import get_scheduler
64 | scheduler = get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs)
65 | else:
66 | raise NotImplementedError(f'invalid scheduler config: {cfg.scheduler}')
67 | return scheduler
68 |
69 |
70 | @torch.no_grad()
71 | def accuracy(output, target, topk=(1,)):
72 | """Computes the accuracy over the k top predictions for the specified values of k"""
73 | # reshape
74 | target = target.reshape(-1)
75 | output = output.reshape(target.size(0), -1)
76 |
77 | maxk = max(topk)
78 | batch_size = target.size(0)
79 |
80 | _, pred = output.topk(maxk, 1, True, True)
81 | pred = pred.t()
82 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
83 |
84 | res = []
85 | for k in topk:
86 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
87 | res.append(correct_k.mul_(100.0 / batch_size))
88 | return res
89 |
90 |
91 | class SmoothedValue(object):
92 | """Track a series of values and provide access to smoothed values over a
93 | window or the global series average.
94 | """
95 |
96 | def __init__(self, window_size=20, fmt=None):
97 | if fmt is None:
98 | fmt = "{median:.4f} ({global_avg:.4f})"
99 | self.deque = deque(maxlen=window_size)
100 | self.total = 0.0
101 | self.count = 0
102 | self.fmt = fmt
103 |
104 | def update(self, value, n=1):
105 | self.deque.append(value)
106 | self.count += n
107 | self.total += value * n
108 |
109 | def synchronize_between_processes(self, device='cuda'):
110 | """
111 | Warning: does not synchronize the deque!
112 | """
113 | if not using_distributed():
114 | return
115 | print(f"device={device}")
116 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
117 | dist.barrier()
118 | dist.all_reduce(t)
119 | t = t.tolist()
120 | self.count = int(t[0])
121 | self.total = t[1]
122 |
123 | @property
124 | def median(self):
125 | d = torch.tensor(list(self.deque))
126 | return d.median().item()
127 |
128 | @property
129 | def avg(self):
130 | d = torch.tensor(list(self.deque), dtype=torch.float32)
131 | return d.mean().item()
132 |
133 | @property
134 | def global_avg(self):
135 | return self.total / self.count
136 |
137 | @property
138 | def max(self):
139 | return max(self.deque)
140 |
141 | @property
142 | def value(self):
143 | return self.deque[-1]
144 |
145 | def __str__(self):
146 | return self.fmt.format(
147 | median=self.median,
148 | avg=self.avg,
149 | global_avg=self.global_avg,
150 | max=self.max,
151 | value=self.value)
152 |
153 |
154 | class MetricLogger(object):
155 | def __init__(self, delimiter="\t"):
156 | self.meters = defaultdict(SmoothedValue)
157 | self.delimiter = delimiter
158 |
159 | def update(self, **kwargs):
160 | n = kwargs.pop('n', 1)
161 | for k, v in kwargs.items():
162 | if isinstance(v, torch.Tensor):
163 | v = v.item()
164 | assert isinstance(v, (float, int))
165 | self.meters[k].update(v, n=n)
166 |
167 | def __getattr__(self, attr):
168 | if attr in self.meters:
169 | return self.meters[attr]
170 | if attr in self.__dict__:
171 | return self.__dict__[attr]
172 | raise AttributeError("'{}' object has no attribute '{}'".format(
173 | type(self).__name__, attr))
174 |
175 | def __str__(self):
176 | loss_str = []
177 | for name, meter in self.meters.items():
178 | loss_str.append(
179 | "{}: {}".format(name, str(meter))
180 | )
181 | return self.delimiter.join(loss_str)
182 |
183 | def synchronize_between_processes(self, device='cuda'):
184 | for meter in self.meters.values():
185 | meter.synchronize_between_processes(device=device)
186 |
187 | def add_meter(self, name, meter):
188 | self.meters[name] = meter
189 |
190 | def log_every(self, iterable, print_freq, header=None):
191 | i = 0
192 | if not header:
193 | header = ''
194 | start_time = time.time()
195 | end = time.time()
196 | iter_time = SmoothedValue(fmt='{avg:.4f}')
197 | data_time = SmoothedValue(fmt='{avg:.4f}')
198 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
199 | log_msg = [
200 | header,
201 | '[{0' + space_fmt + '}/{1}]',
202 | 'eta: {eta}',
203 | '{meters}',
204 | 'time: {time}',
205 | 'data: {data}'
206 | ]
207 | if torch.cuda.is_available():
208 | log_msg.append('max mem: {memory:.0f}')
209 | log_msg = self.delimiter.join(log_msg)
210 | MB = 1024.0 * 1024.0
211 | for obj in iterable:
212 | data_time.update(time.time() - end)
213 | yield obj
214 | iter_time.update(time.time() - end)
215 | if i % print_freq == 0 or i == len(iterable) - 1:
216 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
217 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
218 | if torch.cuda.is_available():
219 | print(log_msg.format(
220 | i, len(iterable), eta=eta_string,
221 | meters=str(self),
222 | time=str(iter_time), data=str(data_time),
223 | memory=torch.cuda.max_memory_allocated() / MB))
224 | else:
225 | print(log_msg.format(
226 | i, len(iterable), eta=eta_string,
227 | meters=str(self),
228 | time=str(iter_time), data=str(data_time)))
229 | i += 1
230 | end = time.time()
231 | total_time = time.time() - start_time
232 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
233 | print('{} Total time: {} ({:.4f} s / it)'.format(
234 | header, total_time_str, total_time / len(iterable)))
235 |
236 |
237 | class NormalizeInverse(torchvision.transforms.Normalize):
238 | """
239 | Undoes the normalization and returns the reconstructed images in the input domain.
240 | """
241 |
242 | def __init__(self, mean, std):
243 | mean = torch.as_tensor(mean)
244 | std = torch.as_tensor(std)
245 | std_inv = 1 / (std + 1e-7)
246 | mean_inv = -mean * std_inv
247 | super().__init__(mean=mean_inv, std=std_inv)
248 |
249 | def __call__(self, tensor):
250 | return super().__call__(tensor.clone())
251 |
252 |
253 | def set_requires_grad(module, requires_grad=True):
254 | for p in module.parameters():
255 | p.requires_grad = requires_grad
256 |
257 |
258 | def resume_from_checkpoint(cfg, model, optimizer=None, scheduler=None, model_ema=None):
259 |
260 | # Resume model state dict
261 | checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu')
262 | if 'model' in checkpoint:
263 | state_dict, key = checkpoint['model'], 'model'
264 | else:
265 | state_dict, key = checkpoint, 'N/A'
266 | if any(k.startswith('module.') for k in state_dict.keys()):
267 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
268 | print('Removed "module." from checkpoint state dict')
269 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
270 | print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}')
271 | if len(missing_keys):
272 | print(f' - Missing_keys: {missing_keys}')
273 | if len(unexpected_keys):
274 | print(f' - Unexpected_keys: {unexpected_keys}')
275 | # Resume model ema
276 | if cfg.ema.use_ema:
277 | if checkpoint['model_ema']:
278 | model_ema.load_state_dict(checkpoint['model_ema'])
279 | print('Loaded model ema from checkpoint')
280 | else:
281 | model_ema.load_state_dict(model.parameters())
282 | print('No model ema in checkpoint; loaded current parameters into model')
283 | else:
284 | if 'model_ema' in checkpoint:
285 | print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)')
286 | else:
287 | print('Not using model ema, and no model_ema found in checkpoint.')
288 |
289 | # Resume optimization state
290 | if cfg.checkpoint.resume_training and 'train' in cfg.job_type:
291 | if 'steps' in checkpoint:
292 | checkpoint['step'] = checkpoint['steps']
293 | assert {'optimizer', 'scheduler', 'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys()))
294 | optimizer.load_state_dict(checkpoint['optimizer'])
295 | scheduler.load_state_dict(checkpoint['scheduler'])
296 | epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val']
297 | train_state = TrainState(epoch=epoch, step=step, best_val=best_val)
298 | print(f'Loaded optimizer/scheduler at epoch {epoch} from checkpoint')
299 | elif cfg.checkpoint.resume_optimizer_only:
300 | optimizer.load_state_dict(checkpoint['optimizer'])  # raises KeyError if the checkpoint has no optimizer state
301 | train_state = TrainState()  # optimizer-only resume: epoch/step/best_val keep their defaults
302 | print('Loaded optimizer from checkpoint, but did not load scheduler/epoch')
303 | else:
304 | train_state = TrainState()
305 | print('Did not resume training (i.e. optimizer/scheduler/epoch)')
306 |
307 | return train_state
308 |
309 |
310 | def setup_distributed_print(is_master):
311 | """
312 | This function disables printing when not in master process
313 | """
314 | import builtins as __builtin__
315 | builtin_print = __builtin__.print
316 |
317 | def print(*args, **kwargs):
318 | force = kwargs.pop('force', False)
319 | if is_master or force:
320 | builtin_print(*args, **kwargs)
321 |
322 | __builtin__.print = print
323 |
324 |
325 | def using_distributed():
326 | return dist.is_available() and dist.is_initialized()
327 |
328 |
329 | def get_rank():
330 | return dist.get_rank() if using_distributed() else 0
331 |
332 |
333 | def set_seed(seed):
334 | rank = get_rank()
335 | seed = seed + rank
336 | torch.manual_seed(seed)
337 | torch.cuda.manual_seed(seed)
338 | np.random.seed(seed)
339 | random.seed(seed)
340 | torch.backends.cudnn.enabled = True
341 | torch.backends.cudnn.benchmark = True
342 | if using_distributed():
343 | print(f'Seeding node {rank} with seed {seed}', force=True)
344 | else:
345 | print(f'Seeding node {rank} with seed {seed}')
346 |
347 |
348 | def tensor_to_pil(image: torch.Tensor):
349 | assert len(image.shape) == 3 and image.shape[0] == 3, f"{image.shape=}"
350 | image = (image.float() * 0.5 + 0.5).clamp(0, 1).detach().cpu().requires_grad_(False)
351 | ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
352 | return Image.fromarray(ndarr)
353 |
354 |
355 | def albumentations_to_torch(transform):
356 | def _transform(img, target):
357 | augmented = transform(image=img, mask=target)
358 | return augmented['image'], augmented['mask']
359 | return _transform
360 |
--------------------------------------------------------------------------------
/semantic-segmentation/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc functions, including distributed helpers, mostly from torchvision
3 | """
4 | import time
5 | import datetime
6 | import random
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | import torchvision
11 | from dataclasses import dataclass
12 | from collections import defaultdict, deque
13 | from typing import Callable, Optional
14 | from PIL import Image
15 | from accelerate import Accelerator
16 | from omegaconf import DictConfig
17 |
18 |
19 | @dataclass
20 | class TrainState:
21 | epoch: int = 0
22 | step: int = 0
23 | best_val: Optional[float] = None
24 |
25 |
26 | def get_optimizer(cfg: DictConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer:
27 | # Determine the learning rate
28 | if cfg.optimizer.scale_learning_rate_with_batch_size:
29 | lr = accelerator.state.num_processes * cfg.data.loader.batch_size * cfg.optimizer.base_lr
30 | print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format(
31 | ws=accelerator.state.num_processes, bs=cfg.data.loader.batch_size, blr=cfg.optimizer.base_lr, lr=lr))
32 | else: # use the given learning rate directly (no scaling by batch size)
33 | lr = cfg.lr
34 | print('lr = {lr} (absolute learning rate)'.format(lr=lr))
35 | # Construct optimizer
36 | if cfg.optimizer.kind == 'torch':
37 | parameters = [p for p in model.parameters() if p.requires_grad]
38 | optimizer = getattr(torch.optim, cfg.optimizer.cls)(parameters, lr=lr, **cfg.optimizer.kwargs)
39 | elif cfg.optimizer.kind == 'timm':
40 | from timm.optim import create_optimizer_v2
41 | optimizer = create_optimizer_v2(model, lr=lr, **cfg.optimizer.kwargs)
42 | elif cfg.optimizer.kind == 'transformers':
43 | import transformers
44 | parameters = [p for p in model.parameters() if p.requires_grad]
45 | optimizer = getattr(transformers, cfg.optimizer.name)(parameters, lr=lr, **cfg.optimizer.kwargs)
46 | else:
47 | raise NotImplementedError(f'invalid optimizer config: {cfg.optimizer}')
48 | return optimizer
49 |
50 |
51 | def get_scheduler(cfg: DictConfig, optimizer: torch.optim.Optimizer) -> Callable:
52 | if cfg.scheduler.kind == 'torch':
53 | Sch = getattr(torch.optim.lr_scheduler, cfg.scheduler.cls)
54 | scheduler = Sch(optimizer=optimizer, **cfg.scheduler.kwargs)
55 | if cfg.scheduler.warmup:
56 | from warmup_scheduler import GradualWarmupScheduler
57 | scheduler = GradualWarmupScheduler( # wrap scheduler with warmup
58 | optimizer, multiplier=1, total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler)
59 | elif cfg.scheduler.kind == 'timm':
60 | from timm.scheduler import create_scheduler
61 | scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs)
62 | elif cfg.scheduler.kind == 'transformers':
63 | from transformers import get_scheduler
64 | scheduler = get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs)
65 | else:
66 | raise NotImplementedError(f'invalid scheduler config: {cfg.scheduler}')
67 | return scheduler
68 |
69 |
70 | @torch.no_grad()
71 | def accuracy(output, target, topk=(1,)):
72 | """Computes the accuracy over the k top predictions for the specified values of k"""
73 | # reshape
74 | target = target.reshape(-1)
75 | output = output.reshape(target.size(0), -1)
76 |
77 | maxk = max(topk)
78 | batch_size = target.size(0)
79 |
80 | _, pred = output.topk(maxk, 1, True, True)
81 | pred = pred.t()
82 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
83 |
84 | res = []
85 | for k in topk:
86 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
87 | res.append(correct_k.mul_(100.0 / batch_size))
88 | return res
89 |
90 |
91 | class SmoothedValue(object):
92 | """Track a series of values and provide access to smoothed values over a
93 | window or the global series average.
94 | """
95 |
96 | def __init__(self, window_size=20, fmt=None):
97 | if fmt is None:
98 | fmt = "{median:.4f} ({global_avg:.4f})"
99 | self.deque = deque(maxlen=window_size)
100 | self.total = 0.0
101 | self.count = 0
102 | self.fmt = fmt
103 |
104 | def update(self, value, n=1):
105 | self.deque.append(value)
106 | self.count += n
107 | self.total += value * n
108 |
109 | def synchronize_between_processes(self, device='cuda'):
110 | """
111 | Warning: does not synchronize the deque!
112 | """
113 | if not using_distributed():
114 | return
115 | print(f"device={device}")
116 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
117 | dist.barrier()
118 | dist.all_reduce(t)
119 | t = t.tolist()
120 | self.count = int(t[0])
121 | self.total = t[1]
122 |
123 | @property
124 | def median(self):
125 | d = torch.tensor(list(self.deque))
126 | return d.median().item()
127 |
128 | @property
129 | def avg(self):
130 | d = torch.tensor(list(self.deque), dtype=torch.float32)
131 | return d.mean().item()
132 |
133 | @property
134 | def global_avg(self):
135 | return self.total / self.count
136 |
137 | @property
138 | def max(self):
139 | return max(self.deque)
140 |
141 | @property
142 | def value(self):
143 | return self.deque[-1]
144 |
145 | def __str__(self):
146 | return self.fmt.format(
147 | median=self.median,
148 | avg=self.avg,
149 | global_avg=self.global_avg,
150 | max=self.max,
151 | value=self.value)
152 |
153 |
154 | class MetricLogger(object):
155 | def __init__(self, delimiter="\t"):
156 | self.meters = defaultdict(SmoothedValue)
157 | self.delimiter = delimiter
158 |
159 | def update(self, **kwargs):
160 | n = kwargs.pop('n', 1)
161 | for k, v in kwargs.items():
162 | if isinstance(v, torch.Tensor):
163 | v = v.item()
164 | assert isinstance(v, (float, int))
165 | self.meters[k].update(v, n=n)
166 |
167 | def __getattr__(self, attr):
168 | if attr in self.meters:
169 | return self.meters[attr]
170 | if attr in self.__dict__:
171 | return self.__dict__[attr]
172 | raise AttributeError("'{}' object has no attribute '{}'".format(
173 | type(self).__name__, attr))
174 |
175 | def __str__(self):
176 | loss_str = []
177 | for name, meter in self.meters.items():
178 | loss_str.append(
179 | "{}: {}".format(name, str(meter))
180 | )
181 | return self.delimiter.join(loss_str)
182 |
183 | def synchronize_between_processes(self, device='cuda'):
184 | for meter in self.meters.values():
185 | meter.synchronize_between_processes(device=device)
186 |
187 | def add_meter(self, name, meter):
188 | self.meters[name] = meter
189 |
190 | def log_every(self, iterable, print_freq, header=None):
191 | i = 0
192 | if not header:
193 | header = ''
194 | start_time = time.time()
195 | end = time.time()
196 | iter_time = SmoothedValue(fmt='{avg:.4f}')
197 | data_time = SmoothedValue(fmt='{avg:.4f}')
198 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
199 | log_msg = [
200 | header,
201 | '[{0' + space_fmt + '}/{1}]',
202 | 'eta: {eta}',
203 | '{meters}',
204 | 'time: {time}',
205 | 'data: {data}'
206 | ]
207 | if torch.cuda.is_available():
208 | log_msg.append('max mem: {memory:.0f}')
209 | log_msg = self.delimiter.join(log_msg)
210 | MB = 1024.0 * 1024.0
211 | for obj in iterable:
212 | data_time.update(time.time() - end)
213 | yield obj
214 | iter_time.update(time.time() - end)
215 | if i % print_freq == 0 or i == len(iterable) - 1:
216 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
217 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
218 | if torch.cuda.is_available():
219 | print(log_msg.format(
220 | i, len(iterable), eta=eta_string,
221 | meters=str(self),
222 | time=str(iter_time), data=str(data_time),
223 | memory=torch.cuda.max_memory_allocated() / MB))
224 | else:
225 | print(log_msg.format(
226 | i, len(iterable), eta=eta_string,
227 | meters=str(self),
228 | time=str(iter_time), data=str(data_time)))
229 | i += 1
230 | end = time.time()
231 | total_time = time.time() - start_time
232 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
233 | print('{} Total time: {} ({:.4f} s / it)'.format(
234 | header, total_time_str, total_time / len(iterable)))
235 |
236 |
237 | class NormalizeInverse(torchvision.transforms.Normalize):
238 | """
239 | Undoes the normalization and returns the reconstructed images in the input domain.
240 | """
241 |
242 | def __init__(self, mean, std):
243 | mean = torch.as_tensor(mean)
244 | std = torch.as_tensor(std)
245 | std_inv = 1 / (std + 1e-7)
246 | mean_inv = -mean * std_inv
247 | super().__init__(mean=mean_inv, std=std_inv)
248 |
249 | def __call__(self, tensor):
250 | return super().__call__(tensor.clone())
251 |
252 |
253 | def set_requires_grad(module, requires_grad=True):
254 | for p in module.parameters():
255 | p.requires_grad = requires_grad
256 |
257 |
258 | def resume_from_checkpoint(cfg, model, optimizer=None, scheduler=None, model_ema=None):
259 |
260 | # Resume model state dict
261 | checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu')
262 | if 'model' in checkpoint:
263 | state_dict, key = checkpoint['model'], 'model'
264 | else:
265 | state_dict, key = checkpoint, 'N/A'
266 | if any(k.startswith('module.') for k in state_dict.keys()):
267 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
268 | print('Removed "module." from checkpoint state dict')
269 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
270 | print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}')
271 | if len(missing_keys):
272 | print(f' - Missing_keys: {missing_keys}')
273 | if len(unexpected_keys):
274 | print(f' - Unexpected_keys: {unexpected_keys}')
275 | # Resume model ema
276 | if cfg.ema.use_ema:
277 | if checkpoint['model_ema']:
278 | model_ema.load_state_dict(checkpoint['model_ema'])
279 | print('Loaded model ema from checkpoint')
280 | else:
281 | model_ema.load_state_dict(model.parameters())
282 | print('No model ema in checkpoint; loaded current parameters into model')
283 | else:
284 | if 'model_ema' in checkpoint:
285 | print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)')
286 | else:
287 | print('Not using model ema, and no model_ema found in checkpoint.')
288 |
289 | # Resume optimization state
290 | if cfg.checkpoint.resume_training and 'train' in cfg.job_type:
291 | if 'steps' in checkpoint:
292 | checkpoint['step'] = checkpoint['steps']
293 | assert {'optimizer', 'scheduler', 'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys()))
294 | optimizer.load_state_dict(checkpoint['optimizer'])
295 | scheduler.load_state_dict(checkpoint['scheduler'])
296 | epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val']
297 | train_state = TrainState(epoch=epoch, step=step, best_val=best_val)
298 | print(f'Loaded optimizer/scheduler at epoch {epoch} from checkpoint')
299 | elif cfg.checkpoint.resume_optimizer_only:
300 | optimizer.load_state_dict(checkpoint['optimizer'])  # raises KeyError if the checkpoint has no optimizer state
301 | train_state = TrainState()  # optimizer-only resume: epoch/step/best_val keep their defaults
302 | print('Loaded optimizer from checkpoint, but did not load scheduler/epoch')
303 | else:
304 | train_state = TrainState()
305 | print('Did not resume training (i.e. optimizer/scheduler/epoch)')
306 |
307 | return train_state
308 |
309 |
310 | def setup_distributed_print(is_master):
311 | """
312 | This function disables printing when not in master process
313 | """
314 | import builtins as __builtin__
315 | builtin_print = __builtin__.print
316 |
317 | def print(*args, **kwargs):
318 | force = kwargs.pop('force', False)
319 | if is_master or force:
320 | builtin_print(*args, **kwargs)
321 |
322 | __builtin__.print = print
323 |
324 |
325 | def using_distributed():
326 | return dist.is_available() and dist.is_initialized()
327 |
328 |
329 | def get_rank():
330 | return dist.get_rank() if using_distributed() else 0
331 |
332 |
333 | def set_seed(seed):
334 | rank = get_rank()
335 | seed = seed + rank
336 | torch.manual_seed(seed)
337 | torch.cuda.manual_seed(seed)
338 | np.random.seed(seed)
339 | random.seed(seed)
340 | torch.backends.cudnn.enabled = True
341 | torch.backends.cudnn.benchmark = True
342 | if using_distributed():
343 | print(f'Seeding node {rank} with seed {seed}', force=True)
344 | else:
345 | print(f'Seeding node {rank} with seed {seed}')
346 |
347 |
348 | def tensor_to_pil(image: torch.Tensor):
349 | assert len(image.shape) == 3 and image.shape[0] == 3, f"{image.shape=}"
350 | image = (image.float() * 0.5 + 0.5).clamp(0, 1).detach().cpu().requires_grad_(False)
351 | ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
352 | return Image.fromarray(ndarr)
353 |
354 |
355 | def albumentations_to_torch(transform):
356 | def _transform(img, target):
357 | augmented = transform(image=img, mask=target)
358 | return augmented['image'], augmented['mask']
359 | return _transform
360 |
--------------------------------------------------------------------------------
/object-localization/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Datasets file. Code adapted from LOST: https://github.com/valeoai/LOST
3 | """
4 | import os
5 | import torch
6 | import json
7 | import torchvision
8 | import numpy as np
9 | import skimage.io
10 |
11 | from PIL import Image
12 | from tqdm import tqdm
13 | from torchvision import transforms as pth_transforms
14 |
15 |
16 |
17 | class ImageDataset:
18 | def __init__(self, image_path, transform):
19 |
20 | self.image_path = image_path
21 | self.transform = transform
22 | self.name = image_path.split("/")[-1]
23 |
24 | # Read the image
25 | with open(image_path, "rb") as f:
26 | img = Image.open(f)
27 | img = img.convert("RGB")
28 |
29 | # Build a dataloader
30 | img = self.transform(img)
31 | self.dataloader = [[img, image_path]]
32 |
33 | def get_image_name(self, *args, **kwargs):
34 | return self.image_path.split("/")[-1].split(".")[0]
35 |
36 | def load_image(self, *args, **kwargs):
37 | return skimage.io.imread(self.image_path)
38 |
39 |
40 | class Dataset:
41 | def __init__(self, dataset_name, dataset_set, remove_hards, transform):
42 | """
43 | Build the dataloader
44 | """
45 |
46 | self.dataset_name = dataset_name
47 | self.set = dataset_set
48 | self.transform = transform
49 |
50 | if dataset_name == "VOC07":
51 | self.root_path = "datasets/VOC2007"
52 | self.year = "2007"
53 | elif dataset_name == "VOC12":
54 | self.root_path = "datasets/VOC2012"
55 | self.year = "2012"
56 | elif dataset_name == "COCO20k":
57 | self.year = "2014"
58 | self.root_path = f"datasets/COCO/images/{dataset_set}{self.year}"
59 | self.sel20k = 'datasets/coco_20k_filenames.txt'
60 | # JSON file constructed based on COCO train2014 gt
61 | self.all_annfile = "datasets/COCO/annotations/instances_train2014.json"
62 | self.annfile = "datasets/instances_train2014_sel20k.json"
63 | if not os.path.exists(self.annfile):
64 | select_coco_20k(self.sel20k, self.all_annfile)
65 | else:
66 | raise ValueError("Unknown dataset.")
67 |
68 | if not os.path.exists(self.root_path):
69 | print(self.root_path)
70 | raise ValueError("Please follow the README to setup the datasets.")
71 |
72 | self.name = f"{self.dataset_name}_{self.set}"
73 |
74 | # Build the dataloader
75 | if "VOC" in dataset_name:
76 | self.dataloader = torchvision.datasets.VOCDetection(
77 | self.root_path,
78 | year=self.year,
79 | image_set=self.set,
80 | transform=self.transform,
81 | download=False,
82 | )
83 | elif "COCO20k" == dataset_name:
84 | self.dataloader = torchvision.datasets.CocoDetection(
85 | self.root_path, annFile=self.annfile, transform=self.transform
86 | )
87 | else:
88 | raise ValueError("Unknown dataset.")
89 |
90 | # Set hards images that are not included
91 | self.remove_hards = remove_hards
92 | self.hards = []
93 | if remove_hards:
94 |             self.name += "-nohards"
95 | self.hards = self.get_hards()
96 | print(f"Nb images discarded {len(self.hards)}")
97 |
98 | def load_image(self, im_name):
99 | """
100 | Load the image corresponding to the im_name
101 | """
102 | if "VOC" in self.dataset_name:
103 | image = skimage.io.imread(f"/datasets_local/VOC{self.year}/JPEGImages/{im_name}")
104 | elif "COCO" in self.dataset_name:
105 | im_path = self.path_20k[self.sel_20k.index(im_name)]
106 | image = skimage.io.imread(f"/datasets_local/COCO/images/{im_path}")
107 | else:
108 |             raise ValueError("Unknown dataset.")
109 | return image
110 |
111 | def get_image_name(self, inp):
112 | """
113 | Return the image name
114 | """
115 | if "VOC" in self.dataset_name:
116 | im_name = inp["annotation"]["filename"]
117 | elif "COCO" in self.dataset_name:
118 | im_name = str(inp[0]["image_id"])
119 |
120 | return im_name
121 |
122 | def extract_gt(self, targets, im_name):
123 | if "VOC" in self.dataset_name:
124 | return extract_gt_VOC(targets, remove_hards=self.remove_hards)
125 | elif "COCO" in self.dataset_name:
126 | return extract_gt_COCO(targets, remove_iscrowd=True)
127 | else:
128 | raise ValueError("Unknown dataset")
129 |
130 | def extract_classes(self):
131 | if "VOC" in self.dataset_name:
132 | cls_path = f"classes_{self.set}_{self.year}.txt"
133 | elif "COCO" in self.dataset_name:
134 |             cls_path = f"classes_{self.dataset_name}_{self.set}_{self.year}.txt"
135 |
136 | # Load if exists
137 | if os.path.exists(cls_path):
138 | all_classes = []
139 | with open(cls_path, "r") as f:
140 | for line in f:
141 | all_classes.append(line.strip())
142 | else:
143 | print("Extract all classes from the dataset")
144 | if "VOC" in self.dataset_name:
145 | all_classes = self.extract_classes_VOC()
146 | elif "COCO" in self.dataset_name:
147 | all_classes = self.extract_classes_COCO()
148 |
149 | with open(cls_path, "w") as f:
150 | for s in all_classes:
151 | f.write(str(s) + "\n")
152 |
153 | return all_classes
154 |
155 | def extract_classes_VOC(self):
156 | all_classes = []
157 | for im_id, inp in enumerate(tqdm(self.dataloader)):
158 | objects = inp[1]["annotation"]["object"]
159 |
160 | for o in range(len(objects)):
161 | if objects[o]["name"] not in all_classes:
162 | all_classes.append(objects[o]["name"])
163 |
164 | return all_classes
165 |
166 | def extract_classes_COCO(self):
167 | all_classes = []
168 | for im_id, inp in enumerate(tqdm(self.dataloader)):
169 | objects = inp[1]
170 |
171 | for o in range(len(objects)):
172 | if objects[o]["category_id"] not in all_classes:
173 | all_classes.append(objects[o]["category_id"])
174 |
175 | return all_classes
176 |
177 | def get_hards(self):
178 | hard_path = "datasets/hard_%s_%s_%s.txt" % (self.dataset_name, self.set, self.year)
179 | if os.path.exists(hard_path):
180 | hards = []
181 | with open(hard_path, "r") as f:
182 | for line in f:
183 | hards.append(int(line.strip()))
184 | else:
185 | print("Discover hard images that should be discarded")
186 |             hards = []  # non-VOC datasets define no hard images
187 | if "VOC" in self.dataset_name:
188 | # set the hards
189 | hards = discard_hard_voc(self.dataloader)
190 |
191 | with open(hard_path, "w") as f:
192 | for s in hards:
193 | f.write(str(s) + "\n")
194 |
195 | return hards
196 |
197 |
198 | def discard_hard_voc(dataloader):
199 | hards = []
200 | for im_id, inp in enumerate(tqdm(dataloader)):
201 | objects = inp[1]["annotation"]["object"]
202 | nb_obj = len(objects)
203 |
204 | hard = np.zeros(nb_obj)
205 | for i, o in enumerate(range(nb_obj)):
206 | hard[i] = (
207 | 1
208 | if (objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1")
209 | else 0
210 | )
211 |
212 | # all images with only truncated or difficult objects
213 | if np.sum(hard) == nb_obj:
214 | hards.append(im_id)
215 | return hards
216 |
217 |
218 | def extract_gt_COCO(targets, remove_iscrowd=True):
219 | objects = targets
220 | nb_obj = len(objects)
221 |
222 | gt_bbxs = []
223 | gt_clss = []
224 | for o in range(nb_obj):
225 | # Remove iscrowd boxes
226 | if remove_iscrowd and objects[o]["iscrowd"] == 1:
227 | continue
228 | gt_cls = objects[o]["category_id"]
229 | gt_clss.append(gt_cls)
230 | bbx = objects[o]["bbox"]
231 | x1y1x2y2 = [bbx[0], bbx[1], bbx[0] + bbx[2], bbx[1] + bbx[3]]
232 | x1y1x2y2 = [int(round(x)) for x in x1y1x2y2]
233 | gt_bbxs.append(x1y1x2y2)
234 |
235 | return np.asarray(gt_bbxs), gt_clss
236 |
237 |
238 | def extract_gt_VOC(targets, remove_hards=False):
239 | objects = targets["annotation"]["object"]
240 | nb_obj = len(objects)
241 |
242 | gt_bbxs = []
243 | gt_clss = []
244 | for o in range(nb_obj):
245 | if remove_hards and (
246 | objects[o]["truncated"] == "1" or objects[o]["difficult"] == "1"
247 | ):
248 | continue
249 | gt_cls = objects[o]["name"]
250 | gt_clss.append(gt_cls)
251 | obj = objects[o]["bndbox"]
252 | x1y1x2y2 = [
253 | int(obj["xmin"]),
254 | int(obj["ymin"]),
255 | int(obj["xmax"]),
256 | int(obj["ymax"]),
257 | ]
258 | # Original annotations are integers in the range [1, W or H]
259 | # Assuming they mean 1-based pixel indices (inclusive),
260 | # a box with annotation (xmin=1, xmax=W) covers the whole image.
261 | # In coordinate space this is represented by (xmin=0, xmax=W)
262 | x1y1x2y2[0] -= 1
263 | x1y1x2y2[1] -= 1
264 | gt_bbxs.append(x1y1x2y2)
265 |
266 | return np.asarray(gt_bbxs), gt_clss
267 |
268 |
269 | def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
270 | # https://github.com/ultralytics/yolov5/blob/develop/utils/general.py
271 | # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
272 | box2 = box2.T
273 |
274 | # Get the coordinates of bounding boxes
275 | if x1y1x2y2: # x1, y1, x2, y2 = box1
276 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
277 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
278 | else: # transform from xywh to xyxy
279 | b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
280 | b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
281 | b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
282 | b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
283 |
284 | # Intersection area
285 | inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
286 | torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)
287 | ).clamp(0)
288 |
289 | # Union Area
290 | w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
291 | w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
292 | union = w1 * h1 + w2 * h2 - inter + eps
293 |
294 | iou = inter / union
295 | if GIoU or DIoU or CIoU:
296 | cw = torch.max(b1_x2, b2_x2) - torch.min(
297 | b1_x1, b2_x1
298 | ) # convex (smallest enclosing box) width
299 | ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
300 | if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
301 | c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
302 | rho2 = (
303 | (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
304 | (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
305 | ) / 4 # center distance squared
306 | if DIoU:
307 | return iou - rho2 / c2 # DIoU
308 | elif (
309 | CIoU
310 | ): # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
311 | v = (4 / math.pi ** 2) * torch.pow(
312 | torch.atan(w2 / h2) - torch.atan(w1 / h1), 2
313 | )
314 | with torch.no_grad():
315 | alpha = v / (v - iou + (1 + eps))
316 | return iou - (rho2 / c2 + v * alpha) # CIoU
317 | else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
318 | c_area = cw * ch + eps # convex area
319 | return iou - (c_area - union) / c_area # GIoU
320 | else:
321 | return iou # IoU
322 |
323 |
324 | def select_coco_20k(sel_file, all_annotations_file):
325 | print('Building COCO 20k dataset.')
326 |
327 | # load all annotations
328 | with open(all_annotations_file, "r") as f:
329 | train2014 = json.load(f)
330 |
331 | # load selected images
332 | with open(sel_file, "r") as f:
333 | sel_20k = f.readlines()
334 | sel_20k = [s.replace("\n", "") for s in sel_20k]
335 | im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in sel_20k]
336 |
337 | # # OLD
338 | # new_anno = []
339 | # new_images = []
340 | # for i in tqdm(im20k):
341 | # new_anno.extend(
342 | # [a for a in train2014["annotations"] if a["image_id"] == int(i)]
343 | # )
344 | # new_images.extend([a for a in train2014["images"] if a["id"] == int(i)])
345 |
346 | # NEW
347 | from collections import defaultdict
348 | id_to_ann = defaultdict(list) # image_id --> [annotations]
349 | id_to_img = defaultdict(list) # image_id --> [images]
350 | for a in tqdm(train2014["annotations"]):
351 | id_to_ann[a["image_id"]].append(a)
352 | for im in tqdm(train2014["images"]):
353 |         id_to_img[im["id"]].append(im)
354 |     new_anno = [a for i in im20k for a in id_to_ann[int(i)]]  # flatten to plain lists (COCO format)
355 |     new_images = [im for i in im20k for im in id_to_img[int(i)]]
356 | print(len(im20k))
357 | print(len(new_anno))
358 | print(len(new_images))
359 |
360 | train2014_20k = {}
361 | train2014_20k["images"] = new_images
362 | train2014_20k["annotations"] = new_anno
363 | train2014_20k["categories"] = train2014["categories"]
364 |
365 | with open("datasets/instances_train2014_sel20k.json", "w") as outfile:
366 | json.dump(train2014_20k, outfile)
367 |
368 | print('Done.')
369 |
--------------------------------------------------------------------------------
/semantic-segmentation/train.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import os
4 | import sys
5 | from contextlib import nullcontext
6 | from pathlib import Path
7 | from typing import Callable, Iterable, Optional
8 |
9 | import hydra
10 | import numpy as np
11 | import torch
12 | import wandb
13 | from accelerate import Accelerator
14 | from omegaconf import DictConfig, OmegaConf
15 | from PIL import Image
16 | from torch.nn import functional as F
17 | from torch.utils.data import DataLoader
18 | from tqdm import tqdm
19 |
20 | import util as utils
21 | from dataset import get_datasets
22 | from model import get_model
23 |
24 |
25 | @hydra.main(config_path='config', config_name='train')
26 | def main(cfg: DictConfig):
27 |
28 | # Accelerator
29 | accelerator = Accelerator(fp16=cfg.fp16, cpu=cfg.cpu)
30 |
31 | # Logging
32 | utils.setup_distributed_print(accelerator.is_local_main_process)
33 | if cfg.wandb and accelerator.is_local_main_process:
34 | wandb.init(name=cfg.name, job_type=cfg.job_type, config=OmegaConf.to_container(cfg), save_code=True, **cfg.wandb_kwargs)
35 | cfg = DictConfig(wandb.config.as_dict()) # get the config back from wandb for hyperparameter sweeps
36 |
37 | # Configuration
38 | print(OmegaConf.to_yaml(cfg))
39 | print(f'Current working directory: {os.getcwd()}')
40 |
41 | # Set random seed
42 | utils.set_seed(cfg.seed)
43 |
44 | # Create model
45 | model = get_model(**cfg.model)
46 |
47 | # Freeze layers, if desired
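   |     # unfrozen_backbone_layers == 0 freezes the whole backbone ([:None] selects every child module);
   |     # a positive value keeps that many final backbone children trainable; a negative value freezes nothing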
48 | if cfg.unfrozen_backbone_layers >= 0:
49 | num_unfrozen = None if (cfg.unfrozen_backbone_layers == 0) else (-cfg.unfrozen_backbone_layers)
50 | for module in list(model.backbone.children())[:num_unfrozen]:
51 | for p in module.parameters():
52 | p.requires_grad_(False)
53 |
54 | print(f'Parameters (total): {sum(p.numel() for p in model.parameters()):_d}')
55 | print(f'Parameters (train): {sum(p.numel() for p in model.parameters() if p.requires_grad):_d}')
56 | print(f'Backbone parameters (total): {sum(p.numel() for p in model.backbone.parameters()):_d}')
57 | print(f'Backbone parameters (train): {sum(p.numel() for p in model.backbone.parameters() if p.requires_grad):_d}')
58 |
59 | # Optimizer and scheduler
60 | optimizer = utils.get_optimizer(cfg, model, accelerator)
61 | scheduler = utils.get_scheduler(cfg, optimizer)
62 |
63 | # Resume from checkpoint and create the initial training state
64 | if cfg.checkpoint.resume:
65 | train_state: utils.TrainState = utils.resume_from_checkpoint(cfg, model, optimizer, scheduler, model_ema=None)
66 | else:
67 | train_state = utils.TrainState() # start training from scratch
68 |
69 | # Data
70 | dataset_train, dataset_val, collate_fn = get_datasets(cfg)
71 | dataloader_train = DataLoader(dataset_train, shuffle=True, drop_last=True,
72 | collate_fn=collate_fn, **cfg.data.loader)
73 | dataloader_val = DataLoader(dataset_val, shuffle=False, drop_last=False,
74 | collate_fn=collate_fn, **{**cfg.data.loader, 'batch_size': 1})
75 | total_batch_size = cfg.data.loader.batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps
76 |
77 | # SyncBatchNorm
78 | if accelerator.num_processes > 1:
79 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
80 |
81 | # Setup
82 | model, optimizer, dataloader_train = accelerator.prepare(model, optimizer, dataloader_train)
83 |
84 | # Exponential moving average of model parameters
85 | if cfg.ema.use_ema:
86 | from torch_ema import ExponentialMovingAverage
87 | model_ema = ExponentialMovingAverage((p for p in model.parameters() if p.requires_grad), decay=cfg.ema.decay)
88 | print('Initialized model EMA')
89 | else:
90 | model_ema = None
91 | print('Not using model EMA')
92 |
93 | # Shared training, evaluation, and visualization args
94 | kwargs = dict(
95 | cfg=cfg,
96 | model=model,
97 | dataloader_train=dataloader_train,
98 | dataloader_val=dataloader_val,
99 | optimizer=optimizer,
100 | scheduler=scheduler,
101 | accelerator=accelerator,
102 | model_ema=model_ema,
103 | train_state=train_state)
104 |
105 |     # Generation
106 | if cfg.job_type == 'generate':
107 | test_stats = generate(**kwargs)
108 | return 0
109 |
110 | # Evaluation
111 | if cfg.job_type == 'eval':
112 | test_stats = evaluate(**kwargs)
113 | return test_stats['val_loss']
114 |
115 | # Info
116 | print(f'***** Starting training at {datetime.datetime.now()} *****')
117 | print(f' Dataset train size: {len(dataset_train):_}')
118 | print(f' Dataset val size: {len(dataset_val):_}')
119 | print(f' Dataloader train size: {len(dataloader_train):_}')
120 | print(f' Dataloader val size: {len(dataloader_val):_}')
121 | print(f' Batch size per device = {cfg.data.loader.batch_size}')
122 | print(f' Total train batch size (w. parallel, dist & accum) = {total_batch_size}')
123 | print(f' Gradient Accumulation steps = {cfg.gradient_accumulation_steps}')
124 | print(f' Max optimization steps = {cfg.max_train_steps}')
125 | print(f' Max optimization epochs = {cfg.max_train_epochs}')
126 | print(f' Training state = {train_state}')
127 |
128 | # Evaluate masks before training
129 | if cfg.get('eval_masks_before_training', True):
130 | print('Evaluating masks before training...')
131 | if accelerator.is_main_process:
132 | evaluate(**kwargs, evaluate_dataset_pseudolabels=True) # <-- to evaluate the self-training masks
133 | torch.cuda.synchronize()
134 |
135 | # Training loop
136 | while True:
137 |
138 | # Single epoch of training
139 | train_state = train_one_epoch(**kwargs)
140 |
141 | # Save checkpoint on only 1 process
142 | if accelerator.is_local_main_process:
143 | checkpoint_dict = {
144 | 'model': accelerator.unwrap_model(model).state_dict(),
145 | 'optimizer': optimizer.state_dict(),
146 | 'scheduler': scheduler.state_dict(),
147 | 'epoch': train_state.epoch,
148 | 'step': train_state.step,
149 | 'best_val': train_state.best_val,
150 | 'model_ema': model_ema.state_dict() if model_ema else {},
151 | 'cfg': cfg
152 | }
153 | print(f'Saved checkpoint to {str(Path(".").resolve())}')
154 | accelerator.save(checkpoint_dict, 'checkpoint-latest.pth')
155 | if (train_state.epoch > 0) and (train_state.epoch % cfg.checkpoint_every == 0):
156 | accelerator.save(checkpoint_dict, f'checkpoint-{train_state.epoch:04d}.pth')
157 |
158 | # Evaluate
159 | if train_state.epoch % cfg.get('eval_every', 1) == 0:
160 | test_stats = evaluate(**kwargs)
161 | if accelerator.is_local_main_process:
162 | if (train_state.best_val is None) or (test_stats['mIoU'] > train_state.best_val):
163 | train_state.best_val = test_stats['mIoU']
164 | torch.save(checkpoint_dict, 'checkpoint-best.pth')
165 | if cfg.wandb:
166 | wandb.log(test_stats)
167 | wandb.run.summary["best_mIoU"] = train_state.best_val
168 |
169 | # End training
170 | if ((cfg.max_train_steps is not None and train_state.step >= cfg.max_train_steps) or
171 | (cfg.max_train_epochs is not None and train_state.epoch >= cfg.max_train_epochs)):
172 | print(f'Ending training at: {datetime.datetime.now()}')
173 | print(f'Final train state: {train_state}')
174 | sys.exit()
175 |
176 |
177 | def train_one_epoch(
178 | *,
179 | cfg: DictConfig,
180 | model: torch.nn.Module,
181 | dataloader_train: Iterable,
182 | optimizer: torch.optim.Optimizer,
183 | accelerator: Accelerator,
184 | scheduler: Callable,
185 | train_state: utils.TrainState,
186 | model_ema: Optional[object] = None,
187 | **_unused_kwargs
188 | ):
189 |
190 | # Train mode
191 | model.train()
192 | log_header = f'Epoch: [{train_state.epoch}]'
193 | metric_logger = utils.MetricLogger(delimiter=" ")
194 | metric_logger.add_meter('step', utils.SmoothedValue(window_size=1, fmt='{value:.0f}'))
195 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
196 | progress_bar = metric_logger.log_every(dataloader_train, cfg.logging.print_freq, header=log_header)
197 |
198 | # Train
199 | for i, (images, _, pseudolabels, _) in enumerate(progress_bar):
200 | if i >= cfg.get('limit_train_batches', math.inf):
201 | break
202 |
203 | # Forward
204 | output = model(images) # (B, C, H, W)
205 |
206 | # Cross-entropy loss
207 | loss = F.cross_entropy(output, pseudolabels)
208 |
209 | # Measure accuracy
210 | acc1, acc5 = utils.accuracy(output, pseudolabels, topk=(1, 5))
211 |
212 | # Exit if loss is NaN
213 | loss_value = loss.item()
214 | if not math.isfinite(loss_value):
215 | print("Loss is {}, stopping training".format(loss_value))
216 | sys.exit(1)
217 |
218 | # Loss scaling and backward
219 | accelerator.backward(loss)
220 |
221 | # Gradient accumulation, optimizer step, scheduler step
222 | if i % cfg.gradient_accumulation_steps == 0:
223 | optimizer.step()
224 | optimizer.zero_grad()
225 | torch.cuda.synchronize()
226 | if cfg.scheduler.stepwise:
227 | scheduler.step()
228 | train_state.step += 1
229 |
230 | # Model EMA
231 | if model_ema is not None and (train_state.epoch % cfg.ema.update_every) == 0:
232 | model_ema.update((p for p in model.parameters() if p.requires_grad))
233 |
234 | # Logging
235 | log_dict = dict(
236 | lr=optimizer.param_groups[0]["lr"], step=train_state.step,
237 |             train_loss=loss_value,
238 | train_top1=acc1[0], train_top5=acc5[0],
239 | )
240 | metric_logger.update(**log_dict)
241 | if cfg.wandb and accelerator.is_local_main_process:
242 | wandb.log(log_dict)
243 |
244 | # Scheduler
245 | if not cfg.scheduler.stepwise:
246 | scheduler.step()
247 |
248 | # Epoch complete
249 | train_state.epoch += 1
250 |
251 | # Gather stats from all processes
252 | metric_logger.synchronize_between_processes(device=accelerator.device)
253 | print("Averaged stats:", metric_logger)
254 | return train_state
255 |
256 |
257 | @torch.no_grad()
258 | def evaluate(
259 | *,
260 | cfg: DictConfig,
261 | model: torch.nn.Module,
262 | dataloader_val: Iterable,
263 | accelerator: Accelerator,
264 | model_ema: Optional[object] = None,
265 | evaluate_dataset_pseudolabels: bool = False,
266 | **_unused_kwargs
267 | ):
268 |
269 | # To avoid CUDA errors on my machine
270 | torch.backends.cudnn.benchmark = False
271 |
272 | # Eval mode
273 | model.eval()
274 | torch.cuda.synchronize()
275 | eval_context = model_ema.average_parameters if cfg.ema.use_ema else nullcontext
276 |
277 | # Add background class
278 | n_classes = cfg.data.num_classes + 1
279 |
280 | # Iterate
281 | tp = [0] * n_classes
282 | fp = [0] * n_classes
283 | fn = [0] * n_classes
284 |
285 | # Check
286 | assert dataloader_val.batch_size == 1, 'Please use batch_size=1 for val to compute mIoU'
287 |
288 |     # Pre-allocate flat arrays for all predictions and ground-truth labels
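    |     # Assumes each image contributes at most 500 * 500 valid pixels (true for PASCAL VOC,
    |     # whose images are at most 500 px per side), so the buffers below never overflow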
289 | all_preds = np.zeros((len(dataloader_val) * 500 * 500), dtype=np.float32)
290 | all_gt = np.zeros((len(dataloader_val) * 500 * 500), dtype=np.float32)
291 | offset_ = 0
292 |
293 | # Add all pixels to our arrays
294 | with eval_context():
295 | for (inputs, targets, mask, _) in tqdm(dataloader_val, desc='Concatenating all predictions'):
296 |
297 | # Predict
298 | if evaluate_dataset_pseudolabels:
299 | mask = mask
300 | else:
301 | logits = model(inputs.to(accelerator.device).contiguous()).squeeze(0) # (C, H, W)
302 | mask = torch.argmax(logits, dim=0) # (H, W)
303 |
304 | # Convert
305 | target = targets.numpy().squeeze()
306 | mask = mask.cpu().numpy().squeeze()
307 |
308 | # Check where ground-truth is valid and append valid pixels to the array
309 | valid = (target != 255)
310 | n_valid = np.sum(valid)
311 | all_gt[offset_:offset_+n_valid] = target[valid]
312 |
313 |             # Check that the prediction matches the ground-truth resolution
314 | if mask.shape != target.shape:
315 | raise ValueError(f'{mask.shape=} != {target.shape=}')
316 |
317 | # Append the predicted targets in the array
318 | all_preds[offset_:offset_+n_valid, ] = mask[valid]
319 | all_gt[offset_:offset_+n_valid, ] = target[valid]
320 |
321 | # Update offset_
322 | offset_ += n_valid
323 |
324 | # Truncate to the actual number of pixels
325 | all_preds = all_preds[:offset_, ]
326 | all_gt = all_gt[:offset_, ]
327 |
328 | # TP, FP, and FN evaluation
329 | for i_part in range(0, n_classes):
330 | tmp_all_gt = (all_gt == i_part)
331 | tmp_pred = (all_preds == i_part)
332 | tp[i_part] += np.sum(tmp_all_gt & tmp_pred)
333 | fp[i_part] += np.sum(~tmp_all_gt & tmp_pred)
334 | fn[i_part] += np.sum(tmp_all_gt & ~tmp_pred)
335 |
336 | # Calculate Jaccard index
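    |     # Per-class IoU: jac[c] = TP_c / (TP_c + FP_c + FN_c)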
337 | jac = [0] * n_classes
338 | for i_part in range(0, n_classes):
339 | jac[i_part] = float(tp[i_part]) / max(float(tp[i_part] + fp[i_part] + fn[i_part]), 1e-8)
340 |
341 | # Print results
342 | eval_result = dict()
343 | eval_result['jaccards_all_categs'] = jac
344 | eval_result['mIoU'] = np.mean(jac)
345 | print('Evaluation of semantic segmentation ')
346 | print(f'Full eval result: {eval_result}')
347 | print('mIoU is %.2f' % (100*eval_result['mIoU']))
348 | return eval_result
349 |
350 |
351 | @torch.no_grad()
352 | def generate(
353 | *,
354 | cfg: DictConfig,
355 | model: torch.nn.Module,
356 | dataloader_val: Iterable,
357 | accelerator: Accelerator,
358 | model_ema: Optional[object] = None,
359 | **_unused_kwargs
360 | ):
361 |
362 | # To avoid CUDA errors on my machine
363 | torch.backends.cudnn.benchmark = False
364 |
365 | # Eval mode
366 | model.eval()
367 | torch.cuda.synchronize()
368 | eval_context = model_ema.average_parameters if cfg.ema.use_ema else nullcontext
369 |
370 | # Create paths
371 | preds_dir = Path('preds')
372 | gt_dir = Path('gt')
373 | preds_dir.mkdir(exist_ok=True, parents=True)
374 | gt_dir.mkdir(exist_ok=True, parents=True)
375 |
376 | # Generate and save
377 | with eval_context():
378 |         for (inputs, targets, _, metadata) in tqdm(dataloader_val, desc='Generating and saving predictions'):
379 | # Predict
380 | logits = model(inputs.to(accelerator.device).contiguous()).squeeze(0) # (C, H, W)
381 | # Convert
382 | preds = torch.argmax(logits, dim=0).cpu().numpy().astype(np.uint8)
383 | gt = targets.squeeze().numpy().astype(np.uint8)
384 | # Save
385 | Image.fromarray(preds).convert('L').save(preds_dir / f"{metadata[0]['id']}.png")
386 | Image.fromarray(gt).convert('L').save(gt_dir / f"{metadata[0]['id']}.png")
387 |
388 | print(f'Saved to {Path(".").absolute()}')
389 |
390 |
391 | if __name__ == '__main__':
392 | main()
393 |
--------------------------------------------------------------------------------
/object-localization/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Main experiment file. Code adapted from LOST: https://github.com/valeoai/LOST
3 | """
4 | import os
5 | import sys
6 | import argparse
7 | import random
8 | import pickle
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from pprint import pprint
14 | from typing import Union
15 | from pathlib import Path
16 | from torchvision import transforms
17 | from tqdm import tqdm
18 | from PIL import Image
19 |
20 | from networks import get_model
21 | from datasets import ImageDataset, Dataset, bbox_iou
22 | from visualizations import visualize_fms, visualize_predictions, visualize_seed_expansion
23 | from object_discovery import lost, detect_box, dino_seg, get_eigenvectors_from_features, get_largest_cc_box, get_bbox_from_patch_mask
24 |
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser("Visualize Self-Attention maps")
28 | parser.add_argument(
29 | "--arch",
30 | default="vit_small",
31 | type=str,
32 | choices=[
33 | "vit_tiny",
34 | "vit_small",
35 | "vit_base",
36 | "resnet50",
37 | "vgg16_imagenet",
38 | "resnet50_imagenet",
39 | ],
40 | help="Model architecture.",
41 | )
42 | parser.add_argument(
43 | "--patch_size", default=16, type=int, help="Patch resolution of the model."
44 | )
45 |
46 | # Use a dataset
47 | parser.add_argument(
48 | "--dataset",
49 | default="VOC07",
50 | type=str,
51 | choices=[None, "VOC07", "VOC12", "COCO20k"],
52 | help="Dataset name.",
53 | )
54 | parser.add_argument(
55 | "--set",
56 | default="trainval",
57 | type=str,
58 | choices=["val", "train", "trainval", "test"],
59 | help="Path of the image to load.",
60 | )
61 | # Or use a single image
62 | parser.add_argument(
63 | "--image_path",
64 | type=str,
65 | default=None,
66 | help="If want to apply only on one image, give file path.",
67 | )
68 |
69 | # Folder used to output visualizations and
70 | parser.add_argument(
71 | "--output_dir", type=str, default="outputs", help="Output directory to store predictions and visualizations."
72 | )
73 |
74 | # Evaluation setup
75 | parser.add_argument("--no_hard", action="store_true", help="Only used in the case of the VOC_all setup (see the paper).")
76 | parser.add_argument("--no_evaluation", action="store_true", help="Compute the evaluation.")
77 | parser.add_argument("--save_predictions", default=True, type=bool, help="Save predicted bouding boxes.")
78 |
79 | # Visualization
80 | parser.add_argument(
81 | "--visualize",
82 | type=str,
83 | choices=["fms", "seed_expansion", "pred", None],
84 | default=None,
85 | help="Select the different type of visualizations.",
86 | )
87 |
88 | # For ResNet dilation
89 | parser.add_argument("--resnet_dilate", type=int, default=2, help="Dilation level of the resnet model.")
90 |
91 | # LOST parameters
92 | parser.add_argument(
93 | "--which_features",
94 | type=str,
95 | default="k",
96 | choices=["k", "q", "v"],
97 | help="Which features to use",
98 | )
99 | parser.add_argument(
100 | "--k_patches",
101 | type=int,
102 | default=100,
103 | help="Number of patches with the lowest degree considered."
104 | )
105 |
106 | # Misc
107 | parser.add_argument("--name", type=str, default=None, help='Experiment name')
108 | parser.add_argument("--skip_if_exists", action='store_true', help='If results dir already exists , exit')
109 |
110 | # Use dino-seg proposed method
111 | parser.add_argument("--dinoseg", action="store_true", help="Apply DINO-seg baseline.")
112 | parser.add_argument("--dinoseg_head", type=int, default=4)
113 |
114 | # Use eigenvalue method
115 | parser.add_argument("--eigenseg", action='store_true', help='Apply eigenvalue method')
116 | parser.add_argument("--precomputed_eigs_dir", default=None, type=str,
117 | help='Apply eigenvalue method with precomputed bboxes')
118 | parser.add_argument("--precomputed_eigs_downsample", default=16, type=str)
119 | parser.add_argument("--which_matrix", choices=['infer', 'affinity', 'laplacian'],
120 | default='infer', help='Which matrix to use for eigenvector calculation')
121 |
122 | # Parse
123 | args = parser.parse_args()
124 |
125 | # Modify
126 | if args.image_path is not None:
127 | args.save_predictions = False
128 | args.no_evaluation = True
129 | args.dataset = None
130 |
131 | return args
132 |
133 |
134 | @torch.no_grad()
135 | def main():
136 |
137 | # Args
138 | args = parse_args()
139 |
140 | # -------------------------------------------------------------------------------------------------------
141 | # Dataset
142 |
143 | # Transform
144 | transform = transforms.Compose([
145 | transforms.ToTensor(),
146 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
147 | ])
148 |
149 | # If an image_path is given, apply the method only to the image
150 | if args.image_path is not None:
151 | dataset = ImageDataset(args.image_path, transform)
152 | else:
153 | dataset = Dataset(args.dataset, args.set, args.no_hard, transform)
154 |
155 | # Naming
156 | if args.name is not None:
157 | exp_name = args.name
158 | elif args.dinoseg:
159 | # Experiment with the baseline DINO-seg
160 | if "vit" not in args.arch:
161 | raise ValueError("DINO-seg can only be applied to tranformer networks.")
162 | exp_name = f"{args.arch}-{args.patch_size}_dinoseg-head{args.dinoseg_head}"
163 | else:
164 | # Experiment with LOST
165 | exp_name = f"LOST-{args.arch}"
166 | if "resnet" in args.arch:
167 | exp_name += f"dilate{args.resnet_dilate}"
168 | elif "vit" in args.arch:
169 | exp_name += f"{args.patch_size}_{args.which_features}"
170 |
171 | # -------------------------------------------------------------------------------------------------------
172 | # Directories
173 | if args.image_path is None:
174 | args.output_dir = os.path.join(args.output_dir, dataset.name)
175 |
176 | # Skip if already exists
177 | exp_dir = Path(args.output_dir) / exp_name
178 | if args.skip_if_exists and exp_dir.is_dir() and len(list(exp_dir.iterdir())) > 0:
179 | print(f'Directory already exists and is not empty: {str(exp_dir)}')
180 | print(f'Exiting...')
181 | sys.exit()
182 | os.makedirs(args.output_dir, exist_ok=True)
183 |
184 | # -------------------------------------------------------------------------------------------------------
185 | # Model
186 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
187 | model = get_model(args.arch, args.patch_size, args.resnet_dilate, device)
188 |
189 | print(f"Running LOST on the dataset {dataset.name} (exp: {exp_name})")
190 | print(f"Args:")
191 |     pprint(args.__dict__)
192 |
193 | # Visualization
194 | if args.visualize:
195 | vis_folder = f"{args.output_dir}/visualizations/{exp_name}"
196 | os.makedirs(vis_folder, exist_ok=True)
197 |
198 | # -------------------------------------------------------------------------------------------------------
199 | # Loop over images
200 | preds_dict = {}
201 | gt_dict = {}
202 | cnt = 0
203 | corloc = np.zeros(len(dataset.dataloader))
204 |
205 | pbar = tqdm(dataset.dataloader)
206 | for im_id, inp in enumerate(pbar):
207 |
208 | # ------------ IMAGE PROCESSING -------------------------------------------
209 | img = inp[0]
210 | init_image_size = img.shape
211 |
212 | # Get the name of the image
213 | im_name = dataset.get_image_name(inp[1])
214 |
215 | # Pass in case of no gt boxes in the image
216 | if im_name is None:
217 | continue
218 |
219 |         # Crop (eigenseg) or zero-pad (otherwise) the image to a multiple of the patch size
220 | if args.eigenseg:
221 | size_im = (
222 | img.shape[0],
223 | int(np.floor(img.shape[1] / args.patch_size) * args.patch_size),
224 | int(np.floor(img.shape[2] / args.patch_size) * args.patch_size),
225 | )
226 | img = paded = img[:, :size_im[1], :size_im[2]]
227 | else:
228 | size_im = (
229 | img.shape[0],
230 | int(np.ceil(img.shape[1] / args.patch_size) * args.patch_size),
231 | int(np.ceil(img.shape[2] / args.patch_size) * args.patch_size),
232 | )
233 | paded = torch.zeros(size_im)
234 | paded[:, : img.shape[1], : img.shape[2]] = img
235 | img = paded
236 |
237 | # Size for transformers
238 | w_featmap = img.shape[-2] // args.patch_size
239 | h_featmap = img.shape[-1] // args.patch_size
240 |
241 | # ------------ GROUND-TRUTH -------------------------------------------
242 | if not args.no_evaluation:
243 | gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name)
244 |
245 | if gt_bbxs is not None:
246 | # Discard images with no gt annotations
247 | # Happens only in the case of VOC07 and VOC12
248 | if gt_bbxs.shape[0] == 0 and args.no_hard:
249 | continue
250 |
251 | # ------------ EXTRACT FEATURES -------------------------------------------
252 |
253 | # Load precomputer bounding boxes
254 | if args.eigenseg and args.precomputed_eigs_dir is not None:
255 |
256 | # Load
257 | if 'COCO' in dataset.name:
258 | fname = f"COCO_train2014_{int(im_name):012d}.pth"
259 | elif 'VOC' in dataset.name:
260 | fname = im_name.replace('.jpg', '.pth')
261 | precomputed_eigs_file = os.path.join(args.precomputed_eigs_dir, fname)
262 | precomputed_eigs = torch.load(precomputed_eigs_file, map_location='cpu')
263 | eigenvectors = precomputed_eigs['eigenvectors'] # tensor of shape (K, H_lr * W_lr)
264 |
265 | # Get eigenvectors
266 | if args.which_matrix == 'infer': # get the type
267 | which_matrix = Path(args.precomputed_eigs_dir).name.split('_')[0]
268 | else:
269 | which_matrix = args.which_matrix
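    |             # For Laplacian/matting eigendecompositions the first eigenvector is (near-)constant,
    |             # so the second one (index 1) carries the foreground segment; for a plain affinity
    |             # matrix the leading eigenvector (index 0) is used instead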
270 | segment_index = {'matting': 1, 'laplacian': 1, 'affinity': 0}[which_matrix]
271 | patch_mask = (eigenvectors[segment_index] > 0)
272 | pred = get_bbox_from_patch_mask(patch_mask, init_image_size)
273 |
274 | # Extract features from self-supervised model
275 | else:
276 |
277 | # Move to GPU
278 | img = img.cuda(non_blocking=True)
279 |
280 | # ------------ FORWARD PASS -------------------------------------------
281 | if "vit" in args.arch:
282 | # Store the outputs of qkv layer from the last attention layer
283 | feat_out = {}
284 | def hook_fn_forward_qkv(module, input, output):
285 | feat_out["qkv"] = output
286 | model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)
287 |
288 | # Forward pass in the model
289 | attentions = model.get_last_selfattention(img[None, :, :, :])
290 |
291 | # Scaling factor
292 | scales = [args.patch_size, args.patch_size]
293 |
294 | # Dimensions
295 | nb_im = attentions.shape[0] # Batch size
296 | nh = attentions.shape[1] # Number of heads
297 | nb_tokens = attentions.shape[2] # Number of tokens
298 |
299 | # Baseline: compute DINO segmentation technique proposed in the DINO paper
300 | # and select the biggest component
301 | if args.dinoseg:
302 | pred = dino_seg(attentions, (w_featmap, h_featmap), args.patch_size, head=args.dinoseg_head)
303 | pred = np.asarray(pred)
304 | else:
305 | # Extract the qkv features of the last attention layer
306 | qkv = (
307 | feat_out["qkv"]
308 | .reshape(nb_im, nb_tokens, 3, nh, -1 // nh)
309 | .permute(2, 0, 3, 1, 4)
310 | )
311 | q, k, v = qkv[0], qkv[1], qkv[2]
312 | k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
313 | q = q.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
314 | v = v.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
315 |
316 | # Modality selection
317 | if args.which_features == "k":
318 | feats = k[:, 1:, :]
319 | elif args.which_features == "q":
320 | feats = q[:, 1:, :]
321 | elif args.which_features == "v":
322 | feats = v[:, 1:, :]
323 | elif "resnet" in args.arch:
324 | x = model.forward(img[None, :, :, :])
325 | d, w_featmap, h_featmap = x.shape[1:]
326 | feats = x.reshape((1, d, -1)).transpose(2, 1)
327 | # Apply layernorm
328 | layernorm = nn.LayerNorm(feats.size()[1:]).to(device)
329 | feats = layernorm(feats)
330 | # Scaling factor
331 | scales = [
332 | float(img.shape[1]) / x.shape[2],
333 | float(img.shape[2]) / x.shape[3],
334 | ]
335 | elif "vgg16" in args.arch:
336 | x = model.forward(img[None, :, :, :])
337 | d, w_featmap, h_featmap = x.shape[1:]
338 | feats = x.reshape((1, d, -1)).transpose(2, 1)
339 | # Apply layernorm
340 | layernorm = nn.LayerNorm(feats.size()[1:]).to(device)
341 | feats = layernorm(feats)
342 | # Scaling factor
343 | scales = [
344 | float(img.shape[1]) / x.shape[2],
345 | float(img.shape[2]) / x.shape[3],
346 | ]
347 | else:
348 | raise ValueError("Unknown model.")
349 |
350 | # Sizes
351 | dims_wh = [w_featmap, h_featmap]
352 |
353 | # ------------ Apply LOST -------------------------------------------
354 | if not args.dinoseg:
355 | if args.eigenseg:
356 |
357 | # Get eigenvectors
358 | eigenvectors = get_eigenvectors_from_features(feats, args.which_matrix)
359 |
360 | # Get bounding box
361 | assert ('affinity' in args.which_matrix) ^ ('laplacian' in args.which_matrix)
362 | eig_index = 0 if 'affinity' in args.which_matrix else 1
363 | patch_mask = (eigenvectors[:, eig_index] > 0)
364 | pred = get_bbox_from_patch_mask(patch_mask, init_image_size)
365 |
366 | else:
367 |
368 | pred, A, M, scores, seed = lost(feats, dims_wh, scales, init_image_size, k_patches=args.k_patches)
369 |
370 | # ------------ Visualizations -------------------------------------------
371 | if args.visualize == "fms":
372 | visualize_fms(A.clone().cpu().numpy(), seed, scores, dims_wh, scales, vis_folder, im_name)
373 |
374 | elif args.visualize == "seed_expansion":
375 | image = dataset.load_image(im_name)
376 |
377 | # Before expansion
378 | pred_seed, _ = detect_box(A[seed, :], seed, dims_wh, scales=scales, initial_im_size=init_image_size[1:])
379 | visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims_wh, vis_folder, im_name)
380 |
381 | elif args.visualize == "pred":
382 | image = dataset.load_image(im_name)
383 | visualize_predictions(image, pred, seed, scales, dims_wh, vis_folder, im_name)
384 |
385 | # Save the prediction
386 | preds_dict[im_name] = pred
387 |         gt_dict[im_name] = gt_bbxs if not args.no_evaluation else None
388 |
389 | # Evaluation
390 | if args.no_evaluation:
391 | continue
392 |
393 | # Compare prediction to GT boxes
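    |         # CorLoc: the image counts as correctly localized if the single predicted box has
    |         # IoU >= 0.5 with any ground-truth box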
394 | ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(gt_bbxs))
395 |
396 | if torch.any(ious >= 0.5):
397 | corloc[im_id] = 1
398 |
399 | cnt += 1
400 | if cnt % 10 == 0:
401 | pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt} ({int(np.sum(corloc))/cnt * 100:.1f}%)")
402 |
403 | # Save predicted bounding boxes
404 | if args.save_predictions:
405 | folder = f"{args.output_dir}/{exp_name}"
406 | os.makedirs(folder, exist_ok=True)
407 | with open(os.path.join(folder, "preds.pkl"), "wb") as f:
408 | pickle.dump(preds_dict, f)
409 | with open(os.path.join(folder, "gt.pkl"), "wb") as f:
410 | pickle.dump(gt_dict, f)
411 | print(f"Predictions saved to {folder}")
412 |
413 | # Evaluate
414 | if not args.no_evaluation:
415 | print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})")
416 | result_file = os.path.join(folder, 'results.txt')
417 | with open(result_file, 'w') as f:
418 | f.write('corloc,%.1f,,\n'%(100*np.sum(corloc)/cnt))
419 | print('File saved at %s'%result_file)
420 |
421 |
422 | if __name__ == "__main__":
423 | main()
--------------------------------------------------------------------------------
/extract/extract.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from pathlib import Path
3 | from typing import Optional, Tuple
4 |
5 | import cv2
6 | import fire
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from accelerate import Accelerator
11 | from PIL import Image
12 | from scipy.sparse.linalg import eigsh
13 | from sklearn.cluster import KMeans, MiniBatchKMeans
14 | from sklearn.decomposition import PCA
15 | from torchvision.utils import draw_bounding_boxes
16 | from tqdm import tqdm
17 |
18 | import extract_utils as utils
19 |
20 |
21 | def extract_features(
22 | images_list: str,
23 | images_root: Optional[str],
24 | model_name: str,
25 | batch_size: int,
26 | output_dir: str,
27 | which_block: int = -1,
28 | ):
29 | """
30 | Extract features from a list of images.
31 |
32 | Example:
33 | python extract.py extract_features \
34 | --images_list "./data/VOC2012/lists/images.txt" \
35 | --images_root "./data/VOC2012/images" \
36 | --output_dir "./data/VOC2012/features/dino_vits16" \
37 | --model_name dino_vits16 \
38 | --batch_size 1
39 | """
40 |
41 | # Output
42 | utils.make_output_dir(output_dir)
43 |
44 | # Models
45 | model_name = model_name.lower()
46 | model, val_transform, patch_size, num_heads = utils.get_model(model_name)
47 |
48 | # Add hook
49 | if 'dino' in model_name or 'mocov3' in model_name:
50 | feat_out = {}
51 | def hook_fn_forward_qkv(module, input, output):
52 | feat_out["qkv"] = output
53 | model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)
54 | else:
55 | raise ValueError(model_name)
56 |
57 | # Dataset
58 | filenames = Path(images_list).read_text().splitlines()
59 | dataset = utils.ImagesDataset(filenames=filenames, images_root=images_root, transform=val_transform)
60 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=8)
61 | print(f'Dataset size: {len(dataset)=}')
62 | print(f'Dataloader size: {len(dataloader)=}')
63 |
64 | # Prepare
65 | accelerator = Accelerator(fp16=True, cpu=False)
66 | # model, dataloader = accelerator.prepare(model, dataloader)
67 | model = model.to(accelerator.device)
68 |
69 | # Process
70 | pbar = tqdm(dataloader, desc='Processing')
71 | for i, (images, files, indices) in enumerate(pbar):
72 | output_dict = {}
73 |
74 | # Check if file already exists
75 | id = Path(files[0]).stem
76 | output_file = Path(output_dir) / f'{id}.pth'
77 | if output_file.is_file():
78 | pbar.write(f'Skipping existing file {str(output_file)}')
79 | continue
80 |
81 | # Reshape image
82 | P = patch_size
83 | B, C, H, W = images.shape
84 | H_patch, W_patch = H // P, W // P
85 | H_pad, W_pad = H_patch * P, W_patch * P
86 | T = H_patch * W_patch + 1 # number of tokens, add 1 for [CLS]
87 | # images = F.interpolate(images, size=(H_pad, W_pad), mode='bilinear') # resize image
88 | images = images[:, :, :H_pad, :W_pad]
89 | images = images.to(accelerator.device)
90 |
91 | # Forward and collect features into output dict
92 | if 'dino' in model_name or 'mocov3' in model_name:
93 | # accelerator.unwrap_model(model).get_intermediate_layers(images)[0].squeeze(0)
94 | model.get_intermediate_layers(images)[0].squeeze(0)
95 | # output_dict['out'] = out
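   |             # Split the hooked qkv tensor into (q, k, v) per head; note that -1 // num_heads
   |             # evaluates to -1, so the per-head dimension is inferred by reshape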
96 | output_qkv = feat_out["qkv"].reshape(B, T, 3, num_heads, -1 // num_heads).permute(2, 0, 3, 1, 4)
97 | # output_dict['q'] = output_qkv[0].transpose(1, 2).reshape(B, T, -1)[:, 1:, :]
98 | output_dict['k'] = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :]
99 | # output_dict['v'] = output_qkv[2].transpose(1, 2).reshape(B, T, -1)[:, 1:, :]
100 | else:
101 | raise ValueError(model_name)
102 |
103 | # Metadata
104 | output_dict['indices'] = indices[0]
105 | output_dict['file'] = files[0]
106 | output_dict['id'] = id
107 | output_dict['model_name'] = model_name
108 | output_dict['patch_size'] = patch_size
109 | output_dict['shape'] = (B, C, H, W)
110 | output_dict = {k: (v.detach().cpu() if torch.is_tensor(v) else v) for k, v in output_dict.items()}
111 |
112 | # Save
113 | accelerator.save(output_dict, str(output_file))
114 | accelerator.wait_for_everyone()
115 |
116 | print(f'Saved features to {output_dir}')
117 |
118 |
119 | def _extract_eig(
120 | inp: Tuple[int, str],
121 | K: int,
122 | images_root: str,
123 | output_dir: str,
124 | which_matrix: str = 'laplacian',
125 | which_features: str = 'k',
126 | normalize: bool = True,
127 | lapnorm: bool = True,
128 | which_color_matrix: str = 'knn',
129 | threshold_at_zero: bool = True,
130 | image_downsample_factor: Optional[int] = None,
131 | image_color_lambda: float = 10,
132 | ):
133 | index, features_file = inp
134 |
135 | # Load
136 | data_dict = torch.load(features_file, map_location='cpu')
137 | image_id = data_dict['file'][:-4]
138 |
139 | # Load
140 | output_file = str(Path(output_dir) / f'{image_id}.pth')
141 | if Path(output_file).is_file():
142 | print(f'Skipping existing file {str(output_file)}')
143 | return # skip because already generated
144 |
145 | # Load affinity matrix
146 | feats = data_dict[which_features].squeeze().cuda()
147 | if normalize:
148 | feats = F.normalize(feats, p=2, dim=-1)
149 |
150 | # Eigenvectors of affinity matrix
151 | if which_matrix == 'affinity_torch':
152 | W = feats @ feats.T
153 | if threshold_at_zero:
154 | W = (W * (W > 0))
155 | eigenvalues, eigenvectors = torch.eig(W, eigenvectors=True)
156 | eigenvalues = eigenvalues.cpu()
157 | eigenvectors = eigenvectors.cpu()
158 |
159 | # Eigenvectors of affinity matrix with scipy
160 | elif which_matrix == 'affinity_svd':
161 | USV = torch.linalg.svd(feats, full_matrices=False)
162 | eigenvectors = USV[0][:, :K].T.to('cpu', non_blocking=True)
163 | eigenvalues = USV[1][:K].to('cpu', non_blocking=True)
164 |
165 | # Eigenvectors of affinity matrix with scipy
166 | elif which_matrix == 'affinity':
167 | W = (feats @ feats.T)
168 | if threshold_at_zero:
169 | W = (W * (W > 0))
170 | W = W.cpu().numpy()
171 | eigenvalues, eigenvectors = eigsh(W, which='LM', k=K)
172 | eigenvectors = torch.flip(torch.from_numpy(eigenvectors), dims=(-1,)).T
173 |
174 | # Eigenvectors of matting laplacian matrix
175 | elif which_matrix in ['matting_laplacian', 'laplacian']:
176 |
177 | # Get sizes
178 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict)
179 | if image_downsample_factor is None:
180 | image_downsample_factor = P
181 | H_pad_lr, W_pad_lr = H_pad // image_downsample_factor, W_pad // image_downsample_factor
182 |
183 | # Upscale features to match the resolution
184 | if (H_patch, W_patch) != (H_pad_lr, W_pad_lr):
185 | feats = F.interpolate(
186 | feats.T.reshape(1, -1, H_patch, W_patch),
187 | size=(H_pad_lr, W_pad_lr), mode='bilinear', align_corners=False
188 | ).reshape(-1, H_pad_lr * W_pad_lr).T
189 |
190 | ### Feature affinities
191 | W_feat = (feats @ feats.T)
192 | if threshold_at_zero:
193 | W_feat = (W_feat * (W_feat > 0))
194 | W_feat = W_feat / W_feat.max() # NOTE: If features are normalized, this naturally does nothing
195 | W_feat = W_feat.cpu().numpy()
196 |
197 | ### Color affinities
198 | # If we are fusing with color affinites, then load the image and compute
199 | if image_color_lambda > 0:
200 |
201 | # Load image
202 | image_file = str(Path(images_root) / f'{image_id}.jpg')
203 | image_lr = Image.open(image_file).resize((W_pad_lr, H_pad_lr), Image.BILINEAR)
204 | image_lr = np.array(image_lr) / 255.
205 |
206 | # Color affinities (of type scipy.sparse.csr_matrix)
207 | if which_color_matrix == 'knn':
208 | W_lr = utils.knn_affinity(image_lr)
209 | elif which_color_matrix == 'rw':
210 | W_lr = utils.rw_affinity(image_lr)
211 |
212 | # Convert to dense numpy array
213 | W_color = np.array(W_lr.todense().astype(np.float32))
214 |
215 | else:
216 |
217 | # No color affinity
218 | W_color = 0
219 |
220 | # Combine
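    |         # image_color_lambda trades off semantic (feature) affinities against low-level color affinities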
221 | W_comb = W_feat + W_color * image_color_lambda # combination
222 | D_comb = np.array(utils.get_diagonal(W_comb).todense()) # is dense or sparse faster? not sure, should check
223 |
224 | # Extract eigenvectors
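    |         # eigsh with sigma=0 uses shift-invert to find the eigenvalues closest to zero of the
    |         # (generalized) problem (D - W) v = lambda * D v; if the factorization fails, fall back
    |         # to which='SM' (smallest magnitude), which is slower but more robust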
225 | if lapnorm:
226 | try:
227 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM', M=D_comb)
228 | except:
229 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM', M=D_comb)
230 | else:
231 | try:
232 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM')
233 | except:
234 | eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM')
235 | eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float()
236 |
237 | # Sign ambiguity
238 | for k in range(eigenvectors.shape[0]):
239 | if 0.5 < torch.mean((eigenvectors[k] > 0).float()).item() < 1.0: # reverse segment
240 | eigenvectors[k] = 0 - eigenvectors[k]
241 |
242 | # Save dict
243 | output_dict = {'eigenvalues': eigenvalues, 'eigenvectors': eigenvectors}
244 | torch.save(output_dict, output_file)
245 |
246 |
247 | def extract_eigs(
248 | images_root: str,
249 | features_dir: str,
250 | output_dir: str,
251 | which_matrix: str = 'laplacian',
252 | which_color_matrix: str = 'knn',
253 | which_features: str = 'k',
254 | normalize: bool = True,
255 | threshold_at_zero: bool = True,
256 | lapnorm: bool = True,
257 | K: int = 20,
258 | image_downsample_factor: Optional[int] = None,
259 | image_color_lambda: float = 0.0,
260 | multiprocessing: int = 0
261 | ):
262 | """
263 | Extracts eigenvalues from features.
264 |
265 | Example:
266 | python extract.py extract_eigs \
267 | --images_root "./data/VOC2012/images" \
268 | --features_dir "./data/VOC2012/features/dino_vits16" \
269 | --which_matrix "laplacian" \
270 | --output_dir "./data/VOC2012/eigs/laplacian" \
271 | --K 5
272 | """
273 | utils.make_output_dir(output_dir)
274 | kwargs = dict(K=K, which_matrix=which_matrix, which_features=which_features, which_color_matrix=which_color_matrix,
275 | normalize=normalize, threshold_at_zero=threshold_at_zero, images_root=images_root, output_dir=output_dir,
276 | image_downsample_factor=image_downsample_factor, image_color_lambda=image_color_lambda, lapnorm=lapnorm)
277 | print(kwargs)
278 | fn = partial(_extract_eig, **kwargs)
279 | inputs = list(enumerate(sorted(Path(features_dir).iterdir())))
280 | utils.parallel_process(inputs, fn, multiprocessing)
281 |
282 |
283 | def _extract_multi_region_segmentations(
284 | inp: Tuple[int, Tuple[str, str]],
285 | adaptive: bool,
286 | non_adaptive_num_segments: int,
287 | infer_bg_index: bool,
288 | kmeans_baseline: bool,
289 | output_dir: str,
290 | num_eigenvectors: int,
291 | ):
292 | index, (feature_path, eigs_path) = inp
293 |
294 | # Load
295 | data_dict = torch.load(feature_path, map_location='cpu')
296 | data_dict.update(torch.load(eigs_path, map_location='cpu'))
297 |
298 | # Output file
299 | id = Path(data_dict['id'])
300 | output_file = str(Path(output_dir) / f'{id}.png')
301 | if Path(output_file).is_file():
302 | print(f'Skipping existing file {str(output_file)}')
303 | return # skip because already generated
304 |
305 | # Sizes
306 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict)
307 |
308 | # If adaptive, we use the gaps between eigenvalues to determine the number of
309 | # segments per image. If not, we use non_adaptive_num_segments to get a fixed
310 | # number of segments per image.
311 | if adaptive:
312 | indices_by_gap = np.argsort(np.diff(data_dict['eigenvalues'].numpy()))[::-1]
313 | index_largest_gap = indices_by_gap[indices_by_gap != 0][0] # remove zero and take the biggest
314 | n_clusters = index_largest_gap + 1
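    |         # e.g. eigenvalues [0.00, 0.01, 0.02, 0.35, 0.36] have their largest gap between
    |         # indices 2 and 3, giving n_clusters = 3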
315 | # print(f'Number of clusters: {n_clusters}')
316 | else:
317 | n_clusters = non_adaptive_num_segments
318 |
319 | # K-Means
320 | kmeans = KMeans(n_clusters=n_clusters)
321 |
322 | # Compute segments using eigenvector or baseline K-means
323 | if kmeans_baseline:
324 | feats = data_dict['k'].squeeze().numpy()
325 | clusters = kmeans.fit_predict(feats)
326 | else:
327 | eigenvectors = data_dict['eigenvectors'][1:1+num_eigenvectors].numpy() # take non-constant eigenvectors
328 | # import pdb; pdb.set_trace()
329 | clusters = kmeans.fit_predict(eigenvectors.T)
330 |
331 | # Reshape
332 | if clusters.size == H_patch * W_patch: # TODO: better solution might be to pass in patch index
333 | segmap = clusters.reshape(H_patch, W_patch)
334 | elif clusters.size == H_patch * W_patch * 4:
335 | segmap = clusters.reshape(H_patch * 2, W_patch * 2)
336 | else:
337 | raise ValueError()
338 |
339 | # TODO: Improve this step in the pipeline.
340 | # Background detection: we assume that the segment with the most border pixels is the
341 | # background region. We will always make this region equal 0.
342 | if infer_bg_index:
343 |         indices, normalized_counts = utils.get_border_fraction(segmap)
344 |         bg_index = indices[np.argmax(normalized_counts)].item()
345 | bg_region = (segmap == bg_index)
346 | zero_region = (segmap == 0)
347 | segmap[bg_region] = 0
348 | segmap[zero_region] = bg_index
349 |
350 | # Save dict
351 | Image.fromarray(segmap).convert('L').save(output_file)
352 |
353 |
354 | def extract_multi_region_segmentations(
355 | features_dir: str,
356 | eigs_dir: str,
357 | output_dir: str,
358 | adaptive: bool = False,
359 | non_adaptive_num_segments: int = 4,
360 | infer_bg_index: bool = True,
361 | kmeans_baseline: bool = False,
362 | num_eigenvectors: int = 1_000_000,
363 | multiprocessing: int = 0
364 | ):
365 | """
366 | Example:
367 | python extract.py extract_multi_region_segmentations \
368 | --features_dir "./data/VOC2012/features/dino_vits16" \
369 | --eigs_dir "./data/VOC2012/eigs/laplacian" \
370 | --output_dir "./data/VOC2012/multi_region_segmentation/fixed" \
371 | """
372 | utils.make_output_dir(output_dir)
373 | fn = partial(_extract_multi_region_segmentations, adaptive=adaptive, infer_bg_index=infer_bg_index,
374 | non_adaptive_num_segments=non_adaptive_num_segments, num_eigenvectors=num_eigenvectors,
375 | kmeans_baseline=kmeans_baseline, output_dir=output_dir)
376 | inputs = utils.get_paired_input_files(features_dir, eigs_dir)
377 | utils.parallel_process(inputs, fn, multiprocessing)
378 |
379 |
380 | def _extract_single_region_segmentations(
381 | inp: Tuple[int, Tuple[str, str]],
382 | threshold: float,
383 | output_dir: str,
384 | ):
385 | index, (feature_path, eigs_path) = inp
386 |
387 | # Load
388 | data_dict = torch.load(feature_path, map_location='cpu')
389 | data_dict.update(torch.load(eigs_path, map_location='cpu'))
390 |
391 | # Output file
392 | id = Path(data_dict['id'])
393 | output_file = str(Path(output_dir) / f'{id}.png')
394 | if Path(output_file).is_file():
395 | print(f'Skipping existing file {str(output_file)}')
396 | return # skip because already generated
397 |
398 | # Sizes
399 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict)
400 |
401 | # Eigenvector
402 | eigenvector = data_dict['eigenvectors'][1].numpy() # take smallest non-zero eigenvector
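    |     # Thresholding this (Fiedler-like) eigenvector at `threshold` (0 by default) yields a
    |     # binary foreground/background patch mask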
403 | segmap = (eigenvector > threshold).reshape(H_patch, W_patch)
404 |
405 | # Save dict
406 | Image.fromarray(segmap).convert('L').save(output_file)
407 |
408 |
409 | def extract_single_region_segmentations(
410 | features_dir: str,
411 | eigs_dir: str,
412 | output_dir: str,
413 | threshold: float = 0.0,
414 | multiprocessing: int = 0
415 | ):
416 | """
417 | Example:
418 | python extract.py extract_single_region_segmentations \
419 | --features_dir "./data/VOC2012/features/dino_vits16" \
420 | --eigs_dir "./data/VOC2012/eigs/laplacian" \
421 | --output_dir "./data/VOC2012/single_region_segmentation/patches" \
422 | """
423 | utils.make_output_dir(output_dir)
424 | fn = partial(_extract_single_region_segmentations, threshold=threshold, output_dir=output_dir)
425 | inputs = utils.get_paired_input_files(features_dir, eigs_dir)
426 | utils.parallel_process(inputs, fn, multiprocessing)
427 |
428 |
429 | def _extract_bbox(
430 |     inp: Tuple[int, Tuple[str, str]],
431 | num_erode: int,
432 | num_dilate: int,
433 | skip_bg_index: bool,
434 | downsample_factor: Optional[int] = None
435 | ):
436 | index, (feature_path, segmentation_path) = inp
437 |
438 | # Load
439 | data_dict = torch.load(feature_path, map_location='cpu')
440 | segmap = np.array(Image.open(str(segmentation_path)))
441 | image_id = data_dict['id']
442 |
443 | # Sizes
444 | B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict, downsample_factor)
445 |
446 | # Get bounding boxes
447 | outputs = {'bboxes': [], 'bboxes_original_resolution': [], 'segment_indices': [], 'id': image_id,
448 | 'format': "(xmin, ymin, xmax, ymax)"}
449 | for segment_index in sorted(np.unique(segmap).tolist()):
450 | if (not skip_bg_index) or (segment_index > 0): # skip 0, because 0 is the background
451 |
452 | # Erode and dilate mask
453 | binary_mask = (segmap == segment_index)
454 | binary_mask = utils.erode_or_dilate_mask(binary_mask, r=num_erode, erode=True)
455 | binary_mask = utils.erode_or_dilate_mask(binary_mask, r=num_dilate, erode=False)
456 |
457 | # Find box
458 | mask = np.where(binary_mask == 1)
459 |             ymin, ymax = min(mask[0]), max(mask[0]) + 1  # +1 so the exclusive upper bound includes the max pixel
460 |             xmin, xmax = min(mask[1]), max(mask[1]) + 1  # +1 so the exclusive upper bound includes the max pixel
461 |             bbox = [xmin, ymin, xmax, ymax]
462 |             bbox_resized = [x * P for x in bbox]  # rescale from patch coordinates to image coordinates
463 |             bbox_features = [ymin, xmin, ymax, xmax]  # feature-space (row, col) coordinates; currently unused
464 |
465 | # Append
466 | outputs['segment_indices'].append(segment_index)
467 | outputs['bboxes'].append(bbox)
468 | outputs['bboxes_original_resolution'].append(bbox_resized)
469 |
470 | return outputs
471 |
472 |
473 | def extract_bboxes(
474 | features_dir: str,
475 | segmentations_dir: str,
476 | output_file: str,
477 | num_erode: int = 2,
478 | num_dilate: int = 3,
479 | skip_bg_index: bool = True,
480 | downsample_factor: Optional[int] = None,
481 | ):
482 | """
483 | Note: There is no need for multiprocessing here, as it is more convenient to save
484 |     the entire output as a single file (a list of dicts saved with torch.save). Example:
485 | python extract.py extract_bboxes \
486 | --features_dir "./data/VOC2012/features/dino_vits16" \
487 | --segmentations_dir "./data/VOC2012/multi_region_segmentation/fixed" \
488 | --num_erode 2 --num_dilate 5 \
489 | --output_file "./data/VOC2012/multi_region_bboxes/fixed/bboxes_e2_d5.pth" \
490 | """
491 | utils.make_output_dir(str(Path(output_file).parent), check_if_empty=False)
492 | fn = partial(_extract_bbox, num_erode=num_erode, num_dilate=num_dilate, skip_bg_index=skip_bg_index,
493 | downsample_factor=downsample_factor)
494 | inputs = utils.get_paired_input_files(features_dir, segmentations_dir)
495 | all_outputs = [fn(inp) for inp in tqdm(inputs, desc='Extracting bounding boxes')]
496 | torch.save(all_outputs, output_file)
497 | print('Done')
498 |
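# For reference, the file written above is a list with one dict per image; its layout, as assembled
# in _extract_bbox (keys taken from the code, values purely illustrative):
#
#     [{'id': '2007_000032',
#       'format': '(xmin, ymin, xmax, ymax)',
#       'segment_indices': [1, 2, 3],                 # background index 0 is skipped by default
#       'bboxes': [[xmin, ymin, xmax, ymax], ...],    # patch-grid coordinates
#       'bboxes_original_resolution': [[...], ...]},  # pixel coordinates (patch coords * P)
#      ...]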
499 |
500 | def extract_bbox_features(
501 | images_root: str,
502 | bbox_file: str,
503 | model_name: str,
504 | output_file: str,
505 | ):
506 | """
507 | Example:
508 | python extract.py extract_bbox_features \
509 | --model_name dino_vits16 \
510 | --images_root "./data/VOC2012/images" \
511 | --bbox_file "./data/VOC2012/multi_region_bboxes/fixed/bboxes_e2_d5.pth" \
512 |             --output_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_features_e2_d5.pth" \
514 | """
515 |
516 | # Load bounding boxes
517 | bbox_list = torch.load(bbox_file)
518 | total_num_boxes = sum(len(d['bboxes']) for d in bbox_list)
519 | print(f'Loaded bounding box list. There are {total_num_boxes} total bounding boxes.')
520 |
521 | # Models
522 | model_name_lower = model_name.lower()
523 | model, val_transform, patch_size, num_heads = utils.get_model(model_name_lower)
524 | model.eval().to('cuda')
525 |
526 | # Loop over boxes
527 | for bbox_dict in tqdm(bbox_list):
528 | # Get image info
529 | image_id = bbox_dict['id']
530 | bboxes = bbox_dict['bboxes_original_resolution']
531 | # Load image as tensor
532 | image_filename = str(Path(images_root) / f'{image_id}.jpg')
533 | image = val_transform(Image.open(image_filename).convert('RGB')) # (3, H, W)
534 | image = image.unsqueeze(0).to('cuda') # (1, 3, H, W)
535 | features_crops = []
536 | for (xmin, ymin, xmax, ymax) in bboxes:
537 | image_crop = image[:, :, ymin:ymax, xmin:xmax]
538 | features_crop = model(image_crop).squeeze().cpu()
539 | features_crops.append(features_crop)
540 | bbox_dict['features'] = torch.stack(features_crops, dim=0)
541 |
542 | # Save
543 | torch.save(bbox_list, output_file)
544 | print(f'Saved features to {output_file}')
545 |
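# Each crop is passed through the backbone at its native cropped resolution and reduced to a single
# feature vector per box (presumably the backbone's global/CLS embedding, depending on what
# utils.get_model returns), so bbox_dict['features'] ends up with shape (num_boxes_in_image, D).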
546 |
547 | def extract_bbox_clusters(
548 | bbox_features_file: str,
549 | output_file: str,
550 | num_clusters: int = 20,
551 | seed: int = 0,
552 | pca_dim: Optional[int] = 0,
553 | ):
554 | """
555 | Example:
556 | python extract.py extract_bbox_clusters \
557 | --bbox_features_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_features_e2_d5.pth" \
558 | --pca_dim 32 --num_clusters 21 --seed 0 \
559 | --output_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_clusters_e2_d5_pca_32.pth" \
560 | """
561 |
562 | # Load bounding boxes
563 | bbox_list = torch.load(bbox_features_file)
564 | total_num_boxes = sum(len(d['bboxes']) for d in bbox_list)
565 | print(f'Loaded bounding box list. There are {total_num_boxes} total bounding boxes with features.')
566 |
567 | # Loop over boxes and stack features with PyTorch, because Numpy is too slow
568 |     print('Stacking and normalizing features')
569 |     all_features = torch.cat([bbox_dict['features'] for bbox_dict in bbox_list], dim=0)  # (num_bboxes, D)
570 |     all_features = all_features / torch.norm(all_features, dim=-1, keepdim=True)  # (num_bboxes, D), L2-normalized
571 | all_features = all_features.numpy()
572 |
573 | # Cluster: PCA
574 | if pca_dim:
575 | pca = PCA(pca_dim)
576 | print(f'Computing PCA with dimension {pca_dim}')
577 | all_features = pca.fit_transform(all_features)
578 |
579 | # Cluster: K-Means
580 | print(f'Computing K-Means clustering with {num_clusters} clusters')
581 | kmeans = MiniBatchKMeans(n_clusters=num_clusters, batch_size=4096, max_iter=5000, random_state=seed)
582 | clusters = kmeans.fit_predict(all_features)
583 |
584 | # Print
585 | _indices, _counts = np.unique(clusters, return_counts=True)
586 | print(f'Cluster indices: {_indices.tolist()}')
587 | print(f'Cluster counts: {_counts.tolist()}')
588 |
589 | # Loop over boxes and add clusters
590 | idx = 0
591 | for bbox_dict in bbox_list:
592 | num_bboxes = len(bbox_dict['bboxes'])
593 |         del bbox_dict['features']  # the raw features are no longer needed; drop them to keep the saved file small
594 | bbox_dict['clusters'] = clusters[idx: idx + num_bboxes]
595 | idx = idx + num_bboxes
596 |
597 | # Save
598 | torch.save(bbox_list, output_file)
599 |     print(f'Saved clusters to {output_file}')
600 |
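# Note on the cluster count: the docstring example uses --num_clusters 21 for PASCAL VOC, matching
# the 21 labels (20 object classes plus background) used downstream, cf. the num_classes=21 default
# of extract_crf_segmentations. The cluster ids produced here become the label values of the
# semantic maps written by extract_semantic_segmentations.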
601 |
602 | def extract_semantic_segmentations(
603 | segmentations_dir: str,
604 | bbox_clusters_file: str,
605 | output_dir: str,
606 | ):
607 | """
608 | Example:
609 | python extract.py extract_semantic_segmentations \
610 | --segmentations_dir "./data/VOC2012/multi_region_segmentation/fixed" \
611 | --bbox_clusters_file "./data/VOC2012/multi_region_bboxes/fixed/bbox_clusters_e2_d5_pca_32.pth" \
612 | --output_dir "./data/VOC2012/semantic_segmentations/patches/fixed/segmaps_e2_d5_pca_32" \
613 | """
614 |
615 | # Load bounding boxes
616 | bbox_list = torch.load(bbox_clusters_file)
617 | total_num_boxes = sum(len(d['bboxes']) for d in bbox_list)
618 | print(f'Loaded bounding box list. There are {total_num_boxes} total bounding boxes with features and clusters.')
619 |
620 | # Output
621 | utils.make_output_dir(output_dir)
622 |
623 | # Loop over boxes
624 | for bbox_dict in tqdm(bbox_list):
625 | # Get image info
626 | image_id = bbox_dict['id']
627 | # Load segmentation as tensor
628 | segmap_path = str(Path(segmentations_dir) / f'{image_id}.png')
629 | segmap = np.array(Image.open(segmap_path))
630 |         # Check whether the segmap is a binary map with foreground pixels saved as 255 instead of 1,
631 |         # which is the case for some of our baselines
632 | if set(np.unique(segmap).tolist()).issubset({0, 255}):
633 | segmap[segmap == 255] = 1
634 | # Semantic map
635 |         if len(bbox_dict['segment_indices']) != len(bbox_dict['clusters']):
636 |             raise RuntimeError(f"Segment/cluster count mismatch for image {image_id}: "
637 |                                f"{len(bbox_dict['segment_indices'])} segments vs {len(bbox_dict['clusters'])} clusters")
638 | semantic_map = dict(zip(bbox_dict['segment_indices'], bbox_dict['clusters'].tolist()))
639 | assert 0 not in semantic_map, semantic_map
640 | semantic_map[0] = 0 # background region remains zero
641 | # Perform mapping
642 | semantic_segmap = np.vectorize(semantic_map.__getitem__)(segmap)
643 | # Save
644 | output_file = str(Path(output_dir) / f'{image_id}.png')
645 | Image.fromarray(semantic_segmap.astype(np.uint8)).convert('L').save(output_file)
646 |
647 |     print(f'Saved semantic segmentations to {output_dir}')
648 |
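# A worked example of the mapping above: with segment_indices=[1, 2] and clusters=[7, 3],
# semantic_map == {1: 7, 2: 3, 0: 0}, so a patch-level segmap [[0, 1], [2, 1]] is relabelled
# to the semantic segmap [[0, 7], [3, 7]].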
649 |
650 | def _extract_crf_segmentations(
651 | inp: Tuple[int, Tuple[str, str]],
652 | images_root: str,
653 | num_classes: int,
654 | output_dir: str,
655 | crf_params: Tuple,
656 | downsample_factor: int = 16,
657 | ):
658 | index, (image_file, segmap_path) = inp
659 |
660 | # Output file
661 | id = Path(image_file).stem
662 | output_file = str(Path(output_dir) / f'{id}.png')
663 | if Path(output_file).is_file():
664 | print(f'Skipping existing file {str(output_file)}')
665 | return # skip because already generated
666 |
667 | # Load image and segmap
668 | image_file = str(Path(images_root) / f'{id}.jpg')
669 |     image = np.array(Image.open(image_file).convert('RGB'))  # (H, W, 3)
670 | segmap = np.array(Image.open(segmap_path)) # (H_patch, W_patch)
671 |
672 | # Sizes
673 | P = downsample_factor
674 | H, W = image.shape[:2]
675 | H_patch, W_patch = H // P, W // P
676 | H_pad, W_pad = H_patch * P, W_patch * P
677 |
678 | # Resize and expand
679 | segmap_upscaled = cv2.resize(segmap, dsize=(W_pad, H_pad), interpolation=cv2.INTER_NEAREST) # (H_pad, W_pad)
680 | segmap_orig_res = cv2.resize(segmap, dsize=(W, H), interpolation=cv2.INTER_NEAREST) # (H, W)
681 | segmap_orig_res[:H_pad, :W_pad] = segmap_upscaled # replace with the correctly upscaled version, just in case they are different
682 |
683 | # Convert binary
684 | if set(np.unique(segmap_orig_res).tolist()) == {0, 255}:
685 | segmap_orig_res[segmap_orig_res == 255] = 1
686 |
687 | # CRF
688 |     import denseCRF  # make sure you've installed SimpleCRF
689 |     unary_potentials = F.one_hot(torch.from_numpy(segmap_orig_res).long(), num_classes=num_classes).float().numpy()  # denseCRF expects a float32 (H, W, C) probability map
690 |     segmap_crf = denseCRF.densecrf(image, unary_potentials, crf_params)  # (H, W) label map
691 |
692 | # Save
693 | Image.fromarray(segmap_crf).convert('L').save(output_file)
694 |
695 |
696 | def extract_crf_segmentations(
697 | images_list: str,
698 | images_root: str,
699 | segmentations_dir: str,
700 | output_dir: str,
701 | num_classes: int = 21,
702 | downsample_factor: int = 16,
703 | multiprocessing: int = 0,
704 | # CRF parameters
705 |     w1 = 10,    # weight of the bilateral (appearance) term (default: 10.0)
706 |     alpha = 80, # spatial std of the bilateral term (default: 80)
707 |     beta = 13,  # RGB std of the bilateral term (default: 13)
708 |     w2 = 3,     # weight of the spatial (smoothness) term (default: 3.0)
709 |     gamma = 3,  # spatial std of the spatial term (default: 3)
710 |     it = 5.0,   # number of CRF iterations (default: 5.0)
711 | ):
712 | """
713 | Applies a CRF to segmentations in order to sharpen them.
714 |
715 | Example:
716 | python extract.py extract_crf_segmentations \
717 | --images_list "./data/VOC2012/lists/images.txt" \
718 | --images_root "./data/VOC2012/images" \
719 | --segmentations_dir "./data/VOC2012/semantic_segmentations/patches/fixed/segmaps_e2_d5_pca_32" \
720 | --output_dir "./data/VOC2012/semantic_segmentations/crf/fixed/segmaps_e2_d5_pca_32" \
721 | """
722 | try:
723 | import denseCRF
724 |     except ImportError as e:
725 |         raise ImportError(
726 |             'Please install SimpleCRF to compute CRF segmentations:\n'
727 |             'pip3 install SimpleCRF'
728 |         ) from e
729 |
730 | utils.make_output_dir(output_dir)
731 | fn = partial(_extract_crf_segmentations, images_root=images_root, num_classes=num_classes, output_dir=output_dir,
732 | crf_params=(w1, alpha, beta, w2, gamma, it), downsample_factor=downsample_factor)
733 | inputs = utils.get_paired_input_files(images_list, segmentations_dir)
734 | print(f'Found {len(inputs)} images and segmaps')
735 | utils.parallel_process(inputs, fn, multiprocessing)
736 |
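# The crf_params tuple built above is handed directly to denseCRF.densecrf(image, unary, params) in
# the worker, so its order must remain (w1, alpha, beta, w2, gamma, it). A minimal standalone sketch
# of the same call, assuming SimpleCRF is installed and given a hypothetical uint8 image `img` of
# shape (H, W, 3) and a float foreground map `fg` of shape (H, W):
#
#     import denseCRF
#     probs = np.stack([1.0 - fg, fg], axis=-1).astype(np.float32)                 # (H, W, 2)
#     labels = denseCRF.densecrf(img, probs, (10.0, 80.0, 13.0, 3.0, 3.0, 5.0))    # (H, W) label map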
737 |
738 | def vis_segmentations(
739 | images_list: str,
740 | images_root: str,
741 | segmentations_dir: str,
742 | bbox_file: Optional[str] = None,
743 | ):
744 | """
745 | Example:
746 | streamlit run extract.py vis_segmentations -- \
747 | --images_list "./data/VOC2012/lists/images.txt" \
748 | --images_root "./data/VOC2012/images" \
749 | --segmentations_dir "./data/VOC2012/multi_region_segmentation/fixed"
750 | or alternatively:
751 | --segmentations_dir "./data/VOC2012/semantic_segmentations/crf/fixed/segmaps_e2_d5_pca_32/"
752 | """
753 | # Streamlit setup
754 | import streamlit as st
755 | from matplotlib.cm import get_cmap
756 | from skimage.color import label2rgb
757 | st.set_page_config(layout='wide')
758 |
759 | # Inputs
760 | image_paths = []
761 | segmap_paths = []
762 | images_root = Path(images_root)
763 | segmentations_dir = Path(segmentations_dir)
764 | for image_file in Path(images_list).read_text().splitlines():
765 | segmap_file = f'{Path(image_file).stem}.png'
766 | image_paths.append(images_root / image_file)
767 | segmap_paths.append(segmentations_dir / segmap_file)
768 | print(f'Found {len(image_paths)} image and segmap paths')
769 |
770 | # Load optional bounding boxes
771 | if bbox_file is not None:
772 | bboxes_list = torch.load(bbox_file)
773 |
774 | # Colors
775 | colors = get_cmap('tab20', 21).colors[:, :3]
776 |
777 | # Which index
778 | which_index = st.number_input(label='Which index to view (0 for all)', value=0)
779 |
780 | # Load
781 | total = 0
782 | for i, (image_path, segmap_path) in enumerate(zip(image_paths, segmap_paths)):
783 | if total > 40: break
784 | image_id = image_path.stem
785 |
786 | # Streamlit
787 | cols = []
788 |
789 | # Load
790 | image = np.array(Image.open(image_path).convert('RGB'))
791 | segmap = np.array(Image.open(segmap_path))
792 |
793 | # Convert binary
794 | if set(np.unique(segmap).tolist()) == {0, 255}:
795 | segmap[segmap == 255] = 1
796 |
797 | # Resize
798 | segmap_fullres = cv2.resize(segmap, dsize=image.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
799 |
800 | # Only view images with a specific class
801 | if which_index not in np.unique(segmap):
802 | continue
803 | total += 1
804 |
805 | # Streamlit
806 | cols.append({'image': image, 'caption': image_id})
807 |
808 | # Load optional bounding boxes
809 | bboxes = None
810 | if bbox_file is not None:
811 | bboxes = torch.tensor(bboxes_list[i]['bboxes_original_resolution'])
812 | assert bboxes_list[i]['id'] == image_id, f"{bboxes_list[i]['id']=} but {image_id=}"
813 | image_torch = torch.from_numpy(image).permute(2, 0, 1)
814 | image_with_boxes_torch = draw_bounding_boxes(image_torch, bboxes)
815 | image_with_boxes = image_with_boxes_torch.permute(1, 2, 0).numpy()
816 |
817 | # Streamlit
818 | cols.append({'image': image_with_boxes})
819 |
820 | # Color
821 | segmap_label_indices, segmap_label_counts = np.unique(segmap, return_counts=True)
822 | blank_segmap_overlay = label2rgb(label=segmap_fullres, image=np.full_like(image, 128),
823 | colors=colors[segmap_label_indices[segmap_label_indices != 0]], bg_label=0, alpha=1.0)
824 | image_segmap_overlay = label2rgb(label=segmap_fullres, image=image,
825 | colors=colors[segmap_label_indices[segmap_label_indices != 0]], bg_label=0, alpha=0.45)
826 | segmap_caption = dict(zip(segmap_label_indices.tolist(), (segmap_label_counts).tolist()))
827 |
828 | # Streamlit
829 | cols.append({'image': blank_segmap_overlay, 'caption': segmap_caption})
830 | cols.append({'image': image_segmap_overlay, 'caption': segmap_caption})
831 |
832 | # Display
833 | for d, col in zip(cols, st.columns(len(cols))):
834 | col.image(**d)
835 |
836 |
837 | if __name__ == '__main__':
838 | torch.set_grad_enabled(False)
839 | fire.Fire(dict(
840 | extract_features=extract_features,
841 | extract_eigs=extract_eigs,
842 | extract_multi_region_segmentations=extract_multi_region_segmentations,
843 | extract_bboxes=extract_bboxes,
844 | extract_bbox_features=extract_bbox_features,
845 | extract_bbox_clusters=extract_bbox_clusters,
846 | extract_semantic_segmentations=extract_semantic_segmentations,
847 | extract_crf_segmentations=extract_crf_segmentations,
848 | extract_single_region_segmentations=extract_single_region_segmentations,
849 | vis_segmentations=vis_segmentations,
850 | ))
851 |
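# The docstring examples above chain together into the following rough pipeline (each step's output
# path feeds the next step's input path); extract_single_region_segmentations and vis_segmentations
# are optional side paths:
#   1. extract_features
#   2. extract_eigs
#   3. extract_multi_region_segmentations
#   4. extract_bboxes
#   5. extract_bbox_features
#   6. extract_bbox_clusters
#   7. extract_semantic_segmentations
#   8. extract_crf_segmentations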
--------------------------------------------------------------------------------