├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── app.py
├── configs
│   ├── Base-coco-stuff-164K-171.yaml
│   ├── san_clip_vit_large_res4_coco.yaml
│   └── san_clip_vit_res4_coco.yaml
├── datasets
│   ├── prepare_ade20k_sem_seg.py
│   ├── prepare_pcontext_sem_seg_459cls.py
│   ├── prepare_pcontext_sem_seg_59cls.py
│   └── prepare_voc_sem_seg.py
├── docker
│   ├── Dockerfile
│   └── app.Dockerfile
├── docs
│   ├── .nojekyll
│   ├── README.md
│   ├── arch.png
│   ├── index.html
│   └── static
│       ├── css
│       │   ├── bulma-carousel.min.css
│       │   ├── bulma-slider.min.css
│       │   ├── bulma.css.map.txt
│       │   ├── bulma.min.css
│       │   ├── fontawesome.all.min.css
│       │   └── index.css
│       ├── images
│       │   ├── favicon.ico
│       │   └── grid.jpg
│       ├── js
│       │   ├── bulma-carousel.js
│       │   ├── bulma-carousel.min.js
│       │   ├── bulma-slider.js
│       │   ├── bulma-slider.min.js
│       │   ├── fontawesome.all.min.js
│       │   └── index.js
│       └── pdfs
│           └── sample.pdf
├── install.sh
├── predict.py
├── requirements.txt
├── resources
│   ├── arch.png
│   ├── san_vit_b_16.log
│   └── san_vit_large_14.log
├── san
│   ├── __init__.py
│   ├── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── dataset_mappers
│   │   │   ├── __init__.py
│   │   │   └── mask_former_semantic_dataset_mapper.py
│   │   └── datasets
│   │       ├── __init__.py
│   │       ├── register_ade20k_full.py
│   │       ├── register_coco_stuff_164k.py
│   │       ├── register_pcontext.py
│   │       └── register_voc.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── attn_helper.py
│   │   ├── clip_utils
│   │   │   ├── __init__.py
│   │   │   ├── classifier.py
│   │   │   ├── utils.py
│   │   │   └── visual.py
│   │   ├── criterion.py
│   │   ├── layers.py
│   │   ├── matcher.py
│   │   ├── san.py
│   │   └── side_adapter
│   │       ├── __init__.py
│   │       ├── side_adapter.py
│   │       └── timm_wrapper.py
│   ├── test_time_augmentation.py
│   └── utils
│       ├── __init__.py
│       ├── events.py
│       ├── file_io.py
│       └── misc.py
├── train_net.py
└── visualize_json_results.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # output dir
2 | output
3 | instant_test_output
4 | inference_test_output
5 |
6 |
7 | *.json
8 | *.diff
9 | *.jpg
10 | !/projects/DensePose/doc/images/*.jpg
11 | !docs/static/images/*.jpg
12 | # compilation and distribution
13 | __pycache__
14 | _ext
15 | *.pyc
16 | *.pyd
17 | *.so
18 | *.dll
19 | *.egg-info/
20 | build/
21 | dist/
22 | wheels/
23 |
24 | # pytorch/python/numpy formats
25 | *.pth
26 | *.pkl
27 | *.npy
28 | *.ts
29 | model_ts*.txt
30 |
31 | # ipython/jupyter notebooks
32 | *.ipynb
33 | **/.ipynb_checkpoints/
34 |
35 | # Editor temporaries
36 | *.swn
37 | *.swo
38 | *.swp
39 | *~
40 |
41 | # editor settings
42 | .idea
43 | .vscode
44 | _darcs
45 |
46 | # project dirs
47 | /detectron2/model_zoo/configs
48 | /datasets/*
49 | !/datasets/*.*
50 | /projects/*/datasets
51 | /models
52 | /snippet
53 |
54 | .history
55 | .env
56 |
57 | wandb
58 | amlt
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/ambv/black
3 | rev: 22.12.0
4 | hooks:
5 | - id: black
6 | exclude: ^third_party/
7 | - repo: https://github.com/asottile/seed-isort-config
8 | rev: v2.2.0
9 | hooks:
10 | - id: seed-isort-config
11 | - repo: https://github.com/pre-commit/pre-commit-hooks
12 | rev: v4.0.1
13 | hooks:
14 | - id: trailing-whitespace
15 | - id: check-yaml
16 | - id: end-of-file-fixer
17 | - id: requirements-txt-fixer
18 | - id: check-merge-conflict
19 | - id: fix-encoding-pragma
20 | args: ["--remove"]
21 | - id: mixed-line-ending
22 | args: ["--fix=lf"]
23 | # - repo: https://github.com/jumanjihouse/pre-commit-hooks
24 | # rev: 2.1.5
25 | # hooks:
26 | # - id: markdownlint
27 | # args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036", "-t", "allow_different_nesting"]
28 | - repo: https://github.com/myint/docformatter
29 | rev: v1.4
30 | hooks:
31 | - id: docformatter
32 | args: ["--in-place", "--wrap-descriptions", "79"]
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 MendelXu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR2023-Highlight] Side Adapter Network for Open-Vocabulary Semantic Segmentation
2 | # [PAMI] SAN: Side Adapter Network for Open-Vocabulary Semantic Segmentation
3 | [](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-2?p=side-adapter-network-for-open-vocabulary)
4 | [](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-3?p=side-adapter-network-for-open-vocabulary)
5 | [](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-7?p=side-adapter-network-for-open-vocabulary)
6 | [](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-1?p=side-adapter-network-for-open-vocabulary)
7 | [](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-5?p=side-adapter-network-for-open-vocabulary)
8 |
9 | This is the official implementation of our conference paper: "[Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)" (main branch) and journal paper: "[SAN: Side Adapter Network for Open-Vocabulary Semantic Segmentation
10 | ](https://www.computer.org/csdl/journal/tp/2023/12/10238837/1QaEJ3q98aY)" (video branch).
11 |
12 | ## Introduction
13 |
14 | This paper presents a new framework for open-vocabulary semantic segmentation with a pre-trained vision-language model, named Side Adapter Network (SAN). Our approach models the semantic segmentation task as a region recognition problem. A side network is attached to a frozen CLIP model with two branches: one for predicting mask proposals, and the other for predicting attention biases that are applied in the CLIP model to recognize the classes of the masks. This decoupled design eases the burden on CLIP when recognizing the classes of the mask proposals. Since the attached side network can reuse CLIP features, it can be very light. In addition, the entire network can be trained end-to-end, allowing the side network to adapt to the frozen CLIP model, which makes the predicted mask proposals CLIP-aware.
15 | Our approach is fast and accurate, and it only adds a few additional trainable parameters. We evaluate our approach on multiple semantic segmentation benchmarks. Our method significantly outperforms its counterparts, with up to 18 times fewer trainable parameters and 19 times faster inference speed.
16 | ![](resources/arch.png)
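
To make the decoupled design concrete, here is a minimal PyTorch sketch of the two-branch idea described above. It is only an illustrative toy, not the real model in `san/model`; the class name, query count, and feature dimensions are placeholders.

```python
import torch
import torch.nn as nn

# Toy side network: learnable queries produce (i) mask proposals from pixel
# features and (ii) per-head attention biases that a frozen CLIP encoder would
# consume to recognize each proposal. Shapes and names are illustrative only.
class ToySideNetwork(nn.Module):
    def __init__(self, dim=256, num_queries=100, num_heads=12):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim))
        self.mask_head = nn.Linear(dim, dim)        # mask-proposal branch
        self.bias_head = nn.Linear(dim, num_heads)  # attention-bias branch

    def forward(self, pixel_feat):                  # pixel_feat: (B, dim, H, W)
        q = self.queries.unsqueeze(0).expand(pixel_feat.size(0), -1, -1)
        mask_emb = self.mask_head(q)                                  # (B, Q, dim)
        masks = torch.einsum("bqc,bchw->bqhw", mask_emb, pixel_feat)  # (B, Q, H, W)
        attn_bias = self.bias_head(q)                                 # (B, Q, heads)
        return masks, attn_bias

masks, attn_bias = ToySideNetwork()(torch.randn(2, 256, 32, 32))
print(masks.shape, attn_bias.shape)
```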
17 | ### Table of Contents
18 | - [Demo](#6)
19 | - [Installation](#1)
20 | - [Data Preparation](#2)
21 | - [Usage](#3)
22 | - [Training](#5)
23 | - [Evaluation](#4)
24 | - [Visualization](#7)
25 |
26 | - [FAQ](#8)
27 |
28 |
29 |
30 | ### Demo
31 | - Run the demo app on [🤗HuggingFace](https://huggingface.co/spaces/Mendel192/SAN-Demo). (It is running on a low-spec machine and could be slow)
32 | - Run the demo app with docker.
33 | ```
34 | docker build -f docker/app.Dockerfile -t san_app .
35 | docker run -it --shm-size 4G -p 7860:7860 san_app
36 | ```
37 |
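If you already have the training dependencies installed locally (see Installation below), the demo can also be launched without Docker. `docker/app.Dockerfile` simply installs `gradio` and runs `gradio app.py`, so a rough local equivalent is:
```sh
pip install gradio
gradio app.py   # serves on http://localhost:7860 by default
```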
38 |
39 | ### Installation
40 | 1. Clone the repository
41 | ```sh
42 | git clone https://github.com/MendelXu/SAN.git
43 | ```
44 | 2. Navigate to the project directory
45 | ```sh
46 | cd SAN
47 | ```
48 | 3. Install the dependencies
49 | ```sh
50 | bash install.sh
51 | ```
52 | **Hint**: You can run the job in Docker instead of installing the dependencies locally.
53 | Run with pre-built docker:
54 | ```
55 | docker run -it --gpus all --shm-size 8G mendelxu/pytorch:d2_nvcr_2008 /bin/bash
56 | ```
57 | or build your own image with the provided dockerfile `docker/Dockerfile`.
58 |
59 |
60 |
61 | ### Data Preparation
62 | See [SimSeg](https://github.com/MendelXu/zsseg.baseline) for reference. The data should be organized like:
63 | ```
64 | datasets/
65 | coco/
66 | ...
67 | train2017/
68 | val2017/
69 | stuffthingmaps_detectron2/
70 | VOC2012/
71 | ...
72 | images_detectron2/
73 | annotations_detectron2/
74 | pcontext/
75 | ...
76 | val/
77 | pcontext_full/
78 | ...
79 | val/
80 | ADEChallengeData2016/
81 | ...
82 | images/
83 | annotations_detectron2/
84 | ADE20K_2021_17_01/
85 | ...
86 | images/
87 | annotations_detectron2/
88 | ```
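
The `datasets/prepare_*.py` scripts in this repository convert raw annotations into the layout above. For example (paths are placeholders; run each script with `--help` for the exact arguments):
```sh
# ADE20K-150: reads <DETECTRON2_DATASETS>/ADEChallengeData2016/annotations and
# writes annotations_detectron2 next to it.
export DETECTRON2_DATASETS=datasets
python datasets/prepare_ade20k_sem_seg.py

# Pascal VOC-20: <VOC_ROOT> must contain SegmentationClassAug, train.txt and val.txt.
python datasets/prepare_voc_sem_seg.py <VOC_ROOT> -o datasets/VOC2012 --nproc 8
```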
89 | **Hint**: In the code, these datasets are registered under the following dataset names (see the snippet after the mapping for how to list them):
90 | ```
91 | coco_2017_*_stuff_sem_seg : COCO Stuff-171
92 | voc_sem_seg_*: Pascal VOC-20
93 | pcontext_sem_seg_*: Pascal Context-59
94 | ade20k_sem_seg_*: ADE-150
95 | pcontext_full_sem_seg_*: Pascal Context-459
96 | ade20k_full_sem_seg_*: ADE-847
97 | ```
98 |
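If you are unsure which split names are available, you can list them from Python. This assumes that importing `san` triggers the `register_*` modules in `san/data/datasets`, which is the usual detectron2 pattern:
```python
from detectron2.data import DatasetCatalog

import san  # noqa: F401  (importing the package registers the splits above)

print(sorted(name for name in DatasetCatalog.list() if "sem_seg" in name))
```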
99 |
100 | ### Usage
101 |
102 |
103 | - #### Pretrained Weights
104 |
105 | |Model|Config |Weights|Logs|
106 | |-----|-------|---|---|
107 | |SAN-ViT-B/16|configs/san_clip_vit_res4_coco.yaml |[Huggingface](https://huggingface.co/Mendel192/san/resolve/main/san_vit_b_16.pth) |[Log](resources/san_vit_b_16.log) |
108 | |SAN-ViT-L/14|configs/san_clip_vit_large_res4_coco.yaml |[Huggingface](https://huggingface.co/Mendel192/san/resolve/main/san_vit_large_14.pth) |[Log](resources/san_vit_large_14.log)|
109 |
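For example, to fetch the ViT-B/16 checkpoint from the table above into the location used by the evaluation example below (plain `wget`; any download tool works):
```sh
mkdir -p output
wget -O output/model.pth https://huggingface.co/Mendel192/san/resolve/main/san_vit_b_16.pth
```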
110 |
111 | - #### Evaluation
112 |
113 |
114 | - Evaluate the trained model on the validation sets of all datasets.
115 | ```sh
116 | python train_net.py --eval-only --config-file <CONFIG_FILE> --num-gpus <NUM_GPU> OUTPUT_DIR <OUTPUT_PATH> MODEL.WEIGHTS <TRAINED_MODEL_PATH>
117 | ```
118 | For example, evaluate our pre-trained model:
119 | ```
120 | # 1. Download SAN (ViT-B/16 CLIP) from https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth.
121 | # 2. put it at `output/model.pth`.
122 | # 3. evaluation
123 | python train_net.py --eval-only --config-file configs/san_clip_vit_res4_coco.yaml --num-gpus 8 OUTPUT_DIR ./output/trained_vit_b16 MODEL.WEIGHTS output/model.pth
124 | ```
125 | - Evaluate the trained model on the validation set of a single dataset.
126 | ```sh
127 | python train_net.py --eval-only --config-file <CONFIG_FILE> --num-gpus <NUM_GPU> OUTPUT_DIR <OUTPUT_PATH> MODEL.WEIGHTS <TRAINED_MODEL_PATH> DATASETS.TEST "('<DATASET_NAME>',)"
128 | ```
129 |
130 | - #### Visualization
131 |
132 |
133 |
134 | ```sh
135 | python visualize_json_results.py --input <JSON_RESULT_FILE> --output <OUTPUT_DIR> --dataset <DATASET_NAME>
136 | # example:
137 | # Generate the results.
138 | # python train_net.py --eval-only --config-file configs/san_clip_vit_res4_coco.yaml --num-gpus 1 OUTPUT_DIR ./output/trained_vit_b16 MODEL.WEIGHTS output/san/san_vit_b_16.pth DATASETS.TEST '("pcontext_sem_seg_val",)'
139 | # Visualizing
140 | # python visualize_json_results.py --input output/trained_vit_b16/inference/sem_seg_predictions.json --output output/viz --dataset pcontext_sem_seg_val
141 | ```
142 |
143 |
144 | - #### Training
145 |
146 | ```sh
147 | wandb off
148 | # [Optional] If you want to log the training logs to wandb.
149 | # wandb login
150 | # wandb on
151 | python train_net.py --config-file <CONFIG_FILE> --num-gpus <NUM_GPU> OUTPUT_DIR <OUTPUT_PATH> WANDB.NAME <WANDB_LOG_NAME>
152 | ```
153 | **Hint**: We use `<>` to denote the variables you should replace according to your own setting.
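Two common variations, assuming `train_net.py` exposes detectron2's default launcher arguments (`--resume`, `--num-machines`, `--machine-rank`, `--dist-url`), which is standard for detectron2 training scripts:
```sh
# Resume an interrupted run from the last checkpoint in <OUTPUT_PATH>.
python train_net.py --config-file configs/san_clip_vit_res4_coco.yaml --num-gpus 8 --resume OUTPUT_DIR <OUTPUT_PATH>

# Two-node training: run on each node with the matching --machine-rank.
python train_net.py --config-file configs/san_clip_vit_res4_coco.yaml --num-gpus 8 \
  --num-machines 2 --machine-rank 0 --dist-url tcp://<MASTER_ADDR>:29500 OUTPUT_DIR <OUTPUT_PATH>
```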
154 |
155 | ### FAQ
156 |
157 |
158 |
159 | If you do not get a timely response from the author on GitHub, please e-mail me directly at [shea.mendel] [AT] [gmail.com].
160 |
161 | ### License
162 | Distributed under the MIT License. See LICENSE for more information.
163 |
164 | ### Cite
165 |
166 | If you find this work helpful, please cite our papers.
167 |
168 | ```
169 | @inproceedings{xu2023side,
170 | title={Side Adapter Network for Open-Vocabulary Semantic Segmentation},
171 | author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang},
172 | booktitle={CVPR},
173 | year={2023}
174 | }
175 |
176 | @article{xu2023san,
177 | title={SAN: Side adapter network for open-vocabulary semantic segmentation},
178 | author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang},
179 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
180 | year={2023},
181 | publisher={IEEE}
182 | }
183 | ```
184 |
185 |
186 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from predict import Predictor, model_cfg
2 | from PIL import Image
3 | import gradio as gr
4 |
5 | # set a lot of global variables
6 |
7 | predictor = None
8 | vocabulary = ["bat man, woman"]
9 | input_image: Image.Image = None
10 | outputs: dict = None
11 | cur_model_name: str = None
12 |
13 |
14 | def set_vocabulary(text):
15 | global vocabulary
16 | vocabulary = text.split(",")
17 | print("set vocabulary to", vocabulary)
18 |
19 |
20 | def set_input(image):
21 | global input_image
22 | input_image = image
23 | print("set input image to", image)
24 |
25 |
26 | def set_predictor(model_name: str):
27 | global cur_model_name
28 | if cur_model_name == model_name:
29 | return
30 | global predictor
31 | predictor = Predictor(**model_cfg[model_name])
32 | print("set predictor to", model_name)
33 | cur_model_name = model_name
34 |
35 |
36 | set_predictor(list(model_cfg.keys())[0])
37 |
38 |
39 | # for visualization
40 | def visualize(vis_mode):
41 | if outputs is None:
42 | return None
43 | return predictor.visualize(**outputs, mode=vis_mode)
44 |
45 |
46 | def segment_image(vis_mode, voc_mode, model_name):
47 | set_predictor(model_name)
48 | if input_image is None:
49 | return None
50 | global outputs
51 | result = predictor.predict(
52 | input_image, vocabulary=vocabulary, augment_vocabulary=voc_mode
53 | )
54 | outputs = result
55 |
56 | return visualize(vis_mode)
57 |
58 |
59 | def segment_e2e(image, vis_mode):
60 | set_input(image)
61 | return segment_image(vis_mode, "COCO-all", cur_model_name)  # default vocabulary expansion, currently selected model
62 |
63 |
64 | # gradio
65 |
66 | with gr.Blocks(
67 | css="""
68 | #submit {background: #3498db; color: white; border: none; padding: 10px 20px; border-radius: 5px;width: 20%;margin: 0 auto; display: block;}
69 |
70 | """
71 | ) as demo:
72 | gr.Markdown(
73 | f"Side Adapter Network for Open-Vocabulary Semantic Segmentation
"
74 | )
75 | gr.Markdown(
76 | """
77 | This is the demo for our conference paper : "[Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)".
78 | """
79 | )
80 | # gr.Image(type="pil", value="./resources/arch.png", shape=(460, 200), elem_id="arch")
81 | gr.Markdown(
82 | """
83 | ---
84 | """
85 | )
86 | with gr.Row():
87 | image = gr.Image(type="pil", elem_id="input_image")
88 | plt = gr.Image(type="pil", elem_id="output_image")
89 |
90 | with gr.Row():
91 | model_name = gr.Dropdown(
92 | list(model_cfg.keys()), label="Model", value="san_vit_b_16"
93 | )
94 | augment_vocabulary = gr.Dropdown(
95 | ["COCO-all", "COCO-stuff"],
96 | label="Vocabulary Expansion",
97 | value="COCO-all",
98 | )
99 | vis_mode = gr.Dropdown(
100 | ["overlay", "mask"], label="Visualization Mode", value="overlay"
101 | )
102 | object_names = gr.Textbox(value=",".join(vocabulary), label="Object Names (Empty inputs will use the vocabulary specified in `Vocabulary Expansion`. Multiple names should be separated with ,.)", lines=5)
103 |
104 | button = gr.Button("Run", elem_id="submit")
105 | note = gr.Markdown(
106 | """
107 | ---
108 | ### FAQ
109 | - **Q**: What is the `Vocabulary Expansion` option for?
110 | **A**: The vocabulary expansion option is used to expand the vocabulary of the model. The model assigns a category to each region with `argmax`. When the provided vocabulary contains only a few thing classes, it will produce many false positives.
111 | - **Q**: Error: `Unexpected token '<', "
--------------------------------------------------------------------------------
/configs/san_clip_vit_large_res4_coco.yaml:
--------------------------------------------------------------------------------
8 | FUSION_MAP: ["0->0", "6->1", "12->2", "18->3"]
9 | ATTN_BIAS:
10 | NUM_HEADS: 16
--------------------------------------------------------------------------------
/configs/san_clip_vit_res4_coco.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: Base-coco-stuff-164K-171.yaml
2 | MODEL:
3 | META_ARCHITECTURE: "SAN"
4 | SOLVER:
5 | BACKBONE_MULTIPLIER: 1.0
--------------------------------------------------------------------------------
/datasets/prepare_ade20k_sem_seg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | import os
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import tqdm
9 | from PIL import Image
10 |
11 |
12 | def convert(input, output):
13 | img = np.asarray(Image.open(input))
14 | assert img.dtype == np.uint8
15 | img = img - 1 # 0 (ignore) becomes 255. others are shifted by 1
16 | Image.fromarray(img).save(output)
17 |
18 |
19 | if __name__ == "__main__":
20 | dataset_dir = (
21 | Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016"
22 | )
23 | for name in ["training", "validation"]:
24 | annotation_dir = dataset_dir / "annotations" / name
25 | output_dir = dataset_dir / "annotations_detectron2" / name
26 | output_dir.mkdir(parents=True, exist_ok=True)
27 | for file in tqdm.tqdm(list(annotation_dir.iterdir())):
28 | output_file = output_dir / file.name
29 | convert(file, output_file)
30 |
--------------------------------------------------------------------------------
/datasets/prepare_pcontext_sem_seg_459cls.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: Zhen Zhu(zzhu@hust.edu.cn)
4 | # Generate train & val data.
5 |
6 |
7 | import os
8 | import argparse
9 | import shutil
10 | from PIL import Image
11 | import numpy as np
12 | from tqdm import tqdm
13 | from scipy.io import loadmat
14 |
15 | LABEL_DIR = "label"
16 | IMAGE_DIR = "image"
17 |
18 |
19 | class PascalContextGenerator(object):
20 | def __init__(self, args, image_dir=IMAGE_DIR, label_dir=LABEL_DIR):
21 | self.args = args
22 | self.train_label_dir = os.path.join(self.args.save_dir, "train", label_dir)
23 | self.val_label_dir = os.path.join(self.args.save_dir, "val", label_dir)
24 | if not os.path.exists(self.train_label_dir):
25 | os.makedirs(self.train_label_dir)
26 |
27 | if not os.path.exists(self.val_label_dir):
28 | os.makedirs(self.val_label_dir)
29 |
30 | self.train_image_dir = os.path.join(self.args.save_dir, "train", image_dir)
31 | self.val_image_dir = os.path.join(self.args.save_dir, "val", image_dir)
32 | if not os.path.exists(self.train_image_dir):
33 | os.makedirs(self.train_image_dir)
34 |
35 | if not os.path.exists(self.val_image_dir):
36 | os.makedirs(self.val_image_dir)
37 | self.all_cls = set()
38 |
39 | def _class_to_index(self, mask):
40 | self.all_cls = self.all_cls.union(set(np.unique(mask).tolist()))
41 | # import pdb
42 | # pdb.set_trace()
43 | mask = mask - 1
44 | return mask
45 |
46 | def generate_label(self):
47 | _image_dir = os.path.join(self.args.img_dir, "JPEGImages")
48 | _anno_dir = self.args.anno_dir
49 | annFile = os.path.join(self.args.img_dir, "trainval_merged.json")
50 |
51 | from detail import Detail
52 |
53 | train_detail = Detail(annFile, _image_dir, "train")
54 | train_ids = train_detail.getImgs()
55 |
56 | for img_id in tqdm(train_ids):
57 | mask = loadmat(
58 | os.path.join(_anno_dir, img_id["file_name"].replace(".jpg", ".mat"))
59 | )["LabelMap"]
60 | mask = Image.fromarray(self._class_to_index(mask))
61 | filename = img_id["file_name"]
62 | basename, _ = os.path.splitext(filename)
63 | if filename.endswith(".jpg"):
64 | imgpath = os.path.join(_image_dir, filename)
65 | shutil.copy(imgpath, os.path.join(self.train_image_dir, filename))
66 | mask_png_name = basename + ".tif"
67 | mask.save(os.path.join(self.train_label_dir, mask_png_name))
68 |
69 | val_detail = Detail(annFile, _image_dir, "val")
70 | val_ids = val_detail.getImgs()
71 | for img_id in tqdm(val_ids):
72 | mask = loadmat(
73 | os.path.join(_anno_dir, img_id["file_name"].replace(".jpg", ".mat"))
74 | )["LabelMap"]
75 | mask = Image.fromarray(self._class_to_index(mask))
76 | filename = img_id["file_name"]
77 | basename, _ = os.path.splitext(filename)
78 | if filename.endswith(".jpg"):
79 | imgpath = os.path.join(_image_dir, filename)
80 | shutil.copy(imgpath, os.path.join(self.val_image_dir, filename))
81 | mask_png_name = basename + ".tif"
82 | mask.save(os.path.join(self.val_label_dir, mask_png_name))
83 |
84 |
85 | if __name__ == "__main__":
86 | parser = argparse.ArgumentParser()
87 | parser.add_argument(
88 | "--save_dir",
89 | default=None,
90 | type=str,
91 | dest="save_dir",
92 | help="The directory to save the data.",
93 | )
94 | # ori_root_dir: VOCdevkit/VOC2010
95 | parser.add_argument(
96 | "--img_dir",
97 | default=None,
98 | type=str,
99 | dest="img_dir",
100 | help="The directory of the VOC2010 data (VOCdevkit/VOC2010), containing JPEGImages and trainval_merged.json.",
101 | )
102 | parser.add_argument(
103 | "--anno_dir",
104 | default=None,
105 | type=str,
106 | dest="anno_dir",
107 | help="The directory of the Pascal Context .mat annotations (459-class LabelMap files).",
108 | )
109 | args = parser.parse_args()
110 |
111 | pascalcontext_seg_generator = PascalContextGenerator(args)
112 | pascalcontext_seg_generator.generate_label()
113 | print(pascalcontext_seg_generator.all_cls)
114 |
--------------------------------------------------------------------------------
/datasets/prepare_pcontext_sem_seg_59cls.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: Zhen Zhu(zzhu@hust.edu.cn)
4 | # Generate train & val data.
5 |
6 |
7 | import os
8 | import argparse
9 | import shutil
10 | from PIL import Image
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 | LABEL_DIR = "label"
15 | IMAGE_DIR = "image"
16 |
17 |
18 | class PascalContextGenerator(object):
19 | def __init__(self, args, image_dir=IMAGE_DIR, label_dir=LABEL_DIR):
20 | self.args = args
21 | self.train_label_dir = os.path.join(self.args.save_dir, "train", label_dir)
22 | self.val_label_dir = os.path.join(self.args.save_dir, "val", label_dir)
23 | if not os.path.exists(self.train_label_dir):
24 | os.makedirs(self.train_label_dir)
25 |
26 | if not os.path.exists(self.val_label_dir):
27 | os.makedirs(self.val_label_dir)
28 |
29 | self.train_image_dir = os.path.join(self.args.save_dir, "train", image_dir)
30 | self.val_image_dir = os.path.join(self.args.save_dir, "val", image_dir)
31 | if not os.path.exists(self.train_image_dir):
32 | os.makedirs(self.train_image_dir)
33 |
34 | if not os.path.exists(self.val_image_dir):
35 | os.makedirs(self.val_image_dir)
36 |
37 | def _class_to_index(self, mask, _mapping, _key):
38 | # assert the values
39 | values = np.unique(mask)
40 | for i in range(len(values)):
41 | assert values[i] in _mapping
42 | index = np.digitize(mask.ravel(), _mapping, right=True)
43 | mask = _key[index].reshape(mask.shape)
44 | mask = mask - 1
45 | return mask
46 |
47 | def generate_label(self):
48 | _image_dir = os.path.join(self.args.ori_root_dir, "JPEGImages")
49 | annFile = os.path.join(self.args.ori_root_dir, "trainval_merged.json")
50 | _mapping = np.sort(
51 | np.array(
52 | [
53 | 0,
54 | 2,
55 | 259,
56 | 260,
57 | 415,
58 | 324,
59 | 9,
60 | 258,
61 | 144,
62 | 18,
63 | 19,
64 | 22,
65 | 23,
66 | 397,
67 | 25,
68 | 284,
69 | 158,
70 | 159,
71 | 416,
72 | 33,
73 | 162,
74 | 420,
75 | 454,
76 | 295,
77 | 296,
78 | 427,
79 | 44,
80 | 45,
81 | 46,
82 | 308,
83 | 59,
84 | 440,
85 | 445,
86 | 31,
87 | 232,
88 | 65,
89 | 354,
90 | 424,
91 | 68,
92 | 326,
93 | 72,
94 | 458,
95 | 34,
96 | 207,
97 | 80,
98 | 355,
99 | 85,
100 | 347,
101 | 220,
102 | 349,
103 | 360,
104 | 98,
105 | 187,
106 | 104,
107 | 105,
108 | 366,
109 | 189,
110 | 368,
111 | 113,
112 | 115,
113 | ]
114 | )
115 | )
116 | _key = np.array(range(len(_mapping))).astype("uint8")
117 |
118 | from detail import Detail
119 |
120 | train_detail = Detail(annFile, _image_dir, "train")
121 | train_ids = train_detail.getImgs()
122 | for img_id in tqdm(train_ids):
123 | mask = Image.fromarray(
124 | self._class_to_index(
125 | train_detail.getMask(img_id), _mapping=_mapping, _key=_key
126 | )
127 | )
128 | filename = img_id["file_name"]
129 | basename, _ = os.path.splitext(filename)
130 | if filename.endswith(".jpg"):
131 | imgpath = os.path.join(_image_dir, filename)
132 | shutil.copy(imgpath, os.path.join(self.train_image_dir, filename))
133 | mask_png_name = basename + ".png"
134 | mask.save(os.path.join(self.train_label_dir, mask_png_name))
135 |
136 | val_detail = Detail(annFile, _image_dir, "val")
137 | val_ids = val_detail.getImgs()
138 | for img_id in tqdm(val_ids):
139 | mask = Image.fromarray(
140 | self._class_to_index(
141 | val_detail.getMask(img_id), _mapping=_mapping, _key=_key
142 | )
143 | )
144 | filename = img_id["file_name"]
145 | basename, _ = os.path.splitext(filename)
146 | if filename.endswith(".jpg"):
147 | imgpath = os.path.join(_image_dir, filename)
148 | shutil.copy(imgpath, os.path.join(self.val_image_dir, filename))
149 | mask_png_name = basename + ".png"
150 | mask.save(os.path.join(self.val_label_dir, mask_png_name))
151 |
152 |
153 | if __name__ == "__main__":
154 | parser = argparse.ArgumentParser()
155 | parser.add_argument(
156 | "--save_dir",
157 | default=None,
158 | type=str,
159 | dest="save_dir",
160 | help="The directory to save the data.",
161 | )
162 | # ori_root_dir: VOCdevkit/VOC2010
163 | parser.add_argument(
164 | "--ori_root_dir",
165 | default=None,
166 | type=str,
167 | dest="ori_root_dir",
168 | help="The directory of the VOC2010 data (VOCdevkit/VOC2010).",
169 | )
170 |
171 | args = parser.parse_args()
172 |
173 | pascalcontext_seg_generator = PascalContextGenerator(args)
174 | pascalcontext_seg_generator.generate_label()
175 |
--------------------------------------------------------------------------------
/datasets/prepare_voc_sem_seg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import shutil
5 | from functools import partial
6 | from glob import glob
7 |
8 | import mmcv
9 | import numpy as np
10 | from PIL import Image
11 |
12 |
13 | full_clsID_to_trID = {
14 | 0: 255,
15 | 1: 0,
16 | 2: 1,
17 | 3: 2,
18 | 4: 3,
19 | 5: 4,
20 | 6: 5,
21 | 7: 6,
22 | 8: 7,
23 | 9: 8,
24 | 10: 9,
25 | 11: 10,
26 | 12: 11,
27 | 13: 12,
28 | 14: 13,
29 | 15: 14,
30 | 16: 15,
31 | 17: 16,
32 | 18: 17,
33 | 19: 18,
34 | 20: 19,
35 | 255: 255,
36 | }
37 | # 2-cow/motorbike, 4-airplane/sofa, 6-cat/tv, 8-train/bottle, 10-
38 | # chair/potted plant
39 | # "aeroplane",
40 | # "bicycle",
41 | # "bird",
42 | # "boat",
43 | # "bottle",
44 | # "bus",
45 | # "car",
46 | # "cat",
47 | # "chair",
48 | # "cow",
49 | # "diningtable",
50 | # "dog",
51 | # "horse",
52 | # "motorbike",
53 | # "person",
54 | # "pottedplant",
55 | # "sheep",
56 | # "sofa",
57 | # "train",
58 | # "tv",
59 |
60 | novel_clsID = [16, 17, 18, 19, 20]
61 | novel_2_clsID = [10, 14]
62 | novel_4_clsID = [1, 10, 14, 18]
63 | novel_6_clsID = [1, 8, 10, 14, 18, 20]
64 | novel_8_clsID = [1, 5, 8, 10, 14, 18, 19, 20]
65 | novel_10_clsID = [1, 5, 8, 9, 10, 14, 16, 18, 19, 20]
66 |
67 | base_clsID = [k for k in full_clsID_to_trID.keys() if k not in novel_clsID + [0, 255]]
68 | base_2_clsID = [k for k in base_clsID if k not in novel_2_clsID]
69 | base_4_clsID = [k for k in base_clsID if k not in novel_4_clsID]
70 | base_6_clsID = [k for k in base_clsID if k not in novel_6_clsID]
71 | base_8_clsID = [k for k in base_clsID if k not in novel_8_clsID]
72 | base_10_clsID = [k for k in base_clsID if k not in novel_10_clsID]
73 |
74 | novel_clsID_to_trID = {k: i for i, k in enumerate(novel_clsID)}
75 | base_clsID_to_trID = {k: i for i, k in enumerate(base_clsID)}
76 |
77 | base_2_clsID_to_trID = {k: i for i, k in enumerate(base_2_clsID)}
78 | base_4_clsID_to_trID = {k: i for i, k in enumerate(base_4_clsID)}
79 | base_6_clsID_to_trID = {k: i for i, k in enumerate(base_6_clsID)}
80 | base_8_clsID_to_trID = {k: i for i, k in enumerate(base_8_clsID)}
81 | base_10_clsID_to_trID = {k: i for i, k in enumerate(base_10_clsID)}
82 |
83 |
84 | def convert_to_trainID(
85 | maskpath, out_mask_dir, is_train, clsID_to_trID=full_clsID_to_trID, suffix=""
86 | ):
87 | mask = np.array(Image.open(maskpath))
88 | mask_copy = np.ones_like(mask, dtype=np.uint8) * 255
89 | for clsID, trID in clsID_to_trID.items():
90 | mask_copy[mask == clsID] = trID
91 | seg_filename = (
92 | osp.join(out_mask_dir, "train" + suffix, osp.basename(maskpath))
93 | if is_train
94 | else osp.join(out_mask_dir, "val" + suffix, osp.basename(maskpath))
95 | )
96 | if len(np.unique(mask_copy)) == 1 and np.unique(mask_copy)[0] == 255:
97 | return
98 | Image.fromarray(mask_copy).save(seg_filename, "PNG")
99 |
100 |
101 | def parse_args():
102 | parser = argparse.ArgumentParser(
103 | description="Convert VOC2012 annotations to detectron2 format"
104 | ) # noqa
105 | parser.add_argument("voc_path", help="voc path")
106 | parser.add_argument("-o", "--out_dir", help="output path")
107 | parser.add_argument("--nproc", default=16, type=int, help="number of process")
108 | args = parser.parse_args()
109 | return args
110 |
111 |
112 | def main():
113 | args = parse_args()
114 | voc_path = args.voc_path
115 | nproc = args.nproc
116 | print(full_clsID_to_trID)
117 | print(base_clsID_to_trID)
118 | print(novel_clsID_to_trID)
119 | out_dir = args.out_dir or voc_path
120 | # out_img_dir = osp.join(out_dir, 'images')
121 | out_mask_dir = osp.join(out_dir, "annotations_detectron2")
122 | out_image_dir = osp.join(out_dir, "images_detectron2")
123 | for dir_name in [
124 | "train",
125 | "val",
126 | "train_base",
127 | "train_base_2",
128 | "train_base_4",
129 | "train_base_6",
130 | "train_base_8",
131 | "train_base_10",
132 | "train_novel",
133 | "val_base",
134 | "val_novel",
135 | ]:
136 | os.makedirs(osp.join(out_mask_dir, dir_name), exist_ok=True)
137 | if dir_name in ["train", "val"]:
138 | os.makedirs(osp.join(out_image_dir, dir_name), exist_ok=True)
139 |
140 | train_list = [
141 | osp.join(voc_path, "SegmentationClassAug", f + ".png")
142 | for f in np.loadtxt(osp.join(voc_path, "train.txt"), dtype=str).tolist()
143 | ]
144 | test_list = [
145 | osp.join(voc_path, "SegmentationClassAug", f + ".png")
146 | for f in np.loadtxt(osp.join(voc_path, "val.txt"), dtype=str).tolist()
147 | ]
148 |
149 | if args.nproc > 1:
150 | mmcv.track_parallel_progress(
151 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
152 | train_list,
153 | nproc=nproc,
154 | )
155 | mmcv.track_parallel_progress(
156 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
157 | test_list,
158 | nproc=nproc,
159 | )
160 | else:
161 | mmcv.track_progress(
162 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
163 | train_list,
164 | )
165 | mmcv.track_progress(
166 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
167 | test_list,
168 | )
169 |
170 | print("Done!")
171 |
172 |
173 | if __name__ == "__main__":
174 | main()
175 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:22.08-py3
2 |
3 | RUN pip install cython scipy shapely timm h5py submitit scikit-image wandb setuptools numpy Pillow pycocotools~=2.0.4 fvcore tabulate tqdm ftfy regex opencv-python open_clip_torch cityscapesscripts tensorboard
4 | RUN pip install 'git+https://github.com/facebookresearch/detectron2.git'
5 | RUN pip install opencv-python-headless==4.5.5.64
6 |
--------------------------------------------------------------------------------
/docker/app.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:22.08-py3
2 |
3 | RUN pip install 'git+https://github.com/facebookresearch/detectron2.git'
4 | RUN pip install cython scipy shapely timm h5py submitit scikit-image wandb setuptools numpy Pillow pycocotools~=2.0.4 fvcore tabulate tqdm ftfy regex open_clip_torch cityscapesscripts tensorboard gradio
5 |
6 | RUN useradd -m -u 1000 user
7 | # Switch to the "user" user
8 | USER user
9 | # Set home to the user's home directory
10 | ENV HOME=/home/user \
11 | PATH=/home/user/.local/bin:$PATH
12 |
13 | # Set the working directory to the user's home directory
14 | WORKDIR $HOME
15 | RUN git clone https://github.com/MendelXu/SAN app
16 |
17 | WORKDIR $HOME/app
18 | ENV GRADIO_SERVER_NAME=0.0.0.0
19 | EXPOSE 7860
20 | RUN echo "gradio app.py">>run.sh
21 | CMD ["script","-c","sh run.sh","/dev/null"]
22 |
--------------------------------------------------------------------------------
/docs/.nojekyll:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Academic Project Page Template
2 | This is an academic paper project page template.
3 |
4 |
5 | Example project pages built using this template are:
6 | - https://www.vision.huji.ac.il/deepsim/
7 | - https://www.vision.huji.ac.il/3d_ads/
8 | - https://www.vision.huji.ac.il/ssrl_ad/
9 | - https://www.vision.huji.ac.il/conffusion/
10 |
11 |
12 | ## Start using the template
13 | To start using the template click on `Use this Template`.
14 |
15 | The template uses HTML for controlling the content and CSS for controlling the style.
16 | To edit the website's content, edit the `index.html` file. It contains different HTML "building blocks"; use whichever ones you need and comment out the rest.
17 |
18 | ## Components
19 | - Teaser video
20 | - Images Carousel
21 | - Youtube embedding
22 | - Video Carousel
23 | - PDF Poster
24 | - Bibtex citation
25 |
26 | ## Tips:
27 | - The `index.html` file contains comments instructing you what to replace; you should follow these comments.
28 | - The `meta` tags in the `index.html` file are used to provide metadata about your paper
29 | (e.g. helping search engines index the website, showing a preview image when sharing the website, etc.)
30 | - The resolution of images and videos can usually be around 1920-2048; there is rarely a need for higher resolution, which only takes longer to load.
31 | - All the images and videos you use should be compressed to allow for fast loading of the website (and thus better indexing by search engines). For images, you can use [TinyPNG](https://tinypng.com); for videos, you need to find a tradeoff between size and quality.
32 | - When using large video files (larger than 10MB), it's better to host the video on YouTube, as serving the video from the website can take time.
33 | - Using a tracker can help you analyze the traffic and see where users came from. [statcounter](https://statcounter.com) is a free, easy to use tracker that takes under 5 minutes to set up.
34 | - This project page can also be made into a github pages website.
35 | - Replace the favicon with one of your choosing (the default one is of the Hebrew University).
36 | - Suggestions, improvements and comments are welcome, simply open an issue or contact me. You can find my contact information at [https://pages.cs.huji.ac.il/eliahu-horwitz/](https://pages.cs.huji.ac.il/eliahu-horwitz/)
37 |
38 | ## Acknowledgments
39 | Parts of this project page were adopted from the [Nerfies](https://nerfies.github.io/) page.
40 |
41 | ## Website License
42 | This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.
--------------------------------------------------------------------------------
/docs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/arch.png
--------------------------------------------------------------------------------
/docs/static/css/bulma-carousel.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center 
center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10}
--------------------------------------------------------------------------------
/docs/static/css/bulma-slider.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff}
--------------------------------------------------------------------------------
/docs/static/css/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Noto Sans', sans-serif;
3 | }
4 |
5 |
6 | .footer .icon-link {
7 | font-size: 25px;
8 | color: #000;
9 | }
10 |
11 | .link-block a {
12 | margin-top: 5px;
13 | margin-bottom: 5px;
14 | }
15 |
16 | .dnerf {
17 | font-variant: small-caps;
18 | }
19 |
20 |
21 | .teaser .hero-body {
22 | padding-top: 0;
23 | padding-bottom: 3rem;
24 | }
25 |
26 | .teaser {
27 | font-family: 'Google Sans', sans-serif;
28 | }
29 |
30 |
31 | .publication-title {
32 | }
33 |
34 | .publication-banner {
35 | max-height: parent;
36 |
37 | }
38 |
39 | .publication-banner video {
40 | position: relative;
41 | left: auto;
42 | top: auto;
43 | transform: none;
44 | object-fit: fit;
45 | }
46 |
47 | .publication-header .hero-body {
48 | }
49 |
50 | .publication-title {
51 | font-family: 'Google Sans', sans-serif;
52 | }
53 |
54 | .publication-authors {
55 | font-family: 'Google Sans', sans-serif;
56 | }
57 |
58 | .publication-venue {
59 | color: #555;
60 | width: fit-content;
61 | font-weight: bold;
62 | }
63 |
64 | .publication-awards {
65 | color: #ff3860;
66 | width: fit-content;
67 | font-weight: bolder;
68 | }
69 |
70 | .publication-authors {
71 | }
72 |
73 | .publication-authors a {
74 | color: hsl(204, 86%, 53%) !important;
75 | }
76 |
77 | .publication-authors a:hover {
78 | text-decoration: underline;
79 | }
80 |
81 | .author-block {
82 | display: inline-block;
83 | }
84 |
85 | .publication-banner img {
86 | }
87 |
88 | .publication-authors {
89 | /*color: #4286f4;*/
90 | }
91 |
92 | .publication-video {
93 | position: relative;
94 | width: 100%;
95 | height: 0;
96 | padding-bottom: 56.25%;
97 |
98 | overflow: hidden;
99 | border-radius: 10px !important;
100 | }
101 |
102 | .publication-video iframe {
103 | position: absolute;
104 | top: 0;
105 | left: 0;
106 | width: 100%;
107 | height: 100%;
108 | }
109 |
110 | .publication-body img {
111 | }
112 |
113 | .results-carousel {
114 | overflow: hidden;
115 | }
116 |
117 | .results-carousel .item {
118 | margin: 5px;
119 | overflow: hidden;
120 | padding: 20px;
121 | font-size: 0;
122 | }
123 |
124 | .results-carousel video {
125 | margin: 0;
126 | }
127 |
128 | .slider-pagination .slider-page {
129 | background: #000000;
130 | }
131 |
132 | .eql-cntrb {
133 | font-size: smaller;
134 | }
135 |
136 |
137 |
138 |
--------------------------------------------------------------------------------
/docs/static/images/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/static/images/favicon.ico
--------------------------------------------------------------------------------
/docs/static/images/grid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/static/images/grid.jpg
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/static/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 |
4 | $(document).ready(function() {
5 | // Check for click events on the navbar burger icon
6 |
7 | var options = {
8 | slidesToScroll: 1,
9 | slidesToShow: 1,
10 | loop: true,
11 | infinite: true,
12 | autoplay: true,
13 | autoplaySpeed: 5000,
14 | }
15 |
16 | // Initialize all div with carousel class
17 | var carousels = bulmaCarousel.attach('.carousel', options);
18 |
19 | bulmaSlider.attach();
20 |
21 | })
22 |
--------------------------------------------------------------------------------
/docs/static/pdfs/sample.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/static/pdfs/sample.pdf
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | python -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
2 | python -m pip install -r requirements.txt
3 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.6'
4 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | import numpy as np
4 |
5 | try:
6 | # ignore ShapelyDeprecationWarning from fvcore
7 | import warnings
8 |
9 | from shapely.errors import ShapelyDeprecationWarning
10 |
11 | warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
12 | except Exception:
13 | pass
14 | import os
15 |
16 | import huggingface_hub
17 | import torch
18 | from detectron2.checkpoint import DetectionCheckpointer
19 | from detectron2.config import get_cfg
20 | from detectron2.data import MetadataCatalog
21 | from detectron2.engine import DefaultTrainer
22 | from detectron2.projects.deeplab import add_deeplab_config
23 | from detectron2.utils.visualizer import Visualizer, random_color
24 | from huggingface_hub import hf_hub_download
25 | from PIL import Image
26 |
27 | from san import add_san_config
28 | from san.data.datasets.register_coco_stuff_164k import COCO_CATEGORIES
29 |
30 | model_cfg = {
31 | "san_vit_b_16": {
32 | "config_file": "configs/san_clip_vit_res4_coco.yaml",
33 | "model_path": "huggingface:san_vit_b_16.pth",
34 | },
35 | "san_vit_large_16": {
36 | "config_file": "configs/san_clip_vit_large_res4_coco.yaml",
37 | "model_path": "huggingface:san_vit_large_14.pth",
38 | },
39 | }
40 |
41 |
42 | def download_model(model_path: str):
43 | """
44 | Download the model from huggingface hub.
45 | Args:
46 | model_path (str): the model path
47 | Returns:
48 | str: the downloaded model path
49 | """
50 | if "HF_TOKEN" in os.environ:
51 | huggingface_hub.login(token=os.environ["HF_TOKEN"])
52 | model_path = model_path.split(":")[1]
53 | model_path = hf_hub_download("Mendel192/san", filename=model_path)
54 | return model_path
55 |
56 |
57 | def setup(config_file: str, device=None):
58 | """
59 | Create configs and perform basic setups.
60 | """
61 | cfg = get_cfg()
62 | # for poly lr schedule
63 | add_deeplab_config(cfg)
64 | add_san_config(cfg)
65 | cfg.merge_from_file(config_file)
66 |     cfg.MODEL.DEVICE = device or ("cuda" if torch.cuda.is_available() else "cpu")
67 | cfg.freeze()
68 | return cfg
69 |
70 |
71 | class Predictor(object):
72 | def __init__(self, config_file: str, model_path: str):
73 | """
74 | Args:
75 | config_file (str): the config file path
76 | model_path (str): the model path
77 | """
78 | cfg = setup(config_file)
79 | self.model = DefaultTrainer.build_model(cfg)
80 | if model_path.startswith("huggingface:"):
81 | model_path = download_model(model_path)
82 | print("Loading model from: ", model_path)
83 | DetectionCheckpointer(self.model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
84 | model_path
85 | )
86 | print("Loaded model from: ", model_path)
87 | self.model.eval()
88 | if torch.cuda.is_available():
89 | self.device = torch.device("cuda")
90 | self.model = self.model.cuda()
91 |
92 | def predict(
93 | self,
94 | image_data_or_path: Union[Image.Image, str],
95 | vocabulary: List[str] = [],
96 |         augment_vocabulary: Union[str, bool] = True,
97 | output_file: str = None,
98 | ) -> Union[dict, None]:
99 | """
100 | Predict the segmentation result.
101 | Args:
102 | image_data_or_path (Union[Image.Image, str]): the input image or the image path
103 | vocabulary (List[str]): the vocabulary used for the segmentation
104 |             augment_vocabulary (Union[str, bool]): the augmentation set name ("COCO-all" or "COCO-stuff"); a bool falls back to merging in the default COCO vocabulary
105 | output_file (str): the output file path
106 | Returns:
107 | Union[dict, None]: the segmentation result
108 | """
109 | if isinstance(image_data_or_path, str):
110 | image_data = Image.open(image_data_or_path)
111 | else:
112 | image_data = image_data_or_path
113 | w, h = image_data.size
114 | image_tensor: torch.Tensor = self._preprocess(image_data)
115 | vocabulary = list(set([v.lower().strip() for v in vocabulary]))
116 | # remove invalid vocabulary
117 | vocabulary = [v for v in vocabulary if v != ""]
118 | print("vocabulary:", vocabulary)
119 | ori_vocabulary = vocabulary
120 |
121 |         if isinstance(augment_vocabulary, str):
122 | vocabulary = self.augment_vocabulary(vocabulary, augment_vocabulary)
123 | else:
124 | vocabulary = self._merge_vocabulary(vocabulary)
125 | if len(ori_vocabulary) == 0:
126 | ori_vocabulary = vocabulary
127 | with torch.no_grad():
128 | result = self.model(
129 | [
130 | {
131 | "image": image_tensor,
132 | "height": h,
133 | "width": w,
134 | "vocabulary": vocabulary,
135 | }
136 | ]
137 | )[0]["sem_seg"]
138 | seg_map = self._postprocess(result, ori_vocabulary)
139 | if output_file:
140 | self.visualize(image_data, seg_map, ori_vocabulary, output_file)
141 | return
142 | return {
143 | "image": image_data,
144 | "sem_seg": seg_map,
145 | "vocabulary": ori_vocabulary,
146 | }
147 |
148 | def visualize(
149 | self,
150 | image: Image.Image,
151 | sem_seg: np.ndarray,
152 | vocabulary: List[str],
153 | output_file: str = None,
154 | mode: str = "overlay",
155 | ) -> Union[Image.Image, None]:
156 | """
157 | Visualize the segmentation result.
158 | Args:
159 | image (Image.Image): the input image
160 | sem_seg (np.ndarray): the segmentation result
161 | vocabulary (List[str]): the vocabulary used for the segmentation
162 | output_file (str): the output file path
163 | mode (str): the visualization mode, can be "overlay" or "mask"
164 | Returns:
165 | Image.Image: the visualization result. If output_file is not None, return None.
166 | """
167 | # add temporary metadata
168 | # set numpy seed to make sure the colors are the same
169 | np.random.seed(0)
170 | colors = [random_color(rgb=True, maximum=255) for _ in range(len(vocabulary))]
171 | MetadataCatalog.get("_temp").set(stuff_classes=vocabulary, stuff_colors=colors)
172 | metadata = MetadataCatalog.get("_temp")
173 | if mode == "overlay":
174 | v = Visualizer(image, metadata)
175 | v = v.draw_sem_seg(sem_seg, area_threshold=0).get_image()
176 | v = Image.fromarray(v)
177 | else:
178 | v = np.zeros((image.size[1], image.size[0], 3), dtype=np.uint8)
179 | labels, areas = np.unique(sem_seg, return_counts=True)
180 | sorted_idxs = np.argsort(-areas).tolist()
181 | labels = labels[sorted_idxs]
182 | for label in filter(lambda l: l < len(metadata.stuff_classes), labels):
183 | v[sem_seg == label] = metadata.stuff_colors[label]
184 | v = Image.fromarray(v)
185 | # remove temporary metadata
186 | MetadataCatalog.remove("_temp")
187 | if output_file is None:
188 | return v
189 | v.save(output_file)
190 | print(f"saved to {output_file}")
191 |
192 | def _merge_vocabulary(self, vocabulary: List[str]) -> List[str]:
193 | default_voc = [c["name"] for c in COCO_CATEGORIES]
194 | return vocabulary + [c for c in default_voc if c not in vocabulary]
195 |
196 | def augment_vocabulary(
197 | self, vocabulary: List[str], aug_set: str = "COCO-all"
198 | ) -> List[str]:
199 | default_voc = [c["name"] for c in COCO_CATEGORIES]
200 | stuff_voc = [
201 | c["name"]
202 | for c in COCO_CATEGORIES
203 | if "isthing" not in c or c["isthing"] == 0
204 | ]
205 | if aug_set == "COCO-all":
206 | return vocabulary + [c for c in default_voc if c not in vocabulary]
207 | elif aug_set == "COCO-stuff":
208 | return vocabulary + [c for c in stuff_voc if c not in vocabulary]
209 | else:
210 | return vocabulary
211 |
212 | def _preprocess(self, image: Image.Image) -> torch.Tensor:
213 | """
214 | Preprocess the input image.
215 | Args:
216 | image (Image.Image): the input image
217 | Returns:
218 | torch.Tensor: the preprocessed image
219 | """
220 | image = image.convert("RGB")
221 | # resize short side to 640
222 | w, h = image.size
223 | if w < h:
224 | image = image.resize((640, int(h * 640 / w)))
225 | else:
226 | image = image.resize((int(w * 640 / h), 640))
227 | image = torch.from_numpy(np.asarray(image)).float()
228 | image = image.permute(2, 0, 1)
229 | return image
230 |
231 | def _postprocess(
232 | self, result: torch.Tensor, ori_vocabulary: List[str]
233 | ) -> np.ndarray:
234 | """
235 | Postprocess the segmentation result.
236 | Args:
237 | result (torch.Tensor): the segmentation result
238 | ori_vocabulary (List[str]): the original vocabulary used for the segmentation
239 | Returns:
240 | np.ndarray: the postprocessed segmentation result
241 | """
242 | result = result.argmax(dim=0).cpu().numpy() # (H, W)
243 | if len(ori_vocabulary) == 0:
244 | return result
245 | result[result >= len(ori_vocabulary)] = len(ori_vocabulary)
246 | return result
247 |
248 |
249 | def pre_download():
250 |     """Pre-download models from huggingface and open_clip to avoid network issues."""
251 | for model_name, model_info in model_cfg.items():
252 | download_model(model_info["model_path"])
253 | cfg = setup(model_info["config_file"])
254 | DefaultTrainer.build_model(cfg)
255 |
256 |
257 | if __name__ == "__main__":
258 | from argparse import ArgumentParser
259 |
260 | parser = ArgumentParser()
261 | parser.add_argument(
262 | "--config-file", type=str, required=True, help="path to config file"
263 | )
264 | parser.add_argument(
265 | "--model-path", type=str, required=True, help="path to model file"
266 | )
267 | parser.add_argument(
268 | "--img-path", type=str, required=True, help="path to image file."
269 | )
270 | parser.add_argument("--aug-vocab", action="store_true", help="augment vocabulary.")
271 | parser.add_argument(
272 | "--vocab",
273 | type=str,
274 | default="",
275 |         help="list of category names, separated by ','.",
276 | )
277 | parser.add_argument(
278 | "--output-file", type=str, default=None, help="path to output file."
279 | )
280 | args = parser.parse_args()
281 | predictor = Predictor(config_file=args.config_file, model_path=args.model_path)
282 | predictor.predict(
283 | args.img_path,
284 | args.vocab.split(","),
285 | args.aug_vocab,
286 | output_file=args.output_file,
287 | )
288 |
--------------------------------------------------------------------------------
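For reference, predict.py can also be driven directly from Python instead of the __main__ CLI above. A minimal sketch (the image path and vocabulary are placeholders, not files shipped with the repo):

    from predict import Predictor

    predictor = Predictor(
        config_file="configs/san_clip_vit_res4_coco.yaml",
        model_path="huggingface:san_vit_b_16.pth",  # fetched via hf_hub_download on first use
    )
    # With output_file=None the raw result dict is returned instead of a saved visualization.
    result = predictor.predict(
        "example.jpg",                    # hypothetical input image
        vocabulary=["sky", "tree", "road"],
        augment_vocabulary="COCO-all",    # or "COCO-stuff", or a bool (see predict() above)
        output_file=None,
    )
    print(result["sem_seg"].shape, result["vocabulary"])
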
/requirements.txt:
--------------------------------------------------------------------------------
1 | cython
2 | scipy
3 | shapely
4 | timm
5 | h5py
6 | submitit
7 | scikit-image
8 | wandb
9 | setuptools
10 | numpy==1.22.4
11 | Pillow==9.3.0
12 | pycocotools~=2.0.4
13 | fvcore
14 | tabulate
15 | tqdm
16 | ftfy
17 | regex
18 | opencv-python
19 | open_clip_torch==2.16.0
20 | mmcv==1.3.14
--------------------------------------------------------------------------------
/resources/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/resources/arch.png
--------------------------------------------------------------------------------
/san/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data # register all new datasets
2 | from . import model
3 | from . import utils
4 |
5 | # config
6 | from .config import add_san_config
7 |
8 | # dataset loading
9 | from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
10 | MaskFormerSemanticDatasetMapper,
11 | )
12 |
13 | # models
14 | from .test_time_augmentation import SemanticSegmentorWithTTA
15 |
--------------------------------------------------------------------------------
/san/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from detectron2.config import CfgNode as CN
3 |
4 |
5 | def add_san_config(cfg):
6 | # copied from maskformer2
7 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
8 | # Color augmentation
9 | cfg.INPUT.COLOR_AUG_SSD = False
10 | # We retry random cropping until no single category in semantic segmentation GT occupies more
11 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
12 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
13 | # Pad image and segmentation GT in dataset mapper.
14 | cfg.INPUT.SIZE_DIVISIBILITY = -1
15 |
16 | # solver config
17 | # optimizer
18 | # weight decay on embedding
19 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
20 | cfg.SOLVER.WEIGHT_DECAY_EMBED_GROUP = [
21 | "absolute_pos_embed",
22 | "positional_embedding",
23 | "pos_embed",
24 | "query_embed",
25 | "relative_position_bias_table",
26 | ]
27 | cfg.SOLVER.OPTIMIZER = "ADAMW"
28 | cfg.SOLVER.BACKBONE_MULTIPLIER = 1.0
29 | cfg.SOLVER.CLIP_MULTIPLIER = 1.0
30 | cfg.SOLVER.TEST_IMS_PER_BATCH = 1
31 |
32 | # san
33 | cfg.MODEL.SAN = CN()
34 | cfg.MODEL.SAN.NO_OBJECT_WEIGHT = 0.1
35 | cfg.MODEL.SAN.CLASS_WEIGHT = 2.0
36 | cfg.MODEL.SAN.DICE_WEIGHT = 5.0
37 | cfg.MODEL.SAN.MASK_WEIGHT = 5.0
38 | cfg.MODEL.SAN.TRAIN_NUM_POINTS = 112 * 112
39 | cfg.MODEL.SAN.NUM_CLASSES = 171
40 | cfg.MODEL.SAN.OVERSAMPLE_RATIO = 3.0
41 | cfg.MODEL.SAN.IMPORTANCE_SAMPLE_RATIO = 0.75
42 | cfg.MODEL.SAN.CLIP_MODEL_NAME = "ViT-B/16"
43 | cfg.MODEL.SAN.CLIP_PRETRAINED_NAME = "openai"
44 | cfg.MODEL.SAN.CLIP_TEMPLATE_SET = "vild"
45 | cfg.MODEL.SAN.FEATURE_LAST_LAYER_IDX = 9
46 | cfg.MODEL.SAN.CLIP_FROZEN_EXCLUDE = ["positional_embedding"]
47 | cfg.MODEL.SAN.CLIP_DEEPER_FROZEN_EXCLUDE = []
48 | cfg.MODEL.SAN.REC_CROSS_ATTN = False
49 | cfg.MODEL.SAN.REC_DOWNSAMPLE_METHOD = "max"
50 | cfg.MODEL.SAN.SOS_TOKEN_FORMAT = "cls_token"
51 | cfg.MODEL.SAN.SIZE_DIVISIBILITY = 32
52 | cfg.MODEL.SAN.ASYMETRIC_INPUT = True
53 | cfg.MODEL.SAN.CLIP_RESOLUTION = 0.5
54 |
55 | cfg.MODEL.SAN.SEM_SEG_POSTPROCESS_BEFORE_INFERENCE = True
56 | # side adapter
57 | cfg.MODEL.SIDE_ADAPTER = CN()
58 | cfg.MODEL.SIDE_ADAPTER.NAME = "RegionwiseSideAdapterNetwork"
59 | cfg.MODEL.SIDE_ADAPTER.VIT_NAME = "vit_w240n6d8_patch16"
60 | cfg.MODEL.SIDE_ADAPTER.PRETRAINED = False
61 | cfg.MODEL.SIDE_ADAPTER.IMAGE_SIZE = 640
62 | cfg.MODEL.SIDE_ADAPTER.DROP_PATH_RATE = 0.0
63 | cfg.MODEL.SIDE_ADAPTER.NUM_QUERIES = 100
64 | cfg.MODEL.SIDE_ADAPTER.FUSION_TYPE = "add"
65 | cfg.MODEL.SIDE_ADAPTER.FUSION_MAP = ["0->0", "3->1", "6->2", "9->3"]
66 | cfg.MODEL.SIDE_ADAPTER.DEEP_SUPERVISION_IDXS = [7, 8]
67 |
68 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS = CN()
69 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_HEADS = 12
70 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_LAYERS = 1
71 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.EMBED_CHANNELS = 256
72 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_CHANNELS = 256
73 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_NUM_LAYERS = 3
74 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.RESCALE_ATTN_BIAS = True
75 |
76 | # wandb
77 | cfg.WANDB = CN()
78 | cfg.WANDB.PROJECT = "san"
79 | cfg.WANDB.NAME = None
80 | # use flash attention
81 | cfg.MODEL.FLASH = False
82 |
--------------------------------------------------------------------------------
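For reference, a minimal sketch of how add_san_config is consumed, mirroring setup() in predict.py (the config path is one of the files under configs/):

    from detectron2.config import get_cfg
    from detectron2.projects.deeplab import add_deeplab_config
    from san import add_san_config

    cfg = get_cfg()
    add_deeplab_config(cfg)  # adds the poly lr schedule options
    add_san_config(cfg)      # adds the SAN / side-adapter defaults defined above
    cfg.merge_from_file("configs/san_clip_vit_res4_coco.yaml")
    cfg.freeze()
    print(cfg.MODEL.SAN.CLIP_MODEL_NAME, cfg.MODEL.SIDE_ADAPTER.NUM_QUERIES)
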
/san/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datasets
2 | from .build import build_detection_train_loader, build_detection_test_loader
3 |
--------------------------------------------------------------------------------
/san/data/dataset_mappers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/san/data/dataset_mappers/__init__.py
--------------------------------------------------------------------------------
/san/data/dataset_mappers/mask_former_semantic_dataset_mapper.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | from torch.nn import functional as F
7 |
8 | from detectron2.config import configurable
9 | from detectron2.data import MetadataCatalog
10 | from detectron2.data import detection_utils as utils
11 | from detectron2.data import transforms as T
12 | from detectron2.projects.point_rend import ColorAugSSDTransform
13 | from detectron2.structures import BitMasks, Instances
14 |
15 | __all__ = ["MaskFormerSemanticDatasetMapper"]
16 |
17 |
18 | class MaskFormerSemanticDatasetMapper:
19 | """
20 | A callable which takes a dataset dict in Detectron2 Dataset format,
21 |     and maps it into a format used by MaskFormer for semantic segmentation.
22 |
23 | The callable currently does the following:
24 |
25 |     1. Reads the image from "file_name"
26 |     2. Applies geometric transforms to the image and annotation
27 |     3. Finds and applies suitable cropping to the image and annotation
28 |     4. Converts the image and annotation to Tensors
29 | """
30 |
31 | @configurable
32 | def __init__(
33 | self,
34 | is_train=True,
35 | *,
36 | augmentations,
37 | image_format,
38 | ignore_label,
39 | size_divisibility,
40 | ):
41 | """
42 | NOTE: this interface is experimental.
43 | Args:
44 | is_train: for training or inference
45 | augmentations: a list of augmentations or deterministic transforms to apply
46 | image_format: an image format supported by :func:`detection_utils.read_image`.
47 |             ignore_label: the label that is ignored during evaluation
48 | size_divisibility: pad image size to be divisible by this value
49 | """
50 | self.is_train = is_train
51 | self.tfm_gens = augmentations
52 | self.img_format = image_format
53 | self.ignore_label = ignore_label
54 | self.size_divisibility = size_divisibility
55 |
56 | logger = logging.getLogger(__name__)
57 | mode = "training" if is_train else "inference"
58 | logger.info(
59 | f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}"
60 | )
61 |
62 | @classmethod
63 | def from_config(cls, cfg, is_train=True):
64 | # Build augmentation
65 | augs = [
66 | T.ResizeShortestEdge(
67 | cfg.INPUT.MIN_SIZE_TRAIN,
68 | cfg.INPUT.MAX_SIZE_TRAIN,
69 | cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
70 | )
71 | ]
72 | if cfg.INPUT.CROP.ENABLED:
73 | augs.append(
74 | T.RandomCrop_CategoryAreaConstraint(
75 | cfg.INPUT.CROP.TYPE,
76 | cfg.INPUT.CROP.SIZE,
77 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
78 | cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
79 | )
80 | )
81 | if cfg.INPUT.COLOR_AUG_SSD:
82 | augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
83 | augs.append(T.RandomFlip())
84 |
85 |         # Assume this always applies to the training set.
86 | dataset_names = cfg.DATASETS.TRAIN
87 | meta = MetadataCatalog.get(dataset_names[0])
88 | ignore_label = meta.ignore_label
89 |
90 | ret = {
91 | "is_train": is_train,
92 | "augmentations": augs,
93 | "image_format": cfg.INPUT.FORMAT,
94 | "ignore_label": ignore_label,
95 | "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
96 | }
97 | return ret
98 |
99 | def __call__(self, dataset_dict):
100 | """
101 | Args:
102 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
103 |
104 | Returns:
105 | dict: a format that builtin models in detectron2 accept
106 | """
107 | assert (
108 | self.is_train
109 | ), "MaskFormerSemanticDatasetMapper should only be used for training!"
110 |
111 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
112 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
113 | utils.check_image_size(dataset_dict, image)
114 |
115 | if "sem_seg_file_name" in dataset_dict:
116 | # PyTorch transformation not implemented for uint16, so converting it to double first
117 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype(
118 | "double"
119 | )
120 | else:
121 | sem_seg_gt = None
122 |
123 | if sem_seg_gt is None:
124 | raise ValueError(
125 | "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
126 | dataset_dict["file_name"]
127 | )
128 | )
129 |
130 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
131 | aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
132 | image = aug_input.image
133 | sem_seg_gt = aug_input.sem_seg
134 |
135 | # Pad image and segmentation label here!
136 | image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
137 | if sem_seg_gt is not None:
138 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
139 |
140 | if self.size_divisibility > 0:
141 | image_size = (image.shape[-2], image.shape[-1])
142 | padding_size = [
143 | 0,
144 | self.size_divisibility - image_size[1],
145 | 0,
146 | self.size_divisibility - image_size[0],
147 | ]
148 | image = F.pad(image, padding_size, value=128).contiguous()
149 | if sem_seg_gt is not None:
150 | sem_seg_gt = F.pad(
151 | sem_seg_gt, padding_size, value=self.ignore_label
152 | ).contiguous()
153 |
154 | image_shape = (image.shape[-2], image.shape[-1]) # h, w
155 |
156 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
157 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
158 | # Therefore it's important to use torch.Tensor.
159 | dataset_dict["image"] = image
160 |
161 | if sem_seg_gt is not None:
162 | dataset_dict["sem_seg"] = sem_seg_gt.long()
163 |
164 | if "annotations" in dataset_dict:
165 | raise ValueError(
166 | "Semantic segmentation dataset should not have 'annotations'."
167 | )
168 |
169 | # Prepare per-category binary masks
170 | if sem_seg_gt is not None:
171 | sem_seg_gt = sem_seg_gt.numpy()
172 | instances = Instances(image_shape)
173 | classes = np.unique(sem_seg_gt)
174 | # remove ignored region
175 | classes = classes[classes != self.ignore_label]
176 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
177 |
178 | masks = []
179 | for class_id in classes:
180 | masks.append(sem_seg_gt == class_id)
181 |
182 | if len(masks) == 0:
183 | # Some image does not have annotation (all ignored)
184 | instances.gt_masks = torch.zeros(
185 | (0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])
186 | )
187 | else:
188 | masks = BitMasks(
189 | torch.stack(
190 | [
191 | torch.from_numpy(np.ascontiguousarray(x.copy()))
192 | for x in masks
193 | ]
194 | )
195 | )
196 | instances.gt_masks = masks.tensor
197 |
198 | dataset_dict["instances"] = instances
199 |
200 | return dataset_dict
201 |
--------------------------------------------------------------------------------
/san/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | register_ade20k_full,
3 | register_coco_stuff_164k,
4 | register_pcontext,
5 | register_voc,
6 | )
7 |
--------------------------------------------------------------------------------
/san/data/datasets/register_coco_stuff_164k.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from detectron2.data import DatasetCatalog, MetadataCatalog
4 | from detectron2.data.datasets import load_sem_seg
5 |
6 | COCO_CATEGORIES = [
7 | {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
8 | {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
9 | {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
10 | {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
11 | {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
12 | {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
13 | {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
14 | {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
15 | {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
16 | {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
17 | {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
18 | {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
19 | {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
20 | {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
21 | {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
22 | {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
23 | {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
24 | {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
25 | {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
26 | {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
27 | {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
28 | {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
29 | {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
30 | {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
31 | {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
32 | {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
33 | {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
34 | {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
35 | {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
36 | {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
37 | {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
38 | {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
39 | {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
40 | {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
41 | {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
42 | {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
43 | {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
44 | {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
45 | {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
46 | {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
47 | {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
48 | {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
49 | {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
50 | {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
51 | {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
52 | {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
53 | {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
54 | {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
55 | {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
56 | {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
57 | {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
58 | {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
59 | {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
60 | {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
61 | {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
62 | {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
63 | {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
64 | {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
65 | {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
66 | {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
67 | {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
68 | {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
69 | {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
70 | {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
71 | {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
72 | {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
73 | {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
74 | {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
75 | {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
76 | {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
77 | {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
78 | {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
79 | {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
80 | {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
81 | {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
82 | {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
83 | {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
84 | {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
85 | {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
86 | {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
87 | {"id": 92, "name": "banner", "supercategory": "textile"},
88 | {"id": 93, "name": "blanket", "supercategory": "textile"},
89 | {"id": 94, "name": "branch", "supercategory": "plant"},
90 | {"id": 95, "name": "bridge", "supercategory": "building"},
91 | {"id": 96, "name": "building-other", "supercategory": "building"},
92 | {"id": 97, "name": "bush", "supercategory": "plant"},
93 | {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
94 | {"id": 99, "name": "cage", "supercategory": "structural"},
95 | {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
96 | {"id": 101, "name": "carpet", "supercategory": "floor"},
97 | {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
98 | {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
99 | {"id": 104, "name": "cloth", "supercategory": "textile"},
100 | {"id": 105, "name": "clothes", "supercategory": "textile"},
101 | {"id": 106, "name": "clouds", "supercategory": "sky"},
102 | {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
103 | {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
104 | {"id": 109, "name": "curtain", "supercategory": "textile"},
105 | {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
106 | {"id": 111, "name": "dirt", "supercategory": "ground"},
107 | {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
108 | {"id": 113, "name": "fence", "supercategory": "structural"},
109 | {"id": 114, "name": "floor-marble", "supercategory": "floor"},
110 | {"id": 115, "name": "floor-other", "supercategory": "floor"},
111 | {"id": 116, "name": "floor-stone", "supercategory": "floor"},
112 | {"id": 117, "name": "floor-tile", "supercategory": "floor"},
113 | {"id": 118, "name": "floor-wood", "supercategory": "floor"},
114 | {"id": 119, "name": "flower", "supercategory": "plant"},
115 | {"id": 120, "name": "fog", "supercategory": "water"},
116 | {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
117 | {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
118 | {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
119 | {"id": 124, "name": "grass", "supercategory": "plant"},
120 | {"id": 125, "name": "gravel", "supercategory": "ground"},
121 | {"id": 126, "name": "ground-other", "supercategory": "ground"},
122 | {"id": 127, "name": "hill", "supercategory": "solid"},
123 | {"id": 128, "name": "house", "supercategory": "building"},
124 | {"id": 129, "name": "leaves", "supercategory": "plant"},
125 | {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
126 | {"id": 131, "name": "mat", "supercategory": "textile"},
127 | {"id": 132, "name": "metal", "supercategory": "raw-material"},
128 | {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
129 | {"id": 134, "name": "moss", "supercategory": "plant"},
130 | {"id": 135, "name": "mountain", "supercategory": "solid"},
131 | {"id": 136, "name": "mud", "supercategory": "ground"},
132 | {"id": 137, "name": "napkin", "supercategory": "textile"},
133 | {"id": 138, "name": "net", "supercategory": "structural"},
134 | {"id": 139, "name": "paper", "supercategory": "raw-material"},
135 | {"id": 140, "name": "pavement", "supercategory": "ground"},
136 | {"id": 141, "name": "pillow", "supercategory": "textile"},
137 | {"id": 142, "name": "plant-other", "supercategory": "plant"},
138 | {"id": 143, "name": "plastic", "supercategory": "raw-material"},
139 | {"id": 144, "name": "platform", "supercategory": "ground"},
140 | {"id": 145, "name": "playingfield", "supercategory": "ground"},
141 | {"id": 146, "name": "railing", "supercategory": "structural"},
142 | {"id": 147, "name": "railroad", "supercategory": "ground"},
143 | {"id": 148, "name": "river", "supercategory": "water"},
144 | {"id": 149, "name": "road", "supercategory": "ground"},
145 | {"id": 150, "name": "rock", "supercategory": "solid"},
146 | {"id": 151, "name": "roof", "supercategory": "building"},
147 | {"id": 152, "name": "rug", "supercategory": "textile"},
148 | {"id": 153, "name": "salad", "supercategory": "food-stuff"},
149 | {"id": 154, "name": "sand", "supercategory": "ground"},
150 | {"id": 155, "name": "sea", "supercategory": "water"},
151 | {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
152 | {"id": 157, "name": "sky-other", "supercategory": "sky"},
153 | {"id": 158, "name": "skyscraper", "supercategory": "building"},
154 | {"id": 159, "name": "snow", "supercategory": "ground"},
155 | {"id": 160, "name": "solid-other", "supercategory": "solid"},
156 | {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
157 | {"id": 162, "name": "stone", "supercategory": "solid"},
158 | {"id": 163, "name": "straw", "supercategory": "plant"},
159 | {"id": 164, "name": "structural-other", "supercategory": "structural"},
160 | {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
161 | {"id": 166, "name": "tent", "supercategory": "building"},
162 | {"id": 167, "name": "textile-other", "supercategory": "textile"},
163 | {"id": 168, "name": "towel", "supercategory": "textile"},
164 | {"id": 169, "name": "tree", "supercategory": "plant"},
165 | {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
166 | {"id": 171, "name": "wall-brick", "supercategory": "wall"},
167 | {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
168 | {"id": 173, "name": "wall-other", "supercategory": "wall"},
169 | {"id": 174, "name": "wall-panel", "supercategory": "wall"},
170 | {"id": 175, "name": "wall-stone", "supercategory": "wall"},
171 | {"id": 176, "name": "wall-tile", "supercategory": "wall"},
172 | {"id": 177, "name": "wall-wood", "supercategory": "wall"},
173 | {"id": 178, "name": "water-other", "supercategory": "water"},
174 | {"id": 179, "name": "waterdrops", "supercategory": "water"},
175 | {"id": 180, "name": "window-blind", "supercategory": "window"},
176 | {"id": 181, "name": "window-other", "supercategory": "window"},
177 | {"id": 182, "name": "wood", "supercategory": "solid"},
178 | ]
179 |
180 |
181 | def _get_coco_stuff_meta():
182 |     # Id 0 is reserved for ignore_label; we change ignore_label from 0
183 |     # to 255 in our pre-processing.
184 | stuff_ids = [k["id"] for k in COCO_CATEGORIES]
185 | assert len(stuff_ids) == 171, len(stuff_ids)
186 |
187 |     # For semantic segmentation, this mapping maps from dataset category ids
188 |     # (used for processing results) to contiguous stuff ids in [0, 170] (used in models).
189 | stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
190 | stuff_classes = [k["name"] for k in COCO_CATEGORIES]
191 |
192 | ret = {
193 | "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
194 | "stuff_classes": stuff_classes,
195 | }
196 | return ret
197 |
198 |
199 | def register_all_coco_stuff_164k(root):
200 | root = os.path.join(root, "coco")
201 | meta = _get_coco_stuff_meta()
202 |
203 | for name, image_dirname, sem_seg_dirname in [
204 | ("train", "train2017", "stuffthingmaps_detectron2/train2017"),
205 | ("test", "val2017", "stuffthingmaps_detectron2/val2017"),
206 | ]:
207 | image_dir = os.path.join(root, image_dirname)
208 | gt_dir = os.path.join(root, sem_seg_dirname)
209 | all_name = f"coco_2017_{name}_stuff_sem_seg"
210 | DatasetCatalog.register(
211 | all_name,
212 | lambda x=image_dir, y=gt_dir: load_sem_seg(
213 | y, x, gt_ext="png", image_ext="jpg"
214 | ),
215 | )
216 | MetadataCatalog.get(all_name).set(
217 | image_root=image_dir,
218 | sem_seg_root=gt_dir,
219 | evaluator_type="sem_seg",
220 | ignore_label=255,
221 | **meta,
222 | )
223 |
224 |
225 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
226 | register_all_coco_stuff_164k(_root)
227 |
--------------------------------------------------------------------------------
/san/data/datasets/register_pcontext.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from detectron2.data import DatasetCatalog, MetadataCatalog
4 | from detectron2.data.datasets import load_sem_seg
5 |
6 | PCONTEXT_SEM_SEG_CATEGORIES = [
7 | "aeroplane",
8 | "bag",
9 | "bed",
10 | "bedclothes",
11 | "bench",
12 | "bicycle",
13 | "bird",
14 | "boat",
15 | "book",
16 | "bottle",
17 | "building",
18 | "bus",
19 | "cabinet",
20 | "car",
21 | "cat",
22 | "ceiling",
23 | "chair",
24 | "cloth",
25 | "computer",
26 | "cow",
27 | "cup",
28 | "curtain",
29 | "dog",
30 | "door",
31 | "fence",
32 | "floor",
33 | "flower",
34 | "food",
35 | "grass",
36 | "ground",
37 | "horse",
38 | "keyboard",
39 | "light",
40 | "motorbike",
41 | "mountain",
42 | "mouse",
43 | "person",
44 | "plate",
45 | "platform",
46 | "pottedplant",
47 | "road",
48 | "rock",
49 | "sheep",
50 | "shelves",
51 | "sidewalk",
52 | "sign",
53 | "sky",
54 | "snow",
55 | "sofa",
56 | "diningtable",
57 | "track",
58 | "train",
59 | "tree",
60 | "truck",
61 | "tvmonitor",
62 | "wall",
63 | "water",
64 | "window",
65 | "wood",
66 | ]
67 |
68 | PCONTEXT_FULL_SEM_SEG_CATEGORIES = [
69 | "accordion",
70 | "aeroplane",
71 | "air conditioner",
72 | "antenna",
73 | "artillery",
74 | "ashtray",
75 | "atrium",
76 | "baby carriage",
77 | "bag",
78 | "ball",
79 | "balloon",
80 | "bamboo weaving",
81 | "barrel",
82 | "baseball bat",
83 | "basket",
84 | "basketball backboard",
85 | "bathtub",
86 | "bed",
87 | "bedclothes",
88 | "beer",
89 | "bell",
90 | "bench",
91 | "bicycle",
92 | "binoculars",
93 | "bird",
94 | "bird cage",
95 | "bird feeder",
96 | "bird nest",
97 | "blackboard",
98 | "board",
99 | "boat",
100 | "bone",
101 | "book",
102 | "bottle",
103 | "bottle opener",
104 | "bowl",
105 | "box",
106 | "bracelet",
107 | "brick",
108 | "bridge",
109 | "broom",
110 | "brush",
111 | "bucket",
112 | "building",
113 | "bus",
114 | "cabinet",
115 | "cabinet door",
116 | "cage",
117 | "cake",
118 | "calculator",
119 | "calendar",
120 | "camel",
121 | "camera",
122 | "camera lens",
123 | "can",
124 | "candle",
125 | "candle holder",
126 | "cap",
127 | "car",
128 | "card",
129 | "cart",
130 | "case",
131 | "casette recorder",
132 | "cash register",
133 | "cat",
134 | "cd",
135 | "cd player",
136 | "ceiling",
137 | "cell phone",
138 | "cello",
139 | "chain",
140 | "chair",
141 | "chessboard",
142 | "chicken",
143 | "chopstick",
144 | "clip",
145 | "clippers",
146 | "clock",
147 | "closet",
148 | "cloth",
149 | "clothes tree",
150 | "coffee",
151 | "coffee machine",
152 | "comb",
153 | "computer",
154 | "concrete",
155 | "cone",
156 | "container",
157 | "control booth",
158 | "controller",
159 | "cooker",
160 | "copying machine",
161 | "coral",
162 | "cork",
163 | "corkscrew",
164 | "counter",
165 | "court",
166 | "cow",
167 | "crabstick",
168 | "crane",
169 | "crate",
170 | "cross",
171 | "crutch",
172 | "cup",
173 | "curtain",
174 | "cushion",
175 | "cutting board",
176 | "dais",
177 | "disc",
178 | "disc case",
179 | "dishwasher",
180 | "dock",
181 | "dog",
182 | "dolphin",
183 | "door",
184 | "drainer",
185 | "dray",
186 | "drink dispenser",
187 | "drinking machine",
188 | "drop",
189 | "drug",
190 | "drum",
191 | "drum kit",
192 | "duck",
193 | "dumbbell",
194 | "earphone",
195 | "earrings",
196 | "egg",
197 | "electric fan",
198 | "electric iron",
199 | "electric pot",
200 | "electric saw",
201 | "electronic keyboard",
202 | "engine",
203 | "envelope",
204 | "equipment",
205 | "escalator",
206 | "exhibition booth",
207 | "extinguisher",
208 | "eyeglass",
209 | "fan",
210 | "faucet",
211 | "fax machine",
212 | "fence",
213 | "ferris wheel",
214 | "fire extinguisher",
215 | "fire hydrant",
216 | "fire place",
217 | "fish",
218 | "fish tank",
219 | "fishbowl",
220 | "fishing net",
221 | "fishing pole",
222 | "flag",
223 | "flagstaff",
224 | "flame",
225 | "flashlight",
226 | "floor",
227 | "flower",
228 | "fly",
229 | "foam",
230 | "food",
231 | "footbridge",
232 | "forceps",
233 | "fork",
234 | "forklift",
235 | "fountain",
236 | "fox",
237 | "frame",
238 | "fridge",
239 | "frog",
240 | "fruit",
241 | "funnel",
242 | "furnace",
243 | "game controller",
244 | "game machine",
245 | "gas cylinder",
246 | "gas hood",
247 | "gas stove",
248 | "gift box",
249 | "glass",
250 | "glass marble",
251 | "globe",
252 | "glove",
253 | "goal",
254 | "grandstand",
255 | "grass",
256 | "gravestone",
257 | "ground",
258 | "guardrail",
259 | "guitar",
260 | "gun",
261 | "hammer",
262 | "hand cart",
263 | "handle",
264 | "handrail",
265 | "hanger",
266 | "hard disk drive",
267 | "hat",
268 | "hay",
269 | "headphone",
270 | "heater",
271 | "helicopter",
272 | "helmet",
273 | "holder",
274 | "hook",
275 | "horse",
276 | "horse-drawn carriage",
277 | "hot-air balloon",
278 | "hydrovalve",
279 | "ice",
280 | "inflator pump",
281 | "ipod",
282 | "iron",
283 | "ironing board",
284 | "jar",
285 | "kart",
286 | "kettle",
287 | "key",
288 | "keyboard",
289 | "kitchen range",
290 | "kite",
291 | "knife",
292 | "knife block",
293 | "ladder",
294 | "ladder truck",
295 | "ladle",
296 | "laptop",
297 | "leaves",
298 | "lid",
299 | "life buoy",
300 | "light",
301 | "light bulb",
302 | "lighter",
303 | "line",
304 | "lion",
305 | "lobster",
306 | "lock",
307 | "machine",
308 | "mailbox",
309 | "mannequin",
310 | "map",
311 | "mask",
312 | "mat",
313 | "match book",
314 | "mattress",
315 | "menu",
316 | "metal",
317 | "meter box",
318 | "microphone",
319 | "microwave",
320 | "mirror",
321 | "missile",
322 | "model",
323 | "money",
324 | "monkey",
325 | "mop",
326 | "motorbike",
327 | "mountain",
328 | "mouse",
329 | "mouse pad",
330 | "musical instrument",
331 | "napkin",
332 | "net",
333 | "newspaper",
334 | "oar",
335 | "ornament",
336 | "outlet",
337 | "oven",
338 | "oxygen bottle",
339 | "pack",
340 | "pan",
341 | "paper",
342 | "paper box",
343 | "paper cutter",
344 | "parachute",
345 | "parasol",
346 | "parterre",
347 | "patio",
348 | "pelage",
349 | "pen",
350 | "pen container",
351 | "pencil",
352 | "person",
353 | "photo",
354 | "piano",
355 | "picture",
356 | "pig",
357 | "pillar",
358 | "pillow",
359 | "pipe",
360 | "pitcher",
361 | "plant",
362 | "plastic",
363 | "plate",
364 | "platform",
365 | "player",
366 | "playground",
367 | "pliers",
368 | "plume",
369 | "poker",
370 | "poker chip",
371 | "pole",
372 | "pool table",
373 | "postcard",
374 | "poster",
375 | "pot",
376 | "pottedplant",
377 | "printer",
378 | "projector",
379 | "pumpkin",
380 | "rabbit",
381 | "racket",
382 | "radiator",
383 | "radio",
384 | "rail",
385 | "rake",
386 | "ramp",
387 | "range hood",
388 | "receiver",
389 | "recorder",
390 | "recreational machines",
391 | "remote control",
392 | "road",
393 | "robot",
394 | "rock",
395 | "rocket",
396 | "rocking horse",
397 | "rope",
398 | "rug",
399 | "ruler",
400 | "runway",
401 | "saddle",
402 | "sand",
403 | "saw",
404 | "scale",
405 | "scanner",
406 | "scissors",
407 | "scoop",
408 | "screen",
409 | "screwdriver",
410 | "sculpture",
411 | "scythe",
412 | "sewer",
413 | "sewing machine",
414 | "shed",
415 | "sheep",
416 | "shell",
417 | "shelves",
418 | "shoe",
419 | "shopping cart",
420 | "shovel",
421 | "sidecar",
422 | "sidewalk",
423 | "sign",
424 | "signal light",
425 | "sink",
426 | "skateboard",
427 | "ski",
428 | "sky",
429 | "sled",
430 | "slippers",
431 | "smoke",
432 | "snail",
433 | "snake",
434 | "snow",
435 | "snowmobiles",
436 | "sofa",
437 | "spanner",
438 | "spatula",
439 | "speaker",
440 | "speed bump",
441 | "spice container",
442 | "spoon",
443 | "sprayer",
444 | "squirrel",
445 | "stage",
446 | "stair",
447 | "stapler",
448 | "stick",
449 | "sticky note",
450 | "stone",
451 | "stool",
452 | "stove",
453 | "straw",
454 | "stretcher",
455 | "sun",
456 | "sunglass",
457 | "sunshade",
458 | "surveillance camera",
459 | "swan",
460 | "sweeper",
461 | "swim ring",
462 | "swimming pool",
463 | "swing",
464 | "switch",
465 | "table",
466 | "tableware",
467 | "tank",
468 | "tap",
469 | "tape",
470 | "tarp",
471 | "telephone",
472 | "telephone booth",
473 | "tent",
474 | "tire",
475 | "toaster",
476 | "toilet",
477 | "tong",
478 | "tool",
479 | "toothbrush",
480 | "towel",
481 | "toy",
482 | "toy car",
483 | "track",
484 | "train",
485 | "trampoline",
486 | "trash bin",
487 | "tray",
488 | "tree",
489 | "tricycle",
490 | "tripod",
491 | "trophy",
492 | "truck",
493 | "tube",
494 | "turtle",
495 | "tvmonitor",
496 | "tweezers",
497 | "typewriter",
498 | "umbrella",
499 | "unknown",
500 | "vacuum cleaner",
501 | "vending machine",
502 | "video camera",
503 | "video game console",
504 | "video player",
505 | "video tape",
506 | "violin",
507 | "wakeboard",
508 | "wall",
509 | "wallet",
510 | "wardrobe",
511 | "washing machine",
512 | "watch",
513 | "water",
514 | "water dispenser",
515 | "water pipe",
516 | "water skate board",
517 | "watermelon",
518 | "whale",
519 | "wharf",
520 | "wheel",
521 | "wheelchair",
522 | "window",
523 | "window blinds",
524 | "wineglass",
525 | "wire",
526 | "wood",
527 | "wool",
528 | ]
529 |
530 |
531 | def register_all_pcontext_59(root):
532 | root = os.path.join(root, "pcontext")
533 | for name, dirname in [("train", "train"), ("val", "val")]:
534 | image_dir = os.path.join(root, dirname, "image")
535 | gt_dir = os.path.join(root, dirname, "label")
536 | name = f"pcontext_sem_seg_{name}"
537 | DatasetCatalog.register(
538 | name,
539 | lambda x=image_dir, y=gt_dir: load_sem_seg(
540 | y, x, gt_ext="png", image_ext="jpg"
541 | ),
542 | )
543 | MetadataCatalog.get(name).set(
544 | stuff_classes=PCONTEXT_SEM_SEG_CATEGORIES[:],
545 | image_root=image_dir,
546 | sem_seg_root=gt_dir,
547 | evaluator_type="sem_seg",
548 | ignore_label=255,
549 | )
550 |
551 |
552 | def register_all_pcontext_full(root):
553 | root = os.path.join(root, "pcontext_full")
554 | for name, dirname in [("train", "train"), ("val", "val")]:
555 | image_dir = os.path.join(root, dirname, "image")
556 | gt_dir = os.path.join(root, dirname, "label")
557 | name = f"pcontext_full_sem_seg_{name}"
558 | DatasetCatalog.register(
559 | name,
560 | lambda x=image_dir, y=gt_dir: load_sem_seg(
561 | y, x, gt_ext="tif", image_ext="jpg"
562 | ),
563 | )
564 | MetadataCatalog.get(name).set(
565 | stuff_classes=PCONTEXT_FULL_SEM_SEG_CATEGORIES[:],
566 | image_root=image_dir,
567 | sem_seg_root=gt_dir,
568 | evaluator_type="sem_seg",
569 | ignore_label=65535,
570 | )
571 |
572 |
573 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
574 | register_all_pcontext_59(_root)
575 | register_all_pcontext_full(_root)
576 |
--------------------------------------------------------------------------------
/san/data/datasets/register_voc.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from detectron2.data import DatasetCatalog, MetadataCatalog
4 | from detectron2.data.datasets import load_sem_seg
5 |
6 | CLASS_NAMES = (
7 | "aeroplane",
8 | "bicycle",
9 | "bird",
10 | "boat",
11 | "bottle",
12 | "bus",
13 | "car",
14 | "cat",
15 | "chair",
16 | "cow",
17 | "diningtable",
18 | "dog",
19 | "horse",
20 | "motorbike",
21 | "person",
22 | "pottedplant",
23 | "sheep",
24 | "sofa",
25 | "train",
26 | "tv",
27 | )
28 |
29 |
30 | def _get_voc_meta(cat_list):
31 | ret = {
32 | "stuff_classes": cat_list,
33 | }
34 | return ret
35 |
36 |
37 | def register_all_voc_11k(root):
38 | root = os.path.join(root, "VOC2012")
39 | meta = _get_voc_meta(CLASS_NAMES)
40 |
41 | for name, image_dirname, sem_seg_dirname in [
42 | ("train", "JPEGImages", "annotations_detectron2/train"),
43 | ("val", "JPEGImages", "annotations_detectron2/val"),
44 | ]:
45 | image_dir = os.path.join(root, image_dirname)
46 | gt_dir = os.path.join(root, sem_seg_dirname)
47 | all_name = f"voc_sem_seg_{name}"
48 | DatasetCatalog.register(
49 | all_name,
50 | lambda x=image_dir, y=gt_dir: load_sem_seg(
51 | y, x, gt_ext="png", image_ext="jpg"
52 | ),
53 | )
54 | MetadataCatalog.get(all_name).set(
55 | image_root=image_dir,
56 | sem_seg_root=gt_dir,
57 | evaluator_type="sem_seg",
58 | ignore_label=255,
59 | **meta,
60 | )
61 |
62 |
63 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
64 | register_all_voc_11k(_root)
65 |
--------------------------------------------------------------------------------
/san/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .side_adapter import *
2 | from .san import SAN
3 |
--------------------------------------------------------------------------------
/san/model/clip_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import get_predefined_templates
2 | from .visual import FeatureExtractor, RecWithAttnbiasHead
3 | from .classifier import PredefinedOvClassifier, LearnableBgOvClassifier
4 |
--------------------------------------------------------------------------------
/san/model/clip_utils/classifier.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from torch.nn import functional as F
3 | import torch
4 | from detectron2.utils.registry import Registry
5 | from open_clip.model import CLIP
6 | from torch import nn
7 | from .utils import get_labelset_from_dataset
8 | from open_clip import tokenizer
9 |
10 |
11 | class PredefinedOvClassifier(nn.Module):
12 | def __init__(
13 | self,
14 | clip_model: CLIP,
15 | cache_feature: bool = True,
16 | templates: List[str] = ["a photo of {}"],
17 | ):
18 |         # copy the text branch of the clip model (everything except the visual encoder) into this module
19 | super().__init__()
20 | for name, child in clip_model.named_children():
21 | if "visual" not in name:
22 | self.add_module(name, child)
23 | for name, param in clip_model.named_parameters(recurse=False):
24 | self.register_parameter(name, param)
25 | for name, buffer in clip_model.named_buffers(recurse=False):
26 | self.register_buffer(name, buffer)
27 | self.templates = templates
28 | self._freeze()
29 |
30 | self.cache_feature = cache_feature
31 | if self.cache_feature:
32 | self.cache = {}
33 |
34 | def forward(self, category_names: List[str]):
35 | text_embed_bucket = []
36 | for template in self.templates:
37 | noun_tokens = tokenizer.tokenize(
38 | [template.format(noun) for noun in category_names]
39 | )
40 | text_inputs = noun_tokens.to(self.text_projection.data.device)
41 | text_embed = self.encode_text(text_inputs, normalize=True)
42 | text_embed_bucket.append(text_embed)
43 | text_embed = torch.stack(text_embed_bucket).mean(dim=0)
44 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
45 | return text_embed
46 |
47 | @torch.no_grad()
48 | def encode_text(self, text, normalize: bool = False):
49 | cast_dtype = self.transformer.get_cast_dtype()
50 |
51 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
52 |
53 | x = x + self.positional_embedding.to(cast_dtype)
54 | x = x.permute(1, 0, 2) # NLD -> LND
55 | x = self.transformer(x, attn_mask=self.attn_mask)
56 | x = x.permute(1, 0, 2) # LND -> NLD
57 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
58 | # take features from the eot embedding (eot_token is the highest number in each sequence)
59 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
60 | return F.normalize(x, dim=-1) if normalize else x
61 |
62 | def get_classifier_by_vocabulary(self, vocabulary: List[str]):
63 | if self.cache_feature:
64 | new_words = [word for word in vocabulary if word not in self.cache]
65 | if len(new_words) > 0:
66 | cat_embeddings = self(new_words)
67 | self.cache.update(dict(zip(new_words, cat_embeddings)))
68 | cat_embeddings = torch.stack([self.cache[word] for word in vocabulary])
69 | else:
70 | cat_embeddings = self(vocabulary)
71 | return cat_embeddings
72 |
73 | def get_classifier_by_dataset_name(self, dataset_name: str):
74 | if self.cache_feature:
75 | if dataset_name not in self.cache:
76 | category_names = get_labelset_from_dataset(dataset_name)
77 | cat_embeddings = self(category_names)
78 | self.cache[dataset_name] = cat_embeddings
79 | cat_embeddings = self.cache[dataset_name]
80 | else:
81 | category_names = get_labelset_from_dataset(dataset_name)
82 | cat_embeddings = self(category_names)
83 | return cat_embeddings
84 |
85 | def _freeze(self):
86 | for param in self.parameters():
87 | param.requires_grad = False
88 |
89 | def train(self, mode=True):
90 | super().train(False)
91 |
92 |
93 | class LearnableBgOvClassifier(PredefinedOvClassifier):
94 | def __init__(
95 | self,
96 | clip_model: CLIP,
97 | cache_feature: bool = True,
98 | templates: List[str] = ["a photo of {}"],
99 | ):
100 | super().__init__(clip_model, cache_feature, templates)
101 | self.bg_embed = nn.Parameter(torch.randn(1, self.text_projection.shape[0]))
102 | nn.init.normal_(
103 | self.bg_embed,
104 | std=self.bg_embed.shape[1] ** -0.5,
105 | )
106 |
107 | def get_classifier_by_vocabulary(self, vocabulary: List[str]):
108 | cat_embedding = super().get_classifier_by_vocabulary(vocabulary)
109 | cat_embedding = torch.cat([cat_embedding, self.bg_embed], dim=0)
110 | cat_embedding = F.normalize(cat_embedding, p=2, dim=-1)
111 | return cat_embedding
112 |
113 | def get_classifier_by_dataset_name(self, dataset_name: str):
114 | cat_embedding = super().get_classifier_by_dataset_name(dataset_name)
115 | cat_embedding = torch.cat([cat_embedding, self.bg_embed], dim=0)
116 | cat_embedding = F.normalize(cat_embedding, p=2, dim=-1)
117 | return cat_embedding
118 |
--------------------------------------------------------------------------------
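A minimal usage sketch for the classifiers above (illustrative, not part of the repository). The CLIP model name and pretrained tag are assumptions; any pairing supported by open_clip should behave the same way.

import torch
import open_clip

from san.model.clip_utils import LearnableBgOvClassifier, get_predefined_templates

# Build a CLIP model; weights are downloaded on first use (model/tag are assumptions).
clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-16", pretrained="openai")
classifier = LearnableBgOvClassifier(
    clip_model, templates=get_predefined_templates("vild")
)

with torch.no_grad():
    weights = classifier.get_classifier_by_vocabulary(["cat", "dog", "grass"])
# One L2-normalized embedding per category plus the learnable background row:
# weights.shape == (len(vocabulary) + 1, embed_dim)
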
/san/model/clip_utils/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from detectron2.data import MetadataCatalog
3 |
4 |
5 | PREDEFINED_LABELSETS = {}
6 |
7 | PREDEFINED_TEMPLATES = {
8 | "imagenet": [
9 | "a bad photo of a {}.",
10 | "a photo of many {}.",
11 | "a sculpture of a {}.",
12 | "a photo of the hard to see {}.",
13 | "a low resolution photo of the {}.",
14 | "a rendering of a {}.",
15 | "graffiti of a {}.",
16 | "a bad photo of the {}.",
17 | "a cropped photo of the {}.",
18 | "a tattoo of a {}.",
19 | "the embroidered {}.",
20 | "a photo of a hard to see {}.",
21 | "a bright photo of a {}.",
22 | "a photo of a clean {}.",
23 | "a photo of a dirty {}.",
24 | "a dark photo of the {}.",
25 | "a drawing of a {}.",
26 | "a photo of my {}.",
27 | "the plastic {}.",
28 | "a photo of the cool {}.",
29 | "a close-up photo of a {}.",
30 | "a black and white photo of the {}.",
31 | "a painting of the {}.",
32 | "a painting of a {}.",
33 | "a pixelated photo of the {}.",
34 | "a sculpture of the {}.",
35 | "a bright photo of the {}.",
36 | "a cropped photo of a {}.",
37 | "a plastic {}.",
38 | "a photo of the dirty {}.",
39 | "a jpeg corrupted photo of a {}.",
40 | "a blurry photo of the {}.",
41 | "a photo of the {}.",
42 | "a good photo of the {}.",
43 | "a rendering of the {}.",
44 | "a {} in a video game.",
45 | "a photo of one {}.",
46 | "a doodle of a {}.",
47 | "a close-up photo of the {}.",
48 | "a photo of a {}.",
49 | "the origami {}.",
50 | "the {} in a video game.",
51 | "a sketch of a {}.",
52 | "a doodle of the {}.",
53 | "a origami {}.",
54 | "a low resolution photo of a {}.",
55 | "the toy {}.",
56 | "a rendition of the {}.",
57 | "a photo of the clean {}.",
58 | "a photo of a large {}.",
59 | "a rendition of a {}.",
60 | "a photo of a nice {}.",
61 | "a photo of a weird {}.",
62 | "a blurry photo of a {}.",
63 | "a cartoon {}.",
64 | "art of a {}.",
65 | "a sketch of the {}.",
66 | "a embroidered {}.",
67 | "a pixelated photo of a {}.",
68 | "itap of the {}.",
69 | "a jpeg corrupted photo of the {}.",
70 | "a good photo of a {}.",
71 | "a plushie {}.",
72 | "a photo of the nice {}.",
73 | "a photo of the small {}.",
74 | "a photo of the weird {}.",
75 | "the cartoon {}.",
76 | "art of the {}.",
77 | "a drawing of the {}.",
78 | "a photo of the large {}.",
79 | "a black and white photo of a {}.",
80 | "the plushie {}.",
81 | "a dark photo of a {}.",
82 | "itap of a {}.",
83 | "graffiti of the {}.",
84 | "a toy {}.",
85 | "itap of my {}.",
86 | "a photo of a cool {}.",
87 | "a photo of a small {}.",
88 | "a tattoo of the {}.",
89 | ],
90 | "vild": [
91 | "a photo of a {}.",
92 | "This is a photo of a {}",
93 | "There is a {} in the scene",
94 | "There is the {} in the scene",
95 | "a photo of a {} in the scene",
96 | "a photo of a small {}.",
97 | "a photo of a medium {}.",
98 | "a photo of a large {}.",
99 | "This is a photo of a small {}.",
100 | "This is a photo of a medium {}.",
101 | "This is a photo of a large {}.",
102 | "There is a small {} in the scene.",
103 | "There is a medium {} in the scene.",
104 | "There is a large {} in the scene.",
105 | ],
106 | }
107 |
108 |
109 | def get_labelset_from_dataset(dataset_name: str) -> List[str]:
110 | if dataset_name not in PREDEFINED_LABELSETS:
111 | try:
112 | labelset = [
113 | c.strip() for c in MetadataCatalog.get(dataset_name).stuff_classes
114 | ]
115 |         except AttributeError:
116 | labelset = [
117 | c.strip() for c in MetadataCatalog.get(dataset_name).thing_classes
118 | ]
119 | else:
120 | labelset = PREDEFINED_LABELSETS[dataset_name]
121 | return labelset
122 |
123 |
124 | def get_predefined_templates(template_set_name: str) -> List[str]:
125 | if template_set_name not in PREDEFINED_TEMPLATES:
126 | raise ValueError(f"Template set {template_set_name} not found")
127 | return PREDEFINED_TEMPLATES[template_set_name]
128 |
--------------------------------------------------------------------------------
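The prompt templates above are combined with a category name at runtime; a tiny sketch (illustrative, "traffic light" is an arbitrary example class):

from san.model.clip_utils.utils import get_predefined_templates, get_labelset_from_dataset

templates = get_predefined_templates("vild")
prompts = [t.format("traffic light") for t in templates]
# -> "a photo of a traffic light.", "This is a photo of a traffic light", ...

# get_labelset_from_dataset resolves class names through detectron2's MetadataCatalog,
# so the dataset has to be registered beforehand (the register_* modules under
# san/data/datasets handle the datasets used in this repo).
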
/san/model/clip_utils/visual.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from open_clip.transformer import VisionTransformer
6 | from detectron2.layers import ShapeSpec
7 | from ..attn_helper import cross_attn_layer, downsample2d, resize_pos_embed2d
8 |
9 |
10 | class ClipOutput(dict):
11 | def __init__(self, spacial_shape, *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | self.spacial_shape = spacial_shape
14 |
15 | def save(self, idx: int, clip_feat: torch.Tensor):
16 | l, n, c = clip_feat.shape
17 | self[idx] = (
18 | clip_feat[1:].permute(1, 2, 0).reshape(n, c, *self.spacial_shape)
19 | ) # n, c, h, w
20 | self[f"{idx}_cls_token"] = clip_feat[0:1] # 1, n, c
21 |
22 |
23 | class FeatureExtractor(nn.Module):
24 | def __init__(
25 | self,
26 | visual_encoder: VisionTransformer,
27 | last_layer_idx: int = -1,
28 | frozen_exclude=[],
29 | ):
30 | super().__init__()
31 | self.output_tokens = visual_encoder.output_tokens
32 | self.image_size = visual_encoder.image_size
33 | self.patch_size = visual_encoder.patch_size
34 | self.grid_size = visual_encoder.grid_size
35 | self.num_features = visual_encoder.ln_pre.normalized_shape[0]
36 |
37 | self.input_patchnorm = visual_encoder.input_patchnorm
38 | self.patchnorm_pre_ln = visual_encoder.patchnorm_pre_ln
39 | self.conv1 = visual_encoder.conv1
40 |
41 | # class embeddings and positional embeddings
42 | self.class_embedding = visual_encoder.class_embedding
43 | self.positional_embedding = visual_encoder.positional_embedding
44 | # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
45 | self.patch_dropout = visual_encoder.patch_dropout
46 | self.ln_pre = visual_encoder.ln_pre
47 | if last_layer_idx == -1:
48 | self.resblocks = visual_encoder.transformer.resblocks
49 | self.last_output_idx = len(self.resblocks) + 1
50 | else:
51 | self.resblocks = visual_encoder.transformer.resblocks[:last_layer_idx]
52 | self.last_output_idx = last_layer_idx + 1
53 | #
54 | self.frozen_exclude = frozen_exclude
55 | self._freeze(self.frozen_exclude)
56 |
57 | def forward(self, x: torch.Tensor):
58 | if self.input_patchnorm:
59 | raise NotImplementedError("input_patchnorm is not implemented yet.")
60 | else:
61 | x = self.conv1(x) # shape = [*, width, grid, grid]
62 | _, _, h, w = x.shape
63 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
64 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
65 |
66 | # class embeddings and positional embeddings
67 | x = torch.cat(
68 | [
69 | self.class_embedding.to(x.dtype)
70 | + torch.zeros(
71 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
72 | ),
73 | x,
74 | ],
75 | dim=1,
76 | ) # shape = [*, grid ** 2 + 1, width]
77 | pos_embed = self.positional_embedding.to(x.dtype)
78 | pos_embed = resize_pos_embed2d(pos_embed[None, ...], self.grid_size, (h, w))[0]
79 | x = x + pos_embed
80 |
81 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
82 | x = self.patch_dropout(x)
83 | x = self.ln_pre(x)
84 | x = x.permute(1, 0, 2) # NLD -> LND
85 |
86 | outputs = ClipOutput(spacial_shape=(h, w))
87 | outputs.save(0, x)
88 | for i, resblock in enumerate(self.resblocks, start=1):
89 | x = resblock(x)
90 | outputs.save(i, x)
91 | return outputs
92 |
93 | def _freeze(self, frozen_exclude):
94 | if "all" in frozen_exclude:
95 | return
96 | for name, param in self.named_parameters():
97 | if not any([exclude in name for exclude in frozen_exclude]):
98 | param.requires_grad = False
99 |
100 | @property
101 | def output_shapes(self):
102 | return {
103 | i: ShapeSpec(channels=self.num_features)
104 | for i in range(self.last_output_idx)
105 | }
106 |
107 | @property
108 | def size_divisibility(self):
109 | return self.patch_size[0]
110 |
111 |
112 | class RecWithAttnbiasHead(nn.Module):
113 | def __init__(
114 | self,
115 | visual_encoder: VisionTransformer,
116 | first_layer_idx: int = 0,
117 | frozen_exclude: List[str] = [],
118 | sos_token_format: str = "cls_token",
119 | sos_token_num: int = 1,
120 | cross_attn: bool = True,
121 | downsample_method: str = "bilinear",
122 | ):
123 | super().__init__()
124 | self.output_tokens = visual_encoder.output_tokens
125 | self.output_dim = visual_encoder.output_dim
126 | self.first_layer_idx = first_layer_idx
127 | self.cross_attn = cross_attn
128 | self.downsample_method = downsample_method
129 |
130 | if first_layer_idx < 0:
131 | raise NotImplementedError("first_layer_idx < 0 is not implemented yet.")
132 | self.resblocks = visual_encoder.transformer.resblocks[first_layer_idx:]
133 | self.global_average_pool = visual_encoder.global_average_pool
134 | self.attn_pool = visual_encoder.attn_pool
135 | assert (
136 | self.attn_pool is None
137 | ), "recognition with attn_pool is not implemented yet."
138 | assert (
139 | not self.global_average_pool
140 | ), "recognition with global_average_pool is not implemented yet."
141 | self.ln_post = visual_encoder.ln_post
142 | self.proj = visual_encoder.proj
143 |
144 | self.sos_token_format = sos_token_format
145 | self.sos_token_num = sos_token_num
146 |         self.frozen_exclude = list(frozen_exclude)  # copy so appending "sos_token" below cannot mutate the default argument
147 |
148 | if sos_token_format in ["learnable_token", "pos_embedding"]:
149 | self.sos_token = nn.Parameter(
150 | torch.randn(sos_token_num, 1, self.proj.shape[0])
151 | )
152 | nn.init.normal_(self.sos_token, std=0.02)
153 | self.frozen_exclude.append("sos_token")
154 | self._freeze(self.frozen_exclude)
155 |
156 | def _freeze(self, frozen_exclude):
157 | if "all" in frozen_exclude:
158 | return
159 | for name, param in self.named_parameters():
160 | if not any([exclude in name for exclude in frozen_exclude]):
161 | param.requires_grad = False
162 |
163 | def forward(self, features, attn_bias, normalize: bool = False):
164 | # construct clip shadow features.
165 | cls_token = features[f"{self.first_layer_idx}_cls_token"] # 1,n,c
166 | pix_feat = features[self.first_layer_idx] # n,c,h,w
167 | n, c, h, w = pix_feat.shape
168 | x = torch.cat(
169 | [cls_token, pix_feat.reshape(n, c, -1).permute(2, 0, 1)]
170 | ) # 1+l,n,c
171 |
172 | # construct sos token.
173 | if self.sos_token_format == "cls_token":
174 | sos_token = cls_token.repeat(self.sos_token_num, 1, 1)
175 | elif self.sos_token_format == "learnable_token":
176 | sos_token = self.sos_token.expand(-1, n, -1)
177 | elif self.sos_token_format == "pos_embedding":
178 | sos_token = self.sos_token.expand(-1, n, -1) + cls_token
179 |
180 | # construct attn biases.
181 | attn_biases = self._build_attn_biases(attn_bias, target_shape=(h, w))
182 | if self.cross_attn:
183 | for i, resblock in enumerate(self.resblocks):
184 | if self.cross_attn:
185 | sos_token = cross_attn_layer(
186 | resblock,
187 | sos_token,
188 | x[1:,],
189 | attn_biases[i],
190 | )
191 | if i < len(self.resblocks) - 1:
192 | x = resblock(x)
193 | else:
194 | x = torch.cat([sos_token, x], dim=0)
195 | for i, resblock in enumerate(self.resblocks):
196 | x = resblock(x, attn_mask=attn_biases[i])
197 | sos_token = x[: self.sos_token_num]
198 |
199 | sos_token = sos_token.permute(1, 0, 2) # LND -> NLD
200 |
201 | sos_token = self.ln_post(sos_token)
202 |
203 | if self.proj is not None:
204 | sos_token = sos_token @ self.proj
205 | if normalize:
206 | sos_token = F.normalize(sos_token, dim=-1)
207 | return sos_token
208 |
209 | def _build_attn_biases(self, attn_biases, target_shape):
210 | formatted_attn_biases = []
211 | for attn_bias in attn_biases:
212 | # convert it to proper format: N*num_head,L,L
213 | # attn_bias: [N, num_head/1, num_sos,H,W]
214 | n, num_head, num_sos, h, w = attn_bias.shape
215 | # reshape and downsample
216 | attn_bias = downsample2d(
217 | attn_bias.reshape(n, num_head * num_sos, h, w),
218 | target_shape,
219 | method=self.downsample_method,
220 | )
221 | attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape)
222 | true_num_head = self.resblocks[0].attn.num_heads
223 | assert (
224 | num_head == 1 or num_head == true_num_head
225 | ), f"num_head={num_head} is not supported."
226 | if num_head == 1:
227 | attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1)
228 | attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1)
229 | L = attn_bias.shape[-1]
230 | if self.cross_attn:
231 | # [n*num_head, num_sos, L]
232 | formatted_attn_biases.append(attn_bias)
233 | else:
234 | # [n*num_head, num_sos+1+L, num_sos+1+L]
235 | new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L, num_sos + 1 + L)
236 | new_attn_bias[:, :num_sos] = -100
237 | new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0
238 | new_attn_bias[:num_sos, num_sos] = -100
239 | new_attn_bias = (
240 | new_attn_bias[None, ...].expand(n * true_num_head, -1, -1).clone()
241 | )
242 | new_attn_bias[..., :num_sos, -L:] = attn_bias
243 | formatted_attn_biases.append(new_attn_bias)
244 |
245 | if len(formatted_attn_biases) == 1:
246 | formatted_attn_biases = [formatted_attn_biases[0] for _ in self.resblocks]
247 | return formatted_attn_biases
248 |
--------------------------------------------------------------------------------
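A shape-level sketch of how the two halves of the CLIP ViT are used (illustrative; the split index 9, the image size, and the 100 sos tokens are assumptions, and the attribute names accessed above rely on the open_clip version pinned by this repo):

import torch
import open_clip

from san.model.clip_utils.visual import FeatureExtractor, RecWithAttnbiasHead

clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-16", pretrained="openai")
extractor = FeatureExtractor(clip_model.visual, last_layer_idx=9)       # shallow blocks 0..8
rec_head = RecWithAttnbiasHead(clip_model.visual, first_layer_idx=9,    # deep blocks 9..11
                               sos_token_num=100)

images = torch.randn(2, 3, 320, 320)              # already CLIP-normalized, 320 / 16 = 20 patches per side
feats = extractor(images)                         # {0..9: [N, C, 20, 20]} plus "{i}_cls_token" entries

# One attention bias per sos token over the 20x20 grid; a single bias head is
# broadcast to every attention head inside the recognition module.
attn_bias = [torch.zeros(2, 1, 100, 20, 20)]      # [N, 1, num_sos, H, W]
sos = rec_head(feats, attn_bias, normalize=True)  # [N, num_sos, output_dim]
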
/san/model/criterion.py:
--------------------------------------------------------------------------------
1 | # Copied from maskformer2 by Bowen Cheng
2 | import logging
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from detectron2.utils.comm import get_world_size
9 | from detectron2.projects.point_rend.point_features import (
10 | get_uncertain_point_coords_with_randomness,
11 | point_sample,
12 | )
13 |
14 | from san.utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
15 |
16 |
17 | def dice_loss(
18 | inputs: torch.Tensor,
19 | targets: torch.Tensor,
20 | num_masks: float,
21 | ):
22 | """
23 | Compute the DICE loss, similar to generalized IOU for masks
24 | Args:
25 | inputs: A float tensor of arbitrary shape.
26 | The predictions for each example.
27 | targets: A float tensor with the same shape as inputs. Stores the binary
28 | classification label for each element in inputs
29 | (0 for the negative class and 1 for the positive class).
30 | """
31 | inputs = inputs.sigmoid()
32 | inputs = inputs.flatten(1)
33 | numerator = 2 * (inputs * targets).sum(-1)
34 | denominator = inputs.sum(-1) + targets.sum(-1)
35 | loss = 1 - (numerator + 1) / (denominator + 1)
36 | return loss.sum() / num_masks
37 |
38 |
39 | dice_loss_jit = torch.jit.script(dice_loss) # type: torch.jit.ScriptModule
40 |
41 |
42 | def sigmoid_ce_loss(
43 | inputs: torch.Tensor,
44 | targets: torch.Tensor,
45 | num_masks: float,
46 | ):
47 | """
48 | Args:
49 | inputs: A float tensor of arbitrary shape.
50 | The predictions for each example.
51 | targets: A float tensor with the same shape as inputs. Stores the binary
52 | classification label for each element in inputs
53 | (0 for the negative class and 1 for the positive class).
54 | Returns:
55 | Loss tensor
56 | """
57 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
58 |
59 | return loss.mean(1).sum() / num_masks
60 |
61 |
62 | sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss) # type: torch.jit.ScriptModule
63 |
64 |
65 | def calculate_uncertainty(logits):
66 | """
67 |     We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
68 | foreground class in `classes`.
69 | Args:
70 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
71 | class-agnostic, where R is the total number of predicted masks in all images and C is
72 | the number of foreground classes. The values are logits.
73 | Returns:
74 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
75 | the most uncertain locations having the highest uncertainty score.
76 | """
77 | assert logits.shape[1] == 1
78 | gt_class_logits = logits.clone()
79 | return -(torch.abs(gt_class_logits))
80 |
81 |
82 | class SetCriterion(nn.Module):
83 | """This class computes the loss for DETR.
84 | The process happens in two steps:
85 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
86 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
87 | """
88 |
89 | def __init__(
90 | self,
91 | num_classes,
92 | matcher,
93 | weight_dict,
94 | eos_coef,
95 | losses,
96 | num_points,
97 | oversample_ratio,
98 | importance_sample_ratio,
99 | ):
100 | """Create the criterion.
101 | Parameters:
102 | num_classes: number of object categories, omitting the special no-object category
103 | matcher: module able to compute a matching between targets and proposals
104 | weight_dict: dict containing as key the names of the losses and as values their relative weight.
105 | eos_coef: relative classification weight applied to the no-object category
106 | losses: list of all the losses to be applied. See get_loss for list of available losses.
107 | """
108 | super().__init__()
109 | self.num_classes = num_classes
110 | self.matcher = matcher
111 | self.weight_dict = weight_dict
112 | self.eos_coef = eos_coef
113 | self.losses = losses
114 | empty_weight = torch.ones(self.num_classes + 1)
115 | empty_weight[-1] = self.eos_coef
116 | self.register_buffer("empty_weight", empty_weight)
117 |
118 | # pointwise mask loss parameters
119 | self.num_points = num_points
120 | self.oversample_ratio = oversample_ratio
121 | self.importance_sample_ratio = importance_sample_ratio
122 |
123 | def loss_labels(self, outputs, targets, indices, num_masks):
124 | """Classification loss (NLL)
125 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
126 | """
127 | assert "pred_logits" in outputs
128 | src_logits = outputs["pred_logits"].float()
129 |
130 | idx = self._get_src_permutation_idx(indices)
131 | target_classes_o = torch.cat(
132 | [t["labels"][J] for t, (_, J) in zip(targets, indices)]
133 | )
134 | target_classes = torch.full(
135 | src_logits.shape[:2],
136 | self.num_classes,
137 | dtype=torch.int64,
138 | device=src_logits.device,
139 | )
140 | target_classes[idx] = target_classes_o
141 |
142 | loss_ce = F.cross_entropy(
143 | src_logits.transpose(1, 2), target_classes, self.empty_weight
144 | )
145 | losses = {"loss_ce": loss_ce}
146 | return losses
147 |
148 | def loss_masks(self, outputs, targets, indices, num_masks):
149 | """Compute the losses related to the masks: the focal loss and the dice loss.
150 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
151 | """
152 | assert "pred_masks" in outputs
153 |
154 | src_idx = self._get_src_permutation_idx(indices)
155 | tgt_idx = self._get_tgt_permutation_idx(indices)
156 | src_masks = outputs["pred_masks"]
157 | src_masks = src_masks[src_idx]
158 | masks = [t["masks"] for t in targets]
159 | # TODO use valid to mask invalid areas due to padding in loss
160 | target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
161 | target_masks = target_masks.to(src_masks)
162 | target_masks = target_masks[tgt_idx]
163 |
164 | # No need to upsample predictions as we are using normalized coordinates :)
165 | # N x 1 x H x W
166 | src_masks = src_masks[:, None]
167 | target_masks = target_masks[:, None]
168 |
169 | with torch.no_grad():
170 | # sample point_coords
171 | point_coords = get_uncertain_point_coords_with_randomness(
172 | src_masks,
173 | lambda logits: calculate_uncertainty(logits),
174 | self.num_points,
175 | self.oversample_ratio,
176 | self.importance_sample_ratio,
177 | )
178 | # get gt labels
179 | point_labels = point_sample(
180 | target_masks,
181 | point_coords,
182 | align_corners=False,
183 | ).squeeze(1)
184 |
185 | point_logits = point_sample(
186 | src_masks,
187 | point_coords,
188 | align_corners=False,
189 | ).squeeze(1)
190 |
191 | losses = {
192 | "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
193 | "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks),
194 | }
195 |
196 | del src_masks
197 | del target_masks
198 | return losses
199 |
200 | def _get_src_permutation_idx(self, indices):
201 | # permute predictions following indices
202 | batch_idx = torch.cat(
203 | [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
204 | )
205 | src_idx = torch.cat([src for (src, _) in indices])
206 | return batch_idx, src_idx
207 |
208 | def _get_tgt_permutation_idx(self, indices):
209 | # permute targets following indices
210 | batch_idx = torch.cat(
211 | [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
212 | )
213 | tgt_idx = torch.cat([tgt for (_, tgt) in indices])
214 | return batch_idx, tgt_idx
215 |
216 | def get_loss(self, loss, outputs, targets, indices, num_masks):
217 | loss_map = {
218 | "labels": self.loss_labels,
219 | "masks": self.loss_masks,
220 | }
221 | assert loss in loss_map, f"do you really want to compute {loss} loss?"
222 | return loss_map[loss](outputs, targets, indices, num_masks)
223 |
224 | def forward(self, outputs, targets):
225 | """This performs the loss computation.
226 | Parameters:
227 | outputs: dict of tensors, see the output specification of the model for the format
228 | targets: list of dicts, such that len(targets) == batch_size.
229 | The expected keys in each dict depends on the losses applied, see each loss' doc
230 | """
231 | outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
232 |
233 | # Retrieve the matching between the outputs of the last layer and the targets
234 | indices = self.matcher(outputs_without_aux, targets)
235 |
236 |         # Compute the average number of target masks across all nodes, for normalization purposes
237 | num_masks = sum(len(t["labels"]) for t in targets)
238 | num_masks = torch.as_tensor(
239 | [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
240 | )
241 | if is_dist_avail_and_initialized():
242 | torch.distributed.all_reduce(num_masks)
243 | num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
244 |
245 | # Compute all the requested losses
246 | losses = {}
247 | for loss in self.losses:
248 | losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
249 |
250 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
251 | if "aux_outputs" in outputs:
252 | for i, aux_outputs in enumerate(outputs["aux_outputs"]):
253 | indices = self.matcher(aux_outputs, targets)
254 | for loss in self.losses:
255 | l_dict = self.get_loss(
256 | loss, aux_outputs, targets, indices, num_masks
257 | )
258 | l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
259 | losses.update(l_dict)
260 |
261 | return losses
262 |
263 | def __repr__(self):
264 | head = "Criterion " + self.__class__.__name__
265 | body = [
266 | "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
267 | "losses: {}".format(self.losses),
268 | "weight_dict: {}".format(self.weight_dict),
269 | "num_classes: {}".format(self.num_classes),
270 | "eos_coef: {}".format(self.eos_coef),
271 | "num_points: {}".format(self.num_points),
272 | "oversample_ratio: {}".format(self.oversample_ratio),
273 | "importance_sample_ratio: {}".format(self.importance_sample_ratio),
274 | ]
275 | _repr_indent = 4
276 | lines = [head] + [" " * _repr_indent + line for line in body]
277 | return "\n".join(lines)
278 |
--------------------------------------------------------------------------------
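A self-contained sketch of the criterion on dummy predictions (illustrative; the loss weights, class counts, and point counts below are Mask2Former-style assumptions, not values read from this repo's configs):

import torch

from san.model.matcher import HungarianMatcher
from san.model.criterion import SetCriterion

matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0, num_points=12544)
criterion = SetCriterion(
    num_classes=171,
    matcher=matcher,
    weight_dict={"loss_ce": 2.0, "loss_mask": 5.0, "loss_dice": 5.0},
    eos_coef=0.1,
    losses=["labels", "masks"],
    num_points=12544,
    oversample_ratio=3.0,
    importance_sample_ratio=0.75,
)

outputs = {
    "pred_logits": torch.randn(2, 100, 172),    # [B, num_queries, num_classes + 1]
    "pred_masks": torch.randn(2, 100, 80, 80),  # [B, num_queries, H_pred, W_pred]
}
targets = [  # one dict per image, labels in [0, num_classes)
    {"labels": torch.tensor([3, 17]), "masks": (torch.rand(2, 320, 320) > 0.5).float()},
    {"labels": torch.tensor([5]), "masks": (torch.rand(1, 320, 320) > 0.5).float()},
]
losses = criterion(outputs, targets)  # {"loss_ce": ..., "loss_mask": ..., "loss_dice": ...}
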
/san/model/layers.py:
--------------------------------------------------------------------------------
1 | import fvcore.nn.weight_init as weight_init
2 | import torch
3 | from detectron2.layers import CNNBlockBase, Conv2d
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class LayerNorm(nn.Module):
9 | """
10 | A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
11 | variance normalization over the channel dimension for inputs that have shape
12 | (batch_size, channels, height, width).
13 | https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
14 | """
15 |
16 | def __init__(self, normalized_shape, eps=1e-6):
17 | super().__init__()
18 | self.weight = nn.Parameter(torch.ones(normalized_shape))
19 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
20 | self.eps = eps
21 | self.normalized_shape = (normalized_shape,)
22 |
23 | def forward(self, x: torch.Tensor):
24 | u = x.mean(1, keepdim=True)
25 | s = (x - u).pow(2).mean(1, keepdim=True)
26 | x = (x - u) / torch.sqrt(s + self.eps)
27 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
28 | return x
29 |
30 |
31 | class MLP(nn.Module):
32 | """Very simple multi-layer perceptron (also called FFN)"""
33 |
34 | def __init__(
35 | self, input_dim, hidden_dim, output_dim, num_layers, affine_func=nn.Linear
36 | ):
37 | super().__init__()
38 | self.num_layers = num_layers
39 | h = [hidden_dim] * (num_layers - 1)
40 | self.layers = nn.ModuleList(
41 | affine_func(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
42 | )
43 |
44 | def forward(self, x: torch.Tensor):
45 | for i, layer in enumerate(self.layers):
46 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
47 | return x
48 |
49 |
50 | class AddFusion(CNNBlockBase):
51 | def __init__(self, in_channels, out_channels):
52 | super().__init__(in_channels, out_channels, 1)
53 | self.input_proj = nn.Sequential(
54 | LayerNorm(in_channels),
55 | Conv2d(
56 | in_channels,
57 | out_channels,
58 | kernel_size=1,
59 | ),
60 | )
61 | weight_init.c2_xavier_fill(self.input_proj[-1])
62 |
63 | def forward(self, x: torch.Tensor, y: torch.Tensor, spatial_shape: tuple):
64 | # x: [N,L,C] y: [N,C,H,W]
65 | y = (
66 | F.interpolate(
67 | self.input_proj(y.contiguous()),
68 | size=spatial_shape,
69 | mode="bilinear",
70 | align_corners=False,
71 | )
72 | .permute(0, 2, 3, 1)
73 | .reshape(x.shape)
74 | )
75 | x = x + y
76 | return x
77 |
78 |
79 | def build_fusion_layer(fusion_type: str, in_channels: int, out_channels: int):
80 | if fusion_type == "add":
81 | return AddFusion(in_channels, out_channels)
82 | else:
83 | raise ValueError("Unknown fusion type: {}".format(fusion_type))
84 |
--------------------------------------------------------------------------------
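A quick shape check for the two building blocks above (illustrative; the channel sizes are assumptions matching a 240-dim side adapter and a ViT-B/16 CLIP backbone):

import torch

from san.model.layers import MLP, build_fusion_layer

mlp = MLP(input_dim=240, hidden_dim=256, output_dim=256, num_layers=3)
print(mlp(torch.randn(2, 100, 240)).shape)        # torch.Size([2, 100, 256])

fusion = build_fusion_layer("add", in_channels=768, out_channels=240)
x = torch.randn(2, 400, 240)                      # side-adapter tokens, [N, L, C] with L = 20 * 20
y = torch.randn(2, 768, 40, 40)                   # CLIP feature map, [N, C, H, W]
print(fusion(x, y, (20, 20)).shape)               # torch.Size([2, 400, 240])
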
/san/model/matcher.py:
--------------------------------------------------------------------------------
1 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
2 | # Copied from maskformer2
3 | """
4 | Modules to compute the matching cost and solve the corresponding LSAP.
5 | """
6 | import torch
7 | import torch.nn.functional as F
8 | from scipy.optimize import linear_sum_assignment
9 | from torch import nn
10 | from torch.cuda.amp import autocast
11 |
12 | from detectron2.projects.point_rend.point_features import point_sample
13 |
14 |
15 | def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
16 | """
17 | Compute the DICE loss, similar to generalized IOU for masks
18 | Args:
19 | inputs: A float tensor of arbitrary shape.
20 | The predictions for each example.
21 | targets: A float tensor with the same shape as inputs. Stores the binary
22 | classification label for each element in inputs
23 | (0 for the negative class and 1 for the positive class).
24 | """
25 | inputs = inputs.sigmoid()
26 | inputs = inputs.flatten(1)
27 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
28 | denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
29 | loss = 1 - (numerator + 1) / (denominator + 1)
30 | return loss
31 |
32 |
33 | batch_dice_loss_jit = torch.jit.script(batch_dice_loss) # type: torch.jit.ScriptModule
34 |
35 |
36 | def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
37 | """
38 | Args:
39 | inputs: A float tensor of arbitrary shape.
40 | The predictions for each example.
41 | targets: A float tensor with the same shape as inputs. Stores the binary
42 | classification label for each element in inputs
43 | (0 for the negative class and 1 for the positive class).
44 | Returns:
45 | Loss tensor
46 | """
47 | hw = inputs.shape[1]
48 |
49 | pos = F.binary_cross_entropy_with_logits(
50 | inputs, torch.ones_like(inputs), reduction="none"
51 | )
52 | neg = F.binary_cross_entropy_with_logits(
53 | inputs, torch.zeros_like(inputs), reduction="none"
54 | )
55 |
56 | loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
57 | "nc,mc->nm", neg, (1 - targets)
58 | )
59 |
60 | return loss / hw
61 |
62 |
63 | batch_sigmoid_ce_loss_jit = torch.jit.script(
64 | batch_sigmoid_ce_loss
65 | ) # type: torch.jit.ScriptModule
66 |
67 |
68 | class HungarianMatcher(nn.Module):
69 | """This class computes an assignment between the targets and the predictions of the network
70 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
71 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
72 | while the others are un-matched (and thus treated as non-objects).
73 | """
74 |
75 | def __init__(
76 | self,
77 | cost_class: float = 1,
78 | cost_mask: float = 1,
79 | cost_dice: float = 1,
80 | num_points: int = 0,
81 | ):
82 | """Creates the matcher
83 | Params:
84 | cost_class: This is the relative weight of the classification error in the matching cost
85 | cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
86 | cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
87 | """
88 | super().__init__()
89 | self.cost_class = cost_class
90 | self.cost_mask = cost_mask
91 | self.cost_dice = cost_dice
92 |
93 | assert (
94 | cost_class != 0 or cost_mask != 0 or cost_dice != 0
95 | ), "all costs cant be 0"
96 |
97 | self.num_points = num_points
98 |
99 | @torch.no_grad()
100 | def memory_efficient_forward(self, outputs, targets):
101 | """More memory-friendly matching"""
102 | bs, num_queries = outputs["pred_logits"].shape[:2]
103 |
104 | indices = []
105 |
106 | # Iterate through batch size
107 | for b in range(bs):
108 | out_prob = outputs["pred_logits"][b].softmax(
109 | -1
110 | ) # [num_queries, num_classes]
111 | tgt_ids = targets[b]["labels"]
112 |
113 | # Compute the classification cost. Contrary to the loss, we don't use the NLL,
114 | # but approximate it in 1 - proba[target class].
115 |             # The 1 is a constant that doesn't change the matching, it can be omitted.
116 | cost_class = -out_prob[:, tgt_ids]
117 |
118 | out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred]
119 | # gt masks are already padded when preparing target
120 | tgt_mask = targets[b]["masks"].to(out_mask)
121 |
122 | out_mask = out_mask[:, None]
123 | tgt_mask = tgt_mask[:, None]
124 | # all masks share the same set of points for efficient matching!
125 | point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
126 | # get gt labels
127 | tgt_mask = point_sample(
128 | tgt_mask,
129 | point_coords.repeat(tgt_mask.shape[0], 1, 1),
130 | align_corners=False,
131 | ).squeeze(1)
132 |
133 | out_mask = point_sample(
134 | out_mask,
135 | point_coords.repeat(out_mask.shape[0], 1, 1),
136 | align_corners=False,
137 | ).squeeze(1)
138 |
139 | with autocast(enabled=False):
140 | out_mask = out_mask.float()
141 | tgt_mask = tgt_mask.float()
142 | # Compute the focal loss between masks
143 | cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask)
144 |
145 |                 # Compute the dice loss between masks
146 | cost_dice = batch_dice_loss_jit(out_mask, tgt_mask)
147 |
148 | # Final cost matrix
149 | C = (
150 | self.cost_mask * cost_mask
151 | + self.cost_class * cost_class
152 | + self.cost_dice * cost_dice
153 | )
154 | C = C.reshape(num_queries, -1).cpu()
155 |
156 | indices.append(linear_sum_assignment(C))
157 |
158 | return [
159 | (
160 | torch.as_tensor(i, dtype=torch.int64),
161 | torch.as_tensor(j, dtype=torch.int64),
162 | )
163 | for i, j in indices
164 | ]
165 |
166 | @torch.no_grad()
167 | def forward(self, outputs, targets):
168 | """Performs the matching
169 | Params:
170 | outputs: This is a dict that contains at least these entries:
171 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
172 | "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
173 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
174 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
175 | objects in the target) containing the class labels
176 | "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
177 | Returns:
178 | A list of size batch_size, containing tuples of (index_i, index_j) where:
179 | - index_i is the indices of the selected predictions (in order)
180 | - index_j is the indices of the corresponding selected targets (in order)
181 | For each batch element, it holds:
182 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
183 | """
184 | return self.memory_efficient_forward(outputs, targets)
185 |
186 | def __repr__(self, _repr_indent=4):
187 | head = "Matcher " + self.__class__.__name__
188 | body = [
189 | "cost_class: {}".format(self.cost_class),
190 | "cost_mask: {}".format(self.cost_mask),
191 | "cost_dice: {}".format(self.cost_dice),
192 | ]
193 | lines = [head] + [" " * _repr_indent + line for line in body]
194 | return "\n".join(lines)
195 |
--------------------------------------------------------------------------------
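The matcher can also be exercised on its own; it returns one (query indices, target indices) pair per image (illustrative sketch with arbitrary shapes):

import torch

from san.model.matcher import HungarianMatcher

matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0, num_points=4096)
outputs = {
    "pred_logits": torch.randn(1, 100, 172),
    "pred_masks": torch.randn(1, 100, 80, 80),
}
targets = [{"labels": torch.tensor([3, 17]), "masks": (torch.rand(2, 320, 320) > 0.5).float()}]
pred_idx, tgt_idx = matcher(outputs, targets)[0]
# len(pred_idx) == len(tgt_idx) == min(num_queries, num_targets) == 2
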
/san/model/san.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import open_clip
4 | import torch
5 | from detectron2.config import configurable
6 | from detectron2.modeling import META_ARCH_REGISTRY
7 | from detectron2.modeling.postprocessing import sem_seg_postprocess
8 | from detectron2.structures import ImageList
9 | from detectron2.utils.memory import retry_if_cuda_oom
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from .clip_utils import (
14 | FeatureExtractor,
15 | LearnableBgOvClassifier,
16 | PredefinedOvClassifier,
17 | RecWithAttnbiasHead,
18 | get_predefined_templates,
19 | )
20 | from .criterion import SetCriterion
21 | from .matcher import HungarianMatcher
22 | from .side_adapter import build_side_adapter_network
23 |
24 |
25 | @META_ARCH_REGISTRY.register()
26 | class SAN(nn.Module):
27 | @configurable
28 | def __init__(
29 | self,
30 | *,
31 | clip_visual_extractor: nn.Module,
32 | clip_rec_head: nn.Module,
33 | side_adapter_network: nn.Module,
34 | ov_classifier: PredefinedOvClassifier,
35 | criterion: SetCriterion,
36 | size_divisibility: int,
37 | asymetric_input: bool = True,
38 | clip_resolution: float = 0.5,
39 | pixel_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
40 | pixel_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
41 | sem_seg_postprocess_before_inference: bool = False,
42 | ):
43 | super().__init__()
44 | self.asymetric_input = asymetric_input
45 | self.clip_resolution = clip_resolution
46 | self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
47 | self.size_divisibility = size_divisibility
48 | self.criterion = criterion
49 |
50 | self.side_adapter_network = side_adapter_network
51 | self.clip_visual_extractor = clip_visual_extractor
52 | self.clip_rec_head = clip_rec_head
53 | self.ov_classifier = ov_classifier
54 | self.register_buffer(
55 | "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
56 | )
57 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
58 |
59 | @classmethod
60 | def from_config(cls, cfg):
61 | ## copied from maskformer2
62 | # Loss parameters
63 | no_object_weight = cfg.MODEL.SAN.NO_OBJECT_WEIGHT
64 | # loss weights
65 | class_weight = cfg.MODEL.SAN.CLASS_WEIGHT
66 | dice_weight = cfg.MODEL.SAN.DICE_WEIGHT
67 | mask_weight = cfg.MODEL.SAN.MASK_WEIGHT
68 |
69 | # building criterion
70 | matcher = HungarianMatcher(
71 | cost_class=class_weight,
72 | cost_mask=mask_weight,
73 | cost_dice=dice_weight,
74 | num_points=cfg.MODEL.SAN.TRAIN_NUM_POINTS,
75 | )
76 |
77 | weight_dict = {
78 | "loss_ce": class_weight,
79 | "loss_mask": mask_weight,
80 | "loss_dice": dice_weight,
81 | }
82 | aux_weight_dict = {}
83 | for i in range(len(cfg.MODEL.SIDE_ADAPTER.DEEP_SUPERVISION_IDXS) - 1):
84 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
85 | weight_dict.update(aux_weight_dict)
86 | losses = ["labels", "masks"]
87 |
88 | criterion = SetCriterion(
89 | num_classes=cfg.MODEL.SAN.NUM_CLASSES,
90 | matcher=matcher,
91 | weight_dict=weight_dict,
92 | eos_coef=no_object_weight,
93 | losses=losses,
94 | num_points=cfg.MODEL.SAN.TRAIN_NUM_POINTS,
95 | oversample_ratio=cfg.MODEL.SAN.OVERSAMPLE_RATIO,
96 | importance_sample_ratio=cfg.MODEL.SAN.IMPORTANCE_SAMPLE_RATIO,
97 | )
98 | ## end of copy
99 |
100 | model, _, preprocess = open_clip.create_model_and_transforms(
101 | cfg.MODEL.SAN.CLIP_MODEL_NAME,
102 | pretrained=cfg.MODEL.SAN.CLIP_PRETRAINED_NAME,
103 | )
104 | ov_classifier = LearnableBgOvClassifier(
105 | model, templates=get_predefined_templates(cfg.MODEL.SAN.CLIP_TEMPLATE_SET)
106 | )
107 |
108 | clip_visual_extractor = FeatureExtractor(
109 | model.visual,
110 | last_layer_idx=cfg.MODEL.SAN.FEATURE_LAST_LAYER_IDX,
111 | frozen_exclude=cfg.MODEL.SAN.CLIP_FROZEN_EXCLUDE,
112 | )
113 | clip_rec_head = RecWithAttnbiasHead(
114 | model.visual,
115 | first_layer_idx=cfg.MODEL.SAN.FEATURE_LAST_LAYER_IDX,
116 | frozen_exclude=cfg.MODEL.SAN.CLIP_DEEPER_FROZEN_EXCLUDE,
117 | cross_attn=cfg.MODEL.SAN.REC_CROSS_ATTN,
118 | sos_token_format=cfg.MODEL.SAN.SOS_TOKEN_FORMAT,
119 | sos_token_num=cfg.MODEL.SIDE_ADAPTER.NUM_QUERIES,
120 | downsample_method=cfg.MODEL.SAN.REC_DOWNSAMPLE_METHOD,
121 | )
122 |
123 | pixel_mean, pixel_std = (
124 | preprocess.transforms[-1].mean,
125 | preprocess.transforms[-1].std,
126 | )
127 | pixel_mean = [255.0 * x for x in pixel_mean]
128 | pixel_std = [255.0 * x for x in pixel_std]
129 |
130 | return {
131 | "clip_visual_extractor": clip_visual_extractor,
132 | "clip_rec_head": clip_rec_head,
133 | "side_adapter_network": build_side_adapter_network(
134 | cfg, clip_visual_extractor.output_shapes
135 | ),
136 | "ov_classifier": ov_classifier,
137 | "criterion": criterion,
138 | "size_divisibility": cfg.MODEL.SAN.SIZE_DIVISIBILITY,
139 | "asymetric_input": cfg.MODEL.SAN.ASYMETRIC_INPUT,
140 | "clip_resolution": cfg.MODEL.SAN.CLIP_RESOLUTION,
141 | "sem_seg_postprocess_before_inference": cfg.MODEL.SAN.SEM_SEG_POSTPROCESS_BEFORE_INFERENCE,
142 | "pixel_mean": pixel_mean,
143 | "pixel_std": pixel_std,
144 | }
145 |
146 | def forward(self, batched_inputs):
147 | # get classifier weight for each dataset
148 | # !! Could be computed once and saved. It will run only once per dataset.
149 | if "vocabulary" in batched_inputs[0]:
150 | ov_classifier_weight = (
151 | self.ov_classifier.logit_scale.exp()
152 | * self.ov_classifier.get_classifier_by_vocabulary(
153 | batched_inputs[0]["vocabulary"]
154 | )
155 | )
156 | else:
157 | dataset_names = [x["meta"]["dataset_name"] for x in batched_inputs]
158 | assert (
159 | len(list(set(dataset_names))) == 1
160 | ), "All images in a batch must be from the same dataset."
161 | ov_classifier_weight = (
162 | self.ov_classifier.logit_scale.exp()
163 | * self.ov_classifier.get_classifier_by_dataset_name(dataset_names[0])
164 | ) # C+1,ndim
165 | images = [x["image"].to(self.device) for x in batched_inputs]
166 | images = [(x - self.pixel_mean) / self.pixel_std for x in images]
167 | images = ImageList.from_tensors(images, self.size_divisibility)
168 | clip_input = images.tensor
169 | if self.asymetric_input:
170 | clip_input = F.interpolate(
171 | clip_input, scale_factor=self.clip_resolution, mode="bilinear"
172 | )
173 | clip_image_features = self.clip_visual_extractor(clip_input)
174 | mask_preds, attn_biases = self.side_adapter_network(
175 | images.tensor, clip_image_features
176 | )
177 | # !! Could be optimized to run in parallel.
178 | mask_embs = [
179 | self.clip_rec_head(clip_image_features, attn_bias, normalize=True)
180 | for attn_bias in attn_biases
181 | ] # [B,N,C]
182 | mask_logits = [
183 | torch.einsum("bqc,nc->bqn", mask_emb, ov_classifier_weight)
184 | for mask_emb in mask_embs
185 | ]
186 | if self.training:
187 | if "instances" in batched_inputs[0]:
188 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
189 | targets = self.prepare_targets(gt_instances, images)
190 | else:
191 | targets = None
192 | outputs = {
193 | "pred_logits": mask_logits[-1],
194 | "pred_masks": mask_preds[-1],
195 | "aux_outputs": [
196 | {
197 | "pred_logits": aux_pred_logits,
198 | "pred_masks": aux_pred_masks,
199 | }
200 | for aux_pred_logits, aux_pred_masks in zip(
201 | mask_logits[:-1], mask_preds[:-1]
202 | )
203 | ],
204 | }
205 | # bipartite matching-based loss
206 | losses = self.criterion(outputs, targets)
207 |
208 | for k in list(losses.keys()):
209 | if k in self.criterion.weight_dict:
210 | losses[k] *= self.criterion.weight_dict[k]
211 | else:
212 | # remove this loss if not specified in `weight_dict`
213 | losses.pop(k)
214 | return losses
215 | else:
216 | mask_preds = mask_preds[-1]
217 | mask_logits = mask_logits[-1]
218 | # torch.cuda.empty_cache()
219 | # Inference
220 | mask_preds = F.interpolate(
221 | mask_preds,
222 | size=(images.tensor.shape[-2], images.tensor.shape[-1]),
223 | mode="bilinear",
224 | align_corners=False,
225 | )
226 | processed_results = []
227 | for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
228 | mask_logits, mask_preds, batched_inputs, images.image_sizes
229 | ):
230 | height = input_per_image.get("height", image_size[0])
231 | width = input_per_image.get("width", image_size[1])
232 | processed_results.append({})
233 |
234 | if self.sem_seg_postprocess_before_inference:
235 | mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
236 | mask_pred_result, image_size, height, width
237 | )
238 | mask_cls_result = mask_cls_result.to(mask_pred_result)
239 | r = retry_if_cuda_oom(self.semantic_inference)(
240 | mask_cls_result, mask_pred_result
241 | )
242 | if not self.sem_seg_postprocess_before_inference:
243 | r = retry_if_cuda_oom(sem_seg_postprocess)(
244 | r, image_size, height, width
245 | )
246 | processed_results[-1]["sem_seg"] = r
247 | return processed_results
248 |
249 | def prepare_targets(self, targets, images):
250 | h_pad, w_pad = images.tensor.shape[-2:]
251 | new_targets = []
252 | for targets_per_image in targets:
253 | # pad gt
254 | gt_masks = targets_per_image.gt_masks
255 | padded_masks = torch.zeros(
256 | (gt_masks.shape[0], h_pad, w_pad),
257 | dtype=gt_masks.dtype,
258 | device=gt_masks.device,
259 | )
260 | padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
261 | new_targets.append(
262 | {
263 | "labels": targets_per_image.gt_classes,
264 | "masks": padded_masks,
265 | }
266 | )
267 | return new_targets
268 |
269 | def semantic_inference(self, mask_cls, mask_pred):
270 | mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
271 | mask_pred = mask_pred.sigmoid()
272 | semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
273 | return semseg
274 |
275 | @property
276 | def device(self):
277 | return self.pixel_mean.device
278 |
--------------------------------------------------------------------------------
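A hedged end-to-end inference sketch (illustrative only). It assumes the repo and its dependencies are installed, that the package exposes a config helper here called add_san_config (defined in san/config.py; the exact name is an assumption), and that a trained checkpoint is available; the checkpoint path is a placeholder.

import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer

from san import add_san_config  # assumed helper that adds the MODEL.SAN / MODEL.SIDE_ADAPTER keys

cfg = get_cfg()
add_san_config(cfg)
cfg.merge_from_file("configs/san_clip_vit_res4_coco.yaml")
model = build_model(cfg).eval()
DetectionCheckpointer(model).load("path/to/checkpoint.pth")  # placeholder path

image = torch.randint(0, 256, (3, 640, 640)).float()  # CHW image in [0, 255]; normalization happens inside forward()
inputs = [{"image": image, "height": 640, "width": 640,
           "vocabulary": ["cat", "dog", "grass"]}]     # open-vocabulary class names
with torch.no_grad():
    sem_seg = model(inputs)[0]["sem_seg"]              # [num_classes, H, W] per-class scores
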
/san/model/side_adapter/__init__.py:
--------------------------------------------------------------------------------
1 | from . import timm_wrapper
2 | from .side_adapter import build_side_adapter_network
3 |
--------------------------------------------------------------------------------
/san/model/side_adapter/side_adapter.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import partial
3 | from typing import Dict, List, Tuple
4 |
5 | import torch
6 | from detectron2.config import configurable
7 | from detectron2.layers import ShapeSpec
8 | from detectron2.utils.logger import log_first_n
9 | from detectron2.utils.registry import Registry
10 | from timm import create_model
11 | from timm.models.vision_transformer import VisionTransformer
12 | from torch import nn
13 | from torch.nn import functional as F
14 |
15 | from ..layers import MLP, build_fusion_layer
16 | from .timm_wrapper import PatchEmbed
17 |
18 | SIDE_ADAPTER_REGISTRY = Registry("SIDE_ADAPTER")
19 | SIDE_ADAPTER_REGISTRY.__doc__ = """
20 | Registry for side adapter.
21 | """
22 |
23 |
24 | def build_side_adapter_network(cfg, input_shape):
25 | name = cfg.MODEL.SIDE_ADAPTER.NAME
26 | return SIDE_ADAPTER_REGISTRY.get(name)(cfg, input_shape)
27 |
28 |
29 | class MLPMaskDecoder(nn.Module):
30 | def __init__(
31 | self,
32 | *,
33 | in_channels: int,
34 | total_heads: int = 1,
35 | total_layers: int = 1,
36 | embed_channels: int = 256,
37 | mlp_channels: int = 256,
38 | mlp_num_layers: int = 3,
39 | rescale_attn_bias: bool = False,
40 | ):
41 | super().__init__()
42 | self.total_heads = total_heads
43 | self.total_layers = total_layers
44 |
45 | dense_affine_func = partial(nn.Conv2d, kernel_size=1)
46 | # Query Branch
47 | self.query_mlp = MLP(in_channels, mlp_channels, embed_channels, mlp_num_layers)
48 | # Pixel Branch
49 | self.pix_mlp = MLP(
50 | in_channels,
51 | mlp_channels,
52 | embed_channels,
53 | mlp_num_layers,
54 | affine_func=dense_affine_func,
55 | )
56 | # Attention Bias Branch
57 | self.attn_mlp = MLP(
58 | in_channels,
59 | mlp_channels,
60 | embed_channels * self.total_heads * self.total_layers,
61 | mlp_num_layers,
62 | affine_func=dense_affine_func,
63 | )
64 | if rescale_attn_bias:
65 | self.bias_scaling = nn.Linear(1, 1)
66 | else:
67 | self.bias_scaling = nn.Identity()
68 |
69 | def forward(
70 | self, query: torch.Tensor, x: torch.Tensor
71 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
72 | # query: [B,N,C]
73 | # x: [B,C,H,W]
74 | query = self.query_mlp(query)
75 | pix = self.pix_mlp(x)
76 | b, c, h, w = pix.shape
77 |         # predict masks
78 | mask_preds = torch.einsum("bqc,bchw->bqhw", query, pix)
79 | # generate attn bias
80 | attn = self.attn_mlp(x)
81 | attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w)
82 | attn_bias = torch.einsum("bqc,blnchw->blnqhw", query, attn)
83 | attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1)
84 | attn_bias = attn_bias.chunk(self.total_layers, dim=1)
85 | attn_bias = [attn.squeeze(1) for attn in attn_bias]
86 | return mask_preds, attn_bias
87 |
88 |
89 | @SIDE_ADAPTER_REGISTRY.register()
90 | class RegionwiseSideAdapterNetwork(nn.Module):
91 | @configurable
92 | def __init__(
93 | self,
94 | vit_model: VisionTransformer,
95 | fusion_layers: nn.ModuleList,
96 | mask_decoder: nn.Module,
97 | num_queries: int,
98 | fusion_map: Dict[int, int],
99 | deep_supervision_idxs: List[int],
100 | ):
101 | super().__init__()
102 | # remove cls token
103 | if vit_model.cls_token is not None:
104 | vit_model.pos_embed = nn.Parameter(vit_model.pos_embed[:, 1:, ...])
105 | del vit_model.cls_token
106 | vit_model.cls_token = None
107 | # delete out norm
108 | del vit_model.norm
109 | vit_model.norm = nn.Identity()
110 | self.vit_model = vit_model
111 |
112 | self.num_queries = num_queries
113 | self.num_features = vit_model.num_features
114 | # add query token
115 | self.query_embed = nn.Parameter(torch.zeros(1, num_queries, self.num_features))
116 | self.query_pos_embed = nn.Parameter(
117 | torch.zeros(1, num_queries, self.num_features)
118 | )
119 | nn.init.normal_(self.query_embed, std=0.02)
120 | nn.init.normal_(self.query_pos_embed, std=0.02)
121 | self.fusion_layers = fusion_layers
122 | self.fusion_map = fusion_map
123 | self.mask_decoder = mask_decoder
124 | # for training
125 | self.deep_supervision_idxs = deep_supervision_idxs
126 |
127 | @classmethod
128 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
129 | vit = create_model(
130 | cfg.MODEL.SIDE_ADAPTER.VIT_NAME,
131 | cfg.MODEL.SIDE_ADAPTER.PRETRAINED,
132 | img_size=cfg.MODEL.SIDE_ADAPTER.IMAGE_SIZE,
133 | drop_path_rate=cfg.MODEL.SIDE_ADAPTER.DROP_PATH_RATE,
134 | fc_norm=False,
135 | num_classes=0,
136 | embed_layer=PatchEmbed,
137 | )
138 | # ["0->0","3->1","6->2","9->3"]
139 | fusion_map: List[str] = cfg.MODEL.SIDE_ADAPTER.FUSION_MAP
140 |
141 | x2side_map = {int(j): int(i) for i, j in [x.split("->") for x in fusion_map]}
142 | # build fusion layers
143 | fusion_type: str = cfg.MODEL.SIDE_ADAPTER.FUSION_TYPE
144 | fusion_layers = nn.ModuleDict(
145 | {
146 | f"layer_{tgt_idx}": build_fusion_layer(
147 | fusion_type, input_shape[src_idx].channels, vit.num_features
148 | )
149 | for tgt_idx, src_idx in x2side_map.items()
150 | }
151 | )
152 | # build mask decoder
153 | return {
154 | "vit_model": vit,
155 | "num_queries": cfg.MODEL.SIDE_ADAPTER.NUM_QUERIES,
156 | "fusion_layers": fusion_layers,
157 | "fusion_map": x2side_map,
158 | "mask_decoder": MLPMaskDecoder(
159 | in_channels=vit.num_features,
160 | total_heads=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_HEADS,
161 | total_layers=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_LAYERS,
162 | embed_channels=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.EMBED_CHANNELS,
163 | mlp_channels=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_CHANNELS,
164 | mlp_num_layers=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_NUM_LAYERS,
165 | rescale_attn_bias=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.RESCALE_ATTN_BIAS,
166 | ),
167 | "deep_supervision_idxs": cfg.MODEL.SIDE_ADAPTER.DEEP_SUPERVISION_IDXS,
168 | }
169 |
170 | def forward(
171 | self, image: torch.Tensor, clip_features: List[torch.Tensor]
172 | ) -> Dict[str, List[torch.Tensor]]:
173 | features = self.forward_features(image, clip_features)
174 | return self.decode_masks(features)
175 |
176 | def decode_masks(
177 | self, features: List[Dict[str, torch.Tensor]]
178 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
179 | if not self.training:
180 | features = [features[-1]]
181 | mask_preds = []
182 | attn_biases = []
183 | for feature in features:
184 | mask_pred, attn_bias = self.mask_decoder(**feature)
185 | mask_preds.append(mask_pred)
186 | attn_biases.append(attn_bias)
187 | return mask_preds, attn_biases
188 |
189 | def forward_features(
190 | self, image: torch.Tensor, clip_features: List[torch.Tensor]
191 | ) -> List[Dict[str, torch.Tensor]]:
192 | x, (h, w) = self.vit_model.patch_embed(image)
193 | L = x.shape[1] # token length
194 | pos_embed = self.vit_model.pos_embed
195 | ori_h, ori_w = self.vit_model.patch_embed.grid_size
196 | if pos_embed.shape[1] != L:
197 | pos_embed = (
198 | F.interpolate(
199 | pos_embed.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
200 | size=[h, w],
201 | mode="bicubic",
202 | align_corners=False,
203 | )
204 | .flatten(2)
205 | .permute(0, 2, 1)
206 | )
207 | pos_embed = torch.cat(
208 | [self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed], dim=1
209 | )
210 | x = torch.cat(
211 | [self.query_embed.expand(x.shape[0], -1, -1), x],
212 | dim=1,
213 | ) # B, Q+L, C
214 | x = x + pos_embed
215 | x = self.vit_model.norm_pre(x)
216 | x = self.fuse(0, x, clip_features, (h, w))
217 | outs = []
218 | for i, blk in enumerate(self.vit_model.blocks, start=1):
219 | x = blk(x)
220 | x = self.fuse(i, x, clip_features, (h, w))
221 | if i in self.deep_supervision_idxs:
222 | outs.append(
223 | {
224 | "query": x[:, :-L, ...],
225 | "x": x[:, -L:, ...]
226 | .permute(0, 2, 1)
227 | .reshape(x.shape[0], x.shape[-1], h, w),
228 | }
229 | )
230 |
231 | if i < len(self.vit_model.blocks):
232 | x = x + pos_embed
233 |
234 | return outs
235 |
236 | def fuse(
237 | self,
238 | block_idx: int,
239 | x: torch.Tensor,
240 | clip_features: List[torch.Tensor],
241 | spatial_shape: Tuple[int, int],
242 | ) -> torch.Tensor:
243 | if block_idx in self.fusion_map:
244 | src_idx = self.fusion_map[block_idx]
245 | L = spatial_shape[0] * spatial_shape[1]
246 | x = torch.cat(
247 | [
248 | x[:, :-L, ...],
249 | self.fusion_layers[f"layer_{block_idx}"](
250 | x[:, -L:, ...], clip_features[src_idx], spatial_shape
251 | ),
252 | ],
253 | dim=1,
254 | )
255 | log_first_n(
256 | logging.INFO,
257 | f"fuse clip {src_idx} to {block_idx}",
258 | len(self.fusion_map),
259 | )
260 | return x
261 |
--------------------------------------------------------------------------------
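A shape sketch for the mask decoder above (illustrative; 240 channels and 12 heads are assumptions matching a 240-dim side adapter and a 12-head CLIP ViT):

import torch

from san.model.side_adapter.side_adapter import MLPMaskDecoder

decoder = MLPMaskDecoder(
    in_channels=240, total_heads=12, total_layers=1,
    embed_channels=256, mlp_channels=256, mlp_num_layers=3,
)
query = torch.randn(2, 100, 240)     # [B, N, C] query tokens
x = torch.randn(2, 240, 40, 40)      # [B, C, H, W] side-adapter feature map
mask_preds, attn_biases = decoder(query, x)
print(mask_preds.shape)                          # torch.Size([2, 100, 40, 40])
print(len(attn_biases), attn_biases[0].shape)    # 1 torch.Size([2, 12, 100, 40, 40])
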
/san/model/side_adapter/timm_wrapper.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from torch import nn
3 | from timm.models.vision_transformer import _create_vision_transformer
4 | from timm.models import register_model
5 | from timm.models.layers import to_2tuple
6 |
7 |
8 | class PatchEmbed(nn.Module):
9 |     """2D Image to Patch Embedding. Modified from the original implementation to also return the 2D patch size."""
10 |
11 | def __init__(
12 | self,
13 | img_size=224,
14 | patch_size=16,
15 | in_chans=3,
16 | embed_dim=768,
17 | norm_layer=None,
18 | flatten=True,
19 | bias=True,
20 | **kwargs
21 | ):
22 | super().__init__()
23 |         if len(kwargs) > 0:
24 |             warnings.warn(f"Unused kwargs are provided: {kwargs}.")
25 | img_size = to_2tuple(img_size)
26 | patch_size = to_2tuple(patch_size)
27 | self.img_size = img_size
28 | self.patch_size = patch_size
29 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
30 | self.num_patches = self.grid_size[0] * self.grid_size[1]
31 | self.flatten = flatten
32 |
33 | self.proj = nn.Conv2d(
34 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
35 | )
36 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
37 |
38 | def forward(self, x):
39 | x = self.proj(x)
40 | h, w = x.shape[-2:]
41 | if self.flatten:
42 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
43 | x = self.norm(x)
44 | return x, (h, w)
45 |
46 |
47 | @register_model
48 | def vit_w144n6d8_patch16(pretrained=False, **kwargs):
49 | assert not pretrained
50 | model_kwargs = dict(patch_size=16, embed_dim=144, depth=8, num_heads=6, **kwargs)
51 | model = _create_vision_transformer(
52 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs
53 | )
54 | return model
55 |
56 |
57 | @register_model
58 | def vit_w192n6d8_patch16(pretrained=False, **kwargs):
59 | assert not pretrained
60 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=8, num_heads=6, **kwargs)
61 | model = _create_vision_transformer(
62 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs
63 | )
64 | return model
65 |
66 |
67 | @register_model
68 | def vit_w240n6d8_patch16(pretrained=False, **kwargs):
69 | assert not pretrained
70 | model_kwargs = dict(patch_size=16, embed_dim=240, depth=8, num_heads=6, **kwargs)
71 | model = _create_vision_transformer(
72 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs
73 | )
74 | return model
75 |
76 |
77 | @register_model
78 | def vit_w288n6d8_patch16(pretrained=False, **kwargs):
79 | assert not pretrained
80 | model_kwargs = dict(patch_size=16, embed_dim=288, depth=8, num_heads=6, **kwargs)
81 | model = _create_vision_transformer(
82 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs
83 | )
84 | return model
85 |
--------------------------------------------------------------------------------
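The registered variants are created through timm's factory exactly as from_config does above (illustrative; 640 is an arbitrary image size and the snippet relies on the timm version pinned by this repo):

import torch
from timm import create_model

from san.model.side_adapter.timm_wrapper import PatchEmbed  # importing also registers the vit_w*_patch16 variants

vit = create_model(
    "vit_w240n6d8_patch16",
    pretrained=False,
    img_size=640,
    num_classes=0,
    fc_norm=False,
    embed_layer=PatchEmbed,
)
tokens, (h, w) = vit.patch_embed(torch.randn(1, 3, 640, 640))
print(tokens.shape, (h, w))   # torch.Size([1, 1600, 240]) (40, 40)
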
/san/test_time_augmentation.py:
--------------------------------------------------------------------------------
1 | # Copied from mask2former repo.
2 | import copy
3 | import logging
4 | from itertools import count
5 |
6 | import numpy as np
7 | import torch
8 | from fvcore.transforms import HFlipTransform
9 | from torch import nn
10 | from torch.nn.parallel import DistributedDataParallel
11 |
12 | from detectron2.data.detection_utils import read_image
13 | from detectron2.modeling import DatasetMapperTTA
14 |
15 |
16 | __all__ = [
17 | "SemanticSegmentorWithTTA",
18 | ]
19 |
20 |
21 | class SemanticSegmentorWithTTA(nn.Module):
22 | """
23 | A SemanticSegmentor with test-time augmentation enabled.
24 | Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
25 | """
26 |
27 | def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
28 | """
29 | Args:
30 | cfg (CfgNode):
31 | model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
32 | tta_mapper (callable): takes a dataset dict and returns a list of
33 | augmented versions of the dataset dict. Defaults to
34 | `DatasetMapperTTA(cfg)`.
35 | batch_size (int): batch the augmented images into this batch size for inference.
36 | """
37 | super().__init__()
38 | if isinstance(model, DistributedDataParallel):
39 | model = model.module
40 | self.cfg = cfg.clone()
41 |
42 | self.model = model
43 |
44 | if tta_mapper is None:
45 | tta_mapper = DatasetMapperTTA(cfg)
46 | self.tta_mapper = tta_mapper
47 | self.batch_size = batch_size
48 |
49 | def __call__(self, batched_inputs):
50 | """
51 | Same input/output format as :meth:`SemanticSegmentor.forward`
52 | """
53 |
54 | def _maybe_read_image(dataset_dict):
55 | ret = copy.copy(dataset_dict)
56 | if "image" not in ret:
57 | image = read_image(ret.pop("file_name"), self.model.input_format)
58 | image = torch.from_numpy(
59 | np.ascontiguousarray(image.transpose(2, 0, 1))
60 | ) # CHW
61 | ret["image"] = image
62 | if "height" not in ret and "width" not in ret:
63 | ret["height"] = image.shape[1]
64 | ret["width"] = image.shape[2]
65 | return ret
66 |
67 | processed_results = []
68 | for x in batched_inputs:
69 | result = self._inference_one_image(_maybe_read_image(x))
70 | processed_results.append(result)
71 | return processed_results
72 |
73 | def _inference_one_image(self, input):
74 | """
75 | Args:
76 | input (dict): one dataset dict with "image" field being a CHW tensor
77 | Returns:
78 | dict: one output dict
79 | """
80 | orig_shape = (input["height"], input["width"])
81 | augmented_inputs, tfms = self._get_augmented_inputs(input)
82 |
83 | final_predictions = None
84 | count_predictions = 0
85 | for input, tfm in zip(augmented_inputs, tfms):
86 | count_predictions += 1
87 | with torch.no_grad():
88 | if final_predictions is None:
89 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
90 | final_predictions = (
91 | self.model([input])[0].pop("sem_seg").flip(dims=[2])
92 | )
93 | else:
94 | final_predictions = self.model([input])[0].pop("sem_seg")
95 | else:
96 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
97 | final_predictions += (
98 | self.model([input])[0].pop("sem_seg").flip(dims=[2])
99 | )
100 | else:
101 | final_predictions += self.model([input])[0].pop("sem_seg")
102 |
103 | final_predictions = final_predictions / count_predictions
104 | return {"sem_seg": final_predictions}
105 |
106 | def _get_augmented_inputs(self, input):
107 | augmented_inputs = self.tta_mapper(input)
108 | tfms = [x.pop("transforms") for x in augmented_inputs]
109 | return augmented_inputs, tfms
110 |
--------------------------------------------------------------------------------
/san/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .events import WandbWriter, setup_wandb
2 | from . import file_io
3 |
--------------------------------------------------------------------------------
/san/utils/events.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | from detectron2.utils import comm
4 | from detectron2.utils.events import EventWriter, get_event_storage
5 |
6 |
7 | def setup_wandb(cfg, args):
8 | if comm.is_main_process():
9 | init_args = {
10 | k.lower(): v
11 | for k, v in cfg.WANDB.items()
12 | if isinstance(k, str) and k not in ["config", "name"]
13 | }
14 |         # only include the most relevant parts of the config to avoid an overly large W&B config table
15 | # TODO: add configurable params to select which part of `cfg` should be saved in config
16 | if "config_exclude_keys" in init_args:
17 | init_args["config"] = cfg
18 | init_args["config"]["cfg_file"] = args.config_file
19 | else:
20 | init_args["config"] = {
21 | "model": cfg.MODEL,
22 | "solver": cfg.SOLVER,
23 | "cfg_file": args.config_file,
24 | }
25 | if ("name" not in init_args) or (init_args["name"] is None):
26 | init_args["name"] = os.path.basename(args.config_file)
27 | wandb.init(**init_args)
28 |
29 |
30 | class BaseRule(object):
31 | def __call__(self, target):
32 | return target
33 |
34 |
35 | class IsIn(BaseRule):
36 | def __init__(self, keyword: str):
37 | self.keyword = keyword
38 |
39 | def __call__(self, target):
40 | return self.keyword in target
41 |
42 |
43 | class Prefix(BaseRule):
44 | def __init__(self, keyword: str):
45 | self.keyword = keyword
46 |
47 | def __call__(self, target):
48 | return "/".join([self.keyword, target])
49 |
50 |
51 | class WandbWriter(EventWriter):
52 | """
53 |     Write all scalars, images and histograms from the event storage to Weights & Biases.
54 | """
55 |
56 | def __init__(self):
57 | """
58 | Args:
59 | log_dir (str): the directory to save the output events
60 | kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
61 | """
62 | self._last_write = -1
63 | self._group_rules = [
64 | (IsIn("/"), BaseRule()),
65 | (IsIn("loss"), Prefix("train")),
66 | ]
67 |
68 | def write(self):
69 | storage = get_event_storage()
70 |
71 | def _group_name(scalar_name):
72 | for rule, op in self._group_rules:
73 | if rule(scalar_name):
74 | return op(scalar_name)
75 | return scalar_name
76 |
77 | stats = {
78 | _group_name(name): scalars[0]
79 | for name, scalars in storage.latest().items()
80 | if scalars[1] > self._last_write
81 | }
82 | if len(stats) > 0:
83 | self._last_write = max([v[1] for k, v in storage.latest().items()])
84 |
85 | # storage.put_{image,histogram} is only meant to be used by
86 | # tensorboard writer. So we access its internal fields directly from here.
87 | if len(storage._vis_data) >= 1:
88 | stats["image"] = [
89 | wandb.Image(img, caption=img_name)
90 | for img_name, img, step_num in storage._vis_data
91 | ]
92 |             # Storage stores all image data and relies on this writer to clear it.
93 | # As a result it assumes only one writer will use its image data.
94 | # An alternative design is to let storage store limited recent
95 | # data (e.g. only the most recent image) that all writers can access.
96 | # In that case a writer may not see all image data if its period is long.
97 | storage.clear_images()
98 |
99 | if len(storage._histograms) >= 1:
100 |
101 | def create_bar(tag, bucket_limits, bucket_counts, **kwargs):
102 | data = [
103 | [label, val] for (label, val) in zip(bucket_limits, bucket_counts)
104 | ]
105 | table = wandb.Table(data=data, columns=["label", "value"])
106 | return wandb.plot.bar(table, "label", "value", title=tag)
107 |
108 | stats["hist"] = [create_bar(**params) for params in storage._histograms]
109 |
110 | storage.clear_histograms()
111 |
112 | if len(stats) == 0:
113 | return
114 | wandb.log(stats, step=storage.iter)
115 |
116 | def close(self):
117 | wandb.finish()
118 |
--------------------------------------------------------------------------------
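A minimal sketch of how the grouping rules in `WandbWriter` rename scalars before logging (the import path is an assumption; the rule list mirrors `_group_rules` above):

from san.utils.events import BaseRule, IsIn, Prefix  # assumed import path

rules = [(IsIn("/"), BaseRule()), (IsIn("loss"), Prefix("train"))]

def group(name):
    for rule, op in rules:
        if rule(name):
            return op(name)
    return name

assert group("val/mIoU") == "val/mIoU"      # already grouped, kept as-is
assert group("loss_ce") == "train/loss_ce"  # losses are grouped under "train/"
assert group("lr") == "lr"                  # no rule matches, passed through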
/san/utils/file_io.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 | import warnings
5 | from iopath.common.file_io import PathHandler
6 | from detectron2.utils.file_io import PathManager
7 | from zipfile import ZipFile
8 | from detectron2.utils.logger import log_first_n
9 | from io import BytesIO
10 | import multiprocessing
11 |
12 | __zip_file_pool__ = {
13 | # (path, mode): ZipFile
14 | }
15 |
16 |
17 | def find_zip_parent(path: str, mode: str = "r", is_dir=False):
18 | """Find the best match zipfile from the end"""
19 | if mode[-1] == "b":
20 | mode = mode[:-1]
21 | if is_dir:
22 | par_path = path
23 | else:
24 | par_path = os.path.dirname(path)
25 | visited = [par_path]
26 | # count = 0
27 | while par_path:
28 | # count += 1
29 | if ((par_path, mode) in __zip_file_pool__) and (
30 | __zip_file_pool__[(par_path, mode)].fp is not None
31 | ):
32 | # zip file is still open
33 | zip_file = __zip_file_pool__[(par_path, mode)]
34 | for path in visited[:-1]:
35 | __zip_file_pool__[(path, mode)] = zip_file
36 | return zip_file
37 | elif os.path.isfile(par_path + ".zip"):
38 |             log_first_n(logging.INFO, "Open zip file {}.".format(par_path + ".zip"), n=1)
39 | zip_file = ZipFile(par_path + ".zip", mode=mode)
40 | for path in visited:
41 | __zip_file_pool__[(path, mode)] = zip_file
42 | # return par_path, zip_file, count
43 | return zip_file
44 |
45 | par_path = os.path.sep.join(par_path.split(os.path.sep)[:-1])
46 | visited.append(par_path)
47 | # return None, None, count
48 | return None
49 |
50 |
51 | class ZipFileHandler(PathHandler):
52 | """
53 | Load data from zipfile and return a file-like object
54 | """
55 |
56 | PREFIX = "zip://"
57 |
58 | def _get_supported_prefixes(self):
59 | return [self.PREFIX]
60 |
61 | def _get_local_path(self, path, **kwargs):
62 | name = path[len(self.PREFIX) :]
63 |
64 | return name
65 |
66 | def _open(self, path: str, mode: str = "r", buffering=-1, **kwargs):
67 | """Open a file and return a file object.
68 | Args:
69 |             path (str): a URI of the form "zip://<path inside or next to a .zip archive>"
70 | mode (str, optional): file open mode. Defaults to "r".
71 |
72 | Returns:
73 |             BytesIO: a file-like object with the member's contents
74 | """
75 |
76 | path = self._get_local_path(path)
77 | zip_file: ZipFile = find_zip_parent(path, mode)
78 | if zip_file is None:
79 | warnings.warn(
80 | "No zipfile contains {}, falling back to naive PathHandler".format(
81 | path
82 | ),
83 | )
84 | return PathManager.open(path, mode, buffering, **kwargs)
85 | assert mode in [
86 | # "r",
87 | "rb",
88 | ], "Writing to ZipFile object is not thread safe. Only read mode is supported for now." # Need to deal with write mode carefully, maybe we will change it later.
89 | filename = os.path.join(
90 | zip_file.filelist[0].filename,
91 | path[len(os.path.splitext(zip_file.filename)[0]) + 1 :],
92 | )
93 | log_first_n(
94 | logging.INFO,
95 | "[Example] load file {} from zip file {}.".format(
96 | filename, zip_file.filename
97 | ),
98 | n=1,
99 | )
100 | assert (
101 | buffering == -1
102 | ), f"{self.__class__.__name__} does not support the `buffering` argument"
103 | # Use zipfile.Path in Python 3.10
104 |
105 | if mode[-1] == "b":
106 | mode = mode[:-1]
107 |
108 | return BytesIO(
109 | zip_file.read(filename)
110 | ) # If any errors occur, check whether a ZipFile object is called by multiple threads/processes at the same time.
111 |
112 | def _ls(self, path: str, **kwargs):
113 | """
114 | List the contents of the directory at the provided URI.
115 |
116 | Args:
117 | path (str): A URI supported by this PathHandler
118 |
119 | Returns:
120 | List[str]: list of contents in given path
121 | """
122 | path = self._get_local_path(path)
123 | zip_file: ZipFile = find_zip_parent(path, is_dir=True)
124 | assert zip_file is not None, "No zipfile contains {}".format(path)
125 | file_names = zip_file.namelist()
126 | in_archive_path = os.path.join(
127 | file_names[0], path[len(os.path.splitext(zip_file.filename)[0]) + 1 :]
128 | ).rstrip(os.path.sep)
129 | file_names = [
130 | f[len(in_archive_path) + 1 :]
131 | for f in file_names
132 | if f.startswith(in_archive_path)
133 | ]
134 | # must be closed to avoid thread safety issues
135 | zip_file.close()
136 | return file_names
137 |
138 |
139 | PathManager.register_handler(ZipFileHandler())
140 |
--------------------------------------------------------------------------------
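A minimal, self-contained sketch of the `zip://` handler above. The demo paths are made up and created on the fly; note that the handler expects the first archive member to be the top-level directory entry, which is the case when an entire directory is zipped.

import os
import zipfile

import san.utils.file_io  # noqa: F401  -- registers ZipFileHandler on import (assumed import path)
from detectron2.utils.file_io import PathManager

os.makedirs("demo_data", exist_ok=True)
with zipfile.ZipFile("demo_data/images.zip", "w") as zf:
    zf.writestr("images/", "")                      # top-level directory entry
    zf.writestr("images/train/0001.txt", b"hello")  # a regular member

# "demo_data/images" does not exist as a directory; the handler locates
# demo_data/images.zip and reads the member from the archive instead.
with PathManager.open("zip://demo_data/images/train/0001.txt", "rb") as f:
    print(f.read())  # b'hello'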
/san/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
3 | """
4 | Misc functions, including distributed helpers.
5 | Mostly copy-paste from torchvision references.
6 | """
7 | from typing import List, Optional
8 | from functools import reduce
9 | import torch
10 | import torch.distributed as dist
11 | import torchvision
12 | from torch import Tensor
13 |
14 |
15 | def _max_by_axis(the_list):
16 | # type: (List[List[int]]) -> List[int]
17 | maxes = the_list[0]
18 | for sublist in the_list[1:]:
19 | for index, item in enumerate(sublist):
20 | maxes[index] = max(maxes[index], item)
21 | return maxes
22 |
23 |
24 | class NestedTensor(object):
25 | def __init__(self, tensors, mask: Optional[Tensor]):
26 | self.tensors = tensors
27 | self.mask = mask
28 |
29 | def to(self, device):
30 | # type: (Device) -> NestedTensor # noqa
31 | cast_tensor = self.tensors.to(device)
32 | mask = self.mask
33 | if mask is not None:
34 | assert mask is not None
35 | cast_mask = mask.to(device)
36 | else:
37 | cast_mask = None
38 | return NestedTensor(cast_tensor, cast_mask)
39 |
40 | def decompose(self):
41 | return self.tensors, self.mask
42 |
43 | def __repr__(self):
44 | return str(self.tensors)
45 |
46 |
47 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
48 | # TODO make this more general
49 | if tensor_list[0].ndim == 3:
50 | if torchvision._is_tracing():
51 | # nested_tensor_from_tensor_list() does not export well to ONNX
52 | # call _onnx_nested_tensor_from_tensor_list() instead
53 | return _onnx_nested_tensor_from_tensor_list(tensor_list)
54 |
55 | # TODO make it support different-sized images
56 | max_size = _max_by_axis([list(img.shape) for img in tensor_list])
57 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
58 | batch_shape = [len(tensor_list)] + max_size
59 | b, c, h, w = batch_shape
60 | dtype = tensor_list[0].dtype
61 | device = tensor_list[0].device
62 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
63 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
64 | for img, pad_img, m in zip(tensor_list, tensor, mask):
65 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
66 | m[: img.shape[1], : img.shape[2]] = False
67 | else:
68 | raise ValueError("not supported")
69 | return NestedTensor(tensor, mask)
70 |
71 |
72 | # _onnx_nested_tensor_from_tensor_list() is an implementation of
73 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
74 | @torch.jit.unused
75 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
76 | max_size = []
77 | for i in range(tensor_list[0].dim()):
78 | max_size_i = torch.max(
79 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
80 | ).to(torch.int64)
81 | max_size.append(max_size_i)
82 | max_size = tuple(max_size)
83 |
84 | # work around for
85 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
86 | # m[: img.shape[1], :img.shape[2]] = False
87 | # which is not yet supported in onnx
88 | padded_imgs = []
89 | padded_masks = []
90 | for img in tensor_list:
91 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
92 | padded_img = torch.nn.functional.pad(
93 | img, (0, padding[2], 0, padding[1], 0, padding[0])
94 | )
95 | padded_imgs.append(padded_img)
96 |
97 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
98 | padded_mask = torch.nn.functional.pad(
99 | m, (0, padding[2], 0, padding[1]), "constant", 1
100 | )
101 | padded_masks.append(padded_mask.to(torch.bool))
102 |
103 | tensor = torch.stack(padded_imgs)
104 | mask = torch.stack(padded_masks)
105 |
106 | return NestedTensor(tensor, mask=mask)
107 |
108 |
109 | def is_dist_avail_and_initialized():
110 | if not dist.is_available():
111 | return False
112 | if not dist.is_initialized():
113 | return False
114 | return True
115 |
116 |
117 | def get_module_by_name(module, access_string):
118 | names = access_string.split(sep=".")
119 | return reduce(getattr, names, module)
120 |
--------------------------------------------------------------------------------
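A minimal sketch of `nested_tensor_from_tensor_list` above (the import path is an assumption): variable-sized images are padded to a common shape, and a boolean mask marks the padded pixels.

import torch
from san.utils.misc import nested_tensor_from_tensor_list  # assumed import path

imgs = [torch.rand(3, 480, 600), torch.rand(3, 512, 512)]
nt = nested_tensor_from_tensor_list(imgs)
tensors, mask = nt.decompose()
print(tensors.shape)  # torch.Size([2, 3, 512, 600]) -- padded to the per-axis maximum
print(mask.shape)     # torch.Size([2, 512, 600])    -- True where pixels are padding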
/train_net.py:
--------------------------------------------------------------------------------
1 | try:
2 | # ignore ShapelyDeprecationWarning from fvcore
3 | import warnings
4 |
5 | from shapely.errors import ShapelyDeprecationWarning
6 |
7 | warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
8 | except:
9 | pass
10 | import copy
11 | import itertools
12 | import logging
13 | import os
14 | from collections import OrderedDict, defaultdict
15 | from typing import Any, Dict, List, Set
16 |
17 | import detectron2.utils.comm as comm
18 | import torch
19 |
20 | warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.parallel")
21 | from detectron2.checkpoint import DetectionCheckpointer
22 | from detectron2.config import get_cfg
23 | from detectron2.data import MetadataCatalog
24 | from detectron2.engine import (
25 | DefaultTrainer,
26 | default_argument_parser,
27 | default_setup,
28 | launch,
29 | )
30 | from detectron2.evaluation import (
31 | CityscapesSemSegEvaluator,
32 | DatasetEvaluators,
33 | SemSegEvaluator,
34 | verify_results,
35 | )
36 | from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
37 | from detectron2.solver.build import maybe_add_gradient_clipping
38 | from detectron2.utils.logger import setup_logger
39 | from tabulate import tabulate
40 |
41 | from san import (
42 | MaskFormerSemanticDatasetMapper,
43 | SemanticSegmentorWithTTA,
44 | add_san_config,
45 | )
46 | from san.data import build_detection_test_loader, build_detection_train_loader
47 | from san.utils import WandbWriter, setup_wandb
48 |
49 |
50 | class Trainer(DefaultTrainer):
51 | def build_writers(self):
52 | writers = super().build_writers()
53 | # use wandb writer instead.
54 | writers[-1] = WandbWriter()
55 | return writers
56 |
57 | @classmethod
58 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
59 | if output_folder is None:
60 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
61 | evaluator_list = []
62 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
63 | # semantic segmentation
64 | if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
65 | evaluator_list.append(
66 | SemSegEvaluator(
67 | dataset_name,
68 | distributed=True,
69 | output_dir=output_folder,
70 | )
71 | )
72 |
73 | if evaluator_type == "cityscapes_sem_seg":
74 | assert (
75 | torch.cuda.device_count() > comm.get_rank()
76 | ), "CityscapesEvaluator currently do not work with multiple machines."
77 | return CityscapesSemSegEvaluator(dataset_name)
78 |
79 | if len(evaluator_list) == 0:
80 | raise NotImplementedError(
81 | "no Evaluator for the dataset {} with the type {}".format(
82 | dataset_name, evaluator_type
83 | )
84 | )
85 | elif len(evaluator_list) == 1:
86 | return evaluator_list[0]
87 | return DatasetEvaluators(evaluator_list)
88 |
89 | @classmethod
90 | def build_train_loader(cls, cfg):
91 |         # reuse the MaskFormer dataset mapper
92 | if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
93 | mapper = MaskFormerSemanticDatasetMapper(cfg, True)
94 | return build_detection_train_loader(cfg, mapper=mapper)
95 | else:
96 | mapper = None
97 | return build_detection_train_loader(cfg, mapper=mapper)
98 |
99 | @classmethod
100 | def build_test_loader(cls, cfg, dataset_name):
101 | # Add dataset meta info.
102 | return build_detection_test_loader(cfg, dataset_name)
103 |
104 | @classmethod
105 | def build_lr_scheduler(cls, cfg, optimizer):
106 | # use poly scheduler
107 | return build_lr_scheduler(cfg, optimizer)
108 |
109 | @classmethod
110 | def build_optimizer(cls, cfg, model):
111 | weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
112 | weight_decay_embed_group = cfg.SOLVER.WEIGHT_DECAY_EMBED_GROUP
113 | weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED
114 |
115 | defaults = {}
116 | defaults["lr"] = cfg.SOLVER.BASE_LR
117 | defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY
118 |
119 | norm_module_types = (
120 | torch.nn.BatchNorm1d,
121 | torch.nn.BatchNorm2d,
122 | torch.nn.BatchNorm3d,
123 | torch.nn.SyncBatchNorm,
124 | # NaiveSyncBatchNorm inherits from BatchNorm2d
125 | torch.nn.GroupNorm,
126 | torch.nn.InstanceNorm1d,
127 | torch.nn.InstanceNorm2d,
128 | torch.nn.InstanceNorm3d,
129 | torch.nn.LayerNorm,
130 | torch.nn.LocalResponseNorm,
131 | )
132 |
133 | params: List[Dict[str, Any]] = []
134 | memo: Set[torch.nn.parameter.Parameter] = set()
135 | for module_name, module in model.named_modules():
136 | for module_param_name, value in module.named_parameters(recurse=False):
137 | if not value.requires_grad:
138 | continue
139 | # Avoid duplicating parameters
140 | if value in memo:
141 | continue
142 | memo.add(value)
143 |
144 | hyperparams = copy.copy(defaults)
145 | hyperparams["param_name"] = ".".join([module_name, module_param_name])
146 | if "side_adapter_network" in module_name:
147 | hyperparams["lr"] = (
148 | hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
149 | )
150 | # scale clip lr
151 | if "clip" in module_name:
152 | hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.CLIP_MULTIPLIER
153 | if any([x in module_param_name for x in weight_decay_embed_group]):
154 | hyperparams["weight_decay"] = weight_decay_embed
155 | if isinstance(module, norm_module_types):
156 | hyperparams["weight_decay"] = weight_decay_norm
157 | if isinstance(module, torch.nn.Embedding):
158 | hyperparams["weight_decay"] = weight_decay_embed
159 | params.append({"params": [value], **hyperparams})
160 |
161 | def maybe_add_full_model_gradient_clipping(optim):
162 |             # detectron2 doesn't support full-model gradient clipping yet
163 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
164 | enable = (
165 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED
166 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
167 | and clip_norm_val > 0.0
168 | )
169 |
170 | class FullModelGradientClippingOptimizer(optim):
171 | def step(self, closure=None):
172 | all_params = itertools.chain(
173 | *[x["params"] for x in self.param_groups]
174 | )
175 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
176 | super().step(closure=closure)
177 |
178 | return FullModelGradientClippingOptimizer if enable else optim
179 |
180 | optimizer_type = cfg.SOLVER.OPTIMIZER
181 | if optimizer_type == "SGD":
182 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
183 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
184 | )
185 | elif optimizer_type == "ADAMW":
186 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
187 | params, cfg.SOLVER.BASE_LR
188 | )
189 | else:
190 | raise NotImplementedError(f"no optimizer type {optimizer_type}")
191 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
192 | optimizer = maybe_add_gradient_clipping(cfg, optimizer)
193 | # display the lr and wd of each param group in a table
194 | optim_info = defaultdict(list)
195 | total_params_size = 0
196 | for group in optimizer.param_groups:
197 | optim_info["Param Name"].append(group["param_name"])
198 | optim_info["Param Shape"].append(
199 | "X".join([str(x) for x in list(group["params"][0].shape)])
200 | )
201 | total_params_size += group["params"][0].numel()
202 | optim_info["Lr"].append(group["lr"])
203 | optim_info["Wd"].append(group["weight_decay"])
204 |         # Append a summary row with the total parameter count
205 | optim_info["Param Name"].append("Total")
206 | optim_info["Param Shape"].append("{:.2f}M".format(total_params_size / 1e6))
207 | optim_info["Lr"].append("-")
208 | optim_info["Wd"].append("-")
209 | table = tabulate(
210 | list(zip(*optim_info.values())),
211 | headers=optim_info.keys(),
212 | tablefmt="grid",
213 | floatfmt=".2e",
214 | stralign="center",
215 | numalign="center",
216 | )
217 | logger = logging.getLogger("san")
218 | logger.info("Optimizer Info:\n{}\n".format(table))
219 | return optimizer
220 |
221 | @classmethod
222 | def test_with_TTA(cls, cfg, model):
223 | logger = logging.getLogger("detectron2.trainer")
224 |         # At the end of training, run an evaluation with TTA.
225 | logger.info("Running inference with test-time augmentation ...")
226 | model = SemanticSegmentorWithTTA(cfg, model)
227 | evaluators = [
228 | cls.build_evaluator(
229 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
230 | )
231 | for name in cfg.DATASETS.TEST
232 | ]
233 | res = cls.test(cfg, model, evaluators)
234 | res = OrderedDict({k + "_TTA": v for k, v in res.items()})
235 | return res
236 |
237 |
238 | def setup(args):
239 | """
240 | Create configs and perform basic setups.
241 | """
242 | cfg = get_cfg()
243 | # for poly lr schedule
244 | add_deeplab_config(cfg)
245 | add_san_config(cfg)
246 | cfg.merge_from_file(args.config_file)
247 | cfg.merge_from_list(args.opts)
248 | cfg.freeze()
249 | default_setup(cfg, args)
250 | if not args.eval_only:
251 | setup_wandb(cfg, args)
252 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="san")
253 | return cfg
254 |
255 |
256 | def main(args):
257 | cfg = setup(args)
258 |
259 | if args.eval_only:
260 | model = Trainer.build_model(cfg)
261 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
262 | cfg.MODEL.WEIGHTS, resume=args.resume
263 | )
264 | res = Trainer.test(cfg, model)
265 | if cfg.TEST.AUG.ENABLED:
266 | res.update(Trainer.test_with_TTA(cfg, model))
267 | if comm.is_main_process():
268 | verify_results(cfg, res)
269 | return res
270 |
271 | trainer = Trainer(cfg)
272 |
273 | trainer.resume_or_load(resume=args.resume)
274 | return trainer.train()
275 |
276 |
277 | if __name__ == "__main__":
278 | args = default_argument_parser().parse_args()
279 | print("Command Line Args:", args)
280 | launch(
281 | main,
282 | args.num_gpus,
283 | num_machines=args.num_machines,
284 | machine_rank=args.machine_rank,
285 | dist_url=args.dist_url,
286 | args=(args,),
287 | )
288 |
--------------------------------------------------------------------------------
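A minimal, standalone sketch of the full-model gradient-clipping pattern used in `Trainer.build_optimizer` above (the names and clip value are illustrative; in the trainer they come from `cfg.SOLVER.CLIP_GRADIENTS`):

import itertools

import torch

def with_full_model_clipping(optim_cls, clip_norm_val):
    class FullModelGradientClippingOptimizer(optim_cls):
        def step(self, closure=None):
            # clip the gradient norm over *all* param groups before the update
            all_params = itertools.chain(*[g["params"] for g in self.param_groups])
            torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
            super().step(closure=closure)
    return FullModelGradientClippingOptimizer

model = torch.nn.Linear(4, 2)
opt = with_full_model_clipping(torch.optim.AdamW, clip_norm_val=0.01)(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).sum().backward()
opt.step()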
/visualize_json_results.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | import argparse
5 | import json
6 | import numpy as np
7 | import os
8 | from collections import defaultdict
9 | import cv2
10 | import tqdm
11 | from fvcore.common.file_io import PathManager
12 | from PIL import Image
13 | from detectron2.data import DatasetCatalog, MetadataCatalog
14 | from detectron2.structures import Boxes, BoxMode, Instances
15 | from detectron2.utils.logger import setup_logger
16 | from detectron2.utils.visualizer import Visualizer, GenericMask
17 | import sys
18 |
19 | import san
20 |
21 |
22 | def create_instances(predictions, image_size, ignore_label=255):
23 | ret = Instances(image_size)
24 |
25 | labels = np.asarray(
26 | [dataset_id_map(predictions[i]["category_id"]) for i in range(len(predictions))]
27 | )
28 | ret.pred_classes = labels
29 | ret.pred_masks = [
30 | GenericMask(predictions[i]["segmentation"], *image_size)
31 | for i in range(len(predictions))
32 | ]
33 | # convert instance to sem_seg map
34 | sem_seg = np.ones(image_size[:2], dtype=np.uint16) * ignore_label
35 | for mask, label in zip(ret.pred_masks, ret.pred_classes):
36 | sem_seg[mask.mask == 1] = label
37 | return sem_seg
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser(
42 | description="A script that visualizes the json predictions from COCO or LVIS dataset."
43 | )
44 | parser.add_argument(
45 | "--input", required=True, help="JSON file produced by the model"
46 | )
47 | parser.add_argument("--output", required=True, help="output directory")
48 | parser.add_argument(
49 | "--dataset", help="name of the dataset", default="coco_2017_val"
50 | )
51 | parser.add_argument(
52 | "--conf-threshold", default=0.5, type=float, help="confidence threshold"
53 | )
54 | args = parser.parse_args()
55 |
56 | logger = setup_logger()
57 |
58 | with PathManager.open(args.input, "r") as f:
59 | predictions = json.load(f)
60 |
61 | pred_by_image = defaultdict(list)
62 | for p in predictions:
63 |
64 | pred_by_image[p["file_name"]].append(p)
65 |
66 | dicts = list(DatasetCatalog.get(args.dataset))
67 | metadata = MetadataCatalog.get(args.dataset)
68 | if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
69 |
70 | def dataset_id_map(ds_id):
71 | return metadata.thing_dataset_id_to_contiguous_id[ds_id]
72 |
73 | elif "lvis" in args.dataset:
74 | # LVIS results are in the same format as COCO results, but have a different
75 | # mapping from dataset category id to contiguous category id in [0, #categories - 1]
76 | def dataset_id_map(ds_id):
77 | return ds_id - 1
78 |
79 | elif "sem_seg" in args.dataset:
80 |
81 | def dataset_id_map(ds_id):
82 | return ds_id
83 |
84 | else:
85 | raise ValueError("Unsupported dataset: {}".format(args.dataset))
86 |
87 | os.makedirs(args.output, exist_ok=True)
88 |
89 | for dic in tqdm.tqdm(dicts):
90 | img = cv2.imread(dic["file_name"], cv2.IMREAD_COLOR)[:, :, ::-1]
91 | basename = os.path.basename(dic["file_name"])
92 | if dic["file_name"] in pred_by_image:
93 | pred = create_instances(
94 | pred_by_image[dic["file_name"]],
95 | img.shape[:2],
96 | ignore_label=metadata.ignore_label,
97 | )
98 |
99 | vis = Visualizer(img, metadata)
100 | vis_pred = vis.draw_sem_seg(pred).get_image()
101 | # import pdb
102 | # pdb.set_trace()
103 | vis = Visualizer(img, metadata)
104 | with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
105 | sem_seg = Image.open(f)
106 | sem_seg = np.asarray(sem_seg, dtype="uint16")
107 | vis_gt = vis.draw_sem_seg(sem_seg).get_image()
108 |             # resize pred and gt to the same height (512 px)
109 | ratio = vis_gt.shape[0] / 512
110 | tgt_w = int(vis_pred.shape[1] / ratio)
111 |             vis_pred = cv2.resize(vis_pred, (tgt_w, 512))
112 |             vis_gt = cv2.resize(vis_gt, (tgt_w, 512))
113 |             img = cv2.resize(img, (tgt_w, 512))
114 | # build grid view
115 | blank_int = 255 * np.ones((vis_gt.shape[0], 10, 3), dtype=np.uint8)
116 | concat = np.concatenate(
117 | (img, blank_int, vis_pred, blank_int, vis_gt), axis=1
118 | )
119 | cv2.imwrite(os.path.join(args.output, basename), concat[:, :, ::-1])
--------------------------------------------------------------------------------