├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── app.py ├── configs ├── Base-coco-stuff-164K-171.yaml ├── san_clip_vit_large_res4_coco.yaml └── san_clip_vit_res4_coco.yaml ├── datasets ├── prepare_ade20k_sem_seg.py ├── prepare_pcontext_sem_seg_459cls.py ├── prepare_pcontext_sem_seg_59cls.py └── prepare_voc_sem_seg.py ├── docker ├── Dockerfile └── app.Dockerfile ├── docs ├── .nojekyll ├── README.md ├── arch.png ├── index.html └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ ├── images │ ├── favicon.ico │ └── grid.jpg │ ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js │ └── pdfs │ └── sample.pdf ├── install.sh ├── predict.py ├── requirements.txt ├── resources ├── arch.png ├── san_vit_b_16.log └── san_vit_large_14.log ├── san ├── __init__.py ├── config.py ├── data │ ├── __init__.py │ ├── build.py │ ├── dataset_mappers │ │ ├── __init__.py │ │ └── mask_former_semantic_dataset_mapper.py │ └── datasets │ │ ├── __init__.py │ │ ├── register_ade20k_full.py │ │ ├── register_coco_stuff_164k.py │ │ ├── register_pcontext.py │ │ └── register_voc.py ├── model │ ├── __init__.py │ ├── attn_helper.py │ ├── clip_utils │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── utils.py │ │ └── visual.py │ ├── criterion.py │ ├── layers.py │ ├── matcher.py │ ├── san.py │ └── side_adapter │ │ ├── __init__.py │ │ ├── side_adapter.py │ │ └── timm_wrapper.py ├── test_time_augmentation.py └── utils │ ├── __init__.py │ ├── events.py │ ├── file_io.py │ └── misc.py ├── train_net.py └── visualize_json_results.py /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output 3 | instant_test_output 4 | inference_test_output 5 | 6 | 7 | *.json 8 | *.diff 9 | *.jpg 10 | !/projects/DensePose/doc/images/*.jpg 11 | !docs/static/images/*.jpg 12 | # compilation and distribution 13 | __pycache__ 14 | _ext 15 | *.pyc 16 | *.pyd 17 | *.so 18 | *.dll 19 | *.egg-info/ 20 | build/ 21 | dist/ 22 | wheels/ 23 | 24 | # pytorch/python/numpy formats 25 | *.pth 26 | *.pkl 27 | *.npy 28 | *.ts 29 | model_ts*.txt 30 | 31 | # ipython/jupyter notebooks 32 | *.ipynb 33 | **/.ipynb_checkpoints/ 34 | 35 | # Editor temporaries 36 | *.swn 37 | *.swo 38 | *.swp 39 | *~ 40 | 41 | # editor settings 42 | .idea 43 | .vscode 44 | _darcs 45 | 46 | # project dirs 47 | /detectron2/model_zoo/configs 48 | /datasets/* 49 | !/datasets/*.* 50 | /projects/*/datasets 51 | /models 52 | /snippet 53 | 54 | .history 55 | .env 56 | 57 | wandb 58 | amlt -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: 22.12.0 4 | hooks: 5 | - id: black 6 | exclude: ^third_party/ 7 | - repo: https://github.com/asottile/seed-isort-config 8 | rev: v2.2.0 9 | hooks: 10 | - id: seed-isort-config 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v4.0.1 13 | hooks: 14 | - id: trailing-whitespace 15 | - id: check-yaml 16 | - id: end-of-file-fixer 17 | - id: requirements-txt-fixer 18 | - id: check-merge-conflict 19 | - id: fix-encoding-pragma 20 | args: ["--remove"] 21 | - id: mixed-line-ending 22 | args: ["--fix=lf"] 23 | # - repo: https://github.com/jumanjihouse/pre-commit-hooks 24 | # 
rev: 2.1.5 25 | # hooks: 26 | # - id: markdownlint 27 | # args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036", "-t", "allow_different_nesting"] 28 | - repo: https://github.com/myint/docformatter 29 | rev: v1.4 30 | hooks: 31 | - id: docformatter 32 | args: ["--in-place", "--wrap-descriptions", "79"] 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MendelXu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR2023-Highlight] Side Adapter Network for Open-Vocabulary Semantic Segmentation 2 | # [PAMI] SAN: Side Adapter Network for Open-Vocabulary Semantic Segmentation 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/side-adapter-network-for-open-vocabulary/open-vocabulary-semantic-segmentation-on-2)](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-2?p=side-adapter-network-for-open-vocabulary) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/side-adapter-network-for-open-vocabulary/open-vocabulary-semantic-segmentation-on-3)](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-3?p=side-adapter-network-for-open-vocabulary) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/side-adapter-network-for-open-vocabulary/open-vocabulary-semantic-segmentation-on-7)](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-7?p=side-adapter-network-for-open-vocabulary) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/side-adapter-network-for-open-vocabulary/open-vocabulary-semantic-segmentation-on-1)](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-1?p=side-adapter-network-for-open-vocabulary) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/side-adapter-network-for-open-vocabulary/open-vocabulary-semantic-segmentation-on-5)](https://paperswithcode.com/sota/open-vocabulary-semantic-segmentation-on-5?p=side-adapter-network-for-open-vocabulary) 8 | 9 | This is the official implementation of our conference paper : "[Side Adapter Network for 
Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)" (main branch) and journal paper: "[SAN: Side Adapter Network for Open-Vocabulary Semantic Segmentation 10 | ](https://www.computer.org/csdl/journal/tp/2023/12/10238837/1QaEJ3q98aY)" (video branch). 11 | 12 | ## Introduction 13 | 14 | This paper presents a new framework for open-vocabulary semantic segmentation with a pre-trained vision-language model, named Side Adapter Network (SAN). Our approach models the semantic segmentation task as a region recognition problem. A side network is attached to a frozen CLIP model with two branches: one for predicting mask proposals, and the other for predicting an attention bias that is applied in the CLIP model to recognize the class of each mask. This decoupled design has the benefit of leveraging CLIP to recognize the class of the mask proposals. Since the attached side network can reuse CLIP features, it can be very light. In addition, the entire network can be trained end-to-end, allowing the side network to be adapted to the frozen CLIP model, which makes the predicted mask proposals CLIP-aware. 15 | Our approach is fast, accurate, and only adds a few additional trainable parameters. We evaluate our approach on multiple semantic segmentation benchmarks. Our method significantly outperforms other counterparts, with up to 18 times fewer trainable parameters and 19 times faster inference speed. 16 | ![](resources/arch.png) 17 | ### Table of Contents 18 | - [Demo](#6) 19 | - [Installation](#1) 20 | - [Data Preparation](#2) 21 | - [Usage](#3) 22 | - [Training](#5) 23 | - [Evaluation](#4) 24 | - [Visualization](#7) 25 | 26 | - [FAQ](#8) 27 | 28 | 29 | 30 | ### Demo 31 | - Run the demo app on [🤗HuggingFace](https://huggingface.co/spaces/Mendel192/SAN-Demo). (It is running on a low-spec machine and may be slow.) 32 | - Run the demo app with Docker: 33 | ``` 34 | docker build -f docker/app.Dockerfile -t san_app . 35 | docker run -it --shm-size 4G -p 7860:7860 san_app 36 | ``` 37 | 38 | 39 | ### Installation 40 | 1. Clone the repository 41 | ```sh 42 | git clone https://github.com/MendelXu/SAN.git 43 | ``` 44 | 2. Navigate to the project directory 45 | ```sh 46 | cd SAN 47 | ``` 48 | 3. Install the dependencies 49 | ```sh 50 | bash install.sh 51 | ``` 52 | **Hint**: You can run the job in Docker instead of installing the dependencies locally. 53 | Run with the pre-built Docker image: 54 | ``` 55 | docker run -it --gpus all --shm-size 8G mendelxu/pytorch:d2_nvcr_2008 /bin/bash 56 | ``` 57 | or build your own image with the provided Dockerfile `docker/Dockerfile`. 58 | 59 | 60 | 61 | ### Data Preparation 62 | See [SimSeg](https://github.com/MendelXu/zsseg.baseline) for reference. The data should be organized as: 63 | ``` 64 | datasets/ 65 | coco/ 66 | ... 67 | train2017/ 68 | val2017/ 69 | stuffthingmaps_detectron2/ 70 | VOC2012/ 71 | ... 72 | images_detectron2/ 73 | annotations_detectron2/ 74 | pcontext/ 75 | ... 76 | val/ 77 | pcontext_full/ 78 | ... 79 | val/ 80 | ADEChallengeData2016/ 81 | ... 82 | images/ 83 | annotations_detectron2/ 84 | ADE20K_2021_17_01/ 85 | ... 86 | images/ 87 | annotations_detectron2/ 88 | ``` 89 | **Hint**: In the code, these datasets are registered under their corresponding dataset names. The relationship is: 90 | ``` 91 | coco_2017_*_stuff_sem_seg: COCO Stuff-171 92 | voc_sem_seg_*: Pascal VOC-20 93 | pcontext_sem_seg_*: Pascal Context-59 94 | ade20k_sem_seg_*: ADE-150 95 | pcontext_full_sem_seg_*: Pascal Context-459 96 | ade20k_full_sem_seg_*: ADE-847 97 | ``` 98 | 99 |
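To double-check which of these splits are actually available in your environment, you can query Detectron2's dataset catalog. The snippet below is a minimal sketch: it assumes that importing the `san` package triggers the registration performed by the `register_*.py` modules under `san/data/datasets`, and that `DETECTRON2_DATASETS` (or the default `datasets/` folder) points at the layout above.
```python
import os

# Point Detectron2 at the dataset root shown above (adjust to your own path).
os.environ.setdefault("DETECTRON2_DATASETS", "datasets")

from detectron2.data import DatasetCatalog, MetadataCatalog

import san  # noqa: F401  # assumed to register the SAN datasets on import

# List every registered semantic-segmentation split and its class count.
for name in sorted(n for n in DatasetCatalog.list() if "sem_seg" in n):
    meta = MetadataCatalog.get(name)
    print(name, len(getattr(meta, "stuff_classes", [])), "classes")
```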
100 | ### Usage 101 | 102 | 103 | - #### Pretrained Weights 104 | 105 | |Model|Config|Weights|Logs| 106 | |-----|-------|---|---| 107 | |SAN-ViT-B/16|configs/san_clip_vit_res4_coco.yaml|[Huggingface](https://huggingface.co/Mendel192/san/resolve/main/san_vit_b_16.pth)|[Log](resources/san_vit_b_16.log)| 108 | |SAN-ViT-L/14|configs/san_clip_vit_large_res4_coco.yaml|[Huggingface](https://huggingface.co/Mendel192/san/resolve/main/san_vit_large_14.pth)|[Log](resources/san_vit_large_14.log)| 109 | 110 | 111 | - #### Evaluation 112 | 113 | 114 | - Evaluate the trained model on the validation sets of all datasets: 115 | ```sh 116 | python train_net.py --eval-only --config-file <config_file> --num-gpus <num_gpu> OUTPUT_DIR <output_path> MODEL.WEIGHTS <trained_model_path> 117 | ``` 118 | For example, evaluate our pre-trained model: 119 | ``` 120 | # 1. Download SAN (ViT-B/16 CLIP) from https://huggingface.co/Mendel192/san/blob/main/san_vit_b_16.pth. 121 | # 2. Put it at `output/model.pth`. 122 | # 3. Run the evaluation. 123 | python train_net.py --eval-only --config-file configs/san_clip_vit_res4_coco.yaml --num-gpus 8 OUTPUT_DIR ./output/trained_vit_b16 MODEL.WEIGHTS output/model.pth 124 | ``` 125 | - Evaluate the trained model on the validation set of a single dataset: 126 | ```sh 127 | python train_net.py --eval-only --config-file <config_file> --num-gpus <num_gpu> OUTPUT_DIR <output_path> MODEL.WEIGHTS <trained_model_path> DATASETS.TEST "('<dataset_name>',)" 128 | ``` 129 | 130 | - #### Visualization 131 | 132 | 133 | 134 | ```sh 135 | python visualize_json_results.py --input <prediction_json_path> --output <output_dir> --dataset <dataset_name> 136 | # Example: 137 | # Generate the results. 138 | # python train_net.py --eval-only --config-file configs/san_clip_vit_res4_coco.yaml --num-gpus 1 OUTPUT_DIR ./output/trained_vit_b16 MODEL.WEIGHTS output/san/san_vit_b_16.pth DATASETS.TEST '("pcontext_sem_seg_val",)' 139 | # Visualize the predictions. 140 | # python visualize_json_results.py --input output/trained_vit_b16/inference/sem_seg_predictions.json --output output/viz --dataset pcontext_sem_seg_val 141 | ``` 142 | 143 | 144 | - #### Training 145 | 146 | ```sh 147 | wandb off 148 | # [Optional] If you want to log the training run to wandb: 149 | # wandb login 150 | # wandb on 151 | python train_net.py --config-file <config_file> --num-gpus <num_gpu> OUTPUT_DIR <output_path> WANDB.NAME <job_name> 152 | ``` 153 | **Hint**: We use `<...>` to denote the variables you should replace according to your own setting. 154 | 155 | ### FAQ 156 | 157 | 158 | 159 | If you cannot get a timely response from the author on GitHub, please e-mail me directly at [shea.mendel] [AT] [gmail.com]. 160 | 161 | ### License 162 | Distributed under the MIT License. See LICENSE for more information. 163 | 164 | ### Cite 165 | 166 | If you find it helpful, you can cite our paper in your work.
167 | 168 | ``` 169 | @inproceedings{xu2023side, 170 | title={Side Adapter Network for Open-Vocabulary Semantic Segmentation}, 171 | author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang}, 172 | booktitle={CVPR}, 173 | year={2023} 174 | } 175 | 176 | @article{xu2023san, 177 | title={SAN: Side adapter network for open-vocabulary semantic segmentation}, 178 | author={Xu, Mengde and Zhang, Zheng and Wei, Fangyun and Hu, Han and Bai, Xiang}, 179 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 180 | year={2023}, 181 | publisher={IEEE} 182 | } 183 | ``` 184 | 185 | 186 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from predict import Predictor, model_cfg 2 | from PIL import Image 3 | import gradio as gr 4 | 5 | # set a lot of global variables 6 | 7 | predictor = None 8 | vocabulary = ["bat man, woman"] 9 | input_image: Image.Image = None 10 | outputs: dict = None 11 | cur_model_name: str = None 12 | 13 | 14 | def set_vocabulary(text): 15 | global vocabulary 16 | vocabulary = text.split(",") 17 | print("set vocabulary to", vocabulary) 18 | 19 | 20 | def set_input(image): 21 | global input_image 22 | input_image = image 23 | print("set input image to", image) 24 | 25 | 26 | def set_predictor(model_name: str): 27 | global cur_model_name 28 | if cur_model_name == model_name: 29 | return 30 | global predictor 31 | predictor = Predictor(**model_cfg[model_name]) 32 | print("set predictor to", model_name) 33 | cur_model_name = model_name 34 | 35 | 36 | set_predictor(list(model_cfg.keys())[0]) 37 | 38 | 39 | # for visualization 40 | def visualize(vis_mode): 41 | if outputs is None: 42 | return None 43 | return predictor.visualize(**outputs, mode=vis_mode) 44 | 45 | 46 | def segment_image(vis_mode, voc_mode, model_name): 47 | set_predictor(model_name) 48 | if input_image is None: 49 | return None 50 | global outputs 51 | result = predictor.predict( 52 | input_image, vocabulary=vocabulary, augment_vocabulary=voc_mode 53 | ) 54 | outputs = result 55 | 56 | return visualize(vis_mode) 57 | 58 | 59 | def segment_e2e(image, vis_mode): 60 | set_input(image) 61 | return segment_image(vis_mode) 62 | 63 | 64 | # gradio 65 | 66 | with gr.Blocks( 67 | css=""" 68 | #submit {background: #3498db; color: white; border: none; padding: 10px 20px; border-radius: 5px;width: 20%;margin: 0 auto; display: block;} 69 | 70 | """ 71 | ) as demo: 72 | gr.Markdown( 73 | f"

Side Adapter Network for Open-Vocabulary Semantic Segmentation

" 74 | ) 75 | gr.Markdown( 76 | """ 77 | This is the demo for our conference paper : "[Side Adapter Network for Open-Vocabulary Semantic Segmentation](https://arxiv.org/abs/2302.12242)". 78 | """ 79 | ) 80 | # gr.Image(type="pil", value="./resources/arch.png", shape=(460, 200), elem_id="arch") 81 | gr.Markdown( 82 | """ 83 | --- 84 | """ 85 | ) 86 | with gr.Row(): 87 | image = gr.Image(type="pil", elem_id="input_image") 88 | plt = gr.Image(type="pil", elem_id="output_image") 89 | 90 | with gr.Row(): 91 | model_name = gr.Dropdown( 92 | list(model_cfg.keys()), label="Model", value="san_vit_b_16" 93 | ) 94 | augment_vocabulary = gr.Dropdown( 95 | ["COCO-all", "COCO-stuff"], 96 | label="Vocabulary Expansion", 97 | value="COCO-all", 98 | ) 99 | vis_mode = gr.Dropdown( 100 | ["overlay", "mask"], label="Visualization Mode", value="overlay" 101 | ) 102 | object_names = gr.Textbox(value=",".join(vocabulary), label="Object Names (Empty inputs will use the vocabulary specified in `Vocabulary Expansion`. Multiple names should be seperated with ,.)", lines=5) 103 | 104 | button = gr.Button("Run", elem_id="submit") 105 | note = gr.Markdown( 106 | """ 107 | --- 108 | ### FAQ 109 | - **Q**: What is the `Vocabulary Expansion` option for? 110 | **A**: The vocabulary expansion option is used to expand the vocabulary of the model. The model assign category to each area with `argmax`. When only a vocabulary with few thing classes is provided, it will produce much false postive. 111 | - **Q**: Error: `Unexpected token '<', " 0", "6->1", "12->2", "18->3"] 9 | ATTN_BIAS: 10 | NUM_HEADS: 16 -------------------------------------------------------------------------------- /configs/san_clip_vit_res4_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-coco-stuff-164K-171.yaml 2 | MODEL: 3 | META_ARCHITECTURE: "SAN" 4 | SOLVER: 5 | BACKBONE_MULTIPLIER: 1.0 -------------------------------------------------------------------------------- /datasets/prepare_ade20k_sem_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | import os 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import tqdm 9 | from PIL import Image 10 | 11 | 12 | def convert(input, output): 13 | img = np.asarray(Image.open(input)) 14 | assert img.dtype == np.uint8 15 | img = img - 1 # 0 (ignore) becomes 255. others are shifted by 1 16 | Image.fromarray(img).save(output) 17 | 18 | 19 | if __name__ == "__main__": 20 | dataset_dir = ( 21 | Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016" 22 | ) 23 | for name in ["training", "validation"]: 24 | annotation_dir = dataset_dir / "annotations" / name 25 | output_dir = dataset_dir / "annotations_detectron2" / name 26 | output_dir.mkdir(parents=True, exist_ok=True) 27 | for file in tqdm.tqdm(list(annotation_dir.iterdir())): 28 | output_file = output_dir / file.name 29 | convert(file, output_file) 30 | -------------------------------------------------------------------------------- /datasets/prepare_pcontext_sem_seg_459cls.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Zhen Zhu(zzhu@hust.edu.cn) 4 | # Generate train & val data. 
5 | 6 | 7 | import os 8 | import argparse 9 | import shutil 10 | from PIL import Image 11 | import numpy as np 12 | from tqdm import tqdm 13 | from scipy.io import loadmat 14 | 15 | LABEL_DIR = "label" 16 | IMAGE_DIR = "image" 17 | 18 | 19 | class PascalContextGenerator(object): 20 | def __init__(self, args, image_dir=IMAGE_DIR, label_dir=LABEL_DIR): 21 | self.args = args 22 | self.train_label_dir = os.path.join(self.args.save_dir, "train", label_dir) 23 | self.val_label_dir = os.path.join(self.args.save_dir, "val", label_dir) 24 | if not os.path.exists(self.train_label_dir): 25 | os.makedirs(self.train_label_dir) 26 | 27 | if not os.path.exists(self.val_label_dir): 28 | os.makedirs(self.val_label_dir) 29 | 30 | self.train_image_dir = os.path.join(self.args.save_dir, "train", image_dir) 31 | self.val_image_dir = os.path.join(self.args.save_dir, "val", image_dir) 32 | if not os.path.exists(self.train_image_dir): 33 | os.makedirs(self.train_image_dir) 34 | 35 | if not os.path.exists(self.val_image_dir): 36 | os.makedirs(self.val_image_dir) 37 | self.all_cls = set() 38 | 39 | def _class_to_index(self, mask): 40 | self.all_cls = self.all_cls.union(set(np.unique(mask).tolist())) 41 | # import pdb 42 | # pdb.set_trace() 43 | mask = mask - 1 44 | return mask 45 | 46 | def generate_label(self): 47 | _image_dir = os.path.join(self.args.img_dir, "JPEGImages") 48 | _anno_dir = self.args.anno_dir 49 | annFile = os.path.join(self.args.img_dir, "trainval_merged.json") 50 | 51 | from detail import Detail 52 | 53 | train_detail = Detail(annFile, _image_dir, "train") 54 | train_ids = train_detail.getImgs() 55 | 56 | for img_id in tqdm(train_ids): 57 | mask = loadmat( 58 | os.path.join(_anno_dir, img_id["file_name"].replace(".jpg", ".mat")) 59 | )["LabelMap"] 60 | mask = Image.fromarray(self._class_to_index(mask)) 61 | filename = img_id["file_name"] 62 | basename, _ = os.path.splitext(filename) 63 | if filename.endswith(".jpg"): 64 | imgpath = os.path.join(_image_dir, filename) 65 | shutil.copy(imgpath, os.path.join(self.train_image_dir, filename)) 66 | mask_png_name = basename + ".tif" 67 | mask.save(os.path.join(self.train_label_dir, mask_png_name)) 68 | 69 | val_detail = Detail(annFile, _image_dir, "val") 70 | val_ids = val_detail.getImgs() 71 | for img_id in tqdm(val_ids): 72 | mask = loadmat( 73 | os.path.join(_anno_dir, img_id["file_name"].replace(".jpg", ".mat")) 74 | )["LabelMap"] 75 | mask = Image.fromarray(self._class_to_index(mask)) 76 | filename = img_id["file_name"] 77 | basename, _ = os.path.splitext(filename) 78 | if filename.endswith(".jpg"): 79 | imgpath = os.path.join(_image_dir, filename) 80 | shutil.copy(imgpath, os.path.join(self.val_image_dir, filename)) 81 | mask_png_name = basename + ".tif" 82 | mask.save(os.path.join(self.val_label_dir, mask_png_name)) 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument( 88 | "--save_dir", 89 | default=None, 90 | type=str, 91 | dest="save_dir", 92 | help="The directory to save the data.", 93 | ) 94 | # ori_root_dir: VOCdevkit/VOC2010 95 | parser.add_argument( 96 | "--img_dir", 97 | default=None, 98 | type=str, 99 | dest="img_dir", 100 | help="The directory of the cityscapes data.", 101 | ) 102 | parser.add_argument( 103 | "--anno_dir", 104 | default=None, 105 | type=str, 106 | dest="anno_dir", 107 | help="The directory of the cityscapes data.", 108 | ) 109 | args = parser.parse_args() 110 | 111 | pascalcontext_seg_generator = PascalContextGenerator(args) 112 | 
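# Convert both the train and val splits: each PASCAL-Context ".mat" label map (key
# "LabelMap") is shifted to 0-based class IDs and saved as a ".tif" mask under
# <save_dir>/{train,val}/label, and the matching JPEG image is copied to
# <save_dir>/{train,val}/image. The raw class IDs encountered along the way are
# accumulated in `all_cls` and printed below as a sanity check.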
pascalcontext_seg_generator.generate_label() 113 | print(pascalcontext_seg_generator.all_cls) 114 | -------------------------------------------------------------------------------- /datasets/prepare_pcontext_sem_seg_59cls.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Zhen Zhu(zzhu@hust.edu.cn) 4 | # Generate train & val data. 5 | 6 | 7 | import os 8 | import argparse 9 | import shutil 10 | from PIL import Image 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | LABEL_DIR = "label" 15 | IMAGE_DIR = "image" 16 | 17 | 18 | class PascalContextGenerator(object): 19 | def __init__(self, args, image_dir=IMAGE_DIR, label_dir=LABEL_DIR): 20 | self.args = args 21 | self.train_label_dir = os.path.join(self.args.save_dir, "train", label_dir) 22 | self.val_label_dir = os.path.join(self.args.save_dir, "val", label_dir) 23 | if not os.path.exists(self.train_label_dir): 24 | os.makedirs(self.train_label_dir) 25 | 26 | if not os.path.exists(self.val_label_dir): 27 | os.makedirs(self.val_label_dir) 28 | 29 | self.train_image_dir = os.path.join(self.args.save_dir, "train", image_dir) 30 | self.val_image_dir = os.path.join(self.args.save_dir, "val", image_dir) 31 | if not os.path.exists(self.train_image_dir): 32 | os.makedirs(self.train_image_dir) 33 | 34 | if not os.path.exists(self.val_image_dir): 35 | os.makedirs(self.val_image_dir) 36 | 37 | def _class_to_index(self, mask, _mapping, _key): 38 | # assert the values 39 | values = np.unique(mask) 40 | for i in range(len(values)): 41 | assert values[i] in _mapping 42 | index = np.digitize(mask.ravel(), _mapping, right=True) 43 | mask = _key[index].reshape(mask.shape) 44 | mask = mask - 1 45 | return mask 46 | 47 | def generate_label(self): 48 | _image_dir = os.path.join(self.args.ori_root_dir, "JPEGImages") 49 | annFile = os.path.join(self.args.ori_root_dir, "trainval_merged.json") 50 | _mapping = np.sort( 51 | np.array( 52 | [ 53 | 0, 54 | 2, 55 | 259, 56 | 260, 57 | 415, 58 | 324, 59 | 9, 60 | 258, 61 | 144, 62 | 18, 63 | 19, 64 | 22, 65 | 23, 66 | 397, 67 | 25, 68 | 284, 69 | 158, 70 | 159, 71 | 416, 72 | 33, 73 | 162, 74 | 420, 75 | 454, 76 | 295, 77 | 296, 78 | 427, 79 | 44, 80 | 45, 81 | 46, 82 | 308, 83 | 59, 84 | 440, 85 | 445, 86 | 31, 87 | 232, 88 | 65, 89 | 354, 90 | 424, 91 | 68, 92 | 326, 93 | 72, 94 | 458, 95 | 34, 96 | 207, 97 | 80, 98 | 355, 99 | 85, 100 | 347, 101 | 220, 102 | 349, 103 | 360, 104 | 98, 105 | 187, 106 | 104, 107 | 105, 108 | 366, 109 | 189, 110 | 368, 111 | 113, 112 | 115, 113 | ] 114 | ) 115 | ) 116 | _key = np.array(range(len(_mapping))).astype("uint8") 117 | 118 | from detail import Detail 119 | 120 | train_detail = Detail(annFile, _image_dir, "train") 121 | train_ids = train_detail.getImgs() 122 | for img_id in tqdm(train_ids): 123 | mask = Image.fromarray( 124 | self._class_to_index( 125 | train_detail.getMask(img_id), _mapping=_mapping, _key=_key 126 | ) 127 | ) 128 | filename = img_id["file_name"] 129 | basename, _ = os.path.splitext(filename) 130 | if filename.endswith(".jpg"): 131 | imgpath = os.path.join(_image_dir, filename) 132 | shutil.copy(imgpath, os.path.join(self.train_image_dir, filename)) 133 | mask_png_name = basename + ".png" 134 | mask.save(os.path.join(self.train_label_dir, mask_png_name)) 135 | 136 | val_detail = Detail(annFile, _image_dir, "val") 137 | val_ids = val_detail.getImgs() 138 | for img_id in tqdm(val_ids): 139 | mask = Image.fromarray( 140 | self._class_to_index( 141 | 
val_detail.getMask(img_id), _mapping=_mapping, _key=_key 142 | ) 143 | ) 144 | filename = img_id["file_name"] 145 | basename, _ = os.path.splitext(filename) 146 | if filename.endswith(".jpg"): 147 | imgpath = os.path.join(_image_dir, filename) 148 | shutil.copy(imgpath, os.path.join(self.val_image_dir, filename)) 149 | mask_png_name = basename + ".png" 150 | mask.save(os.path.join(self.val_label_dir, mask_png_name)) 151 | 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument( 156 | "--save_dir", 157 | default=None, 158 | type=str, 159 | dest="save_dir", 160 | help="The directory to save the data.", 161 | ) 162 | # ori_root_dir: VOCdevkit/VOC2010 163 | parser.add_argument( 164 | "--ori_root_dir", 165 | default=None, 166 | type=str, 167 | dest="ori_root_dir", 168 | help="The directory of the cityscapes data.", 169 | ) 170 | 171 | args = parser.parse_args() 172 | 173 | pascalcontext_seg_generator = PascalContextGenerator(args) 174 | pascalcontext_seg_generator.generate_label() 175 | -------------------------------------------------------------------------------- /datasets/prepare_voc_sem_seg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import shutil 5 | from functools import partial 6 | from glob import glob 7 | 8 | import mmcv 9 | import numpy as np 10 | from PIL import Image 11 | 12 | 13 | full_clsID_to_trID = { 14 | 0: 255, 15 | 1: 0, 16 | 2: 1, 17 | 3: 2, 18 | 4: 3, 19 | 5: 4, 20 | 6: 5, 21 | 7: 6, 22 | 8: 7, 23 | 9: 8, 24 | 10: 9, 25 | 11: 10, 26 | 12: 11, 27 | 13: 12, 28 | 14: 13, 29 | 15: 14, 30 | 16: 15, 31 | 17: 16, 32 | 18: 17, 33 | 19: 18, 34 | 20: 19, 35 | 255: 255, 36 | } 37 | # 2-cow/motobike, 4-airplane/sofa, 6-cat/tv, 8-train/bottle, 10- 38 | # chair/potted plant 39 | # "aeroplane", 40 | # "bicycle", 41 | # "bird", 42 | # "boat", 43 | # "bottle", 44 | # "bus", 45 | # "car", 46 | # "cat", 47 | # "chair", 48 | # "cow", 49 | # "diningtable", 50 | # "dog", 51 | # "horse", 52 | # "motorbike", 53 | # "person", 54 | # "pottedplant", 55 | # "sheep", 56 | # "sofa", 57 | # "train", 58 | # "tv", 59 | 60 | novel_clsID = [16, 17, 18, 19, 20] 61 | novel_2_clsID = [10, 14] 62 | novel_4_clsID = [1, 10, 14, 18] 63 | novel_6_clsID = [1, 8, 10, 14, 18, 20] 64 | novel_8_clsID = [1, 5, 8, 10, 14, 18, 19, 20] 65 | novel_10_clsID = [1, 5, 8, 9, 10, 14, 16, 18, 19, 20] 66 | 67 | base_clsID = [k for k in full_clsID_to_trID.keys() if k not in novel_clsID + [0, 255]] 68 | base_2_clsID = [k for k in base_clsID if k not in novel_2_clsID] 69 | base_4_clsID = [k for k in base_clsID if k not in novel_4_clsID] 70 | base_6_clsID = [k for k in base_clsID if k not in novel_6_clsID] 71 | base_8_clsID = [k for k in base_clsID if k not in novel_8_clsID] 72 | base_10_clsID = [k for k in base_clsID if k not in novel_10_clsID] 73 | 74 | novel_clsID_to_trID = {k: i for i, k in enumerate(novel_clsID)} 75 | base_clsID_to_trID = {k: i for i, k in enumerate(base_clsID)} 76 | 77 | base_2_clsID_to_trID = {k: i for i, k in enumerate(base_2_clsID)} 78 | base_4_clsID_to_trID = {k: i for i, k in enumerate(base_4_clsID)} 79 | base_6_clsID_to_trID = {k: i for i, k in enumerate(base_6_clsID)} 80 | base_8_clsID_to_trID = {k: i for i, k in enumerate(base_8_clsID)} 81 | base_10_clsID_to_trID = {k: i for i, k in enumerate(base_10_clsID)} 82 | 83 | 84 | def convert_to_trainID( 85 | maskpath, out_mask_dir, is_train, clsID_to_trID=full_clsID_to_trID, suffix="" 86 | ): 87 | mask = 
np.array(Image.open(maskpath)) 88 | mask_copy = np.ones_like(mask, dtype=np.uint8) * 255 89 | for clsID, trID in clsID_to_trID.items(): 90 | mask_copy[mask == clsID] = trID 91 | seg_filename = ( 92 | osp.join(out_mask_dir, "train" + suffix, osp.basename(maskpath)) 93 | if is_train 94 | else osp.join(out_mask_dir, "val" + suffix, osp.basename(maskpath)) 95 | ) 96 | if len(np.unique(mask_copy)) == 1 and np.unique(mask_copy)[0] == 255: 97 | return 98 | Image.fromarray(mask_copy).save(seg_filename, "PNG") 99 | 100 | 101 | def parse_args(): 102 | parser = argparse.ArgumentParser( 103 | description="Convert VOC2012 annotations to mmsegmentation format" 104 | ) # noqa 105 | parser.add_argument("voc_path", help="voc path") 106 | parser.add_argument("-o", "--out_dir", help="output path") 107 | parser.add_argument("--nproc", default=16, type=int, help="number of processes") 108 | args = parser.parse_args() 109 | return args 110 | 111 | 112 | def main(): 113 | args = parse_args() 114 | voc_path = args.voc_path 115 | nproc = args.nproc 116 | print(full_clsID_to_trID) 117 | print(base_clsID_to_trID) 118 | print(novel_clsID_to_trID) 119 | out_dir = args.out_dir or voc_path 120 | # out_img_dir = osp.join(out_dir, 'images') 121 | out_mask_dir = osp.join(out_dir, "annotations_detectron2") 122 | out_image_dir = osp.join(out_dir, "images_detectron2") 123 | for dir_name in [ 124 | "train", 125 | "val", 126 | "train_base", 127 | "train_base_2", 128 | "train_base_4", 129 | "train_base_6", 130 | "train_base_8", 131 | "train_base_10", 132 | "train_novel", 133 | "val_base", 134 | "val_novel", 135 | ]: 136 | os.makedirs(osp.join(out_mask_dir, dir_name), exist_ok=True) 137 | if dir_name in ["train", "val"]: 138 | os.makedirs(osp.join(out_image_dir, dir_name), exist_ok=True) 139 | 140 | train_list = [ 141 | osp.join(voc_path, "SegmentationClassAug", f + ".png") 142 | for f in np.loadtxt(osp.join(voc_path, "train.txt"), dtype=str).tolist() 143 | ] 144 | test_list = [ 145 | osp.join(voc_path, "SegmentationClassAug", f + ".png") 146 | for f in np.loadtxt(osp.join(voc_path, "val.txt"), dtype=str).tolist() 147 | ] 148 | 149 | if args.nproc > 1: 150 | mmcv.track_parallel_progress( 151 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 152 | train_list, 153 | nproc=nproc, 154 | ) 155 | mmcv.track_parallel_progress( 156 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 157 | test_list, 158 | nproc=nproc, 159 | ) 160 | else: 161 | mmcv.track_progress( 162 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 163 | train_list, 164 | ) 165 | mmcv.track_progress( 166 | partial(convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 167 | test_list, 168 | ) 169 | 170 | print("Done!") 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.08-py3 2 | 3 | RUN pip install cython scipy shapely timm h5py submitit scikit-image wandb setuptools numpy Pillow pycocotools~=2.0.4 fvcore tabulate tqdm ftfy regex opencv-python open_clip_torch cityscapesscripts tensorboard 4 | RUN pip install 'git+https://github.com/facebookresearch/detectron2.git' 5 | RUN pip install opencv-python-headless==4.5.5.64 6 | -------------------------------------------------------------------------------- /docker/app.Dockerfile:
-------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.08-py3 2 | 3 | RUN pip install 'git+https://github.com/facebookresearch/detectron2.git' 4 | RUN pip install cython scipy shapely timm h5py submitit scikit-image wandb setuptools numpy Pillow pycocotools~=2.0.4 fvcore tabulate tqdm ftfy regex open_clip_torch cityscapesscripts tensorboard gradio 5 | 6 | RUN useradd -m -u 1000 user 7 | # Switch to the "user" user 8 | USER user 9 | # Set home to the user's home directory 10 | ENV HOME=/home/user \ 11 | PATH=/home/user/.local/bin:$PATH 12 | 13 | # Set the working directory to the user's home directory 14 | WORKDIR $HOME 15 | RUN git clone https://github.com/MendelXu/SAN app 16 | 17 | WORKDIR $HOME/app 18 | ENV GRADIO_SERVER_NAME=0.0.0.0 19 | EXPOSE 7860 20 | RUN echo "gradio app.py">>run.sh 21 | CMD ["script","-c","sh run.sh","/dev/null"] 22 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Academic Project Page Template 2 | This is an academic paper project page template. 3 | 4 | 5 | Example project pages built using this template are: 6 | - https://www.vision.huji.ac.il/deepsim/ 7 | - https://www.vision.huji.ac.il/3d_ads/ 8 | - https://www.vision.huji.ac.il/ssrl_ad/ 9 | - https://www.vision.huji.ac.il/conffusion/ 10 | 11 | 12 | ## Start using the template 13 | To start using the template click on `Use this Template`. 14 | 15 | The template uses html for controlling the content and css for controlling the style. 16 | To edit the websites contents edit the `index.html` file. It contains different HTML "building blocks", use whichever ones you need and comment out the rest. 17 | 18 | ## Components 19 | - Teaser video 20 | - Images Carousel 21 | - Youtube embedding 22 | - Video Carousel 23 | - PDF Poster 24 | - Bibtex citation 25 | 26 | ## Tips: 27 | - The `index.html` file contains comments instructing you what to replace, you should follow these comments. 28 | - The `meta` tags in the `index.html` file are used to provide metadata about your paper 29 | (e.g. helping search engine index the website, showing a preview image when sharing the website, etc.) 30 | - The resolution of images and videos can usually be around 1920-2048, there rarely a need for better resolution that take longer to load. 31 | - All the images and videos you use should be compressed to allow for fast loading of the website (and thus better indexing by search engines). For images, you can use [TinyPNG](https://tinypng.com), for videos you can need to find the tradeoff between size and quality. 32 | - When using large video files (larger than 10MB), it's better to use youtube for hosting the video as serving the video from the website can take time. 33 | - Using a tracker can help you analyze the traffic and see where users came from. [statcounter](https://statcounter.com) is a free, easy to use tracker that takes under 5 minutes to set up. 34 | - This project page can also be made into a github pages website. 35 | - Replace the favicon to one of your choosing (the default one is of the Hebrew University). 36 | - Suggestions, improvements and comments are welcome, simply open an issue or contact me. 
You can find my contact information at [https://pages.cs.huji.ac.il/eliahu-horwitz/](https://pages.cs.huji.ac.il/eliahu-horwitz/) 37 | 38 | ## Acknowledgments 39 | Parts of this project page were adopted from the [Nerfies](https://nerfies.github.io/) page. 40 | 41 | ## Website License 42 | Creative Commons License
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License. -------------------------------------------------------------------------------- /docs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/arch.png -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and 
(min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid 
#7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 
1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 
1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black .has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark 
.has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success .has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff} -------------------------------------------------------------------------------- /docs/static/css/index.css: 
-------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | padding: 20px; 121 | font-size: 0; 122 | } 123 | 124 | .results-carousel video { 125 | margin: 0; 126 | } 127 | 128 | .slider-pagination .slider-page { 129 | background: #000000; 130 | } 131 | 132 | .eql-cntrb { 133 | font-size: smaller; 134 | } 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /docs/static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/static/images/favicon.ico -------------------------------------------------------------------------------- /docs/static/images/grid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/static/images/grid.jpg -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var 
e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | $(document).ready(function() { 5 | // Check for click events on the navbar burger icon 6 | 7 | var options = { 8 | slidesToScroll: 1, 9 | slidesToShow: 1, 10 | loop: true, 11 | infinite: true, 12 | autoplay: true, 13 | autoplaySpeed: 5000, 14 | } 15 | 16 | // Initialize all div with carousel class 17 | var carousels = bulmaCarousel.attach('.carousel', options); 18 | 19 | bulmaSlider.attach(); 20 | 21 | }) 22 | -------------------------------------------------------------------------------- /docs/static/pdfs/sample.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/docs/static/pdfs/sample.pdf -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | python -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 2 | python -m pip install -r requirements.txt 3 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.6' 4 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import numpy as np 4 | 5 | try: 6 | # ignore ShapelyDeprecationWarning from fvcore 7 | import warnings 8 | 9 | from shapely.errors import ShapelyDeprecationWarning 10 | 11 | warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) 12 | except: 13 | pass 14 | import os 15 | 16 | import huggingface_hub 17 | import torch 18 | from detectron2.checkpoint import DetectionCheckpointer 19 | from detectron2.config import get_cfg 20 | from detectron2.data import MetadataCatalog 21 | from detectron2.engine import DefaultTrainer 22 | from detectron2.projects.deeplab import add_deeplab_config 23 | from detectron2.utils.visualizer import Visualizer, random_color 24 | from huggingface_hub import hf_hub_download 25 | from PIL import Image 26 | 27 | from san import add_san_config 28 | from san.data.datasets.register_coco_stuff_164k import COCO_CATEGORIES 29 | 30 | model_cfg = { 31 | "san_vit_b_16": { 32 | "config_file": "configs/san_clip_vit_res4_coco.yaml", 33 | "model_path": "huggingface:san_vit_b_16.pth", 34 | }, 35 | "san_vit_large_16": { 36 | "config_file": "configs/san_clip_vit_large_res4_coco.yaml", 37 | "model_path": "huggingface:san_vit_large_14.pth", 38 | }, 39 | } 40 | 41 
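# NOTE: the "huggingface:<filename>" prefix in `model_cfg` is resolved by
# `download_model()` below, which fetches <filename> from the "Mendel192/san"
# hub repo (logging in with the HF_TOKEN environment variable if it is set).
# A plain local checkpoint path works as well and is loaded directly by
# DetectionCheckpointer. Minimal sketch (the local path is a placeholder):
#
#     predictor = Predictor(
#         config_file=model_cfg["san_vit_b_16"]["config_file"],
#         model_path="output/model_final.pth",  # placeholder, not a shipped file
#     )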
| 42 | def download_model(model_path: str): 43 | """ 44 | Download the model from huggingface hub. 45 | Args: 46 | model_path (str): the model path 47 | Returns: 48 | str: the downloaded model path 49 | """ 50 | if "HF_TOKEN" in os.environ: 51 | huggingface_hub.login(token=os.environ["HF_TOKEN"]) 52 | model_path = model_path.split(":")[1] 53 | model_path = hf_hub_download("Mendel192/san", filename=model_path) 54 | return model_path 55 | 56 | 57 | def setup(config_file: str, device=None): 58 | """ 59 | Create configs and perform basic setups. 60 | """ 61 | cfg = get_cfg() 62 | # for poly lr schedule 63 | add_deeplab_config(cfg) 64 | add_san_config(cfg) 65 | cfg.merge_from_file(config_file) 66 | cfg.MODEL.DEVICE = device or "cuda" if torch.cuda.is_available() else "cpu" 67 | cfg.freeze() 68 | return cfg 69 | 70 | 71 | class Predictor(object): 72 | def __init__(self, config_file: str, model_path: str): 73 | """ 74 | Args: 75 | config_file (str): the config file path 76 | model_path (str): the model path 77 | """ 78 | cfg = setup(config_file) 79 | self.model = DefaultTrainer.build_model(cfg) 80 | if model_path.startswith("huggingface:"): 81 | model_path = download_model(model_path) 82 | print("Loading model from: ", model_path) 83 | DetectionCheckpointer(self.model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 84 | model_path 85 | ) 86 | print("Loaded model from: ", model_path) 87 | self.model.eval() 88 | if torch.cuda.is_available(): 89 | self.device = torch.device("cuda") 90 | self.model = self.model.cuda() 91 | 92 | def predict( 93 | self, 94 | image_data_or_path: Union[Image.Image, str], 95 | vocabulary: List[str] = [], 96 | augment_vocabulary: Union[str,bool] = True, 97 | output_file: str = None, 98 | ) -> Union[dict, None]: 99 | """ 100 | Predict the segmentation result. 
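        Example (illustrative sketch; the image path is a placeholder):
            predictor = Predictor(
                config_file="configs/san_clip_vit_res4_coco.yaml",
                model_path="huggingface:san_vit_b_16.pth",
            )
            out = predictor.predict("demo.jpg", vocabulary=["sky", "tree", "person"])
            # out["sem_seg"] is an (H, W) np.ndarray of indices into `vocabulary`;
            # values equal to len(vocabulary) mean "none of the given classes".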
101 | Args: 102 | image_data_or_path (Union[Image.Image, str]): the input image or the image path 103 | vocabulary (List[str]): the vocabulary used for the segmentation 104 | augment_vocabulary (bool): whether to augment the vocabulary 105 | output_file (str): the output file path 106 | Returns: 107 | Union[dict, None]: the segmentation result 108 | """ 109 | if isinstance(image_data_or_path, str): 110 | image_data = Image.open(image_data_or_path) 111 | else: 112 | image_data = image_data_or_path 113 | w, h = image_data.size 114 | image_tensor: torch.Tensor = self._preprocess(image_data) 115 | vocabulary = list(set([v.lower().strip() for v in vocabulary])) 116 | # remove invalid vocabulary 117 | vocabulary = [v for v in vocabulary if v != ""] 118 | print("vocabulary:", vocabulary) 119 | ori_vocabulary = vocabulary 120 | 121 | if isinstance(augment_vocabulary,str): 122 | vocabulary = self.augment_vocabulary(vocabulary, augment_vocabulary) 123 | else: 124 | vocabulary = self._merge_vocabulary(vocabulary) 125 | if len(ori_vocabulary) == 0: 126 | ori_vocabulary = vocabulary 127 | with torch.no_grad(): 128 | result = self.model( 129 | [ 130 | { 131 | "image": image_tensor, 132 | "height": h, 133 | "width": w, 134 | "vocabulary": vocabulary, 135 | } 136 | ] 137 | )[0]["sem_seg"] 138 | seg_map = self._postprocess(result, ori_vocabulary) 139 | if output_file: 140 | self.visualize(image_data, seg_map, ori_vocabulary, output_file) 141 | return 142 | return { 143 | "image": image_data, 144 | "sem_seg": seg_map, 145 | "vocabulary": ori_vocabulary, 146 | } 147 | 148 | def visualize( 149 | self, 150 | image: Image.Image, 151 | sem_seg: np.ndarray, 152 | vocabulary: List[str], 153 | output_file: str = None, 154 | mode: str = "overlay", 155 | ) -> Union[Image.Image, None]: 156 | """ 157 | Visualize the segmentation result. 158 | Args: 159 | image (Image.Image): the input image 160 | sem_seg (np.ndarray): the segmentation result 161 | vocabulary (List[str]): the vocabulary used for the segmentation 162 | output_file (str): the output file path 163 | mode (str): the visualization mode, can be "overlay" or "mask" 164 | Returns: 165 | Image.Image: the visualization result. If output_file is not None, return None. 
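        Note: "overlay" draws the masks on top of the input image with
        detectron2's Visualizer, while "mask" renders a flat color map only.
        Colors are deterministic across calls because the numpy seed is fixed
        to 0 before they are sampled.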
166 | """ 167 | # add temporary metadata 168 | # set numpy seed to make sure the colors are the same 169 | np.random.seed(0) 170 | colors = [random_color(rgb=True, maximum=255) for _ in range(len(vocabulary))] 171 | MetadataCatalog.get("_temp").set(stuff_classes=vocabulary, stuff_colors=colors) 172 | metadata = MetadataCatalog.get("_temp") 173 | if mode == "overlay": 174 | v = Visualizer(image, metadata) 175 | v = v.draw_sem_seg(sem_seg, area_threshold=0).get_image() 176 | v = Image.fromarray(v) 177 | else: 178 | v = np.zeros((image.size[1], image.size[0], 3), dtype=np.uint8) 179 | labels, areas = np.unique(sem_seg, return_counts=True) 180 | sorted_idxs = np.argsort(-areas).tolist() 181 | labels = labels[sorted_idxs] 182 | for label in filter(lambda l: l < len(metadata.stuff_classes), labels): 183 | v[sem_seg == label] = metadata.stuff_colors[label] 184 | v = Image.fromarray(v) 185 | # remove temporary metadata 186 | MetadataCatalog.remove("_temp") 187 | if output_file is None: 188 | return v 189 | v.save(output_file) 190 | print(f"saved to {output_file}") 191 | 192 | def _merge_vocabulary(self, vocabulary: List[str]) -> List[str]: 193 | default_voc = [c["name"] for c in COCO_CATEGORIES] 194 | return vocabulary + [c for c in default_voc if c not in vocabulary] 195 | 196 | def augment_vocabulary( 197 | self, vocabulary: List[str], aug_set: str = "COCO-all" 198 | ) -> List[str]: 199 | default_voc = [c["name"] for c in COCO_CATEGORIES] 200 | stuff_voc = [ 201 | c["name"] 202 | for c in COCO_CATEGORIES 203 | if "isthing" not in c or c["isthing"] == 0 204 | ] 205 | if aug_set == "COCO-all": 206 | return vocabulary + [c for c in default_voc if c not in vocabulary] 207 | elif aug_set == "COCO-stuff": 208 | return vocabulary + [c for c in stuff_voc if c not in vocabulary] 209 | else: 210 | return vocabulary 211 | 212 | def _preprocess(self, image: Image.Image) -> torch.Tensor: 213 | """ 214 | Preprocess the input image. 215 | Args: 216 | image (Image.Image): the input image 217 | Returns: 218 | torch.Tensor: the preprocessed image 219 | """ 220 | image = image.convert("RGB") 221 | # resize short side to 640 222 | w, h = image.size 223 | if w < h: 224 | image = image.resize((640, int(h * 640 / w))) 225 | else: 226 | image = image.resize((int(w * 640 / h), 640)) 227 | image = torch.from_numpy(np.asarray(image)).float() 228 | image = image.permute(2, 0, 1) 229 | return image 230 | 231 | def _postprocess( 232 | self, result: torch.Tensor, ori_vocabulary: List[str] 233 | ) -> np.ndarray: 234 | """ 235 | Postprocess the segmentation result. 
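        Note: the score tensor is argmax-ed over the channel dimension, and any
        index outside the original vocabulary (augmented or background classes)
        is collapsed to len(ori_vocabulary), which callers can treat as "other".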
236 | Args: 237 | result (torch.Tensor): the segmentation result 238 | ori_vocabulary (List[str]): the original vocabulary used for the segmentation 239 | Returns: 240 | np.ndarray: the postprocessed segmentation result 241 | """ 242 | result = result.argmax(dim=0).cpu().numpy() # (H, W) 243 | if len(ori_vocabulary) == 0: 244 | return result 245 | result[result >= len(ori_vocabulary)] = len(ori_vocabulary) 246 | return result 247 | 248 | 249 | def pre_download(): 250 | """pre downlaod model from huggingface and open_clip to avoid network issue.""" 251 | for model_name, model_info in model_cfg.items(): 252 | download_model(model_info["model_path"]) 253 | cfg = setup(model_info["config_file"]) 254 | DefaultTrainer.build_model(cfg) 255 | 256 | 257 | if __name__ == "__main__": 258 | from argparse import ArgumentParser 259 | 260 | parser = ArgumentParser() 261 | parser.add_argument( 262 | "--config-file", type=str, required=True, help="path to config file" 263 | ) 264 | parser.add_argument( 265 | "--model-path", type=str, required=True, help="path to model file" 266 | ) 267 | parser.add_argument( 268 | "--img-path", type=str, required=True, help="path to image file." 269 | ) 270 | parser.add_argument("--aug-vocab", action="store_true", help="augment vocabulary.") 271 | parser.add_argument( 272 | "--vocab", 273 | type=str, 274 | default="", 275 | help="list of category name. seperated with ,.", 276 | ) 277 | parser.add_argument( 278 | "--output-file", type=str, default=None, help="path to output file." 279 | ) 280 | args = parser.parse_args() 281 | predictor = Predictor(config_file=args.config_file, model_path=args.model_path) 282 | predictor.predict( 283 | args.img_path, 284 | args.vocab.split(","), 285 | args.aug_vocab, 286 | output_file=args.output_file, 287 | ) 288 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | scipy 3 | shapely 4 | timm 5 | h5py 6 | submitit 7 | scikit-image 8 | wandb 9 | setuptools 10 | numpy==1.22.4 11 | Pillow==9.3.0 12 | pycocotools~=2.0.4 13 | fvcore 14 | tabulate 15 | tqdm 16 | ftfy 17 | regex 18 | opencv-python 19 | open_clip_torch==2.16.0 20 | mmcv==1.3.14 -------------------------------------------------------------------------------- /resources/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/resources/arch.png -------------------------------------------------------------------------------- /san/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data # register all new datasets 2 | from . import model 3 | from . 
import utils 4 | 5 | # config 6 | from .config import add_san_config 7 | 8 | # dataset loading 9 | from .data.dataset_mappers.mask_former_semantic_dataset_mapper import ( 10 | MaskFormerSemanticDatasetMapper, 11 | ) 12 | 13 | # models 14 | from .test_time_augmentation import SemanticSegmentorWithTTA 15 | -------------------------------------------------------------------------------- /san/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from detectron2.config import CfgNode as CN 3 | 4 | 5 | def add_san_config(cfg): 6 | # copied from maskformer2 7 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" 8 | # Color augmentation 9 | cfg.INPUT.COLOR_AUG_SSD = False 10 | # We retry random cropping until no single category in semantic segmentation GT occupies more 11 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 12 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 13 | # Pad image and segmentation GT in dataset mapper. 14 | cfg.INPUT.SIZE_DIVISIBILITY = -1 15 | 16 | # solver config 17 | # optimizer 18 | # weight decay on embedding 19 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 20 | cfg.SOLVER.WEIGHT_DECAY_EMBED_GROUP = [ 21 | "absolute_pos_embed", 22 | "positional_embedding", 23 | "pos_embed", 24 | "query_embed", 25 | "relative_position_bias_table", 26 | ] 27 | cfg.SOLVER.OPTIMIZER = "ADAMW" 28 | cfg.SOLVER.BACKBONE_MULTIPLIER = 1.0 29 | cfg.SOLVER.CLIP_MULTIPLIER = 1.0 30 | cfg.SOLVER.TEST_IMS_PER_BATCH = 1 31 | 32 | # san 33 | cfg.MODEL.SAN = CN() 34 | cfg.MODEL.SAN.NO_OBJECT_WEIGHT = 0.1 35 | cfg.MODEL.SAN.CLASS_WEIGHT = 2.0 36 | cfg.MODEL.SAN.DICE_WEIGHT = 5.0 37 | cfg.MODEL.SAN.MASK_WEIGHT = 5.0 38 | cfg.MODEL.SAN.TRAIN_NUM_POINTS = 112 * 112 39 | cfg.MODEL.SAN.NUM_CLASSES = 171 40 | cfg.MODEL.SAN.OVERSAMPLE_RATIO = 3.0 41 | cfg.MODEL.SAN.IMPORTANCE_SAMPLE_RATIO = 0.75 42 | cfg.MODEL.SAN.CLIP_MODEL_NAME = "ViT-B/16" 43 | cfg.MODEL.SAN.CLIP_PRETRAINED_NAME = "openai" 44 | cfg.MODEL.SAN.CLIP_TEMPLATE_SET = "vild" 45 | cfg.MODEL.SAN.FEATURE_LAST_LAYER_IDX = 9 46 | cfg.MODEL.SAN.CLIP_FROZEN_EXCLUDE = ["positional_embedding"] 47 | cfg.MODEL.SAN.CLIP_DEEPER_FROZEN_EXCLUDE = [] 48 | cfg.MODEL.SAN.REC_CROSS_ATTN = False 49 | cfg.MODEL.SAN.REC_DOWNSAMPLE_METHOD = "max" 50 | cfg.MODEL.SAN.SOS_TOKEN_FORMAT = "cls_token" 51 | cfg.MODEL.SAN.SIZE_DIVISIBILITY = 32 52 | cfg.MODEL.SAN.ASYMETRIC_INPUT = True 53 | cfg.MODEL.SAN.CLIP_RESOLUTION = 0.5 54 | 55 | cfg.MODEL.SAN.SEM_SEG_POSTPROCESS_BEFORE_INFERENCE = True 56 | # side adapter 57 | cfg.MODEL.SIDE_ADAPTER = CN() 58 | cfg.MODEL.SIDE_ADAPTER.NAME = "RegionwiseSideAdapterNetwork" 59 | cfg.MODEL.SIDE_ADAPTER.VIT_NAME = "vit_w240n6d8_patch16" 60 | cfg.MODEL.SIDE_ADAPTER.PRETRAINED = False 61 | cfg.MODEL.SIDE_ADAPTER.IMAGE_SIZE = 640 62 | cfg.MODEL.SIDE_ADAPTER.DROP_PATH_RATE = 0.0 63 | cfg.MODEL.SIDE_ADAPTER.NUM_QUERIES = 100 64 | cfg.MODEL.SIDE_ADAPTER.FUSION_TYPE = "add" 65 | cfg.MODEL.SIDE_ADAPTER.FUSION_MAP = ["0->0", "3->1", "6->2", "9->3"] 66 | cfg.MODEL.SIDE_ADAPTER.DEEP_SUPERVISION_IDXS = [7, 8] 67 | 68 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS = CN() 69 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_HEADS = 12 70 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_LAYERS = 1 71 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.EMBED_CHANNELS = 256 72 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_CHANNELS = 256 73 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_NUM_LAYERS = 3 74 | cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.RESCALE_ATTN_BIAS = True 75 | 76 | # wandb 77 | cfg.WANDB = CN() 78 | cfg.WANDB.PROJECT = "san" 79 | cfg.WANDB.NAME = 
None 80 | # use flash attention 81 | cfg.MODEL.FLASH = False 82 | -------------------------------------------------------------------------------- /san/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datasets 2 | from .build import build_detection_train_loader, build_detection_test_loader 3 | -------------------------------------------------------------------------------- /san/data/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MendelXu/SAN/e19262d1a23992fc2cef934d6756b89ff50fe5c3/san/data/dataset_mappers/__init__.py -------------------------------------------------------------------------------- /san/data/dataset_mappers/mask_former_semantic_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | from detectron2.config import configurable 9 | from detectron2.data import MetadataCatalog 10 | from detectron2.data import detection_utils as utils 11 | from detectron2.data import transforms as T 12 | from detectron2.projects.point_rend import ColorAugSSDTransform 13 | from detectron2.structures import BitMasks, Instances 14 | 15 | __all__ = ["MaskFormerSemanticDatasetMapper"] 16 | 17 | 18 | class MaskFormerSemanticDatasetMapper: 19 | """ 20 | A callable which takes a dataset dict in Detectron2 Dataset format, 21 | and map it into a format used by MaskFormer for semantic segmentation. 22 | 23 | The callable currently does the following: 24 | 25 | 1. Read the image from "file_name" 26 | 2. Applies geometric transforms to the image and annotation 27 | 3. Find and applies suitable cropping to the image and annotation 28 | 4. Prepare image and annotation to Tensors 29 | """ 30 | 31 | @configurable 32 | def __init__( 33 | self, 34 | is_train=True, 35 | *, 36 | augmentations, 37 | image_format, 38 | ignore_label, 39 | size_divisibility, 40 | ): 41 | """ 42 | NOTE: this interface is experimental. 43 | Args: 44 | is_train: for training or inference 45 | augmentations: a list of augmentations or deterministic transforms to apply 46 | image_format: an image format supported by :func:`detection_utils.read_image`. 
47 | ignore_label: the label that is ignored to evaluation 48 | size_divisibility: pad image size to be divisible by this value 49 | """ 50 | self.is_train = is_train 51 | self.tfm_gens = augmentations 52 | self.img_format = image_format 53 | self.ignore_label = ignore_label 54 | self.size_divisibility = size_divisibility 55 | 56 | logger = logging.getLogger(__name__) 57 | mode = "training" if is_train else "inference" 58 | logger.info( 59 | f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}" 60 | ) 61 | 62 | @classmethod 63 | def from_config(cls, cfg, is_train=True): 64 | # Build augmentation 65 | augs = [ 66 | T.ResizeShortestEdge( 67 | cfg.INPUT.MIN_SIZE_TRAIN, 68 | cfg.INPUT.MAX_SIZE_TRAIN, 69 | cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING, 70 | ) 71 | ] 72 | if cfg.INPUT.CROP.ENABLED: 73 | augs.append( 74 | T.RandomCrop_CategoryAreaConstraint( 75 | cfg.INPUT.CROP.TYPE, 76 | cfg.INPUT.CROP.SIZE, 77 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA, 78 | cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 79 | ) 80 | ) 81 | if cfg.INPUT.COLOR_AUG_SSD: 82 | augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT)) 83 | augs.append(T.RandomFlip()) 84 | 85 | # Assume always applies to the training set. 86 | dataset_names = cfg.DATASETS.TRAIN 87 | meta = MetadataCatalog.get(dataset_names[0]) 88 | ignore_label = meta.ignore_label 89 | 90 | ret = { 91 | "is_train": is_train, 92 | "augmentations": augs, 93 | "image_format": cfg.INPUT.FORMAT, 94 | "ignore_label": ignore_label, 95 | "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY, 96 | } 97 | return ret 98 | 99 | def __call__(self, dataset_dict): 100 | """ 101 | Args: 102 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 103 | 104 | Returns: 105 | dict: a format that builtin models in detectron2 accept 106 | """ 107 | assert ( 108 | self.is_train 109 | ), "MaskFormerSemanticDatasetMapper should only be used for training!" 110 | 111 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 112 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 113 | utils.check_image_size(dataset_dict, image) 114 | 115 | if "sem_seg_file_name" in dataset_dict: 116 | # PyTorch transformation not implemented for uint16, so converting it to double first 117 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype( 118 | "double" 119 | ) 120 | else: 121 | sem_seg_gt = None 122 | 123 | if sem_seg_gt is None: 124 | raise ValueError( 125 | "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format( 126 | dataset_dict["file_name"] 127 | ) 128 | ) 129 | 130 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 131 | aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) 132 | image = aug_input.image 133 | sem_seg_gt = aug_input.sem_seg 134 | 135 | # Pad image and segmentation label here! 
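        # F.pad below pads only the last two dims, in (left, right, top, bottom)
        # order, so the crop is padded on the right/bottom: e.g. a 480x512 (HxW)
        # crop with size_divisibility=640 becomes 640x640. Image pixels are
        # padded with 128 and the GT with `ignore_label`. This assumes the crop
        # never exceeds `size_divisibility`; a larger crop would give a negative
        # pad, which F.pad interprets as cropping.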
136 | image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 137 | if sem_seg_gt is not None: 138 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) 139 | 140 | if self.size_divisibility > 0: 141 | image_size = (image.shape[-2], image.shape[-1]) 142 | padding_size = [ 143 | 0, 144 | self.size_divisibility - image_size[1], 145 | 0, 146 | self.size_divisibility - image_size[0], 147 | ] 148 | image = F.pad(image, padding_size, value=128).contiguous() 149 | if sem_seg_gt is not None: 150 | sem_seg_gt = F.pad( 151 | sem_seg_gt, padding_size, value=self.ignore_label 152 | ).contiguous() 153 | 154 | image_shape = (image.shape[-2], image.shape[-1]) # h, w 155 | 156 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 157 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 158 | # Therefore it's important to use torch.Tensor. 159 | dataset_dict["image"] = image 160 | 161 | if sem_seg_gt is not None: 162 | dataset_dict["sem_seg"] = sem_seg_gt.long() 163 | 164 | if "annotations" in dataset_dict: 165 | raise ValueError( 166 | "Semantic segmentation dataset should not have 'annotations'." 167 | ) 168 | 169 | # Prepare per-category binary masks 170 | if sem_seg_gt is not None: 171 | sem_seg_gt = sem_seg_gt.numpy() 172 | instances = Instances(image_shape) 173 | classes = np.unique(sem_seg_gt) 174 | # remove ignored region 175 | classes = classes[classes != self.ignore_label] 176 | instances.gt_classes = torch.tensor(classes, dtype=torch.int64) 177 | 178 | masks = [] 179 | for class_id in classes: 180 | masks.append(sem_seg_gt == class_id) 181 | 182 | if len(masks) == 0: 183 | # Some image does not have annotation (all ignored) 184 | instances.gt_masks = torch.zeros( 185 | (0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]) 186 | ) 187 | else: 188 | masks = BitMasks( 189 | torch.stack( 190 | [ 191 | torch.from_numpy(np.ascontiguousarray(x.copy())) 192 | for x in masks 193 | ] 194 | ) 195 | ) 196 | instances.gt_masks = masks.tensor 197 | 198 | dataset_dict["instances"] = instances 199 | 200 | return dataset_dict 201 | -------------------------------------------------------------------------------- /san/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | register_ade20k_full, 3 | register_coco_stuff_164k, 4 | register_pcontext, 5 | register_voc, 6 | ) 7 | -------------------------------------------------------------------------------- /san/data/datasets/register_coco_stuff_164k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from detectron2.data import DatasetCatalog, MetadataCatalog 4 | from detectron2.data.datasets import load_sem_seg 5 | 6 | COCO_CATEGORIES = [ 7 | {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, 8 | {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, 9 | {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, 10 | {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, 11 | {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, 12 | {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, 13 | {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, 14 | {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, 15 | {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, 16 | {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, 17 | {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, 18 | {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, 19 | {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, 20 | {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, 21 | {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, 22 | {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, 23 | {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, 24 | {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, 25 | {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, 26 | {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, 27 | {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, 28 | {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, 29 | {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, 30 | {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, 31 | {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, 32 | {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, 33 | {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, 34 | {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, 35 | {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, 36 | {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, 37 | {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, 38 | {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, 39 | {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, 40 | {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, 41 | {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, 42 | {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, 43 | {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, 44 | {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, 45 | {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, 46 | {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, 47 | {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, 48 | {"color": [109, 63, 54], 
"isthing": 1, "id": 47, "name": "cup"}, 49 | {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, 50 | {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, 51 | {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, 52 | {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, 53 | {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, 54 | {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, 55 | {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, 56 | {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, 57 | {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, 58 | {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, 59 | {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, 60 | {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, 61 | {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, 62 | {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, 63 | {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, 64 | {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, 65 | {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, 66 | {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, 67 | {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, 68 | {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, 69 | {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, 70 | {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, 71 | {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, 72 | {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, 73 | {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, 74 | {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, 75 | {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, 76 | {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, 77 | {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, 78 | {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, 79 | {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, 80 | {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, 81 | {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, 82 | {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, 83 | {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, 84 | {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, 85 | {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, 86 | {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, 87 | {"id": 92, "name": "banner", "supercategory": "textile"}, 88 | {"id": 93, "name": "blanket", "supercategory": "textile"}, 89 | {"id": 94, "name": "branch", "supercategory": "plant"}, 90 | {"id": 95, "name": "bridge", "supercategory": "building"}, 91 | {"id": 96, "name": "building-other", "supercategory": "building"}, 92 | {"id": 97, "name": "bush", "supercategory": "plant"}, 93 | {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"}, 94 | {"id": 99, "name": "cage", "supercategory": "structural"}, 95 | {"id": 100, "name": "cardboard", "supercategory": "raw-material"}, 96 | {"id": 101, "name": "carpet", "supercategory": "floor"}, 97 | {"id": 102, "name": "ceiling-other", 
"supercategory": "ceiling"}, 98 | {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"}, 99 | {"id": 104, "name": "cloth", "supercategory": "textile"}, 100 | {"id": 105, "name": "clothes", "supercategory": "textile"}, 101 | {"id": 106, "name": "clouds", "supercategory": "sky"}, 102 | {"id": 107, "name": "counter", "supercategory": "furniture-stuff"}, 103 | {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"}, 104 | {"id": 109, "name": "curtain", "supercategory": "textile"}, 105 | {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"}, 106 | {"id": 111, "name": "dirt", "supercategory": "ground"}, 107 | {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"}, 108 | {"id": 113, "name": "fence", "supercategory": "structural"}, 109 | {"id": 114, "name": "floor-marble", "supercategory": "floor"}, 110 | {"id": 115, "name": "floor-other", "supercategory": "floor"}, 111 | {"id": 116, "name": "floor-stone", "supercategory": "floor"}, 112 | {"id": 117, "name": "floor-tile", "supercategory": "floor"}, 113 | {"id": 118, "name": "floor-wood", "supercategory": "floor"}, 114 | {"id": 119, "name": "flower", "supercategory": "plant"}, 115 | {"id": 120, "name": "fog", "supercategory": "water"}, 116 | {"id": 121, "name": "food-other", "supercategory": "food-stuff"}, 117 | {"id": 122, "name": "fruit", "supercategory": "food-stuff"}, 118 | {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"}, 119 | {"id": 124, "name": "grass", "supercategory": "plant"}, 120 | {"id": 125, "name": "gravel", "supercategory": "ground"}, 121 | {"id": 126, "name": "ground-other", "supercategory": "ground"}, 122 | {"id": 127, "name": "hill", "supercategory": "solid"}, 123 | {"id": 128, "name": "house", "supercategory": "building"}, 124 | {"id": 129, "name": "leaves", "supercategory": "plant"}, 125 | {"id": 130, "name": "light", "supercategory": "furniture-stuff"}, 126 | {"id": 131, "name": "mat", "supercategory": "textile"}, 127 | {"id": 132, "name": "metal", "supercategory": "raw-material"}, 128 | {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"}, 129 | {"id": 134, "name": "moss", "supercategory": "plant"}, 130 | {"id": 135, "name": "mountain", "supercategory": "solid"}, 131 | {"id": 136, "name": "mud", "supercategory": "ground"}, 132 | {"id": 137, "name": "napkin", "supercategory": "textile"}, 133 | {"id": 138, "name": "net", "supercategory": "structural"}, 134 | {"id": 139, "name": "paper", "supercategory": "raw-material"}, 135 | {"id": 140, "name": "pavement", "supercategory": "ground"}, 136 | {"id": 141, "name": "pillow", "supercategory": "textile"}, 137 | {"id": 142, "name": "plant-other", "supercategory": "plant"}, 138 | {"id": 143, "name": "plastic", "supercategory": "raw-material"}, 139 | {"id": 144, "name": "platform", "supercategory": "ground"}, 140 | {"id": 145, "name": "playingfield", "supercategory": "ground"}, 141 | {"id": 146, "name": "railing", "supercategory": "structural"}, 142 | {"id": 147, "name": "railroad", "supercategory": "ground"}, 143 | {"id": 148, "name": "river", "supercategory": "water"}, 144 | {"id": 149, "name": "road", "supercategory": "ground"}, 145 | {"id": 150, "name": "rock", "supercategory": "solid"}, 146 | {"id": 151, "name": "roof", "supercategory": "building"}, 147 | {"id": 152, "name": "rug", "supercategory": "textile"}, 148 | {"id": 153, "name": "salad", "supercategory": "food-stuff"}, 149 | {"id": 154, "name": "sand", "supercategory": "ground"}, 150 | {"id": 155, "name": "sea", "supercategory": 
"water"}, 151 | {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"}, 152 | {"id": 157, "name": "sky-other", "supercategory": "sky"}, 153 | {"id": 158, "name": "skyscraper", "supercategory": "building"}, 154 | {"id": 159, "name": "snow", "supercategory": "ground"}, 155 | {"id": 160, "name": "solid-other", "supercategory": "solid"}, 156 | {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"}, 157 | {"id": 162, "name": "stone", "supercategory": "solid"}, 158 | {"id": 163, "name": "straw", "supercategory": "plant"}, 159 | {"id": 164, "name": "structural-other", "supercategory": "structural"}, 160 | {"id": 165, "name": "table", "supercategory": "furniture-stuff"}, 161 | {"id": 166, "name": "tent", "supercategory": "building"}, 162 | {"id": 167, "name": "textile-other", "supercategory": "textile"}, 163 | {"id": 168, "name": "towel", "supercategory": "textile"}, 164 | {"id": 169, "name": "tree", "supercategory": "plant"}, 165 | {"id": 170, "name": "vegetable", "supercategory": "food-stuff"}, 166 | {"id": 171, "name": "wall-brick", "supercategory": "wall"}, 167 | {"id": 172, "name": "wall-concrete", "supercategory": "wall"}, 168 | {"id": 173, "name": "wall-other", "supercategory": "wall"}, 169 | {"id": 174, "name": "wall-panel", "supercategory": "wall"}, 170 | {"id": 175, "name": "wall-stone", "supercategory": "wall"}, 171 | {"id": 176, "name": "wall-tile", "supercategory": "wall"}, 172 | {"id": 177, "name": "wall-wood", "supercategory": "wall"}, 173 | {"id": 178, "name": "water-other", "supercategory": "water"}, 174 | {"id": 179, "name": "waterdrops", "supercategory": "water"}, 175 | {"id": 180, "name": "window-blind", "supercategory": "window"}, 176 | {"id": 181, "name": "window-other", "supercategory": "window"}, 177 | {"id": 182, "name": "wood", "supercategory": "solid"}, 178 | ] 179 | 180 | 181 | def _get_coco_stuff_meta(): 182 | # Id 0 is reserved for ignore_label, we change ignore_label for 0 183 | # to 255 in our pre-processing. 
184 | stuff_ids = [k["id"] for k in COCO_CATEGORIES] 185 | assert len(stuff_ids) == 171, len(stuff_ids) 186 | 187 | # For semantic segmentation, this mapping maps from contiguous stuff id 188 | # (in [0, 91], used in models) to ids in the dataset (used for processing results) 189 | stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} 190 | stuff_classes = [k["name"] for k in COCO_CATEGORIES] 191 | 192 | ret = { 193 | "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, 194 | "stuff_classes": stuff_classes, 195 | } 196 | return ret 197 | 198 | 199 | def register_all_coco_stuff_164k(root): 200 | root = os.path.join(root, "coco") 201 | meta = _get_coco_stuff_meta() 202 | 203 | for name, image_dirname, sem_seg_dirname in [ 204 | ("train", "train2017", "stuffthingmaps_detectron2/train2017"), 205 | ("test", "val2017", "stuffthingmaps_detectron2/val2017"), 206 | ]: 207 | image_dir = os.path.join(root, image_dirname) 208 | gt_dir = os.path.join(root, sem_seg_dirname) 209 | all_name = f"coco_2017_{name}_stuff_sem_seg" 210 | DatasetCatalog.register( 211 | all_name, 212 | lambda x=image_dir, y=gt_dir: load_sem_seg( 213 | y, x, gt_ext="png", image_ext="jpg" 214 | ), 215 | ) 216 | MetadataCatalog.get(all_name).set( 217 | image_root=image_dir, 218 | sem_seg_root=gt_dir, 219 | evaluator_type="sem_seg", 220 | ignore_label=255, 221 | **meta, 222 | ) 223 | 224 | 225 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 226 | register_all_coco_stuff_164k(_root) 227 | -------------------------------------------------------------------------------- /san/data/datasets/register_pcontext.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from detectron2.data import DatasetCatalog, MetadataCatalog 4 | from detectron2.data.datasets import load_sem_seg 5 | 6 | PCONTEXT_SEM_SEG_CATEGORIES = [ 7 | "aeroplane", 8 | "bag", 9 | "bed", 10 | "bedclothes", 11 | "bench", 12 | "bicycle", 13 | "bird", 14 | "boat", 15 | "book", 16 | "bottle", 17 | "building", 18 | "bus", 19 | "cabinet", 20 | "car", 21 | "cat", 22 | "ceiling", 23 | "chair", 24 | "cloth", 25 | "computer", 26 | "cow", 27 | "cup", 28 | "curtain", 29 | "dog", 30 | "door", 31 | "fence", 32 | "floor", 33 | "flower", 34 | "food", 35 | "grass", 36 | "ground", 37 | "horse", 38 | "keyboard", 39 | "light", 40 | "motorbike", 41 | "mountain", 42 | "mouse", 43 | "person", 44 | "plate", 45 | "platform", 46 | "pottedplant", 47 | "road", 48 | "rock", 49 | "sheep", 50 | "shelves", 51 | "sidewalk", 52 | "sign", 53 | "sky", 54 | "snow", 55 | "sofa", 56 | "diningtable", 57 | "track", 58 | "train", 59 | "tree", 60 | "truck", 61 | "tvmonitor", 62 | "wall", 63 | "water", 64 | "window", 65 | "wood", 66 | ] 67 | 68 | PCONTEXT_FULL_SEM_SEG_CATEGORIES = [ 69 | "accordion", 70 | "aeroplane", 71 | "air conditioner", 72 | "antenna", 73 | "artillery", 74 | "ashtray", 75 | "atrium", 76 | "baby carriage", 77 | "bag", 78 | "ball", 79 | "balloon", 80 | "bamboo weaving", 81 | "barrel", 82 | "baseball bat", 83 | "basket", 84 | "basketball backboard", 85 | "bathtub", 86 | "bed", 87 | "bedclothes", 88 | "beer", 89 | "bell", 90 | "bench", 91 | "bicycle", 92 | "binoculars", 93 | "bird", 94 | "bird cage", 95 | "bird feeder", 96 | "bird nest", 97 | "blackboard", 98 | "board", 99 | "boat", 100 | "bone", 101 | "book", 102 | "bottle", 103 | "bottle opener", 104 | "bowl", 105 | "box", 106 | "bracelet", 107 | "brick", 108 | "bridge", 109 | "broom", 110 | "brush", 111 | "bucket", 112 | "building", 113 | "bus", 114 
| "cabinet", 115 | "cabinet door", 116 | "cage", 117 | "cake", 118 | "calculator", 119 | "calendar", 120 | "camel", 121 | "camera", 122 | "camera lens", 123 | "can", 124 | "candle", 125 | "candle holder", 126 | "cap", 127 | "car", 128 | "card", 129 | "cart", 130 | "case", 131 | "casette recorder", 132 | "cash register", 133 | "cat", 134 | "cd", 135 | "cd player", 136 | "ceiling", 137 | "cell phone", 138 | "cello", 139 | "chain", 140 | "chair", 141 | "chessboard", 142 | "chicken", 143 | "chopstick", 144 | "clip", 145 | "clippers", 146 | "clock", 147 | "closet", 148 | "cloth", 149 | "clothes tree", 150 | "coffee", 151 | "coffee machine", 152 | "comb", 153 | "computer", 154 | "concrete", 155 | "cone", 156 | "container", 157 | "control booth", 158 | "controller", 159 | "cooker", 160 | "copying machine", 161 | "coral", 162 | "cork", 163 | "corkscrew", 164 | "counter", 165 | "court", 166 | "cow", 167 | "crabstick", 168 | "crane", 169 | "crate", 170 | "cross", 171 | "crutch", 172 | "cup", 173 | "curtain", 174 | "cushion", 175 | "cutting board", 176 | "dais", 177 | "disc", 178 | "disc case", 179 | "dishwasher", 180 | "dock", 181 | "dog", 182 | "dolphin", 183 | "door", 184 | "drainer", 185 | "dray", 186 | "drink dispenser", 187 | "drinking machine", 188 | "drop", 189 | "drug", 190 | "drum", 191 | "drum kit", 192 | "duck", 193 | "dumbbell", 194 | "earphone", 195 | "earrings", 196 | "egg", 197 | "electric fan", 198 | "electric iron", 199 | "electric pot", 200 | "electric saw", 201 | "electronic keyboard", 202 | "engine", 203 | "envelope", 204 | "equipment", 205 | "escalator", 206 | "exhibition booth", 207 | "extinguisher", 208 | "eyeglass", 209 | "fan", 210 | "faucet", 211 | "fax machine", 212 | "fence", 213 | "ferris wheel", 214 | "fire extinguisher", 215 | "fire hydrant", 216 | "fire place", 217 | "fish", 218 | "fish tank", 219 | "fishbowl", 220 | "fishing net", 221 | "fishing pole", 222 | "flag", 223 | "flagstaff", 224 | "flame", 225 | "flashlight", 226 | "floor", 227 | "flower", 228 | "fly", 229 | "foam", 230 | "food", 231 | "footbridge", 232 | "forceps", 233 | "fork", 234 | "forklift", 235 | "fountain", 236 | "fox", 237 | "frame", 238 | "fridge", 239 | "frog", 240 | "fruit", 241 | "funnel", 242 | "furnace", 243 | "game controller", 244 | "game machine", 245 | "gas cylinder", 246 | "gas hood", 247 | "gas stove", 248 | "gift box", 249 | "glass", 250 | "glass marble", 251 | "globe", 252 | "glove", 253 | "goal", 254 | "grandstand", 255 | "grass", 256 | "gravestone", 257 | "ground", 258 | "guardrail", 259 | "guitar", 260 | "gun", 261 | "hammer", 262 | "hand cart", 263 | "handle", 264 | "handrail", 265 | "hanger", 266 | "hard disk drive", 267 | "hat", 268 | "hay", 269 | "headphone", 270 | "heater", 271 | "helicopter", 272 | "helmet", 273 | "holder", 274 | "hook", 275 | "horse", 276 | "horse-drawn carriage", 277 | "hot-air balloon", 278 | "hydrovalve", 279 | "ice", 280 | "inflator pump", 281 | "ipod", 282 | "iron", 283 | "ironing board", 284 | "jar", 285 | "kart", 286 | "kettle", 287 | "key", 288 | "keyboard", 289 | "kitchen range", 290 | "kite", 291 | "knife", 292 | "knife block", 293 | "ladder", 294 | "ladder truck", 295 | "ladle", 296 | "laptop", 297 | "leaves", 298 | "lid", 299 | "life buoy", 300 | "light", 301 | "light bulb", 302 | "lighter", 303 | "line", 304 | "lion", 305 | "lobster", 306 | "lock", 307 | "machine", 308 | "mailbox", 309 | "mannequin", 310 | "map", 311 | "mask", 312 | "mat", 313 | "match book", 314 | "mattress", 315 | "menu", 316 | "metal", 317 | "meter box", 318 | "microphone", 
319 | "microwave", 320 | "mirror", 321 | "missile", 322 | "model", 323 | "money", 324 | "monkey", 325 | "mop", 326 | "motorbike", 327 | "mountain", 328 | "mouse", 329 | "mouse pad", 330 | "musical instrument", 331 | "napkin", 332 | "net", 333 | "newspaper", 334 | "oar", 335 | "ornament", 336 | "outlet", 337 | "oven", 338 | "oxygen bottle", 339 | "pack", 340 | "pan", 341 | "paper", 342 | "paper box", 343 | "paper cutter", 344 | "parachute", 345 | "parasol", 346 | "parterre", 347 | "patio", 348 | "pelage", 349 | "pen", 350 | "pen container", 351 | "pencil", 352 | "person", 353 | "photo", 354 | "piano", 355 | "picture", 356 | "pig", 357 | "pillar", 358 | "pillow", 359 | "pipe", 360 | "pitcher", 361 | "plant", 362 | "plastic", 363 | "plate", 364 | "platform", 365 | "player", 366 | "playground", 367 | "pliers", 368 | "plume", 369 | "poker", 370 | "poker chip", 371 | "pole", 372 | "pool table", 373 | "postcard", 374 | "poster", 375 | "pot", 376 | "pottedplant", 377 | "printer", 378 | "projector", 379 | "pumpkin", 380 | "rabbit", 381 | "racket", 382 | "radiator", 383 | "radio", 384 | "rail", 385 | "rake", 386 | "ramp", 387 | "range hood", 388 | "receiver", 389 | "recorder", 390 | "recreational machines", 391 | "remote control", 392 | "road", 393 | "robot", 394 | "rock", 395 | "rocket", 396 | "rocking horse", 397 | "rope", 398 | "rug", 399 | "ruler", 400 | "runway", 401 | "saddle", 402 | "sand", 403 | "saw", 404 | "scale", 405 | "scanner", 406 | "scissors", 407 | "scoop", 408 | "screen", 409 | "screwdriver", 410 | "sculpture", 411 | "scythe", 412 | "sewer", 413 | "sewing machine", 414 | "shed", 415 | "sheep", 416 | "shell", 417 | "shelves", 418 | "shoe", 419 | "shopping cart", 420 | "shovel", 421 | "sidecar", 422 | "sidewalk", 423 | "sign", 424 | "signal light", 425 | "sink", 426 | "skateboard", 427 | "ski", 428 | "sky", 429 | "sled", 430 | "slippers", 431 | "smoke", 432 | "snail", 433 | "snake", 434 | "snow", 435 | "snowmobiles", 436 | "sofa", 437 | "spanner", 438 | "spatula", 439 | "speaker", 440 | "speed bump", 441 | "spice container", 442 | "spoon", 443 | "sprayer", 444 | "squirrel", 445 | "stage", 446 | "stair", 447 | "stapler", 448 | "stick", 449 | "sticky note", 450 | "stone", 451 | "stool", 452 | "stove", 453 | "straw", 454 | "stretcher", 455 | "sun", 456 | "sunglass", 457 | "sunshade", 458 | "surveillance camera", 459 | "swan", 460 | "sweeper", 461 | "swim ring", 462 | "swimming pool", 463 | "swing", 464 | "switch", 465 | "table", 466 | "tableware", 467 | "tank", 468 | "tap", 469 | "tape", 470 | "tarp", 471 | "telephone", 472 | "telephone booth", 473 | "tent", 474 | "tire", 475 | "toaster", 476 | "toilet", 477 | "tong", 478 | "tool", 479 | "toothbrush", 480 | "towel", 481 | "toy", 482 | "toy car", 483 | "track", 484 | "train", 485 | "trampoline", 486 | "trash bin", 487 | "tray", 488 | "tree", 489 | "tricycle", 490 | "tripod", 491 | "trophy", 492 | "truck", 493 | "tube", 494 | "turtle", 495 | "tvmonitor", 496 | "tweezers", 497 | "typewriter", 498 | "umbrella", 499 | "unknown", 500 | "vacuum cleaner", 501 | "vending machine", 502 | "video camera", 503 | "video game console", 504 | "video player", 505 | "video tape", 506 | "violin", 507 | "wakeboard", 508 | "wall", 509 | "wallet", 510 | "wardrobe", 511 | "washing machine", 512 | "watch", 513 | "water", 514 | "water dispenser", 515 | "water pipe", 516 | "water skate board", 517 | "watermelon", 518 | "whale", 519 | "wharf", 520 | "wheel", 521 | "wheelchair", 522 | "window", 523 | "window blinds", 524 | "wineglass", 525 | "wire", 526 | "wood", 
527 | "wool", 528 | ] 529 | 530 | 531 | def register_all_pcontext_59(root): 532 | root = os.path.join(root, "pcontext") 533 | for name, dirname in [("train", "train"), ("val", "val")]: 534 | image_dir = os.path.join(root, dirname, "image") 535 | gt_dir = os.path.join(root, dirname, "label") 536 | name = f"pcontext_sem_seg_{name}" 537 | DatasetCatalog.register( 538 | name, 539 | lambda x=image_dir, y=gt_dir: load_sem_seg( 540 | y, x, gt_ext="png", image_ext="jpg" 541 | ), 542 | ) 543 | MetadataCatalog.get(name).set( 544 | stuff_classes=PCONTEXT_SEM_SEG_CATEGORIES[:], 545 | image_root=image_dir, 546 | sem_seg_root=gt_dir, 547 | evaluator_type="sem_seg", 548 | ignore_label=255, 549 | ) 550 | 551 | 552 | def register_all_pcontext_full(root): 553 | root = os.path.join(root, "pcontext_full") 554 | for name, dirname in [("train", "train"), ("val", "val")]: 555 | image_dir = os.path.join(root, dirname, "image") 556 | gt_dir = os.path.join(root, dirname, "label") 557 | name = f"pcontext_full_sem_seg_{name}" 558 | DatasetCatalog.register( 559 | name, 560 | lambda x=image_dir, y=gt_dir: load_sem_seg( 561 | y, x, gt_ext="tif", image_ext="jpg" 562 | ), 563 | ) 564 | MetadataCatalog.get(name).set( 565 | stuff_classes=PCONTEXT_FULL_SEM_SEG_CATEGORIES[:], 566 | image_root=image_dir, 567 | sem_seg_root=gt_dir, 568 | evaluator_type="sem_seg", 569 | ignore_label=65535, 570 | ) 571 | 572 | 573 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 574 | register_all_pcontext_59(_root) 575 | register_all_pcontext_full(_root) 576 | -------------------------------------------------------------------------------- /san/data/datasets/register_voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from detectron2.data import DatasetCatalog, MetadataCatalog 4 | from detectron2.data.datasets import load_sem_seg 5 | 6 | CLASS_NAMES = ( 7 | "aeroplane", 8 | "bicycle", 9 | "bird", 10 | "boat", 11 | "bottle", 12 | "bus", 13 | "car", 14 | "cat", 15 | "chair", 16 | "cow", 17 | "diningtable", 18 | "dog", 19 | "horse", 20 | "motorbike", 21 | "person", 22 | "pottedplant", 23 | "sheep", 24 | "sofa", 25 | "train", 26 | "tv", 27 | ) 28 | 29 | 30 | def _get_voc_meta(cat_list): 31 | ret = { 32 | "stuff_classes": cat_list, 33 | } 34 | return ret 35 | 36 | 37 | def register_all_voc_11k(root): 38 | root = os.path.join(root, "VOC2012") 39 | meta = _get_voc_meta(CLASS_NAMES) 40 | 41 | for name, image_dirname, sem_seg_dirname in [ 42 | ("train", "JPEGImages", "annotations_detectron2/train"), 43 | ("val", "JPEGImages", "annotations_detectron2/val"), 44 | ]: 45 | image_dir = os.path.join(root, image_dirname) 46 | gt_dir = os.path.join(root, sem_seg_dirname) 47 | all_name = f"voc_sem_seg_{name}" 48 | DatasetCatalog.register( 49 | all_name, 50 | lambda x=image_dir, y=gt_dir: load_sem_seg( 51 | y, x, gt_ext="png", image_ext="jpg" 52 | ), 53 | ) 54 | MetadataCatalog.get(all_name).set( 55 | image_root=image_dir, 56 | sem_seg_root=gt_dir, 57 | evaluator_type="sem_seg", 58 | ignore_label=255, 59 | **meta, 60 | ) 61 | 62 | 63 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 64 | register_all_voc_11k(_root) 65 | -------------------------------------------------------------------------------- /san/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .side_adapter import * 2 | from .san import SAN 3 | -------------------------------------------------------------------------------- /san/model/clip_utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .utils import get_predefined_templates 2 | from .visual import FeatureExtractor, RecWithAttnbiasHead 3 | from .classifier import PredefinedOvClassifier, LearnableBgOvClassifier 4 | -------------------------------------------------------------------------------- /san/model/clip_utils/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from torch.nn import functional as F 3 | import torch 4 | from detectron2.utils.registry import Registry 5 | from open_clip.model import CLIP 6 | from torch import nn 7 | from .utils import get_labelset_from_dataset 8 | from open_clip import tokenizer 9 | 10 | 11 | class PredefinedOvClassifier(nn.Module): 12 | def __init__( 13 | self, 14 | clip_model: CLIP, 15 | cache_feature: bool = True, 16 | templates: List[str] = ["a photo of {}"], 17 | ): 18 | # copy the clip model to this module 19 | super().__init__() 20 | for name, child in clip_model.named_children(): 21 | if "visual" not in name: 22 | self.add_module(name, child) 23 | for name, param in clip_model.named_parameters(recurse=False): 24 | self.register_parameter(name, param) 25 | for name, buffer in clip_model.named_buffers(recurse=False): 26 | self.register_buffer(name, buffer) 27 | self.templates = templates 28 | self._freeze() 29 | 30 | self.cache_feature = cache_feature 31 | if self.cache_feature: 32 | self.cache = {} 33 | 34 | def forward(self, category_names: List[str]): 35 | text_embed_bucket = [] 36 | for template in self.templates: 37 | noun_tokens = tokenizer.tokenize( 38 | [template.format(noun) for noun in category_names] 39 | ) 40 | text_inputs = noun_tokens.to(self.text_projection.data.device) 41 | text_embed = self.encode_text(text_inputs, normalize=True) 42 | text_embed_bucket.append(text_embed) 43 | text_embed = torch.stack(text_embed_bucket).mean(dim=0) 44 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 45 | return text_embed 46 | 47 | @torch.no_grad() 48 | def encode_text(self, text, normalize: bool = False): 49 | cast_dtype = self.transformer.get_cast_dtype() 50 | 51 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 52 | 53 | x = x + self.positional_embedding.to(cast_dtype) 54 | x = x.permute(1, 0, 2) # NLD -> LND 55 | x = self.transformer(x, attn_mask=self.attn_mask) 56 | x = x.permute(1, 0, 2) # LND -> NLD 57 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 58 | # take features from the eot embedding (eot_token is the highest number in each sequence) 59 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 60 | return F.normalize(x, dim=-1) if normalize else x 61 | 62 | def get_classifier_by_vocabulary(self, vocabulary: List[str]): 63 | if self.cache_feature: 64 | new_words = [word for word in vocabulary if word not in self.cache] 65 | if len(new_words) > 0: 66 | cat_embeddings = self(new_words) 67 | self.cache.update(dict(zip(new_words, cat_embeddings))) 68 | cat_embeddings = torch.stack([self.cache[word] for word in vocabulary]) 69 | else: 70 | cat_embeddings = self(vocabulary) 71 | return cat_embeddings 72 | 73 | def get_classifier_by_dataset_name(self, dataset_name: str): 74 | if self.cache_feature: 75 | if dataset_name not in self.cache: 76 | category_names = get_labelset_from_dataset(dataset_name) 77 | cat_embeddings = self(category_names) 78 | self.cache[dataset_name] = cat_embeddings 79 | cat_embeddings = 
self.cache[dataset_name] 80 | else: 81 | category_names = get_labelset_from_dataset(dataset_name) 82 | cat_embeddings = self(category_names) 83 | return cat_embeddings 84 | 85 | def _freeze(self): 86 | for param in self.parameters(): 87 | param.requires_grad = False 88 | 89 | def train(self, mode=True): 90 | super().train(False) 91 | 92 | 93 | class LearnableBgOvClassifier(PredefinedOvClassifier): 94 | def __init__( 95 | self, 96 | clip_model: CLIP, 97 | cache_feature: bool = True, 98 | templates: List[str] = ["a photo of {}"], 99 | ): 100 | super().__init__(clip_model, cache_feature, templates) 101 | self.bg_embed = nn.Parameter(torch.randn(1, self.text_projection.shape[0])) 102 | nn.init.normal_( 103 | self.bg_embed, 104 | std=self.bg_embed.shape[1] ** -0.5, 105 | ) 106 | 107 | def get_classifier_by_vocabulary(self, vocabulary: List[str]): 108 | cat_embedding = super().get_classifier_by_vocabulary(vocabulary) 109 | cat_embedding = torch.cat([cat_embedding, self.bg_embed], dim=0) 110 | cat_embedding = F.normalize(cat_embedding, p=2, dim=-1) 111 | return cat_embedding 112 | 113 | def get_classifier_by_dataset_name(self, dataset_name: str): 114 | cat_embedding = super().get_classifier_by_dataset_name(dataset_name) 115 | cat_embedding = torch.cat([cat_embedding, self.bg_embed], dim=0) 116 | cat_embedding = F.normalize(cat_embedding, p=2, dim=-1) 117 | return cat_embedding 118 | -------------------------------------------------------------------------------- /san/model/clip_utils/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from detectron2.data import MetadataCatalog 3 | 4 | 5 | PREDEFINED_LABELSETS = {} 6 | 7 | PREDEFINED_TEMPLATES = { 8 | "imagenet": [ 9 | "a bad photo of a {}.", 10 | "a photo of many {}.", 11 | "a sculpture of a {}.", 12 | "a photo of the hard to see {}.", 13 | "a low resolution photo of the {}.", 14 | "a rendering of a {}.", 15 | "graffiti of a {}.", 16 | "a bad photo of the {}.", 17 | "a cropped photo of the {}.", 18 | "a tattoo of a {}.", 19 | "the embroidered {}.", 20 | "a photo of a hard to see {}.", 21 | "a bright photo of a {}.", 22 | "a photo of a clean {}.", 23 | "a photo of a dirty {}.", 24 | "a dark photo of the {}.", 25 | "a drawing of a {}.", 26 | "a photo of my {}.", 27 | "the plastic {}.", 28 | "a photo of the cool {}.", 29 | "a close-up photo of a {}.", 30 | "a black and white photo of the {}.", 31 | "a painting of the {}.", 32 | "a painting of a {}.", 33 | "a pixelated photo of the {}.", 34 | "a sculpture of the {}.", 35 | "a bright photo of the {}.", 36 | "a cropped photo of a {}.", 37 | "a plastic {}.", 38 | "a photo of the dirty {}.", 39 | "a jpeg corrupted photo of a {}.", 40 | "a blurry photo of the {}.", 41 | "a photo of the {}.", 42 | "a good photo of the {}.", 43 | "a rendering of the {}.", 44 | "a {} in a video game.", 45 | "a photo of one {}.", 46 | "a doodle of a {}.", 47 | "a close-up photo of the {}.", 48 | "a photo of a {}.", 49 | "the origami {}.", 50 | "the {} in a video game.", 51 | "a sketch of a {}.", 52 | "a doodle of the {}.", 53 | "a origami {}.", 54 | "a low resolution photo of a {}.", 55 | "the toy {}.", 56 | "a rendition of the {}.", 57 | "a photo of the clean {}.", 58 | "a photo of a large {}.", 59 | "a rendition of a {}.", 60 | "a photo of a nice {}.", 61 | "a photo of a weird {}.", 62 | "a blurry photo of a {}.", 63 | "a cartoon {}.", 64 | "art of a {}.", 65 | "a sketch of the {}.", 66 | "a embroidered {}.", 67 | "a pixelated photo of a {}.", 68 | 
"itap of the {}.", 69 | "a jpeg corrupted photo of the {}.", 70 | "a good photo of a {}.", 71 | "a plushie {}.", 72 | "a photo of the nice {}.", 73 | "a photo of the small {}.", 74 | "a photo of the weird {}.", 75 | "the cartoon {}.", 76 | "art of the {}.", 77 | "a drawing of the {}.", 78 | "a photo of the large {}.", 79 | "a black and white photo of a {}.", 80 | "the plushie {}.", 81 | "a dark photo of a {}.", 82 | "itap of a {}.", 83 | "graffiti of the {}.", 84 | "a toy {}.", 85 | "itap of my {}.", 86 | "a photo of a cool {}.", 87 | "a photo of a small {}.", 88 | "a tattoo of the {}.", 89 | ], 90 | "vild": [ 91 | "a photo of a {}.", 92 | "This is a photo of a {}", 93 | "There is a {} in the scene", 94 | "There is the {} in the scene", 95 | "a photo of a {} in the scene", 96 | "a photo of a small {}.", 97 | "a photo of a medium {}.", 98 | "a photo of a large {}.", 99 | "This is a photo of a small {}.", 100 | "This is a photo of a medium {}.", 101 | "This is a photo of a large {}.", 102 | "There is a small {} in the scene.", 103 | "There is a medium {} in the scene.", 104 | "There is a large {} in the scene.", 105 | ], 106 | } 107 | 108 | 109 | def get_labelset_from_dataset(dataset_name: str) -> List[str]: 110 | if dataset_name not in PREDEFINED_LABELSETS: 111 | try: 112 | labelset = [ 113 | c.strip() for c in MetadataCatalog.get(dataset_name).stuff_classes 114 | ] 115 | except: 116 | labelset = [ 117 | c.strip() for c in MetadataCatalog.get(dataset_name).thing_classes 118 | ] 119 | else: 120 | labelset = PREDEFINED_LABELSETS[dataset_name] 121 | return labelset 122 | 123 | 124 | def get_predefined_templates(template_set_name: str) -> List[str]: 125 | if template_set_name not in PREDEFINED_TEMPLATES: 126 | raise ValueError(f"Template set {template_set_name} not found") 127 | return PREDEFINED_TEMPLATES[template_set_name] 128 | -------------------------------------------------------------------------------- /san/model/clip_utils/visual.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from open_clip.transformer import VisionTransformer 6 | from detectron2.layers import ShapeSpec 7 | from ..attn_helper import cross_attn_layer, downsample2d, resize_pos_embed2d 8 | 9 | 10 | class ClipOutput(dict): 11 | def __init__(self, spacial_shape, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.spacial_shape = spacial_shape 14 | 15 | def save(self, idx: int, clip_feat: torch.Tensor): 16 | l, n, c = clip_feat.shape 17 | self[idx] = ( 18 | clip_feat[1:].permute(1, 2, 0).reshape(n, c, *self.spacial_shape) 19 | ) # n, c, h, w 20 | self[f"{idx}_cls_token"] = clip_feat[0:1] # 1, n, c 21 | 22 | 23 | class FeatureExtractor(nn.Module): 24 | def __init__( 25 | self, 26 | visual_encoder: VisionTransformer, 27 | last_layer_idx: int = -1, 28 | frozen_exclude=[], 29 | ): 30 | super().__init__() 31 | self.output_tokens = visual_encoder.output_tokens 32 | self.image_size = visual_encoder.image_size 33 | self.patch_size = visual_encoder.patch_size 34 | self.grid_size = visual_encoder.grid_size 35 | self.num_features = visual_encoder.ln_pre.normalized_shape[0] 36 | 37 | self.input_patchnorm = visual_encoder.input_patchnorm 38 | self.patchnorm_pre_ln = visual_encoder.patchnorm_pre_ln 39 | self.conv1 = visual_encoder.conv1 40 | 41 | # class embeddings and positional embeddings 42 | self.class_embedding = visual_encoder.class_embedding 43 | 
self.positional_embedding = visual_encoder.positional_embedding 44 | # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn 45 | self.patch_dropout = visual_encoder.patch_dropout 46 | self.ln_pre = visual_encoder.ln_pre 47 | if last_layer_idx == -1: 48 | self.resblocks = visual_encoder.transformer.resblocks 49 | self.last_output_idx = len(self.resblocks) + 1 50 | else: 51 | self.resblocks = visual_encoder.transformer.resblocks[:last_layer_idx] 52 | self.last_output_idx = last_layer_idx + 1 53 | # 54 | self.frozen_exclude = frozen_exclude 55 | self._freeze(self.frozen_exclude) 56 | 57 | def forward(self, x: torch.Tensor): 58 | if self.input_patchnorm: 59 | raise NotImplementedError("input_patchnorm is not implemented yet.") 60 | else: 61 | x = self.conv1(x) # shape = [*, width, grid, grid] 62 | _, _, h, w = x.shape 63 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 64 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 65 | 66 | # class embeddings and positional embeddings 67 | x = torch.cat( 68 | [ 69 | self.class_embedding.to(x.dtype) 70 | + torch.zeros( 71 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device 72 | ), 73 | x, 74 | ], 75 | dim=1, 76 | ) # shape = [*, grid ** 2 + 1, width] 77 | pos_embed = self.positional_embedding.to(x.dtype) 78 | pos_embed = resize_pos_embed2d(pos_embed[None, ...], self.grid_size, (h, w))[0] 79 | x = x + pos_embed 80 | 81 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 82 | x = self.patch_dropout(x) 83 | x = self.ln_pre(x) 84 | x = x.permute(1, 0, 2) # NLD -> LND 85 | 86 | outputs = ClipOutput(spacial_shape=(h, w)) 87 | outputs.save(0, x) 88 | for i, resblock in enumerate(self.resblocks, start=1): 89 | x = resblock(x) 90 | outputs.save(i, x) 91 | return outputs 92 | 93 | def _freeze(self, frozen_exclude): 94 | if "all" in frozen_exclude: 95 | return 96 | for name, param in self.named_parameters(): 97 | if not any([exclude in name for exclude in frozen_exclude]): 98 | param.requires_grad = False 99 | 100 | @property 101 | def output_shapes(self): 102 | return { 103 | i: ShapeSpec(channels=self.num_features) 104 | for i in range(self.last_output_idx) 105 | } 106 | 107 | @property 108 | def size_divisibility(self): 109 | return self.patch_size[0] 110 | 111 | 112 | class RecWithAttnbiasHead(nn.Module): 113 | def __init__( 114 | self, 115 | visual_encoder: VisionTransformer, 116 | first_layer_idx: int = 0, 117 | frozen_exclude: List[str] = [], 118 | sos_token_format: str = "cls_token", 119 | sos_token_num: int = 1, 120 | cross_attn: bool = True, 121 | downsample_method: str = "bilinear", 122 | ): 123 | super().__init__() 124 | self.output_tokens = visual_encoder.output_tokens 125 | self.output_dim = visual_encoder.output_dim 126 | self.first_layer_idx = first_layer_idx 127 | self.cross_attn = cross_attn 128 | self.downsample_method = downsample_method 129 | 130 | if first_layer_idx < 0: 131 | raise NotImplementedError("first_layer_idx < 0 is not implemented yet.") 132 | self.resblocks = visual_encoder.transformer.resblocks[first_layer_idx:] 133 | self.global_average_pool = visual_encoder.global_average_pool 134 | self.attn_pool = visual_encoder.attn_pool 135 | assert ( 136 | self.attn_pool is None 137 | ), "recognition with attn_pool is not implemented yet." 138 | assert ( 139 | not self.global_average_pool 140 | ), "recognition with global_average_pool is not implemented yet." 
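        # Note: ln_post and proj below are CLIP's own final LayerNorm and projection,
        # reused here so the sos query tokens land in the same embedding space as the
        # text classifier weights and can be compared against them directly.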
141 | self.ln_post = visual_encoder.ln_post 142 | self.proj = visual_encoder.proj 143 | 144 | self.sos_token_format = sos_token_format 145 | self.sos_token_num = sos_token_num 146 | self.frozen_exclude = frozen_exclude 147 | 148 | if sos_token_format in ["learnable_token", "pos_embedding"]: 149 | self.sos_token = nn.Parameter( 150 | torch.randn(sos_token_num, 1, self.proj.shape[0]) 151 | ) 152 | nn.init.normal_(self.sos_token, std=0.02) 153 | self.frozen_exclude.append("sos_token") 154 | self._freeze(self.frozen_exclude) 155 | 156 | def _freeze(self, frozen_exclude): 157 | if "all" in frozen_exclude: 158 | return 159 | for name, param in self.named_parameters(): 160 | if not any([exclude in name for exclude in frozen_exclude]): 161 | param.requires_grad = False 162 | 163 | def forward(self, features, attn_bias, normalize: bool = False): 164 | # construct clip shadow features. 165 | cls_token = features[f"{self.first_layer_idx}_cls_token"] # 1,n,c 166 | pix_feat = features[self.first_layer_idx] # n,c,h,w 167 | n, c, h, w = pix_feat.shape 168 | x = torch.cat( 169 | [cls_token, pix_feat.reshape(n, c, -1).permute(2, 0, 1)] 170 | ) # 1+l,n,c 171 | 172 | # construct sos token. 173 | if self.sos_token_format == "cls_token": 174 | sos_token = cls_token.repeat(self.sos_token_num, 1, 1) 175 | elif self.sos_token_format == "learnable_token": 176 | sos_token = self.sos_token.expand(-1, n, -1) 177 | elif self.sos_token_format == "pos_embedding": 178 | sos_token = self.sos_token.expand(-1, n, -1) + cls_token 179 | 180 | # construct attn biases. 181 | attn_biases = self._build_attn_biases(attn_bias, target_shape=(h, w)) 182 | if self.cross_attn: 183 | for i, resblock in enumerate(self.resblocks): 184 | if self.cross_attn: 185 | sos_token = cross_attn_layer( 186 | resblock, 187 | sos_token, 188 | x[1:,], 189 | attn_biases[i], 190 | ) 191 | if i < len(self.resblocks) - 1: 192 | x = resblock(x) 193 | else: 194 | x = torch.cat([sos_token, x], dim=0) 195 | for i, resblock in enumerate(self.resblocks): 196 | x = resblock(x, attn_mask=attn_biases[i]) 197 | sos_token = x[: self.sos_token_num] 198 | 199 | sos_token = sos_token.permute(1, 0, 2) # LND -> NLD 200 | 201 | sos_token = self.ln_post(sos_token) 202 | 203 | if self.proj is not None: 204 | sos_token = sos_token @ self.proj 205 | if normalize: 206 | sos_token = F.normalize(sos_token, dim=-1) 207 | return sos_token 208 | 209 | def _build_attn_biases(self, attn_biases, target_shape): 210 | formatted_attn_biases = [] 211 | for attn_bias in attn_biases: 212 | # convert it to proper format: N*num_head,L,L 213 | # attn_bias: [N, num_head/1, num_sos,H,W] 214 | n, num_head, num_sos, h, w = attn_bias.shape 215 | # reshape and downsample 216 | attn_bias = downsample2d( 217 | attn_bias.reshape(n, num_head * num_sos, h, w), 218 | target_shape, 219 | method=self.downsample_method, 220 | ) 221 | attn_bias = attn_bias.reshape(n, num_head, num_sos, *target_shape) 222 | true_num_head = self.resblocks[0].attn.num_heads 223 | assert ( 224 | num_head == 1 or num_head == true_num_head 225 | ), f"num_head={num_head} is not supported." 
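            # Shape sketch: the bias enters as [N, num_head or 1, num_sos, H, W], is
            # downsampled to `target_shape`, and ends up flattened to
            # [N * true_num_head, num_sos, H_t * W_t] before being used either directly
            # (cross-attention) or embedded into a full self-attention mask below.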
226 | if num_head == 1: 227 | attn_bias = attn_bias.repeat(1, true_num_head, 1, 1, 1) 228 | attn_bias = attn_bias.reshape(n * true_num_head, num_sos, -1) 229 | L = attn_bias.shape[-1] 230 | if self.cross_attn: 231 | # [n*num_head, num_sos, L] 232 | formatted_attn_biases.append(attn_bias) 233 | else: 234 | # [n*num_head, num_sos+1+L, num_sos+1+L] 235 | new_attn_bias = attn_bias.new_zeros(num_sos + 1 + L, num_sos + 1 + L) 236 | new_attn_bias[:, :num_sos] = -100 237 | new_attn_bias[torch.arange(num_sos), torch.arange(num_sos)] = 0 238 | new_attn_bias[:num_sos, num_sos] = -100 239 | new_attn_bias = ( 240 | new_attn_bias[None, ...].expand(n * true_num_head, -1, -1).clone() 241 | ) 242 | new_attn_bias[..., :num_sos, -L:] = attn_bias 243 | formatted_attn_biases.append(new_attn_bias) 244 | 245 | if len(formatted_attn_biases) == 1: 246 | formatted_attn_biases = [formatted_attn_biases[0] for _ in self.resblocks] 247 | return formatted_attn_biases 248 | -------------------------------------------------------------------------------- /san/model/criterion.py: -------------------------------------------------------------------------------- 1 | # Copied from maskformer2 by Bowen Cheng 2 | import logging 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from detectron2.utils.comm import get_world_size 9 | from detectron2.projects.point_rend.point_features import ( 10 | get_uncertain_point_coords_with_randomness, 11 | point_sample, 12 | ) 13 | 14 | from san.utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list 15 | 16 | 17 | def dice_loss( 18 | inputs: torch.Tensor, 19 | targets: torch.Tensor, 20 | num_masks: float, 21 | ): 22 | """ 23 | Compute the DICE loss, similar to generalized IOU for masks 24 | Args: 25 | inputs: A float tensor of arbitrary shape. 26 | The predictions for each example. 27 | targets: A float tensor with the same shape as inputs. Stores the binary 28 | classification label for each element in inputs 29 | (0 for the negative class and 1 for the positive class). 30 | """ 31 | inputs = inputs.sigmoid() 32 | inputs = inputs.flatten(1) 33 | numerator = 2 * (inputs * targets).sum(-1) 34 | denominator = inputs.sum(-1) + targets.sum(-1) 35 | loss = 1 - (numerator + 1) / (denominator + 1) 36 | return loss.sum() / num_masks 37 | 38 | 39 | dice_loss_jit = torch.jit.script(dice_loss) # type: torch.jit.ScriptModule 40 | 41 | 42 | def sigmoid_ce_loss( 43 | inputs: torch.Tensor, 44 | targets: torch.Tensor, 45 | num_masks: float, 46 | ): 47 | """ 48 | Args: 49 | inputs: A float tensor of arbitrary shape. 50 | The predictions for each example. 51 | targets: A float tensor with the same shape as inputs. Stores the binary 52 | classification label for each element in inputs 53 | (0 for the negative class and 1 for the positive class). 54 | Returns: 55 | Loss tensor 56 | """ 57 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 58 | 59 | return loss.mean(1).sum() / num_masks 60 | 61 | 62 | sigmoid_ce_loss_jit = torch.jit.script(sigmoid_ce_loss) # type: torch.jit.ScriptModule 63 | 64 | 65 | def calculate_uncertainty(logits): 66 | """ 67 | We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 68 | foreground class in `classes`. 69 | Args: 70 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or 71 | class-agnostic, where R is the total number of predicted masks in all images and C is 72 | the number of foreground classes. The values are logits. 
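            Only the class-agnostic case (a single channel) is used in this codebase,
            as enforced by the assert below.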
73 | Returns: 74 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 75 | the most uncertain locations having the highest uncertainty score. 76 | """ 77 | assert logits.shape[1] == 1 78 | gt_class_logits = logits.clone() 79 | return -(torch.abs(gt_class_logits)) 80 | 81 | 82 | class SetCriterion(nn.Module): 83 | """This class computes the loss for DETR. 84 | The process happens in two steps: 85 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 86 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 87 | """ 88 | 89 | def __init__( 90 | self, 91 | num_classes, 92 | matcher, 93 | weight_dict, 94 | eos_coef, 95 | losses, 96 | num_points, 97 | oversample_ratio, 98 | importance_sample_ratio, 99 | ): 100 | """Create the criterion. 101 | Parameters: 102 | num_classes: number of object categories, omitting the special no-object category 103 | matcher: module able to compute a matching between targets and proposals 104 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 105 | eos_coef: relative classification weight applied to the no-object category 106 | losses: list of all the losses to be applied. See get_loss for list of available losses. 107 | """ 108 | super().__init__() 109 | self.num_classes = num_classes 110 | self.matcher = matcher 111 | self.weight_dict = weight_dict 112 | self.eos_coef = eos_coef 113 | self.losses = losses 114 | empty_weight = torch.ones(self.num_classes + 1) 115 | empty_weight[-1] = self.eos_coef 116 | self.register_buffer("empty_weight", empty_weight) 117 | 118 | # pointwise mask loss parameters 119 | self.num_points = num_points 120 | self.oversample_ratio = oversample_ratio 121 | self.importance_sample_ratio = importance_sample_ratio 122 | 123 | def loss_labels(self, outputs, targets, indices, num_masks): 124 | """Classification loss (NLL) 125 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 126 | """ 127 | assert "pred_logits" in outputs 128 | src_logits = outputs["pred_logits"].float() 129 | 130 | idx = self._get_src_permutation_idx(indices) 131 | target_classes_o = torch.cat( 132 | [t["labels"][J] for t, (_, J) in zip(targets, indices)] 133 | ) 134 | target_classes = torch.full( 135 | src_logits.shape[:2], 136 | self.num_classes, 137 | dtype=torch.int64, 138 | device=src_logits.device, 139 | ) 140 | target_classes[idx] = target_classes_o 141 | 142 | loss_ce = F.cross_entropy( 143 | src_logits.transpose(1, 2), target_classes, self.empty_weight 144 | ) 145 | losses = {"loss_ce": loss_ce} 146 | return losses 147 | 148 | def loss_masks(self, outputs, targets, indices, num_masks): 149 | """Compute the losses related to the masks: the focal loss and the dice loss. 
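        Both losses are evaluated on a sparse set of sampled points (PointRend-style
        uncertainty sampling) rather than on full-resolution masks.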
150 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 151 | """ 152 | assert "pred_masks" in outputs 153 | 154 | src_idx = self._get_src_permutation_idx(indices) 155 | tgt_idx = self._get_tgt_permutation_idx(indices) 156 | src_masks = outputs["pred_masks"] 157 | src_masks = src_masks[src_idx] 158 | masks = [t["masks"] for t in targets] 159 | # TODO use valid to mask invalid areas due to padding in loss 160 | target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() 161 | target_masks = target_masks.to(src_masks) 162 | target_masks = target_masks[tgt_idx] 163 | 164 | # No need to upsample predictions as we are using normalized coordinates :) 165 | # N x 1 x H x W 166 | src_masks = src_masks[:, None] 167 | target_masks = target_masks[:, None] 168 | 169 | with torch.no_grad(): 170 | # sample point_coords 171 | point_coords = get_uncertain_point_coords_with_randomness( 172 | src_masks, 173 | lambda logits: calculate_uncertainty(logits), 174 | self.num_points, 175 | self.oversample_ratio, 176 | self.importance_sample_ratio, 177 | ) 178 | # get gt labels 179 | point_labels = point_sample( 180 | target_masks, 181 | point_coords, 182 | align_corners=False, 183 | ).squeeze(1) 184 | 185 | point_logits = point_sample( 186 | src_masks, 187 | point_coords, 188 | align_corners=False, 189 | ).squeeze(1) 190 | 191 | losses = { 192 | "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks), 193 | "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks), 194 | } 195 | 196 | del src_masks 197 | del target_masks 198 | return losses 199 | 200 | def _get_src_permutation_idx(self, indices): 201 | # permute predictions following indices 202 | batch_idx = torch.cat( 203 | [torch.full_like(src, i) for i, (src, _) in enumerate(indices)] 204 | ) 205 | src_idx = torch.cat([src for (src, _) in indices]) 206 | return batch_idx, src_idx 207 | 208 | def _get_tgt_permutation_idx(self, indices): 209 | # permute targets following indices 210 | batch_idx = torch.cat( 211 | [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)] 212 | ) 213 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 214 | return batch_idx, tgt_idx 215 | 216 | def get_loss(self, loss, outputs, targets, indices, num_masks): 217 | loss_map = { 218 | "labels": self.loss_labels, 219 | "masks": self.loss_masks, 220 | } 221 | assert loss in loss_map, f"do you really want to compute {loss} loss?" 222 | return loss_map[loss](outputs, targets, indices, num_masks) 223 | 224 | def forward(self, outputs, targets): 225 | """This performs the loss computation. 226 | Parameters: 227 | outputs: dict of tensors, see the output specification of the model for the format 228 | targets: list of dicts, such that len(targets) == batch_size. 
229 | The expected keys in each dict depends on the losses applied, see each loss' doc 230 | """ 231 | outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} 232 | 233 | # Retrieve the matching between the outputs of the last layer and the targets 234 | indices = self.matcher(outputs_without_aux, targets) 235 | 236 | # Compute the average number of target boxes accross all nodes, for normalization purposes 237 | num_masks = sum(len(t["labels"]) for t in targets) 238 | num_masks = torch.as_tensor( 239 | [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device 240 | ) 241 | if is_dist_avail_and_initialized(): 242 | torch.distributed.all_reduce(num_masks) 243 | num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() 244 | 245 | # Compute all the requested losses 246 | losses = {} 247 | for loss in self.losses: 248 | losses.update(self.get_loss(loss, outputs, targets, indices, num_masks)) 249 | 250 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 251 | if "aux_outputs" in outputs: 252 | for i, aux_outputs in enumerate(outputs["aux_outputs"]): 253 | indices = self.matcher(aux_outputs, targets) 254 | for loss in self.losses: 255 | l_dict = self.get_loss( 256 | loss, aux_outputs, targets, indices, num_masks 257 | ) 258 | l_dict = {k + f"_{i}": v for k, v in l_dict.items()} 259 | losses.update(l_dict) 260 | 261 | return losses 262 | 263 | def __repr__(self): 264 | head = "Criterion " + self.__class__.__name__ 265 | body = [ 266 | "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)), 267 | "losses: {}".format(self.losses), 268 | "weight_dict: {}".format(self.weight_dict), 269 | "num_classes: {}".format(self.num_classes), 270 | "eos_coef: {}".format(self.eos_coef), 271 | "num_points: {}".format(self.num_points), 272 | "oversample_ratio: {}".format(self.oversample_ratio), 273 | "importance_sample_ratio: {}".format(self.importance_sample_ratio), 274 | ] 275 | _repr_indent = 4 276 | lines = [head] + [" " * _repr_indent + line for line in body] 277 | return "\n".join(lines) 278 | -------------------------------------------------------------------------------- /san/model/layers.py: -------------------------------------------------------------------------------- 1 | import fvcore.nn.weight_init as weight_init 2 | import torch 3 | from detectron2.layers import CNNBlockBase, Conv2d 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class LayerNorm(nn.Module): 9 | """ 10 | A LayerNorm variant, popularized by Transformers, that performs point-wise mean and 11 | variance normalization over the channel dimension for inputs that have shape 12 | (batch_size, channels, height, width). 
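    Unlike torch.nn.LayerNorm, the statistics here are computed over the channel
    dimension (dim 1) only, i.e. the "channels_first" variant used in ConvNeXt.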
13 | https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 14 | """ 15 | 16 | def __init__(self, normalized_shape, eps=1e-6): 17 | super().__init__() 18 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 19 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 20 | self.eps = eps 21 | self.normalized_shape = (normalized_shape,) 22 | 23 | def forward(self, x: torch.Tensor): 24 | u = x.mean(1, keepdim=True) 25 | s = (x - u).pow(2).mean(1, keepdim=True) 26 | x = (x - u) / torch.sqrt(s + self.eps) 27 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 28 | return x 29 | 30 | 31 | class MLP(nn.Module): 32 | """Very simple multi-layer perceptron (also called FFN)""" 33 | 34 | def __init__( 35 | self, input_dim, hidden_dim, output_dim, num_layers, affine_func=nn.Linear 36 | ): 37 | super().__init__() 38 | self.num_layers = num_layers 39 | h = [hidden_dim] * (num_layers - 1) 40 | self.layers = nn.ModuleList( 41 | affine_func(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 42 | ) 43 | 44 | def forward(self, x: torch.Tensor): 45 | for i, layer in enumerate(self.layers): 46 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 47 | return x 48 | 49 | 50 | class AddFusion(CNNBlockBase): 51 | def __init__(self, in_channels, out_channels): 52 | super().__init__(in_channels, out_channels, 1) 53 | self.input_proj = nn.Sequential( 54 | LayerNorm(in_channels), 55 | Conv2d( 56 | in_channels, 57 | out_channels, 58 | kernel_size=1, 59 | ), 60 | ) 61 | weight_init.c2_xavier_fill(self.input_proj[-1]) 62 | 63 | def forward(self, x: torch.Tensor, y: torch.Tensor, spatial_shape: tuple): 64 | # x: [N,L,C] y: [N,C,H,W] 65 | y = ( 66 | F.interpolate( 67 | self.input_proj(y.contiguous()), 68 | size=spatial_shape, 69 | mode="bilinear", 70 | align_corners=False, 71 | ) 72 | .permute(0, 2, 3, 1) 73 | .reshape(x.shape) 74 | ) 75 | x = x + y 76 | return x 77 | 78 | 79 | def build_fusion_layer(fusion_type: str, in_channels: int, out_channels: int): 80 | if fusion_type == "add": 81 | return AddFusion(in_channels, out_channels) 82 | else: 83 | raise ValueError("Unknown fusion type: {}".format(fusion_type)) 84 | -------------------------------------------------------------------------------- /san/model/matcher.py: -------------------------------------------------------------------------------- 1 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py 2 | # Copied from maskformer2 3 | """ 4 | Modules to compute the matching cost and solve the corresponding LSAP. 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from scipy.optimize import linear_sum_assignment 9 | from torch import nn 10 | from torch.cuda.amp import autocast 11 | 12 | from detectron2.projects.point_rend.point_features import point_sample 13 | 14 | 15 | def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): 16 | """ 17 | Compute the DICE loss, similar to generalized IOU for masks 18 | Args: 19 | inputs: A float tensor of arbitrary shape. 20 | The predictions for each example. 21 | targets: A float tensor with the same shape as inputs. Stores the binary 22 | classification label for each element in inputs 23 | (0 for the negative class and 1 for the positive class). 
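    Returns:
        A pairwise cost matrix of shape [num_inputs, num_targets].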
24 | """ 25 | inputs = inputs.sigmoid() 26 | inputs = inputs.flatten(1) 27 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) 28 | denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] 29 | loss = 1 - (numerator + 1) / (denominator + 1) 30 | return loss 31 | 32 | 33 | batch_dice_loss_jit = torch.jit.script(batch_dice_loss) # type: torch.jit.ScriptModule 34 | 35 | 36 | def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): 37 | """ 38 | Args: 39 | inputs: A float tensor of arbitrary shape. 40 | The predictions for each example. 41 | targets: A float tensor with the same shape as inputs. Stores the binary 42 | classification label for each element in inputs 43 | (0 for the negative class and 1 for the positive class). 44 | Returns: 45 | Loss tensor 46 | """ 47 | hw = inputs.shape[1] 48 | 49 | pos = F.binary_cross_entropy_with_logits( 50 | inputs, torch.ones_like(inputs), reduction="none" 51 | ) 52 | neg = F.binary_cross_entropy_with_logits( 53 | inputs, torch.zeros_like(inputs), reduction="none" 54 | ) 55 | 56 | loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( 57 | "nc,mc->nm", neg, (1 - targets) 58 | ) 59 | 60 | return loss / hw 61 | 62 | 63 | batch_sigmoid_ce_loss_jit = torch.jit.script( 64 | batch_sigmoid_ce_loss 65 | ) # type: torch.jit.ScriptModule 66 | 67 | 68 | class HungarianMatcher(nn.Module): 69 | """This class computes an assignment between the targets and the predictions of the network 70 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 71 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 72 | while the others are un-matched (and thus treated as non-objects). 73 | """ 74 | 75 | def __init__( 76 | self, 77 | cost_class: float = 1, 78 | cost_mask: float = 1, 79 | cost_dice: float = 1, 80 | num_points: int = 0, 81 | ): 82 | """Creates the matcher 83 | Params: 84 | cost_class: This is the relative weight of the classification error in the matching cost 85 | cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost 86 | cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost 87 | """ 88 | super().__init__() 89 | self.cost_class = cost_class 90 | self.cost_mask = cost_mask 91 | self.cost_dice = cost_dice 92 | 93 | assert ( 94 | cost_class != 0 or cost_mask != 0 or cost_dice != 0 95 | ), "all costs cant be 0" 96 | 97 | self.num_points = num_points 98 | 99 | @torch.no_grad() 100 | def memory_efficient_forward(self, outputs, targets): 101 | """More memory-friendly matching""" 102 | bs, num_queries = outputs["pred_logits"].shape[:2] 103 | 104 | indices = [] 105 | 106 | # Iterate through batch size 107 | for b in range(bs): 108 | out_prob = outputs["pred_logits"][b].softmax( 109 | -1 110 | ) # [num_queries, num_classes] 111 | tgt_ids = targets[b]["labels"] 112 | 113 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 114 | # but approximate it in 1 - proba[target class]. 115 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 116 | cost_class = -out_prob[:, tgt_ids] 117 | 118 | out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] 119 | # gt masks are already padded when preparing target 120 | tgt_mask = targets[b]["masks"].to(out_mask) 121 | 122 | out_mask = out_mask[:, None] 123 | tgt_mask = tgt_mask[:, None] 124 | # all masks share the same set of points for efficient matching! 
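            # The same `num_points` random locations (normalized [0, 1] coordinates)
            # are shared by predictions and targets, so the resulting cost matrices are
            # [num_queries, num_targets] regardless of the mask resolution.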
125 | point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device) 126 | # get gt labels 127 | tgt_mask = point_sample( 128 | tgt_mask, 129 | point_coords.repeat(tgt_mask.shape[0], 1, 1), 130 | align_corners=False, 131 | ).squeeze(1) 132 | 133 | out_mask = point_sample( 134 | out_mask, 135 | point_coords.repeat(out_mask.shape[0], 1, 1), 136 | align_corners=False, 137 | ).squeeze(1) 138 | 139 | with autocast(enabled=False): 140 | out_mask = out_mask.float() 141 | tgt_mask = tgt_mask.float() 142 | # Compute the focal loss between masks 143 | cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) 144 | 145 | # Compute the dice loss betwen masks 146 | cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) 147 | 148 | # Final cost matrix 149 | C = ( 150 | self.cost_mask * cost_mask 151 | + self.cost_class * cost_class 152 | + self.cost_dice * cost_dice 153 | ) 154 | C = C.reshape(num_queries, -1).cpu() 155 | 156 | indices.append(linear_sum_assignment(C)) 157 | 158 | return [ 159 | ( 160 | torch.as_tensor(i, dtype=torch.int64), 161 | torch.as_tensor(j, dtype=torch.int64), 162 | ) 163 | for i, j in indices 164 | ] 165 | 166 | @torch.no_grad() 167 | def forward(self, outputs, targets): 168 | """Performs the matching 169 | Params: 170 | outputs: This is a dict that contains at least these entries: 171 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 172 | "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks 173 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 174 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 175 | objects in the target) containing the class labels 176 | "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks 177 | Returns: 178 | A list of size batch_size, containing tuples of (index_i, index_j) where: 179 | - index_i is the indices of the selected predictions (in order) 180 | - index_j is the indices of the corresponding selected targets (in order) 181 | For each batch element, it holds: 182 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 183 | """ 184 | return self.memory_efficient_forward(outputs, targets) 185 | 186 | def __repr__(self, _repr_indent=4): 187 | head = "Matcher " + self.__class__.__name__ 188 | body = [ 189 | "cost_class: {}".format(self.cost_class), 190 | "cost_mask: {}".format(self.cost_mask), 191 | "cost_dice: {}".format(self.cost_dice), 192 | ] 193 | lines = [head] + [" " * _repr_indent + line for line in body] 194 | return "\n".join(lines) 195 | -------------------------------------------------------------------------------- /san/model/san.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import open_clip 4 | import torch 5 | from detectron2.config import configurable 6 | from detectron2.modeling import META_ARCH_REGISTRY 7 | from detectron2.modeling.postprocessing import sem_seg_postprocess 8 | from detectron2.structures import ImageList 9 | from detectron2.utils.memory import retry_if_cuda_oom 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from .clip_utils import ( 14 | FeatureExtractor, 15 | LearnableBgOvClassifier, 16 | PredefinedOvClassifier, 17 | RecWithAttnbiasHead, 18 | get_predefined_templates, 19 | ) 20 | from .criterion import SetCriterion 21 | from .matcher import HungarianMatcher 22 | from 
.side_adapter import build_side_adapter_network 23 | 24 | 25 | @META_ARCH_REGISTRY.register() 26 | class SAN(nn.Module): 27 | @configurable 28 | def __init__( 29 | self, 30 | *, 31 | clip_visual_extractor: nn.Module, 32 | clip_rec_head: nn.Module, 33 | side_adapter_network: nn.Module, 34 | ov_classifier: PredefinedOvClassifier, 35 | criterion: SetCriterion, 36 | size_divisibility: int, 37 | asymetric_input: bool = True, 38 | clip_resolution: float = 0.5, 39 | pixel_mean: List[float] = [0.48145466, 0.4578275, 0.40821073], 40 | pixel_std: List[float] = [0.26862954, 0.26130258, 0.27577711], 41 | sem_seg_postprocess_before_inference: bool = False, 42 | ): 43 | super().__init__() 44 | self.asymetric_input = asymetric_input 45 | self.clip_resolution = clip_resolution 46 | self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference 47 | self.size_divisibility = size_divisibility 48 | self.criterion = criterion 49 | 50 | self.side_adapter_network = side_adapter_network 51 | self.clip_visual_extractor = clip_visual_extractor 52 | self.clip_rec_head = clip_rec_head 53 | self.ov_classifier = ov_classifier 54 | self.register_buffer( 55 | "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False 56 | ) 57 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 58 | 59 | @classmethod 60 | def from_config(cls, cfg): 61 | ## copied from maskformer2 62 | # Loss parameters 63 | no_object_weight = cfg.MODEL.SAN.NO_OBJECT_WEIGHT 64 | # loss weights 65 | class_weight = cfg.MODEL.SAN.CLASS_WEIGHT 66 | dice_weight = cfg.MODEL.SAN.DICE_WEIGHT 67 | mask_weight = cfg.MODEL.SAN.MASK_WEIGHT 68 | 69 | # building criterion 70 | matcher = HungarianMatcher( 71 | cost_class=class_weight, 72 | cost_mask=mask_weight, 73 | cost_dice=dice_weight, 74 | num_points=cfg.MODEL.SAN.TRAIN_NUM_POINTS, 75 | ) 76 | 77 | weight_dict = { 78 | "loss_ce": class_weight, 79 | "loss_mask": mask_weight, 80 | "loss_dice": dice_weight, 81 | } 82 | aux_weight_dict = {} 83 | for i in range(len(cfg.MODEL.SIDE_ADAPTER.DEEP_SUPERVISION_IDXS) - 1): 84 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 85 | weight_dict.update(aux_weight_dict) 86 | losses = ["labels", "masks"] 87 | 88 | criterion = SetCriterion( 89 | num_classes=cfg.MODEL.SAN.NUM_CLASSES, 90 | matcher=matcher, 91 | weight_dict=weight_dict, 92 | eos_coef=no_object_weight, 93 | losses=losses, 94 | num_points=cfg.MODEL.SAN.TRAIN_NUM_POINTS, 95 | oversample_ratio=cfg.MODEL.SAN.OVERSAMPLE_RATIO, 96 | importance_sample_ratio=cfg.MODEL.SAN.IMPORTANCE_SAMPLE_RATIO, 97 | ) 98 | ## end of copy 99 | 100 | model, _, preprocess = open_clip.create_model_and_transforms( 101 | cfg.MODEL.SAN.CLIP_MODEL_NAME, 102 | pretrained=cfg.MODEL.SAN.CLIP_PRETRAINED_NAME, 103 | ) 104 | ov_classifier = LearnableBgOvClassifier( 105 | model, templates=get_predefined_templates(cfg.MODEL.SAN.CLIP_TEMPLATE_SET) 106 | ) 107 | 108 | clip_visual_extractor = FeatureExtractor( 109 | model.visual, 110 | last_layer_idx=cfg.MODEL.SAN.FEATURE_LAST_LAYER_IDX, 111 | frozen_exclude=cfg.MODEL.SAN.CLIP_FROZEN_EXCLUDE, 112 | ) 113 | clip_rec_head = RecWithAttnbiasHead( 114 | model.visual, 115 | first_layer_idx=cfg.MODEL.SAN.FEATURE_LAST_LAYER_IDX, 116 | frozen_exclude=cfg.MODEL.SAN.CLIP_DEEPER_FROZEN_EXCLUDE, 117 | cross_attn=cfg.MODEL.SAN.REC_CROSS_ATTN, 118 | sos_token_format=cfg.MODEL.SAN.SOS_TOKEN_FORMAT, 119 | sos_token_num=cfg.MODEL.SIDE_ADAPTER.NUM_QUERIES, 120 | downsample_method=cfg.MODEL.SAN.REC_DOWNSAMPLE_METHOD, 121 | ) 122 | 123 | pixel_mean, 
pixel_std = ( 124 | preprocess.transforms[-1].mean, 125 | preprocess.transforms[-1].std, 126 | ) 127 | pixel_mean = [255.0 * x for x in pixel_mean] 128 | pixel_std = [255.0 * x for x in pixel_std] 129 | 130 | return { 131 | "clip_visual_extractor": clip_visual_extractor, 132 | "clip_rec_head": clip_rec_head, 133 | "side_adapter_network": build_side_adapter_network( 134 | cfg, clip_visual_extractor.output_shapes 135 | ), 136 | "ov_classifier": ov_classifier, 137 | "criterion": criterion, 138 | "size_divisibility": cfg.MODEL.SAN.SIZE_DIVISIBILITY, 139 | "asymetric_input": cfg.MODEL.SAN.ASYMETRIC_INPUT, 140 | "clip_resolution": cfg.MODEL.SAN.CLIP_RESOLUTION, 141 | "sem_seg_postprocess_before_inference": cfg.MODEL.SAN.SEM_SEG_POSTPROCESS_BEFORE_INFERENCE, 142 | "pixel_mean": pixel_mean, 143 | "pixel_std": pixel_std, 144 | } 145 | 146 | def forward(self, batched_inputs): 147 | # get classifier weight for each dataset 148 | # !! Could be computed once and saved. It will run only once per dataset. 149 | if "vocabulary" in batched_inputs[0]: 150 | ov_classifier_weight = ( 151 | self.ov_classifier.logit_scale.exp() 152 | * self.ov_classifier.get_classifier_by_vocabulary( 153 | batched_inputs[0]["vocabulary"] 154 | ) 155 | ) 156 | else: 157 | dataset_names = [x["meta"]["dataset_name"] for x in batched_inputs] 158 | assert ( 159 | len(list(set(dataset_names))) == 1 160 | ), "All images in a batch must be from the same dataset." 161 | ov_classifier_weight = ( 162 | self.ov_classifier.logit_scale.exp() 163 | * self.ov_classifier.get_classifier_by_dataset_name(dataset_names[0]) 164 | ) # C+1,ndim 165 | images = [x["image"].to(self.device) for x in batched_inputs] 166 | images = [(x - self.pixel_mean) / self.pixel_std for x in images] 167 | images = ImageList.from_tensors(images, self.size_divisibility) 168 | clip_input = images.tensor 169 | if self.asymetric_input: 170 | clip_input = F.interpolate( 171 | clip_input, scale_factor=self.clip_resolution, mode="bilinear" 172 | ) 173 | clip_image_features = self.clip_visual_extractor(clip_input) 174 | mask_preds, attn_biases = self.side_adapter_network( 175 | images.tensor, clip_image_features 176 | ) 177 | # !! Could be optimized to run in parallel. 
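        # One attention-bias set is produced per deep-supervision output of the side
        # adapter (a single set at inference time); each yields its own mask embeddings.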
178 | mask_embs = [ 179 | self.clip_rec_head(clip_image_features, attn_bias, normalize=True) 180 | for attn_bias in attn_biases 181 | ] # [B,N,C] 182 | mask_logits = [ 183 | torch.einsum("bqc,nc->bqn", mask_emb, ov_classifier_weight) 184 | for mask_emb in mask_embs 185 | ] 186 | if self.training: 187 | if "instances" in batched_inputs[0]: 188 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 189 | targets = self.prepare_targets(gt_instances, images) 190 | else: 191 | targets = None 192 | outputs = { 193 | "pred_logits": mask_logits[-1], 194 | "pred_masks": mask_preds[-1], 195 | "aux_outputs": [ 196 | { 197 | "pred_logits": aux_pred_logits, 198 | "pred_masks": aux_pred_masks, 199 | } 200 | for aux_pred_logits, aux_pred_masks in zip( 201 | mask_logits[:-1], mask_preds[:-1] 202 | ) 203 | ], 204 | } 205 | # bipartite matching-based loss 206 | losses = self.criterion(outputs, targets) 207 | 208 | for k in list(losses.keys()): 209 | if k in self.criterion.weight_dict: 210 | losses[k] *= self.criterion.weight_dict[k] 211 | else: 212 | # remove this loss if not specified in `weight_dict` 213 | losses.pop(k) 214 | return losses 215 | else: 216 | mask_preds = mask_preds[-1] 217 | mask_logits = mask_logits[-1] 218 | # torch.cuda.empty_cache() 219 | # Inference 220 | mask_preds = F.interpolate( 221 | mask_preds, 222 | size=(images.tensor.shape[-2], images.tensor.shape[-1]), 223 | mode="bilinear", 224 | align_corners=False, 225 | ) 226 | processed_results = [] 227 | for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( 228 | mask_logits, mask_preds, batched_inputs, images.image_sizes 229 | ): 230 | height = input_per_image.get("height", image_size[0]) 231 | width = input_per_image.get("width", image_size[1]) 232 | processed_results.append({}) 233 | 234 | if self.sem_seg_postprocess_before_inference: 235 | mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( 236 | mask_pred_result, image_size, height, width 237 | ) 238 | mask_cls_result = mask_cls_result.to(mask_pred_result) 239 | r = retry_if_cuda_oom(self.semantic_inference)( 240 | mask_cls_result, mask_pred_result 241 | ) 242 | if not self.sem_seg_postprocess_before_inference: 243 | r = retry_if_cuda_oom(sem_seg_postprocess)( 244 | r, image_size, height, width 245 | ) 246 | processed_results[-1]["sem_seg"] = r 247 | return processed_results 248 | 249 | def prepare_targets(self, targets, images): 250 | h_pad, w_pad = images.tensor.shape[-2:] 251 | new_targets = [] 252 | for targets_per_image in targets: 253 | # pad gt 254 | gt_masks = targets_per_image.gt_masks 255 | padded_masks = torch.zeros( 256 | (gt_masks.shape[0], h_pad, w_pad), 257 | dtype=gt_masks.dtype, 258 | device=gt_masks.device, 259 | ) 260 | padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks 261 | new_targets.append( 262 | { 263 | "labels": targets_per_image.gt_classes, 264 | "masks": padded_masks, 265 | } 266 | ) 267 | return new_targets 268 | 269 | def semantic_inference(self, mask_cls, mask_pred): 270 | mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] 271 | mask_pred = mask_pred.sigmoid() 272 | semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) 273 | return semseg 274 | 275 | @property 276 | def device(self): 277 | return self.pixel_mean.device 278 | -------------------------------------------------------------------------------- /san/model/side_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import timm_wrapper 2 | from .side_adapter import build_side_adapter_network 3 | -------------------------------------------------------------------------------- /san/model/side_adapter/side_adapter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | from typing import Dict, List, Tuple 4 | 5 | import torch 6 | from detectron2.config import configurable 7 | from detectron2.layers import ShapeSpec 8 | from detectron2.utils.logger import log_first_n 9 | from detectron2.utils.registry import Registry 10 | from timm import create_model 11 | from timm.models.vision_transformer import VisionTransformer 12 | from torch import nn 13 | from torch.nn import functional as F 14 | 15 | from ..layers import MLP, build_fusion_layer 16 | from .timm_wrapper import PatchEmbed 17 | 18 | SIDE_ADAPTER_REGISTRY = Registry("SIDE_ADAPTER") 19 | SIDE_ADAPTER_REGISTRY.__doc__ = """ 20 | Registry for side adapter. 21 | """ 22 | 23 | 24 | def build_side_adapter_network(cfg, input_shape): 25 | name = cfg.MODEL.SIDE_ADAPTER.NAME 26 | return SIDE_ADAPTER_REGISTRY.get(name)(cfg, input_shape) 27 | 28 | 29 | class MLPMaskDecoder(nn.Module): 30 | def __init__( 31 | self, 32 | *, 33 | in_channels: int, 34 | total_heads: int = 1, 35 | total_layers: int = 1, 36 | embed_channels: int = 256, 37 | mlp_channels: int = 256, 38 | mlp_num_layers: int = 3, 39 | rescale_attn_bias: bool = False, 40 | ): 41 | super().__init__() 42 | self.total_heads = total_heads 43 | self.total_layers = total_layers 44 | 45 | dense_affine_func = partial(nn.Conv2d, kernel_size=1) 46 | # Query Branch 47 | self.query_mlp = MLP(in_channels, mlp_channels, embed_channels, mlp_num_layers) 48 | # Pixel Branch 49 | self.pix_mlp = MLP( 50 | in_channels, 51 | mlp_channels, 52 | embed_channels, 53 | mlp_num_layers, 54 | affine_func=dense_affine_func, 55 | ) 56 | # Attention Bias Branch 57 | self.attn_mlp = MLP( 58 | in_channels, 59 | mlp_channels, 60 | embed_channels * self.total_heads * self.total_layers, 61 | mlp_num_layers, 62 | affine_func=dense_affine_func, 63 | ) 64 | if rescale_attn_bias: 65 | self.bias_scaling = nn.Linear(1, 1) 66 | else: 67 | self.bias_scaling = nn.Identity() 68 | 69 | def forward( 70 | self, query: torch.Tensor, x: torch.Tensor 71 | ) -> Tuple[torch.Tensor, List[torch.Tensor]]: 72 | # query: [B,N,C] 73 | # x: [B,C,H,W] 74 | query = self.query_mlp(query) 75 | pix = self.pix_mlp(x) 76 | b, c, h, w = pix.shape 77 | # preidict mask 78 | mask_preds = torch.einsum("bqc,bchw->bqhw", query, pix) 79 | # generate attn bias 80 | attn = self.attn_mlp(x) 81 | attn = attn.reshape(b, self.total_layers, self.total_heads, c, h, w) 82 | attn_bias = torch.einsum("bqc,blnchw->blnqhw", query, attn) 83 | attn_bias = self.bias_scaling(attn_bias[..., None]).squeeze(-1) 84 | attn_bias = attn_bias.chunk(self.total_layers, dim=1) 85 | attn_bias = [attn.squeeze(1) for attn in attn_bias] 86 | return mask_preds, attn_bias 87 | 88 | 89 | @SIDE_ADAPTER_REGISTRY.register() 90 | class RegionwiseSideAdapterNetwork(nn.Module): 91 | @configurable 92 | def __init__( 93 | self, 94 | vit_model: VisionTransformer, 95 | fusion_layers: nn.ModuleList, 96 | mask_decoder: nn.Module, 97 | num_queries: int, 98 | fusion_map: Dict[int, int], 99 | deep_supervision_idxs: List[int], 100 | ): 101 | super().__init__() 102 | # remove cls token 103 | if vit_model.cls_token is not None: 104 | vit_model.pos_embed = nn.Parameter(vit_model.pos_embed[:, 1:, ...]) 105 | del vit_model.cls_token 106 | 
vit_model.cls_token = None 107 | # delete out norm 108 | del vit_model.norm 109 | vit_model.norm = nn.Identity() 110 | self.vit_model = vit_model 111 | 112 | self.num_queries = num_queries 113 | self.num_features = vit_model.num_features 114 | # add query token 115 | self.query_embed = nn.Parameter(torch.zeros(1, num_queries, self.num_features)) 116 | self.query_pos_embed = nn.Parameter( 117 | torch.zeros(1, num_queries, self.num_features) 118 | ) 119 | nn.init.normal_(self.query_embed, std=0.02) 120 | nn.init.normal_(self.query_pos_embed, std=0.02) 121 | self.fusion_layers = fusion_layers 122 | self.fusion_map = fusion_map 123 | self.mask_decoder = mask_decoder 124 | # for training 125 | self.deep_supervision_idxs = deep_supervision_idxs 126 | 127 | @classmethod 128 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 129 | vit = create_model( 130 | cfg.MODEL.SIDE_ADAPTER.VIT_NAME, 131 | cfg.MODEL.SIDE_ADAPTER.PRETRAINED, 132 | img_size=cfg.MODEL.SIDE_ADAPTER.IMAGE_SIZE, 133 | drop_path_rate=cfg.MODEL.SIDE_ADAPTER.DROP_PATH_RATE, 134 | fc_norm=False, 135 | num_classes=0, 136 | embed_layer=PatchEmbed, 137 | ) 138 | # ["0->0","3->1","6->2","9->3"] 139 | fusion_map: List[str] = cfg.MODEL.SIDE_ADAPTER.FUSION_MAP 140 | 141 | x2side_map = {int(j): int(i) for i, j in [x.split("->") for x in fusion_map]} 142 | # build fusion layers 143 | fusion_type: str = cfg.MODEL.SIDE_ADAPTER.FUSION_TYPE 144 | fusion_layers = nn.ModuleDict( 145 | { 146 | f"layer_{tgt_idx}": build_fusion_layer( 147 | fusion_type, input_shape[src_idx].channels, vit.num_features 148 | ) 149 | for tgt_idx, src_idx in x2side_map.items() 150 | } 151 | ) 152 | # build mask decoder 153 | return { 154 | "vit_model": vit, 155 | "num_queries": cfg.MODEL.SIDE_ADAPTER.NUM_QUERIES, 156 | "fusion_layers": fusion_layers, 157 | "fusion_map": x2side_map, 158 | "mask_decoder": MLPMaskDecoder( 159 | in_channels=vit.num_features, 160 | total_heads=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_HEADS, 161 | total_layers=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.NUM_LAYERS, 162 | embed_channels=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.EMBED_CHANNELS, 163 | mlp_channels=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_CHANNELS, 164 | mlp_num_layers=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.MLP_NUM_LAYERS, 165 | rescale_attn_bias=cfg.MODEL.SIDE_ADAPTER.ATTN_BIAS.RESCALE_ATTN_BIAS, 166 | ), 167 | "deep_supervision_idxs": cfg.MODEL.SIDE_ADAPTER.DEEP_SUPERVISION_IDXS, 168 | } 169 | 170 | def forward( 171 | self, image: torch.Tensor, clip_features: List[torch.Tensor] 172 | ) -> Dict[str, List[torch.Tensor]]: 173 | features = self.forward_features(image, clip_features) 174 | return self.decode_masks(features) 175 | 176 | def decode_masks( 177 | self, features: List[Dict[str, torch.Tensor]] 178 | ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]: 179 | if not self.training: 180 | features = [features[-1]] 181 | mask_preds = [] 182 | attn_biases = [] 183 | for feature in features: 184 | mask_pred, attn_bias = self.mask_decoder(**feature) 185 | mask_preds.append(mask_pred) 186 | attn_biases.append(attn_bias) 187 | return mask_preds, attn_biases 188 | 189 | def forward_features( 190 | self, image: torch.Tensor, clip_features: List[torch.Tensor] 191 | ) -> List[Dict[str, torch.Tensor]]: 192 | x, (h, w) = self.vit_model.patch_embed(image) 193 | L = x.shape[1] # token length 194 | pos_embed = self.vit_model.pos_embed 195 | ori_h, ori_w = self.vit_model.patch_embed.grid_size 196 | if pos_embed.shape[1] != L: 197 | pos_embed = ( 198 | F.interpolate( 199 | pos_embed.reshape(1, ori_h, ori_w, 
-1).permute(0, 3, 1, 2), 200 | size=[h, w], 201 | mode="bicubic", 202 | align_corners=False, 203 | ) 204 | .flatten(2) 205 | .permute(0, 2, 1) 206 | ) 207 | pos_embed = torch.cat( 208 | [self.query_pos_embed.expand(pos_embed.shape[0], -1, -1), pos_embed], dim=1 209 | ) 210 | x = torch.cat( 211 | [self.query_embed.expand(x.shape[0], -1, -1), x], 212 | dim=1, 213 | ) # B, Q+L, C 214 | x = x + pos_embed 215 | x = self.vit_model.norm_pre(x) 216 | x = self.fuse(0, x, clip_features, (h, w)) 217 | outs = [] 218 | for i, blk in enumerate(self.vit_model.blocks, start=1): 219 | x = blk(x) 220 | x = self.fuse(i, x, clip_features, (h, w)) 221 | if i in self.deep_supervision_idxs: 222 | outs.append( 223 | { 224 | "query": x[:, :-L, ...], 225 | "x": x[:, -L:, ...] 226 | .permute(0, 2, 1) 227 | .reshape(x.shape[0], x.shape[-1], h, w), 228 | } 229 | ) 230 | 231 | if i < len(self.vit_model.blocks): 232 | x = x + pos_embed 233 | 234 | return outs 235 | 236 | def fuse( 237 | self, 238 | block_idx: int, 239 | x: torch.Tensor, 240 | clip_features: List[torch.Tensor], 241 | spatial_shape: Tuple[int, int], 242 | ) -> torch.Tensor: 243 | if block_idx in self.fusion_map: 244 | src_idx = self.fusion_map[block_idx] 245 | L = spatial_shape[0] * spatial_shape[1] 246 | x = torch.cat( 247 | [ 248 | x[:, :-L, ...], 249 | self.fusion_layers[f"layer_{block_idx}"]( 250 | x[:, -L:, ...], clip_features[src_idx], spatial_shape 251 | ), 252 | ], 253 | dim=1, 254 | ) 255 | log_first_n( 256 | logging.INFO, 257 | f"fuse clip {src_idx} to {block_idx}", 258 | len(self.fusion_map), 259 | ) 260 | return x 261 | -------------------------------------------------------------------------------- /san/model/side_adapter/timm_wrapper.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from torch import nn 3 | from timm.models.vision_transformer import _create_vision_transformer 4 | from timm.models import register_model 5 | from timm.models.layers import to_2tuple 6 | 7 | 8 | class PatchEmbed(nn.Module): 9 | """2D Image to Patch Embedding. 
Modify the original implementation to allow return the 2D patch size.""" 10 | 11 | def __init__( 12 | self, 13 | img_size=224, 14 | patch_size=16, 15 | in_chans=3, 16 | embed_dim=768, 17 | norm_layer=None, 18 | flatten=True, 19 | bias=True, 20 | **kwargs 21 | ): 22 | super().__init__() 23 | if len(kwargs)>0: 24 | warnings.warn(f"Unused kwargs are provided:{kwargs}.") 25 | img_size = to_2tuple(img_size) 26 | patch_size = to_2tuple(patch_size) 27 | self.img_size = img_size 28 | self.patch_size = patch_size 29 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 30 | self.num_patches = self.grid_size[0] * self.grid_size[1] 31 | self.flatten = flatten 32 | 33 | self.proj = nn.Conv2d( 34 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias 35 | ) 36 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 37 | 38 | def forward(self, x): 39 | x = self.proj(x) 40 | h, w = x.shape[-2:] 41 | if self.flatten: 42 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 43 | x = self.norm(x) 44 | return x, (h, w) 45 | 46 | 47 | @register_model 48 | def vit_w144n6d8_patch16(pretrained=False, **kwargs): 49 | assert not pretrained 50 | model_kwargs = dict(patch_size=16, embed_dim=144, depth=8, num_heads=6, **kwargs) 51 | model = _create_vision_transformer( 52 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs 53 | ) 54 | return model 55 | 56 | 57 | @register_model 58 | def vit_w192n6d8_patch16(pretrained=False, **kwargs): 59 | assert not pretrained 60 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=8, num_heads=6, **kwargs) 61 | model = _create_vision_transformer( 62 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs 63 | ) 64 | return model 65 | 66 | 67 | @register_model 68 | def vit_w240n6d8_patch16(pretrained=False, **kwargs): 69 | assert not pretrained 70 | model_kwargs = dict(patch_size=16, embed_dim=240, depth=8, num_heads=6, **kwargs) 71 | model = _create_vision_transformer( 72 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs 73 | ) 74 | return model 75 | 76 | 77 | @register_model 78 | def vit_w288n6d8_patch16(pretrained=False, **kwargs): 79 | assert not pretrained 80 | model_kwargs = dict(patch_size=16, embed_dim=288, depth=8, num_heads=6, **kwargs) 81 | model = _create_vision_transformer( 82 | "vit_tiny_patch16_224_in21k", pretrained=pretrained, **model_kwargs 83 | ) 84 | return model 85 | -------------------------------------------------------------------------------- /san/test_time_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copied from mask2former repo. 2 | import copy 3 | import logging 4 | from itertools import count 5 | 6 | import numpy as np 7 | import torch 8 | from fvcore.transforms import HFlipTransform 9 | from torch import nn 10 | from torch.nn.parallel import DistributedDataParallel 11 | 12 | from detectron2.data.detection_utils import read_image 13 | from detectron2.modeling import DatasetMapperTTA 14 | 15 | 16 | __all__ = [ 17 | "SemanticSegmentorWithTTA", 18 | ] 19 | 20 | 21 | class SemanticSegmentorWithTTA(nn.Module): 22 | """ 23 | A SemanticSegmentor with test-time augmentation enabled. 24 | Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. 25 | """ 26 | 27 | def __init__(self, cfg, model, tta_mapper=None, batch_size=1): 28 | """ 29 | Args: 30 | cfg (CfgNode): 31 | model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. 
32 | tta_mapper (callable): takes a dataset dict and returns a list of 33 | augmented versions of the dataset dict. Defaults to 34 | `DatasetMapperTTA(cfg)`. 35 | batch_size (int): batch the augmented images into this batch size for inference. 36 | """ 37 | super().__init__() 38 | if isinstance(model, DistributedDataParallel): 39 | model = model.module 40 | self.cfg = cfg.clone() 41 | 42 | self.model = model 43 | 44 | if tta_mapper is None: 45 | tta_mapper = DatasetMapperTTA(cfg) 46 | self.tta_mapper = tta_mapper 47 | self.batch_size = batch_size 48 | 49 | def __call__(self, batched_inputs): 50 | """ 51 | Same input/output format as :meth:`SemanticSegmentor.forward` 52 | """ 53 | 54 | def _maybe_read_image(dataset_dict): 55 | ret = copy.copy(dataset_dict) 56 | if "image" not in ret: 57 | image = read_image(ret.pop("file_name"), self.model.input_format) 58 | image = torch.from_numpy( 59 | np.ascontiguousarray(image.transpose(2, 0, 1)) 60 | ) # CHW 61 | ret["image"] = image 62 | if "height" not in ret and "width" not in ret: 63 | ret["height"] = image.shape[1] 64 | ret["width"] = image.shape[2] 65 | return ret 66 | 67 | processed_results = [] 68 | for x in batched_inputs: 69 | result = self._inference_one_image(_maybe_read_image(x)) 70 | processed_results.append(result) 71 | return processed_results 72 | 73 | def _inference_one_image(self, input): 74 | """ 75 | Args: 76 | input (dict): one dataset dict with "image" field being a CHW tensor 77 | Returns: 78 | dict: one output dict 79 | """ 80 | orig_shape = (input["height"], input["width"]) 81 | augmented_inputs, tfms = self._get_augmented_inputs(input) 82 | 83 | final_predictions = None 84 | count_predictions = 0 85 | for input, tfm in zip(augmented_inputs, tfms): 86 | count_predictions += 1 87 | with torch.no_grad(): 88 | if final_predictions is None: 89 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 90 | final_predictions = ( 91 | self.model([input])[0].pop("sem_seg").flip(dims=[2]) 92 | ) 93 | else: 94 | final_predictions = self.model([input])[0].pop("sem_seg") 95 | else: 96 | if any(isinstance(t, HFlipTransform) for t in tfm.transforms): 97 | final_predictions += ( 98 | self.model([input])[0].pop("sem_seg").flip(dims=[2]) 99 | ) 100 | else: 101 | final_predictions += self.model([input])[0].pop("sem_seg") 102 | 103 | final_predictions = final_predictions / count_predictions 104 | return {"sem_seg": final_predictions} 105 | 106 | def _get_augmented_inputs(self, input): 107 | augmented_inputs = self.tta_mapper(input) 108 | tfms = [x.pop("transforms") for x in augmented_inputs] 109 | return augmented_inputs, tfms 110 | -------------------------------------------------------------------------------- /san/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .events import WandbWriter, setup_wandb 2 | from . 
import file_io 3 | -------------------------------------------------------------------------------- /san/utils/events.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | from detectron2.utils import comm 4 | from detectron2.utils.events import EventWriter, get_event_storage 5 | 6 | 7 | def setup_wandb(cfg, args): 8 | if comm.is_main_process(): 9 | init_args = { 10 | k.lower(): v 11 | for k, v in cfg.WANDB.items() 12 | if isinstance(k, str) and k not in ["config", "name"] 13 | } 14 | # only include most related part to avoid too big table 15 | # TODO: add configurable params to select which part of `cfg` should be saved in config 16 | if "config_exclude_keys" in init_args: 17 | init_args["config"] = cfg 18 | init_args["config"]["cfg_file"] = args.config_file 19 | else: 20 | init_args["config"] = { 21 | "model": cfg.MODEL, 22 | "solver": cfg.SOLVER, 23 | "cfg_file": args.config_file, 24 | } 25 | if ("name" not in init_args) or (init_args["name"] is None): 26 | init_args["name"] = os.path.basename(args.config_file) 27 | wandb.init(**init_args) 28 | 29 | 30 | class BaseRule(object): 31 | def __call__(self, target): 32 | return target 33 | 34 | 35 | class IsIn(BaseRule): 36 | def __init__(self, keyword: str): 37 | self.keyword = keyword 38 | 39 | def __call__(self, target): 40 | return self.keyword in target 41 | 42 | 43 | class Prefix(BaseRule): 44 | def __init__(self, keyword: str): 45 | self.keyword = keyword 46 | 47 | def __call__(self, target): 48 | return "/".join([self.keyword, target]) 49 | 50 | 51 | class WandbWriter(EventWriter): 52 | """ 53 | Write all scalars to a tensorboard file. 54 | """ 55 | 56 | def __init__(self): 57 | """ 58 | Args: 59 | log_dir (str): the directory to save the output events 60 | kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` 61 | """ 62 | self._last_write = -1 63 | self._group_rules = [ 64 | (IsIn("/"), BaseRule()), 65 | (IsIn("loss"), Prefix("train")), 66 | ] 67 | 68 | def write(self): 69 | storage = get_event_storage() 70 | 71 | def _group_name(scalar_name): 72 | for rule, op in self._group_rules: 73 | if rule(scalar_name): 74 | return op(scalar_name) 75 | return scalar_name 76 | 77 | stats = { 78 | _group_name(name): scalars[0] 79 | for name, scalars in storage.latest().items() 80 | if scalars[1] > self._last_write 81 | } 82 | if len(stats) > 0: 83 | self._last_write = max([v[1] for k, v in storage.latest().items()]) 84 | 85 | # storage.put_{image,histogram} is only meant to be used by 86 | # tensorboard writer. So we access its internal fields directly from here. 87 | if len(storage._vis_data) >= 1: 88 | stats["image"] = [ 89 | wandb.Image(img, caption=img_name) 90 | for img_name, img, step_num in storage._vis_data 91 | ] 92 | # Storage stores all image data and rely on this writer to clear them. 93 | # As a result it assumes only one writer will use its image data. 94 | # An alternative design is to let storage store limited recent 95 | # data (e.g. only the most recent image) that all writers can access. 96 | # In that case a writer may not see all image data if its period is long. 
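            # Consequently, the clear below assumes WandbWriter is the only writer
            # that consumes storage._vis_data; any other image-consuming writer
            # registered alongside it would miss the images cleared here.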
97 | storage.clear_images() 98 | 99 | if len(storage._histograms) >= 1: 100 | 101 | def create_bar(tag, bucket_limits, bucket_counts, **kwargs): 102 | data = [ 103 | [label, val] for (label, val) in zip(bucket_limits, bucket_counts) 104 | ] 105 | table = wandb.Table(data=data, columns=["label", "value"]) 106 | return wandb.plot.bar(table, "label", "value", title=tag) 107 | 108 | stats["hist"] = [create_bar(**params) for params in storage._histograms] 109 | 110 | storage.clear_histograms() 111 | 112 | if len(stats) == 0: 113 | return 114 | wandb.log(stats, step=storage.iter) 115 | 116 | def close(self): 117 | wandb.finish() 118 | -------------------------------------------------------------------------------- /san/utils/file_io.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | import warnings 5 | from iopath.common.file_io import PathHandler 6 | from detectron2.utils.file_io import PathManager 7 | from zipfile import ZipFile 8 | from detectron2.utils.logger import log_first_n 9 | from io import BytesIO 10 | import multiprocessing 11 | 12 | __zip_file_pool__ = { 13 | # (path, mode): ZipFile 14 | } 15 | 16 | 17 | def find_zip_parent(path: str, mode: str = "r", is_dir=False): 18 | """Find the best match zipfile from the end""" 19 | if mode[-1] == "b": 20 | mode = mode[:-1] 21 | if is_dir: 22 | par_path = path 23 | else: 24 | par_path = os.path.dirname(path) 25 | visited = [par_path] 26 | # count = 0 27 | while par_path: 28 | # count += 1 29 | if ((par_path, mode) in __zip_file_pool__) and ( 30 | __zip_file_pool__[(par_path, mode)].fp is not None 31 | ): 32 | # zip file is still open 33 | zip_file = __zip_file_pool__[(par_path, mode)] 34 | for path in visited[:-1]: 35 | __zip_file_pool__[(path, mode)] = zip_file 36 | return zip_file 37 | elif os.path.isfile(par_path + ".zip"): 38 | log_first_n(logging.INFO, "Open zip file {}.".format(par_path), n=1) 39 | zip_file = ZipFile(par_path + ".zip", mode=mode) 40 | for path in visited: 41 | __zip_file_pool__[(path, mode)] = zip_file 42 | # return par_path, zip_file, count 43 | return zip_file 44 | 45 | par_path = os.path.sep.join(par_path.split(os.path.sep)[:-1]) 46 | visited.append(par_path) 47 | # return None, None, count 48 | return None 49 | 50 | 51 | class ZipFileHandler(PathHandler): 52 | """ 53 | Load data from zipfile and return a file-like object 54 | """ 55 | 56 | PREFIX = "zip://" 57 | 58 | def _get_supported_prefixes(self): 59 | return [self.PREFIX] 60 | 61 | def _get_local_path(self, path, **kwargs): 62 | name = path[len(self.PREFIX) :] 63 | 64 | return name 65 | 66 | def _open(self, path: str, mode: str = "r", buffering=-1, **kwargs): 67 | """Open a file and return a file object. 68 | Args: 69 | path (str): _description_ 70 | mode (str, optional): file open mode. Defaults to "r". 71 | 72 | Returns: 73 | ByteIO: file-like object 74 | """ 75 | 76 | path = self._get_local_path(path) 77 | zip_file: ZipFile = find_zip_parent(path, mode) 78 | if zip_file is None: 79 | warnings.warn( 80 | "No zipfile contains {}, falling back to naive PathHandler".format( 81 | path 82 | ), 83 | ) 84 | return PathManager.open(path, mode, buffering, **kwargs) 85 | assert mode in [ 86 | # "r", 87 | "rb", 88 | ], "Writing to ZipFile object is not thread safe. Only read mode is supported for now." # Need to deal with write mode carefully, maybe we will change it later. 
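        # Worked example with hypothetical paths: for path "datasets/coco/images/1.jpg"
        # backed by "datasets/coco.zip" whose first archive entry is "coco/", the
        # mapping below yields the in-archive name "coco/images/1.jpg".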
89 | filename = os.path.join( 90 | zip_file.filelist[0].filename, 91 | path[len(os.path.splitext(zip_file.filename)[0]) + 1 :], 92 | ) 93 | log_first_n( 94 | logging.INFO, 95 | "[Example] load file {} from zip file {}.".format( 96 | filename, zip_file.filename 97 | ), 98 | n=1, 99 | ) 100 | assert ( 101 | buffering == -1 102 | ), f"{self.__class__.__name__} does not support the `buffering` argument" 103 | # Use zipfile.Path in Python 3.10 104 | 105 | if mode[-1] == "b": 106 | mode = mode[:-1] 107 | 108 | return BytesIO( 109 | zip_file.read(filename) 110 | ) # If any errors occur, check whether a ZipFile object is called by multiple threads/processes at the same time. 111 | 112 | def _ls(self, path: str, **kwargs): 113 | """ 114 | List the contents of the directory at the provided URI. 115 | 116 | Args: 117 | path (str): A URI supported by this PathHandler 118 | 119 | Returns: 120 | List[str]: list of contents in given path 121 | """ 122 | path = self._get_local_path(path) 123 | zip_file: ZipFile = find_zip_parent(path, is_dir=True) 124 | assert zip_file is not None, "No zipfile contains {}".format(path) 125 | file_names = zip_file.namelist() 126 | in_archive_path = os.path.join( 127 | file_names[0], path[len(os.path.splitext(zip_file.filename)[0]) + 1 :] 128 | ).rstrip(os.path.sep) 129 | file_names = [ 130 | f[len(in_archive_path) + 1 :] 131 | for f in file_names 132 | if f.startswith(in_archive_path) 133 | ] 134 | # must be closed to avoid thread safety issues 135 | zip_file.close() 136 | return file_names 137 | 138 | 139 | PathManager.register_handler(ZipFileHandler()) 140 | -------------------------------------------------------------------------------- /san/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py 3 | """ 4 | Misc functions, including distributed helpers. 5 | Mostly copy-paste from torchvision references. 
6 | """ 7 | from typing import List, Optional 8 | from functools import reduce 9 | import torch 10 | import torch.distributed as dist 11 | import torchvision 12 | from torch import Tensor 13 | 14 | 15 | def _max_by_axis(the_list): 16 | # type: (List[List[int]]) -> List[int] 17 | maxes = the_list[0] 18 | for sublist in the_list[1:]: 19 | for index, item in enumerate(sublist): 20 | maxes[index] = max(maxes[index], item) 21 | return maxes 22 | 23 | 24 | class NestedTensor(object): 25 | def __init__(self, tensors, mask: Optional[Tensor]): 26 | self.tensors = tensors 27 | self.mask = mask 28 | 29 | def to(self, device): 30 | # type: (Device) -> NestedTensor # noqa 31 | cast_tensor = self.tensors.to(device) 32 | mask = self.mask 33 | if mask is not None: 34 | assert mask is not None 35 | cast_mask = mask.to(device) 36 | else: 37 | cast_mask = None 38 | return NestedTensor(cast_tensor, cast_mask) 39 | 40 | def decompose(self): 41 | return self.tensors, self.mask 42 | 43 | def __repr__(self): 44 | return str(self.tensors) 45 | 46 | 47 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 48 | # TODO make this more general 49 | if tensor_list[0].ndim == 3: 50 | if torchvision._is_tracing(): 51 | # nested_tensor_from_tensor_list() does not export well to ONNX 52 | # call _onnx_nested_tensor_from_tensor_list() instead 53 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 54 | 55 | # TODO make it support different-sized images 56 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 57 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 58 | batch_shape = [len(tensor_list)] + max_size 59 | b, c, h, w = batch_shape 60 | dtype = tensor_list[0].dtype 61 | device = tensor_list[0].device 62 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 63 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 64 | for img, pad_img, m in zip(tensor_list, tensor, mask): 65 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 66 | m[: img.shape[1], : img.shape[2]] = False 67 | else: 68 | raise ValueError("not supported") 69 | return NestedTensor(tensor, mask) 70 | 71 | 72 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 73 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
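# It computes the per-dimension maximum size with tensor ops, pads each image (and a
# boolean mask marking the padded region) to that size, and stacks the results,
# mirroring the eager-mode branch of nested_tensor_from_tensor_list() above.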
74 | @torch.jit.unused 75 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 76 | max_size = [] 77 | for i in range(tensor_list[0].dim()): 78 | max_size_i = torch.max( 79 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) 80 | ).to(torch.int64) 81 | max_size.append(max_size_i) 82 | max_size = tuple(max_size) 83 | 84 | # work around for 85 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 86 | # m[: img.shape[1], :img.shape[2]] = False 87 | # which is not yet supported in onnx 88 | padded_imgs = [] 89 | padded_masks = [] 90 | for img in tensor_list: 91 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 92 | padded_img = torch.nn.functional.pad( 93 | img, (0, padding[2], 0, padding[1], 0, padding[0]) 94 | ) 95 | padded_imgs.append(padded_img) 96 | 97 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 98 | padded_mask = torch.nn.functional.pad( 99 | m, (0, padding[2], 0, padding[1]), "constant", 1 100 | ) 101 | padded_masks.append(padded_mask.to(torch.bool)) 102 | 103 | tensor = torch.stack(padded_imgs) 104 | mask = torch.stack(padded_masks) 105 | 106 | return NestedTensor(tensor, mask=mask) 107 | 108 | 109 | def is_dist_avail_and_initialized(): 110 | if not dist.is_available(): 111 | return False 112 | if not dist.is_initialized(): 113 | return False 114 | return True 115 | 116 | 117 | def get_module_by_name(module, access_string): 118 | names = access_string.split(sep=".") 119 | return reduce(getattr, names, module) 120 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | try: 2 | # ignore ShapelyDeprecationWarning from fvcore 3 | import warnings 4 | 5 | from shapely.errors import ShapelyDeprecationWarning 6 | 7 | warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) 8 | except: 9 | pass 10 | import copy 11 | import itertools 12 | import logging 13 | import os 14 | from collections import OrderedDict, defaultdict 15 | from typing import Any, Dict, List, Set 16 | 17 | import detectron2.utils.comm as comm 18 | import torch 19 | 20 | warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.parallel") 21 | from detectron2.checkpoint import DetectionCheckpointer 22 | from detectron2.config import get_cfg 23 | from detectron2.data import MetadataCatalog 24 | from detectron2.engine import ( 25 | DefaultTrainer, 26 | default_argument_parser, 27 | default_setup, 28 | launch, 29 | ) 30 | from detectron2.evaluation import ( 31 | CityscapesSemSegEvaluator, 32 | DatasetEvaluators, 33 | SemSegEvaluator, 34 | verify_results, 35 | ) 36 | from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler 37 | from detectron2.solver.build import maybe_add_gradient_clipping 38 | from detectron2.utils.logger import setup_logger 39 | from tabulate import tabulate 40 | 41 | from san import ( 42 | MaskFormerSemanticDatasetMapper, 43 | SemanticSegmentorWithTTA, 44 | add_san_config, 45 | ) 46 | from san.data import build_detection_test_loader, build_detection_train_loader 47 | from san.utils import WandbWriter, setup_wandb 48 | 49 | 50 | class Trainer(DefaultTrainer): 51 | def build_writers(self): 52 | writers = super().build_writers() 53 | # use wandb writer instead. 
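        # In detectron2, DefaultTrainer.build_writers() typically returns
        # [CommonMetricPrinter, JSONWriter, TensorboardXWriter]; replacing the last
        # entry swaps the TensorBoard writer for W&B while keeping console and
        # JSON metric logging.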
54 | writers[-1] = WandbWriter() 55 | return writers 56 | 57 | @classmethod 58 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 59 | if output_folder is None: 60 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 61 | evaluator_list = [] 62 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 63 | # semantic segmentation 64 | if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]: 65 | evaluator_list.append( 66 | SemSegEvaluator( 67 | dataset_name, 68 | distributed=True, 69 | output_dir=output_folder, 70 | ) 71 | ) 72 | 73 | if evaluator_type == "cityscapes_sem_seg": 74 | assert ( 75 | torch.cuda.device_count() > comm.get_rank() 76 | ), "CityscapesEvaluator currently do not work with multiple machines." 77 | return CityscapesSemSegEvaluator(dataset_name) 78 | 79 | if len(evaluator_list) == 0: 80 | raise NotImplementedError( 81 | "no Evaluator for the dataset {} with the type {}".format( 82 | dataset_name, evaluator_type 83 | ) 84 | ) 85 | elif len(evaluator_list) == 1: 86 | return evaluator_list[0] 87 | return DatasetEvaluators(evaluator_list) 88 | 89 | @classmethod 90 | def build_train_loader(cls, cfg): 91 | # resue maskformer dataset mapper 92 | if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic": 93 | mapper = MaskFormerSemanticDatasetMapper(cfg, True) 94 | return build_detection_train_loader(cfg, mapper=mapper) 95 | else: 96 | mapper = None 97 | return build_detection_train_loader(cfg, mapper=mapper) 98 | 99 | @classmethod 100 | def build_test_loader(cls, cfg, dataset_name): 101 | # Add dataset meta info. 102 | return build_detection_test_loader(cfg, dataset_name) 103 | 104 | @classmethod 105 | def build_lr_scheduler(cls, cfg, optimizer): 106 | # use poly scheduler 107 | return build_lr_scheduler(cfg, optimizer) 108 | 109 | @classmethod 110 | def build_optimizer(cls, cfg, model): 111 | weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM 112 | weight_decay_embed_group = cfg.SOLVER.WEIGHT_DECAY_EMBED_GROUP 113 | weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED 114 | 115 | defaults = {} 116 | defaults["lr"] = cfg.SOLVER.BASE_LR 117 | defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY 118 | 119 | norm_module_types = ( 120 | torch.nn.BatchNorm1d, 121 | torch.nn.BatchNorm2d, 122 | torch.nn.BatchNorm3d, 123 | torch.nn.SyncBatchNorm, 124 | # NaiveSyncBatchNorm inherits from BatchNorm2d 125 | torch.nn.GroupNorm, 126 | torch.nn.InstanceNorm1d, 127 | torch.nn.InstanceNorm2d, 128 | torch.nn.InstanceNorm3d, 129 | torch.nn.LayerNorm, 130 | torch.nn.LocalResponseNorm, 131 | ) 132 | 133 | params: List[Dict[str, Any]] = [] 134 | memo: Set[torch.nn.parameter.Parameter] = set() 135 | for module_name, module in model.named_modules(): 136 | for module_param_name, value in module.named_parameters(recurse=False): 137 | if not value.requires_grad: 138 | continue 139 | # Avoid duplicating parameters 140 | if value in memo: 141 | continue 142 | memo.add(value) 143 | 144 | hyperparams = copy.copy(defaults) 145 | hyperparams["param_name"] = ".".join([module_name, module_param_name]) 146 | if "side_adapter_network" in module_name: 147 | hyperparams["lr"] = ( 148 | hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER 149 | ) 150 | # scale clip lr 151 | if "clip" in module_name: 152 | hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.CLIP_MULTIPLIER 153 | if any([x in module_param_name for x in weight_decay_embed_group]): 154 | hyperparams["weight_decay"] = weight_decay_embed 155 | if isinstance(module, norm_module_types): 156 | hyperparams["weight_decay"] = 
weight_decay_norm 157 | if isinstance(module, torch.nn.Embedding): 158 | hyperparams["weight_decay"] = weight_decay_embed 159 | params.append({"params": [value], **hyperparams}) 160 | 161 | def maybe_add_full_model_gradient_clipping(optim): 162 | # detectron2 doesn't have full model gradient clipping now 163 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 164 | enable = ( 165 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 166 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 167 | and clip_norm_val > 0.0 168 | ) 169 | 170 | class FullModelGradientClippingOptimizer(optim): 171 | def step(self, closure=None): 172 | all_params = itertools.chain( 173 | *[x["params"] for x in self.param_groups] 174 | ) 175 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 176 | super().step(closure=closure) 177 | 178 | return FullModelGradientClippingOptimizer if enable else optim 179 | 180 | optimizer_type = cfg.SOLVER.OPTIMIZER 181 | if optimizer_type == "SGD": 182 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 183 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM 184 | ) 185 | elif optimizer_type == "ADAMW": 186 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 187 | params, cfg.SOLVER.BASE_LR 188 | ) 189 | else: 190 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 191 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": 192 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 193 | # display the lr and wd of each param group in a table 194 | optim_info = defaultdict(list) 195 | total_params_size = 0 196 | for group in optimizer.param_groups: 197 | optim_info["Param Name"].append(group["param_name"]) 198 | optim_info["Param Shape"].append( 199 | "X".join([str(x) for x in list(group["params"][0].shape)]) 200 | ) 201 | total_params_size += group["params"][0].numel() 202 | optim_info["Lr"].append(group["lr"]) 203 | optim_info["Wd"].append(group["weight_decay"]) 204 | # Counting the number of parameters 205 | optim_info["Param Name"].append("Total") 206 | optim_info["Param Shape"].append("{:.2f}M".format(total_params_size / 1e6)) 207 | optim_info["Lr"].append("-") 208 | optim_info["Wd"].append("-") 209 | table = tabulate( 210 | list(zip(*optim_info.values())), 211 | headers=optim_info.keys(), 212 | tablefmt="grid", 213 | floatfmt=".2e", 214 | stralign="center", 215 | numalign="center", 216 | ) 217 | logger = logging.getLogger("san") 218 | logger.info("Optimizer Info:\n{}\n".format(table)) 219 | return optimizer 220 | 221 | @classmethod 222 | def test_with_TTA(cls, cfg, model): 223 | logger = logging.getLogger("detectron2.trainer") 224 | # In the end of training, run an evaluation with TTA. 225 | logger.info("Running inference with test-time augmentation ...") 226 | model = SemanticSegmentorWithTTA(cfg, model) 227 | evaluators = [ 228 | cls.build_evaluator( 229 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 230 | ) 231 | for name in cfg.DATASETS.TEST 232 | ] 233 | res = cls.test(cfg, model, evaluators) 234 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 235 | return res 236 | 237 | 238 | def setup(args): 239 | """ 240 | Create configs and perform basic setups. 
241 | """ 242 | cfg = get_cfg() 243 | # for poly lr schedule 244 | add_deeplab_config(cfg) 245 | add_san_config(cfg) 246 | cfg.merge_from_file(args.config_file) 247 | cfg.merge_from_list(args.opts) 248 | cfg.freeze() 249 | default_setup(cfg, args) 250 | if not args.eval_only: 251 | setup_wandb(cfg, args) 252 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="san") 253 | return cfg 254 | 255 | 256 | def main(args): 257 | cfg = setup(args) 258 | 259 | if args.eval_only: 260 | model = Trainer.build_model(cfg) 261 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 262 | cfg.MODEL.WEIGHTS, resume=args.resume 263 | ) 264 | res = Trainer.test(cfg, model) 265 | if cfg.TEST.AUG.ENABLED: 266 | res.update(Trainer.test_with_TTA(cfg, model)) 267 | if comm.is_main_process(): 268 | verify_results(cfg, res) 269 | return res 270 | 271 | trainer = Trainer(cfg) 272 | 273 | trainer.resume_or_load(resume=args.resume) 274 | return trainer.train() 275 | 276 | 277 | if __name__ == "__main__": 278 | args = default_argument_parser().parse_args() 279 | print("Command Line Args:", args) 280 | launch( 281 | main, 282 | args.num_gpus, 283 | num_machines=args.num_machines, 284 | machine_rank=args.machine_rank, 285 | dist_url=args.dist_url, 286 | args=(args,), 287 | ) 288 | -------------------------------------------------------------------------------- /visualize_json_results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import argparse 5 | import json 6 | import numpy as np 7 | import os 8 | from collections import defaultdict 9 | import cv2 10 | import tqdm 11 | from fvcore.common.file_io import PathManager 12 | from PIL import Image 13 | from detectron2.data import DatasetCatalog, MetadataCatalog 14 | from detectron2.structures import Boxes, BoxMode, Instances 15 | from detectron2.utils.logger import setup_logger 16 | from detectron2.utils.visualizer import Visualizer, GenericMask 17 | import sys 18 | 19 | import san 20 | 21 | 22 | def create_instances(predictions, image_size, ignore_label=255): 23 | ret = Instances(image_size) 24 | 25 | labels = np.asarray( 26 | [dataset_id_map(predictions[i]["category_id"]) for i in range(len(predictions))] 27 | ) 28 | ret.pred_classes = labels 29 | ret.pred_masks = [ 30 | GenericMask(predictions[i]["segmentation"], *image_size) 31 | for i in range(len(predictions)) 32 | ] 33 | # convert instance to sem_seg map 34 | sem_seg = np.ones(image_size[:2], dtype=np.uint16) * ignore_label 35 | for mask, label in zip(ret.pred_masks, ret.pred_classes): 36 | sem_seg[mask.mask == 1] = label 37 | return sem_seg 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser( 42 | description="A script that visualizes the json predictions from COCO or LVIS dataset." 
43 | ) 44 | parser.add_argument( 45 | "--input", required=True, help="JSON file produced by the model" 46 | ) 47 | parser.add_argument("--output", required=True, help="output directory") 48 | parser.add_argument( 49 | "--dataset", help="name of the dataset", default="coco_2017_val" 50 | ) 51 | parser.add_argument( 52 | "--conf-threshold", default=0.5, type=float, help="confidence threshold" 53 | ) 54 | args = parser.parse_args() 55 | 56 | logger = setup_logger() 57 | 58 | with PathManager.open(args.input, "r") as f: 59 | predictions = json.load(f) 60 | 61 | pred_by_image = defaultdict(list) 62 | for p in predictions: 63 | 64 | pred_by_image[p["file_name"]].append(p) 65 | 66 | dicts = list(DatasetCatalog.get(args.dataset)) 67 | metadata = MetadataCatalog.get(args.dataset) 68 | if hasattr(metadata, "thing_dataset_id_to_contiguous_id"): 69 | 70 | def dataset_id_map(ds_id): 71 | return metadata.thing_dataset_id_to_contiguous_id[ds_id] 72 | 73 | elif "lvis" in args.dataset: 74 | # LVIS results are in the same format as COCO results, but have a different 75 | # mapping from dataset category id to contiguous category id in [0, #categories - 1] 76 | def dataset_id_map(ds_id): 77 | return ds_id - 1 78 | 79 | elif "sem_seg" in args.dataset: 80 | 81 | def dataset_id_map(ds_id): 82 | return ds_id 83 | 84 | else: 85 | raise ValueError("Unsupported dataset: {}".format(args.dataset)) 86 | 87 | os.makedirs(args.output, exist_ok=True) 88 | 89 | for dic in tqdm.tqdm(dicts): 90 | img = cv2.imread(dic["file_name"], cv2.IMREAD_COLOR)[:, :, ::-1] 91 | basename = os.path.basename(dic["file_name"]) 92 | if dic["file_name"] in pred_by_image: 93 | pred = create_instances( 94 | pred_by_image[dic["file_name"]], 95 | img.shape[:2], 96 | ignore_label=metadata.ignore_label, 97 | ) 98 | 99 | vis = Visualizer(img, metadata) 100 | vis_pred = vis.draw_sem_seg(pred).get_image() 101 | # import pdb 102 | # pdb.set_trace() 103 | vis = Visualizer(img, metadata) 104 | with PathManager.open(dic["sem_seg_file_name"], "rb") as f: 105 | sem_seg = Image.open(f) 106 | sem_seg = np.asarray(sem_seg, dtype="uint16") 107 | vis_gt = vis.draw_sem_seg(sem_seg).get_image() 108 | # reisze pred and gt to the same height 109 | ratio = vis_gt.shape[0] / 512 110 | tgt_w = int(vis_pred.shape[1] / ratio) 111 | vis_pred = cv2.resize(vis_pred, (tgt_w,512)) 112 | vis_gt = cv2.resize(vis_gt, (tgt_w,512)) 113 | img = cv2.resize(img, (tgt_w,512)) 114 | # build grid view 115 | blank_int = 255 * np.ones((vis_gt.shape[0], 10, 3), dtype=np.uint8) 116 | concat = np.concatenate( 117 | (img, blank_int, vis_pred, blank_int, vis_gt), axis=1 118 | ) 119 | cv2.imwrite(os.path.join(args.output, basename), concat[:, :, ::-1]) --------------------------------------------------------------------------------
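Usage sketch for visualize_json_results.py (hypothetical paths and dataset name, shown only to illustrate the argparse flags defined above): the script reads the prediction JSON passed via --input, looks up ground truth through the registered dataset named by --dataset, and writes image / prediction / ground-truth comparison grids into --output.

    python visualize_json_results.py --input output/inference/sem_seg_predictions.json --output output/visualizations --dataset <registered_sem_seg_dataset>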