├── .gitattributes
├── .gitignore
├── README.md
├── config
│   └── eval.yaml
├── demo
│   └── src
│       └── icon
│           ├── arXiv-Paper.svg
│           ├── bilibili-demo.svg
│           ├── colab-badge.svg
│           ├── license-MIT.svg
│           ├── publication-Paper.svg
│           └── youtube-demo.svg
├── environment.yml
├── eval.py
├── icm
│   ├── criterion
│   │   ├── __init__.py
│   │   ├── loss_function.py
│   │   └── matting_criterion_eval.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_generator.py
│   │   ├── data_module.py
│   │   └── image_file.py
│   ├── lr_scheduler.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── decoder
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── bottleneck_block.py
│   │   │   ├── detail_capture.py
│   │   │   ├── in_context_correspondence.py
│   │   │   └── in_context_decoder.py
│   │   ├── feature_extractor
│   │   │   ├── attention_controllers.py
│   │   │   └── dift_sd.py
│   │   └── in_context_matting.py
│   └── util.py
└── train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | icm/data/image_file.py merge=ours
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .*
2 | !/.gitignore
3 | !/.gitattributes
4 | reference_code/
5 | __pycache__/
6 | logs/
7 | ckpt/
8 | pretrained_models/
9 | datasets/
10 | lightning_logs/
11 | old_logs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # In-Context Matting [CVPR 2024, Highlight]
2 |
12 | This is the official repository of the paper In-Context Matting.
13 |
14 | Details of the model architecture and experimental results can be found on our homepage.
15 |
16 | ## TODO:
17 | - [x] Release code
18 | - [x] Release pre-trained models and instructions for inference
19 | - [x] Release ICM-57 dataset
20 | - [ ] Release training dataset and instructions for training
21 |
22 | ## Requirements
23 | We follow the environment setup of [Stable Diffusion Version 2](https://github.com/Stability-AI/StableDiffusion#requirements).
24 |
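Alternatively, the repository ships an `environment.yml` (conda environment name `icm`). A minimal sketch for creating the environment from it, assuming conda is installed:

```bash
# create the conda environment defined in environment.yml and activate it
conda env create -f environment.yml
conda activate icm
```
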
25 | ## Usage
26 |
27 | To evaluate the performance on the ICM-57 dataset using the `eval.py` script, follow these instructions:
28 |
29 | 1. **Download the Pretrained Model:**
30 | - Download the pretrained model from [this link](https://pan.baidu.com/s/1HPbRRE5ZtPRpOSocm9qOmA?pwd=BA1c).
31 |
32 | 2. **Prepare the dataset:**
33 |    Ensure that the ICM-57 dataset has been downloaded and placed under `datasets/` (see the Dataset section below).
34 |
35 | 3. **Run the Evaluation:**
36 |    Use the following command to run the evaluation script. Replace the placeholders with the actual paths if they differ; a concrete example follows after this list.
37 |
38 | ```bash
39 | python eval.py --checkpoint PATH_TO_MODEL --save_path results/ --config config/eval.yaml
40 | ```
41 |
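For example, assuming the downloaded checkpoint is stored as `ckpt/icm.ckpt` (this file name is only illustrative), the command becomes:

```bash
# ckpt/icm.ckpt is a placeholder path for the downloaded pretrained model
python eval.py --checkpoint ckpt/icm.ckpt --save_path results/ --config config/eval.yaml
```

The predicted results are written to the directory passed via `--save_path`.
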
42 | ### Dataset
43 | **ICM-57**
44 | - Download link: [ICM-57 Dataset](https://pan.baidu.com/s/1ZJU_XHEVhIaVzGFPK_XCRg?pwd=BA1c)
45 | - **Installation Guide**:
46 | 1. After downloading, unzip the dataset into the `datasets/` directory of the project.
47 |   2. Ensure the structure of the dataset folder is as follows (a quick layout check is sketched after this list):
48 | ```
49 | datasets/ICM57/
50 | ├── image
51 | └── alpha
52 | ```
53 |
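After unzipping, a quick sanity check of the layout above (a sketch, assuming `image/` and `alpha/` use matching file names) is:

```bash
# list images that have no alpha matte with the same file name
cd datasets/ICM57
for f in image/*; do
  name=$(basename "$f")
  [ -e "alpha/$name" ] || echo "missing alpha for $name"
done
```
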
54 | ### Acknowledgments
55 |
56 | We would like to express our gratitude to the developers and contributors of the [DIFT](https://github.com/Tsingularity/dift) and [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/) projects. Their shared resources and insights have significantly aided the development of our work.
57 |
58 | ## Statement
59 |
60 |
64 |
65 | This project is under the MIT license. For technical questions, please contact He Guo at [hguo01@hust.edu.cn](mailto:hguo01@hust.edu.cn). For commercial use, please contact Hao Lu at [hlu@hust.edu.cn](mailto:hlu@hust.edu.cn).
66 |
--------------------------------------------------------------------------------
/config/eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: icm.models.in_context_matting.InContextMatting
3 | params:
4 | learning_rate: 0.0004
5 | cfg_loss_function:
6 | target: icm.criterion.loss_function.LossFunction2
7 | params:
8 | losses_seg:
9 | - known_smooth_l1_loss
10 | losses_matting:
11 | - unknown_l1_loss
12 | - known_l1_loss
13 | - loss_pha_laplacian
14 | - loss_gradient_penalty
15 | cfg_scheduler:
16 | target: icm.lr_scheduler.LambdaLinearScheduler
17 | params:
18 | warm_up_steps:
19 | - 250
20 | cycle_lengths:
21 | - 10000000000000
22 | f_start:
23 | - 1.0e-06
24 | f_max:
25 | - 1.0
26 | f_min:
27 | - 1.0
28 | cfg_feature_extractor:
29 | target: icm.models.feature_extractor.dift_sd.FeatureExtractor
30 | params:
31 | sd_id: stabilityai/stable-diffusion-2-1
32 | load_local: true
33 | if_softmax: 1
34 | feature_index_cor: 1
35 | feature_index_matting:
36 | - 0
37 | - 1
38 | attention_res:
39 | - 24
40 | - 48
41 | set_diag_to_one: false
42 | time_steps:
43 | - 200
44 | extract_feature_inputted_to_layer: false
45 | ensemble_size: 4
46 | cfg_in_context_decoder:
47 | target: icm.models.decoder.in_context_decoder.InContextDecoder
48 | params:
49 | freeze_in_context_fusion: false
50 | cfg_detail_decoder:
51 | target: icm.models.decoder.detail_capture.DetailCapture
52 | params:
53 | use_sigmoid: true
54 | ckpt: ''
55 | in_chans: 320
56 | img_chans: 3
57 | convstream_out:
58 | - 48
59 | - 96
60 | - 192
61 | fusion_out:
62 | - 256
63 | - 128
64 | - 64
65 | - 32
66 | cfg_in_context_fusion:
67 | target: icm.models.decoder.in_context_correspondence.SemiTrainingAttentionBlocks
68 | params:
69 | res_ratio: null
70 | pool_type: min
71 | upsample_mode: bicubic
72 | bottle_neck_dim: null
73 | use_norm: 1280
74 | in_ft_dim:
75 | - 1280
76 | - 1280
77 | in_attn_dim:
78 | - 576
79 | - 2304
80 | attn_out_dim: 256
81 | ft_out_dim:
82 | - 320
83 | - 320
84 | training_cross_attn: false
85 | data:
86 | target: icm.data.data_module.DataModuleFromConfig
87 | params:
88 | batch_size: 2
89 | batch_size_val: 1
90 | num_workers: 8
91 | shuffle_train: false
92 | validation:
93 | target: icm.data.data_generator.InContextDataset
94 | params:
95 | crop_size: 768
96 | phase: val
97 | norm_type: sd
98 | data:
99 | target: icm.data.image_file.ContextData
100 | params:
101 | ratio: 0
102 | dataset_name:
103 | - ICM57
104 | trainer:
105 | accelerator: ddp
106 | gpus: 1
107 | max_epochs: 1000
108 | auto_select_gpus: false
109 | num_sanity_val_steps: 0
110 | cfg_logger:
111 | target: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
112 | params:
113 | save_dir: logs
114 | default_hp_metric: false
115 | plugins:
116 | target: pytorch_lightning.plugins.DDPPlugin
117 | params:
118 | find_unused_parameters: false
119 |
--------------------------------------------------------------------------------
/demo/src/icon/arXiv-Paper.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/demo/src/icon/bilibili-demo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/demo/src/icon/colab-badge.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/demo/src/icon/license-MIT.svg:
--------------------------------------------------------------------------------
1 | [SVG badge graphic: "license: MIT"]
--------------------------------------------------------------------------------
/demo/src/icon/publication-Paper.svg:
--------------------------------------------------------------------------------
1 | [SVG badge graphic: "publication: Paper"]
--------------------------------------------------------------------------------
/demo/src/icon/youtube-demo.svg:
--------------------------------------------------------------------------------
1 | [SVG badge graphic: "youtube: demo"]
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: icm
2 | channels:
3 | - xformers
4 | - pytorch
5 | - nvidia
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - blas=1.0=mkl
11 | - brotlipy=0.7.0=py310h7f8727e_1002
12 | - bzip2=1.0.8=h7b6447c_0
13 | - ca-certificates=2023.01.10=h06a4308_0
14 | - certifi=2023.5.7=py310h06a4308_0
15 | - cffi=1.15.1=py310h5eee18b_3
16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
17 | - cryptography=39.0.1=py310h9ce1e76_0
18 | - cuda-cudart=11.7.99=0
19 | - cuda-cupti=11.7.101=0
20 | - cuda-libraries=11.7.1=0
21 | - cuda-nvrtc=11.7.99=0
22 | - cuda-nvtx=11.7.91=0
23 | - cuda-runtime=11.7.1=0
24 | - ffmpeg=4.3=hf484d3e_0
25 | - freetype=2.12.1=h4a9f257_0
26 | - giflib=5.2.1=h5eee18b_3
27 | - gmp=6.2.1=h295c915_3
28 | - gnutls=3.6.15=he1e5248_0
29 | - idna=3.4=py310h06a4308_0
30 | - intel-openmp=2023.1.0=hdb19cb5_46305
31 | - jpeg=9e=h5eee18b_1
32 | - lame=3.100=h7b6447c_0
33 | - lcms2=2.12=h3be6417_0
34 | - ld_impl_linux-64=2.38=h1181459_1
35 | - lerc=3.0=h295c915_0
36 | - libcublas=11.10.3.66=0
37 | - libcufft=10.7.2.124=h4fbf590_0
38 | - libcufile=1.6.1.9=0
39 | - libcurand=10.3.2.106=0
40 | - libcusolver=11.4.0.1=0
41 | - libcusparse=11.7.4.91=0
42 | - libdeflate=1.17=h5eee18b_0
43 | - libffi=3.4.4=h6a678d5_0
44 | - libgcc-ng=11.2.0=h1234567_1
45 | - libgomp=11.2.0=h1234567_1
46 | - libiconv=1.16=h7f8727e_2
47 | - libidn2=2.3.4=h5eee18b_0
48 | - libnpp=11.7.4.75=0
49 | - libnvjpeg=11.8.0.2=0
50 | - libpng=1.6.39=h5eee18b_0
51 | - libstdcxx-ng=11.2.0=h1234567_1
52 | - libtasn1=4.19.0=h5eee18b_0
53 | - libtiff=4.5.0=h6a678d5_2
54 | - libunistring=0.9.10=h27cfd23_0
55 | - libuuid=1.41.5=h5eee18b_0
56 | - libwebp=1.2.4=h11a3e52_1
57 | - libwebp-base=1.2.4=h5eee18b_1
58 | - lz4-c=1.9.4=h6a678d5_0
59 | - mkl=2023.1.0=h6d00ec8_46342
60 | - mkl-service=2.4.0=py310h5eee18b_1
61 | - mkl_fft=1.3.6=py310h1128e8f_1
62 | - mkl_random=1.2.2=py310h1128e8f_1
63 | - ncurses=6.4=h6a678d5_0
64 | - nettle=3.7.3=hbbd107a_1
65 | - numpy=1.24.3=py310h5f9d8c6_1
66 | - numpy-base=1.24.3=py310hb5e798b_1
67 | - openh264=2.1.1=h4ff587b_0
68 | - openssl=1.1.1t=h7f8727e_0
69 | - pillow=9.4.0=py310h6a678d5_0
70 | - pip=23.0.1=py310h06a4308_0
71 | - pycparser=2.21=pyhd3eb1b0_0
72 | - pyopenssl=23.0.0=py310h06a4308_0
73 | - pysocks=1.7.1=py310h06a4308_0
74 | - python=3.10.9=h7a1cb2a_2
75 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0
76 | - pytorch-cuda=11.7=h778d358_5
77 | - pytorch-mutex=1.0=cuda
78 | - readline=8.2=h5eee18b_0
79 | - requests=2.29.0=py310h06a4308_0
80 | - setuptools=67.8.0=py310h06a4308_0
81 | - sqlite=3.41.2=h5eee18b_0
82 | - tbb=2021.8.0=hdb19cb5_0
83 | - tk=8.6.12=h1ccaba5_0
84 | - torchaudio=0.13.1=py310_cu117
85 | - torchvision=0.14.1=py310_cu117
86 | - typing_extensions=4.5.0=py310h06a4308_0
87 | - tzdata=2023c=h04d1e81_0
88 | - urllib3=1.26.15=py310h06a4308_0
89 | - wheel=0.38.4=py310h06a4308_0
90 | - xformers=0.0.20=py310_cu11.7.1_pyt1.13.1
91 | - xz=5.4.2=h5eee18b_0
92 | - zlib=1.2.13=h5eee18b_0
93 | - zstd=1.5.5=hc292b87_0
94 | - pip:
95 | - accelerate==0.19.0
96 | - aiohttp==3.8.4
97 | - aiosignal==1.3.1
98 | - anyio==3.7.0
99 | - argon2-cffi==21.3.0
100 | - argon2-cffi-bindings==21.2.0
101 | - arrow==1.2.3
102 | - asttokens==2.2.1
103 | - async-lru==2.0.2
104 | - async-timeout==4.0.2
105 | - attrs==23.1.0
106 | - babel==2.12.1
107 | - backcall==0.2.0
108 | - beautifulsoup4==4.12.2
109 | - bleach==6.0.0
110 | - brotli==1.0.9
111 | - cmake==3.26.3
112 | - comm==0.1.3
113 | - contourpy==1.0.7
114 | - cycler==0.11.0
115 | - debugpy==1.6.7
116 | - decorator==5.1.1
117 | - defusedxml==0.7.1
118 | - diffusers==0.18.1
119 | - exceptiongroup==1.1.1
120 | - executing==1.2.0
121 | - fastjsonschema==2.17.1
122 | - filelock==3.12.0
123 | - fonttools==4.39.4
124 | - fqdn==1.5.1
125 | - frozenlist==1.3.3
126 | - fsspec==2023.5.0
127 | - gevent==22.10.2
128 | - geventhttpclient==2.0.2
129 | - greenlet==2.0.2
130 | - grpcio==1.54.2
131 | - huggingface-hub==0.14.1
132 | - importlib-metadata==6.6.0
133 | - ipykernel==6.23.1
134 | - ipympl==0.9.3
135 | - ipython==8.13.2
136 | - ipython-genutils==0.2.0
137 | - ipywidgets==8.0.6
138 | - isoduration==20.11.0
139 | - jedi==0.18.2
140 | - jinja2==3.1.2
141 | - json5==0.9.14
142 | - jsonpointer==2.3
143 | - jsonschema==4.17.3
144 | - jupyter-client==8.2.0
145 | - jupyter-core==5.3.0
146 | - jupyter-events==0.6.3
147 | - jupyter-lsp==2.2.0
148 | - jupyter-server
149 | - jupyter-server-terminals==0.4.4
150 | - jupyterlab==4.0.0
151 | - jupyterlab-pygments==0.2.2
152 | - jupyterlab-server==2.22.1
153 | - jupyterlab-widgets==3.0.7
154 | - kiwisolver==1.4.4
155 | - lit==16.0.5
156 | - markupsafe==2.1.2
157 | - matplotlib==3.7.1
158 | - matplotlib-inline==0.1.6
159 | - mistune==2.0.5
160 | - multidict==6.0.4
161 | - mypy-extensions==1.0.0
162 | - nbclient==0.8.0
163 | - nbconvert==7.4.0
164 | - nbformat==5.8.0
165 | - nest-asyncio==1.5.6
166 | - notebook-shim==0.2.3
167 | - overrides==7.3.1
168 | - packaging==23.1
169 | - pandocfilters==1.5.0
170 | - parso==0.8.3
171 | - pexpect==4.8.0
172 | - pickleshare==0.7.5
173 | - platformdirs==3.5.1
174 | - prometheus-client==0.17.0
175 | - prompt-toolkit==3.0.38
176 | - protobuf==3.20.3
177 | - psutil==5.9.5
178 | - ptyprocess==0.7.0
179 | - pure-eval==0.2.2
180 | - pygments==2.15.1
181 | - pyparsing==3.0.9
182 | - pyrsistent==0.19.3
183 | - python-dateutil==2.8.2
184 | - python-json-logger==2.0.7
185 | - python-rapidjson==1.10
186 | - pyyaml==6.0
187 | - regex==2023.5.5
188 | - rfc3339-validator==0.1.4
189 | - rfc3986-validator==0.1.1
190 | - send2trash==1.8.2
191 | - sh==1.14.3
192 | - six==1.16.0
193 | - sniffio==1.3.0
194 | - soupsieve==2.4.1
195 | - stack-data==0.6.2
196 | - terminado==0.17.1
197 | - tinycss2==1.2.1
198 | - tokenizers==0.13.3
199 | - tomli==2.0.1
200 | - tornado==6.3.2
201 | - tqdm==4.65.0
202 | - traitlets==5.9.0
203 | - transformers==4.29.2
204 | - triton==2.0.0.post1
205 | - tritonclient==2.33.0
206 | - typing-inspect==0.6.0
207 | - uri-template==1.2.0
208 | - wcwidth==0.2.6
209 | - webcolors==1.13
210 | - webencodings==0.5.1
211 | - websocket-client==1.5.2
212 | - widgetsnbextension==4.0.7
213 | - wrapt==1.15.0
214 | - yarl==1.9.2
215 | - zipp==3.15.0
216 | - zope-event==4.6
217 | - zope-interface==6.0
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import argparse
3 | from omegaconf import OmegaConf
4 | from icm.util import instantiate_from_config
5 | import torch
6 | from pytorch_lightning import Trainer, seed_everything
7 | import os
8 | from tqdm import tqdm
9 |
10 | def load_model_from_config(config, ckpt, verbose=False):
11 | print(f"Loading model from {ckpt}")
12 | pl_sd = torch.load(ckpt, map_location="cpu")
13 | if "global_step" in pl_sd:
14 | print(f"Global Step: {pl_sd['global_step']}")
15 | sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
16 | model = instantiate_from_config(config)
17 | m, u = model.load_state_dict(sd, strict=False)
18 | if len(m) > 0 and verbose:
19 | print("missing keys:")
20 | print(m)
21 | if len(u) > 0 and verbose:
22 | print("unexpected keys:")
23 | print(u)
24 |
25 | # model.eval()
26 | return model
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser()
31 |
32 | parser.add_argument(
33 | "--checkpoint",
34 | type=str,
35 | default="",
36 | )
37 | parser.add_argument(
38 | "--save_path",
39 | type=str,
40 | default="",
41 | )
42 | parser.add_argument(
43 | "--config",
44 | type=str,
45 | default="",
46 | )
47 | parser.add_argument(
48 | "--seed",
49 | type=int,
50 | default=42,
51 | )
52 |
53 | args = parser.parse_args()
54 | return args
55 |
56 |
57 | if __name__ == '__main__':
58 | args = parse_args()
59 | # if args.checkpoint:
60 | # path = args.checkpoint.split('checkpoints')[0]
61 | # # get the folder of last version folder
62 | # all_folder = os.listdir(path)
63 | # all_folder = [os.path.join(path, folder)
64 | # for folder in all_folder if 'version' in folder]
65 | # all_folder.sort()
66 | # last_version_folder = all_folder[-1]
67 | # # get the hparams.yaml path
68 | # hparams_path = os.path.join(last_version_folder, 'hparams.yaml')
69 | # cfg = OmegaConf.load(hparams_path)
70 | # else:
71 | # raise ValueError('Please input the checkpoint path')
72 |
73 | # set seed
74 | seed_everything(args.seed)
75 |
76 | cfg = OmegaConf.load(args.config)
77 |
78 | """=== Init data ==="""
79 |
80 | cfg_data = cfg.get('data')
81 |
82 | data = instantiate_from_config(cfg_data)
83 | data.setup()
84 |
85 | """=== Init model ==="""
86 | cfg_model = cfg.get('model')
87 |
88 | # model = instantiate_from_config(cfg_model)
89 | model = load_model_from_config(cfg_model, args.checkpoint, verbose=True)
90 |
91 | """=== Start validation ==="""
92 | model.on_train_start()
93 | model.eval()
94 | model.cuda()
95 | # model.train()
96 | # loss_list = []
97 | # for batch in tqdm(data._val_dataloader()):
98 | # # move tensor in batch to cuda
99 | # for key in batch:
100 | # if isinstance(batch[key], torch.Tensor):
101 | # batch[key] = batch[key].cuda()
102 | # output, loss = model.test_step(batch, None)
103 | # loss_list.append(loss.item())
104 | # print('Validation loss: ', sum(loss_list)/len(loss_list))
105 | # print('Validation loss: ', sum(loss_list)/len(loss_list))
106 | # print('Finish validation')
107 |
108 |
109 | # init trainer for validation
110 | cfg_trainer = cfg.get('trainer')
111 | # set gpu = 1
112 | cfg_trainer.gpus = 1
113 |
114 |
115 | # omegaconf to dict
116 | cfg_trainer = OmegaConf.to_container(cfg_trainer)
117 | cfg_trainer.pop('cfg_callbacks') if 'cfg_callbacks' in cfg_trainer else None
118 | # init logger
119 | cfg_logger = cfg_trainer.pop('cfg_logger') if 'cfg_logger' in cfg_trainer else None
120 | cfg_logger['params']['save_dir'] = 'logs/'
121 | cfg_logger['params']['name'] = 'eval'
122 | cfg_trainer['logger'] = instantiate_from_config(cfg_logger)
123 |
124 | # plugin
125 | cfg_plugin = cfg_trainer.pop('plugins') if 'plugins' in cfg_trainer else None
126 |
127 | # init trainer
128 | trainer_opt = argparse.Namespace(**cfg_trainer)
129 | trainer = Trainer.from_argparse_args(trainer_opt)
130 |     # set the directory where validation outputs are saved
131 | model.val_save_path = args.save_path
132 | trainer.validate(model, data.val_dataloader())
133 |
--------------------------------------------------------------------------------
/icm/criterion/__init__.py:
--------------------------------------------------------------------------------
1 | from .loss_function import LossFunction
2 |
--------------------------------------------------------------------------------
/icm/criterion/loss_function.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torchvision.ops import focal_loss
6 |
7 | class LossFunction(nn.Module):
8 | '''
9 | Loss function set
10 | losses=['unknown_l1_loss', 'known_l1_loss',
11 | 'loss_pha_laplacian', 'loss_gradient_penalty',
12 | 'smooth_l1_loss', 'cross_entropy_loss', 'focal_loss']
13 | '''
14 | def __init__(self,
15 | *,
16 | losses,
17 | ):
18 | super(LossFunction, self).__init__()
19 | self.losses = losses
20 |
21 | def loss_gradient_penalty(self, sample_map ,preds, targets):
22 | preds = preds['phas']
23 | targets = targets['phas']
24 | h,w = sample_map.shape[2:]
25 | if torch.sum(sample_map) == 0:
26 | scale = 0
27 | else:
28 | #sample_map for unknown area
29 | scale = sample_map.shape[0]*262144/torch.sum(sample_map)
30 |
31 | #gradient in x
32 | sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type())
33 | delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
34 | delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)
35 |
36 | #gradient in y
37 | sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type())
38 | delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
39 | delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)
40 |
41 | #loss
42 | loss = (F.l1_loss(delta_pred_x*sample_map, delta_gt_x*sample_map)* scale + \
43 | F.l1_loss(delta_pred_y*sample_map, delta_gt_y*sample_map)* scale + \
44 | 0.01 * torch.mean(torch.abs(delta_pred_x*sample_map))* scale + \
45 | 0.01 * torch.mean(torch.abs(delta_pred_y*sample_map))* scale)
46 |
47 | return dict(loss_gradient_penalty=loss)
48 |
49 | def loss_pha_laplacian(self, preds, targets):
50 | assert 'phas' in preds and 'phas' in targets
51 | loss = laplacian_loss(preds['phas'], targets['phas'])
52 |
53 | return dict(loss_pha_laplacian=loss)
54 |
55 | def unknown_l1_loss(self, sample_map, preds, targets):
56 | h,w = sample_map.shape[2:]
57 | if torch.sum(sample_map) == 0:
58 | scale = 0
59 | else:
60 | #sample_map for unknown area
61 | scale = sample_map.shape[0]*262144/torch.sum(sample_map)
62 |
63 | # scale = 1
64 |
65 | loss = F.l1_loss(preds['phas']*sample_map, targets['phas']*sample_map)*scale
66 | return dict(unknown_l1_loss=loss)
67 |
68 | def known_l1_loss(self, sample_map, preds, targets):
69 | new_sample_map = torch.zeros_like(sample_map)
70 | new_sample_map[sample_map==0] = 1
71 | h,w = sample_map.shape[2:]
72 | if torch.sum(new_sample_map) == 0:
73 | scale = 0
74 | else:
75 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map)
76 | # scale = 1
77 |
78 | loss = F.l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale
79 | return dict(known_l1_loss=loss)
80 |
81 | def smooth_l1_loss(self, preds, targets):
82 | assert 'phas' in preds and 'phas' in targets
83 | loss = F.smooth_l1_loss(preds['phas'], targets['phas'])
84 |
85 | return dict(smooth_l1_loss=loss)
86 |
87 | def known_smooth_l1_loss(self, sample_map, preds, targets):
88 | new_sample_map = torch.zeros_like(sample_map)
89 | new_sample_map[sample_map==0] = 1
90 | h,w = sample_map.shape[2:]
91 | if torch.sum(new_sample_map) == 0:
92 | scale = 0
93 | else:
94 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map)
95 | # scale = 1
96 |
97 | loss = F.smooth_l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale
98 |         return dict(known_smooth_l1_loss=loss)
99 |
100 | def cross_entropy_loss(self, preds, targets):
101 | assert 'phas' in preds and 'phas' in targets
102 | loss = F.binary_cross_entropy_with_logits(preds['phas'], targets['phas'])
103 |
104 | return dict(cross_entropy_loss=loss)
105 |
106 | def focal_loss(self, preds, targets):
107 | assert 'phas' in preds and 'phas' in targets
108 | loss = focal_loss.sigmoid_focal_loss(preds['phas'], targets['phas'], reduction='mean')
109 |
110 | return dict(focal_loss=loss)
111 | def forward(self, sample_map, preds, targets):
112 |
113 | preds = {'phas': preds}
114 | targets = {'phas': targets}
115 | losses = dict()
116 | for k in self.losses:
117 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty' or k=='known_smooth_l1_loss':
118 | losses.update(getattr(self, k)(sample_map, preds, targets))
119 | else:
120 | losses.update(getattr(self, k)(preds, targets))
121 | return losses
122 |
123 | class LossFunction2(nn.Module):
124 | '''
125 | Loss function set
126 | losses=['unknown_l1_loss', 'known_l1_loss',
127 | 'loss_pha_laplacian', 'loss_gradient_penalty',
128 | 'smooth_l1_loss', 'cross_entropy_loss', 'focal_loss']
129 | '''
130 | def __init__(self,
131 | *,
132 | losses_seg = ['known_smooth_l1_loss'],
133 | losses_matting = ['unknown_l1_loss', 'known_l1_loss','loss_pha_laplacian', 'loss_gradient_penalty',],
134 | ):
135 | super(LossFunction2, self).__init__()
136 | self.losses_seg = losses_seg
137 | self.losses_matting = losses_matting
138 |
139 | def loss_gradient_penalty(self, sample_map ,preds, targets):
140 | preds = preds['phas']
141 | targets = targets['phas']
142 | h,w = sample_map.shape[2:]
143 | if torch.sum(sample_map) == 0:
144 | scale = 0
145 | else:
146 | #sample_map for unknown area
147 | scale = sample_map.shape[0]*262144/torch.sum(sample_map)
148 |
149 | #gradient in x
150 | sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type())
151 | delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
152 | delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)
153 |
154 | #gradient in y
155 | sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type())
156 | delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
157 | delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)
158 |
159 | #loss
160 | loss = (F.l1_loss(delta_pred_x*sample_map, delta_gt_x*sample_map)* scale + \
161 | F.l1_loss(delta_pred_y*sample_map, delta_gt_y*sample_map)* scale + \
162 | 0.01 * torch.mean(torch.abs(delta_pred_x*sample_map))* scale + \
163 | 0.01 * torch.mean(torch.abs(delta_pred_y*sample_map))* scale)
164 |
165 | return dict(loss_gradient_penalty=loss)
166 |
167 | def loss_pha_laplacian(self, preds, targets):
168 | assert 'phas' in preds and 'phas' in targets
169 | loss = laplacian_loss(preds['phas'], targets['phas'])
170 |
171 | return dict(loss_pha_laplacian=loss)
172 |
173 | def unknown_l1_loss(self, sample_map, preds, targets):
174 | h,w = sample_map.shape[2:]
175 | if torch.sum(sample_map) == 0:
176 | scale = 0
177 | else:
178 | #sample_map for unknown area
179 | scale = sample_map.shape[0]*262144/torch.sum(sample_map)
180 |
181 | # scale = 1
182 |
183 | loss = F.l1_loss(preds['phas']*sample_map, targets['phas']*sample_map)*scale
184 | return dict(unknown_l1_loss=loss)
185 |
186 | def known_l1_loss(self, sample_map, preds, targets):
187 | new_sample_map = torch.zeros_like(sample_map)
188 | new_sample_map[sample_map==0] = 1
189 | h,w = sample_map.shape[2:]
190 | if torch.sum(new_sample_map) == 0:
191 | scale = 0
192 | else:
193 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map)
194 | # scale = 1
195 |
196 | loss = F.l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale
197 | return dict(known_l1_loss=loss)
198 |
199 | def smooth_l1_loss(self, preds, targets):
200 | assert 'phas' in preds and 'phas' in targets
201 | loss = F.smooth_l1_loss(preds['phas'], targets['phas'])
202 |
203 | return dict(smooth_l1_loss=loss)
204 |
205 | def known_smooth_l1_loss(self, sample_map, preds, targets):
206 | new_sample_map = torch.zeros_like(sample_map)
207 | new_sample_map[sample_map==0] = 1
208 | h,w = sample_map.shape[2:]
209 | if torch.sum(new_sample_map) == 0:
210 | scale = 0
211 | else:
212 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map)
213 | # scale = 1
214 |
215 | loss = F.smooth_l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale
216 |         return dict(known_smooth_l1_loss=loss)
217 |
218 | def cross_entropy_loss(self, preds, targets):
219 | assert 'phas' in preds and 'phas' in targets
220 | loss = F.binary_cross_entropy_with_logits(preds['phas'], targets['phas'])
221 |
222 | return dict(cross_entropy_loss=loss)
223 |
224 | def focal_loss(self, preds, targets):
225 | assert 'phas' in preds and 'phas' in targets
226 | loss = focal_loss.sigmoid_focal_loss(preds['phas'], targets['phas'], reduction='mean')
227 |
228 | return dict(focal_loss=loss)
229 | def forward_single_sample(self, sample_map, preds, targets):
230 | # check if targets only have element 0 and 1
231 | if torch.all(targets == 0) or torch.all(targets == 1):
232 |
233 | preds = {'phas': preds}
234 | targets = {'phas': targets}
235 | losses = dict()
236 | for k in self.losses_seg:
237 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty' or k=='known_smooth_l1_loss':
238 | losses.update(getattr(self, k)(sample_map, preds, targets))
239 | else:
240 | losses.update(getattr(self, k)(preds, targets))
241 | return losses
242 | else:
243 | preds = {'phas': preds}
244 | targets = {'phas': targets}
245 | losses = dict()
246 | for k in self.losses_matting:
247 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty' or k=='known_smooth_l1_loss':
248 | losses.update(getattr(self, k)(sample_map, preds, targets))
249 | else:
250 | losses.update(getattr(self, k)(preds, targets))
251 | return losses
252 |
253 | def forward(self, sample_map, preds, targets):
254 | losses = dict()
255 | for i in range(preds.shape[0]):
256 | losses_ = self.forward_single_sample(sample_map[i].unsqueeze(0), preds[i].unsqueeze(0), targets[i].unsqueeze(0))
257 | for k in losses_:
258 | if k in losses:
259 | losses[k] += losses_[k]
260 | else:
261 | losses[k] = losses_[k]
262 | return losses
263 | #-----------------Laplacian Loss-------------------------#
264 | def laplacian_loss(pred, true, max_levels=5):
265 | kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
266 | pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
267 | true_pyramid = laplacian_pyramid(true, kernel, max_levels)
268 | loss = 0
269 | for level in range(max_levels):
270 | loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
271 | return loss / max_levels
272 |
273 | def laplacian_pyramid(img, kernel, max_levels):
274 | current = img
275 | pyramid = []
276 | for _ in range(max_levels):
277 | current = crop_to_even_size(current)
278 | down = downsample(current, kernel)
279 | up = upsample(down, kernel)
280 | diff = current - up
281 | pyramid.append(diff)
282 | current = down
283 | return pyramid
284 |
285 | def gauss_kernel(device='cpu', dtype=torch.float32):
286 | kernel = torch.tensor([[1, 4, 6, 4, 1],
287 | [4, 16, 24, 16, 4],
288 | [6, 24, 36, 24, 6],
289 | [4, 16, 24, 16, 4],
290 | [1, 4, 6, 4, 1]], device=device, dtype=dtype)
291 | kernel /= 256
292 | kernel = kernel[None, None, :, :]
293 | return kernel
294 |
295 | def gauss_convolution(img, kernel):
296 | B, C, H, W = img.shape
297 | img = img.reshape(B * C, 1, H, W)
298 | img = F.pad(img, (2, 2, 2, 2), mode='reflect')
299 | img = F.conv2d(img, kernel)
300 | img = img.reshape(B, C, H, W)
301 | return img
302 |
303 | def downsample(img, kernel):
304 | img = gauss_convolution(img, kernel)
305 | img = img[:, :, ::2, ::2]
306 | return img
307 |
308 | def upsample(img, kernel):
309 | B, C, H, W = img.shape
310 | out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
311 | out[:, :, ::2, ::2] = img * 4
312 | out = gauss_convolution(out, kernel)
313 | return out
314 |
315 | def crop_to_even_size(img):
316 | H, W = img.shape[2:]
317 | H = H - H % 2
318 | W = W - W % 2
319 | return img[:, :, :H, :W]
--------------------------------------------------------------------------------
/icm/criterion/matting_criterion_eval.py:
--------------------------------------------------------------------------------
1 | import scipy.ndimage
2 | import numpy as np
3 | from skimage.measure import label
4 | import scipy.ndimage.morphology
5 | import torch
6 |
7 | def compute_mse_loss(pred, target, trimap):
8 | error_map = (pred - target) / 255.0
9 | loss = np.sum((error_map ** 2) * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8)
10 |
11 | return loss
12 |
13 |
14 | def compute_sad_loss(pred, target, trimap):
15 | error_map = np.abs((pred - target) / 255.0)
16 | loss = np.sum(error_map * (trimap == 128))
17 |
18 | return loss / 1000, np.sum(trimap == 128) / 1000
19 |
20 | def gauss(x, sigma):
21 | y = np.exp(-x ** 2 / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))
22 | return y
23 |
24 |
25 | def dgauss(x, sigma):
26 | y = -x * gauss(x, sigma) / (sigma ** 2)
27 | return y
28 |
29 |
30 | def gaussgradient(im, sigma):
31 | epsilon = 1e-2
32 | halfsize = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))).astype(int)
33 | size = 2 * halfsize + 1
34 | hx = np.zeros((size, size))
35 | for i in range(0, size):
36 | for j in range(0, size):
37 | u = [i - halfsize, j - halfsize]
38 | hx[i, j] = gauss(u[0], sigma) * dgauss(u[1], sigma)
39 |
40 | hx = hx / np.sqrt(np.sum(np.abs(hx) * np.abs(hx)))
41 | hy = hx.transpose()
42 |
43 | gx = scipy.ndimage.convolve(im, hx, mode='nearest')
44 | gy = scipy.ndimage.convolve(im, hy, mode='nearest')
45 |
46 | return gx, gy
47 |
48 | def compute_gradient_loss(pred, target, trimap):
49 |
50 | pred = pred / 255.0
51 | target = target / 255.0
52 |
53 | pred_x, pred_y = gaussgradient(pred, 1.4)
54 | target_x, target_y = gaussgradient(target, 1.4)
55 |
56 | pred_amp = np.sqrt(pred_x ** 2 + pred_y ** 2)
57 | target_amp = np.sqrt(target_x ** 2 + target_y ** 2)
58 |
59 | error_map = (pred_amp - target_amp) ** 2
60 | loss = np.sum(error_map[trimap == 128])
61 |
62 | return loss / 1000.
63 |
64 |
65 | def compute_connectivity_error(pred, target, trimap, step):
66 | pred = pred / 255.0
67 | target = target / 255.0
68 | h, w = pred.shape
69 |
70 | thresh_steps = list(np.arange(0, 1 + step, step))
71 | l_map = np.ones_like(pred, dtype=float) * -1
72 | for i in range(1, len(thresh_steps)):
73 | pred_alpha_thresh = (pred >= thresh_steps[i]).astype(int)
74 | target_alpha_thresh = (target >= thresh_steps[i]).astype(int)
75 |
76 | omega = getLargestCC(pred_alpha_thresh * target_alpha_thresh).astype(int)
77 | flag = ((l_map == -1) & (omega == 0)).astype(int)
78 | l_map[flag == 1] = thresh_steps[i - 1]
79 |
80 | l_map[l_map == -1] = 1
81 |
82 | pred_d = pred - l_map
83 | target_d = target - l_map
84 | pred_phi = 1 - pred_d * (pred_d >= 0.15).astype(int)
85 | target_phi = 1 - target_d * (target_d >= 0.15).astype(int)
86 | loss = np.sum(np.abs(pred_phi - target_phi)[trimap == 128])
87 |
88 | return loss / 1000.
89 |
90 | def getLargestCC(segmentation):
91 | labels = label(segmentation, connectivity=1)
92 | largestCC = labels == np.argmax(np.bincount(labels.flat))
93 | return largestCC
94 |
95 |
96 | def compute_mse_loss_torch(pred, target, trimap):
97 | error_map = (pred - target) / 255.0
98 | # rewrite the loss with torch
99 | # loss = np.sum((error_map ** 2) * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8)
100 | loss = torch.sum((error_map ** 2) * (trimap == 128).float()) / (torch.sum(trimap == 128).float() + 1e-8)
101 |
102 | return loss
103 |
104 |
105 | def compute_sad_loss_torch(pred, target, trimap):
106 | # rewrite the map with torch
107 | # error_map = np.abs((pred - target) / 255.0)
108 | error_map = torch.abs((pred - target) / 255.0)
109 | # loss = np.sum(error_map * (trimap == 128))
110 | loss = torch.sum(error_map * (trimap == 128).float())
111 |
112 | return loss / 1000
113 |
--------------------------------------------------------------------------------
/icm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiny-smart/in-context-matting/a58bcf4b948babcf5f1c2e1c41e2fa040bc53d1e/icm/data/__init__.py
--------------------------------------------------------------------------------
/icm/data/data_generator.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import math
4 | import numbers
5 | import random
6 | import logging
7 | import numpy as np
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 | from torch.nn import functional as F
12 | from torchvision import transforms
13 |
14 | from icm.util import instantiate_from_config
15 | from icm.data.image_file import get_dir_ext
16 | # one-hot or class, choice: [3, 1]
17 | TRIMAP_CHANNEL = 1
18 |
19 | RANDOM_INTERP = True
20 |
21 | CUTMASK_PROB = 0
22 |
23 | interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR,
24 | cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
25 |
26 |
27 | def maybe_random_interp(cv2_interp):
28 | if RANDOM_INTERP:
29 | return np.random.choice(interp_list)
30 | else:
31 | return cv2_interp
32 |
33 |
34 | class ToTensor(object):
35 | """
36 | Convert ndarrays in sample to Tensors with normalization.
37 | """
38 |
39 | def __init__(self, phase="test", norm_type='imagenet'):
40 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
41 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
42 | self.phase = phase
43 | self.norm_type = norm_type
44 |
45 | def __call__(self, sample):
46 | # convert GBR images to RGB
47 |         image, alpha, trimap, mask = sample['image'][:, :, ::-1], \
48 |             sample['alpha'], sample['trimap'], sample['mask']
49 |
50 | alpha[alpha < 0] = 0
51 | alpha[alpha > 1] = 1
52 |
53 | # swap color axis because
54 | # numpy image: H x W x C
55 | # torch image: C X H X W
56 | image = image.transpose((2, 0, 1)).astype(np.float32)
57 | alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
58 | trimap[trimap < 85] = 0
59 | trimap[trimap >= 170] = 1
60 | trimap[trimap >= 85] = 0.5
61 |
62 | mask = np.expand_dims(mask.astype(np.float32), axis=0)
63 |
64 | if self.phase == "train":
65 | # convert GBR images to RGB
66 |             fg = sample['fg'][:, :, ::-1].transpose(
67 |                 (2, 0, 1)).astype(np.float32) / 255.
68 | sample['fg'] = torch.from_numpy(fg).sub_(self.mean).div_(self.std)
69 |             bg = sample['bg'][:, :, ::-1].transpose(
70 |                 (2, 0, 1)).astype(np.float32) / 255.
71 | sample['bg'] = torch.from_numpy(bg).sub_(self.mean).div_(self.std)
72 | # del sample['image_name']
73 |
74 | sample['image'], sample['alpha'], sample['trimap'] = \
75 | torch.from_numpy(image), torch.from_numpy(
76 | alpha), torch.from_numpy(trimap)
77 |
78 | if self.norm_type == 'imagenet':
79 | # normalize image
80 | sample['image'] /= 255.
81 |
82 | sample['image'] = sample['image'].sub_(self.mean).div_(self.std)
83 | elif self.norm_type == 'sd':
84 | sample['image'] = sample['image'].to(dtype=torch.float32) / 127.5 - 1.0
85 | else:
86 | raise NotImplementedError(
87 | "norm_type {} is not implemented".format(self.norm_type))
88 |
89 | if TRIMAP_CHANNEL == 3:
90 | sample['trimap'] = F.one_hot(
91 | sample['trimap'], num_classes=3).permute(2, 0, 1).float()
92 | elif TRIMAP_CHANNEL == 1:
93 | sample['trimap'] = sample['trimap'][None, ...].float()
94 | else:
95 | raise NotImplementedError("TRIMAP_CHANNEL can only be 3 or 1")
96 |
97 | sample['mask'] = torch.from_numpy(mask).float()
98 |
99 | return sample
100 |
101 |
102 | class RandomAffine(object):
103 | """
104 | Random affine translation
105 | """
106 |
107 | def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
108 | if isinstance(degrees, numbers.Number):
109 | if degrees < 0:
110 | raise ValueError(
111 | "If degrees is a single number, it must be positive.")
112 | self.degrees = (-degrees, degrees)
113 | else:
114 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
115 | "degrees should be a list or tuple and it must be of length 2."
116 | self.degrees = degrees
117 |
118 | if translate is not None:
119 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
120 | "translate should be a list or tuple and it must be of length 2."
121 | for t in translate:
122 | if not (0.0 <= t <= 1.0):
123 | raise ValueError(
124 | "translation values should be between 0 and 1")
125 | self.translate = translate
126 |
127 | if scale is not None:
128 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
129 | "scale should be a list or tuple and it must be of length 2."
130 | for s in scale:
131 | if s <= 0:
132 | raise ValueError("scale values should be positive")
133 | self.scale = scale
134 |
135 | if shear is not None:
136 | if isinstance(shear, numbers.Number):
137 | if shear < 0:
138 | raise ValueError(
139 | "If shear is a single number, it must be positive.")
140 | self.shear = (-shear, shear)
141 | else:
142 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
143 | "shear should be a list or tuple and it must be of length 2."
144 | self.shear = shear
145 | else:
146 | self.shear = shear
147 |
148 | self.resample = resample
149 | self.fillcolor = fillcolor
150 | self.flip = flip
151 |
152 | @staticmethod
153 | def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
154 | """Get parameters for affine transformation
155 |
156 | Returns:
157 | sequence: params to be passed to the affine transformation
158 | """
159 | angle = random.uniform(degrees[0], degrees[1])
160 | if translate is not None:
161 | max_dx = translate[0] * img_size[0]
162 | max_dy = translate[1] * img_size[1]
163 | translations = (np.round(random.uniform(-max_dx, max_dx)),
164 | np.round(random.uniform(-max_dy, max_dy)))
165 | else:
166 | translations = (0, 0)
167 |
168 | if scale_ranges is not None:
169 | scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
170 | random.uniform(scale_ranges[0], scale_ranges[1]))
171 | else:
172 | scale = (1.0, 1.0)
173 |
174 | if shears is not None:
175 | shear = random.uniform(shears[0], shears[1])
176 | else:
177 | shear = 0.0
178 |
179 | if flip is not None:
180 |             flip = (np.random.rand(2) < flip).astype(int) * 2 - 1
181 |
182 | return angle, translations, scale, shear, flip
183 |
184 | def __call__(self, sample):
185 | fg, alpha = sample['fg'], sample['alpha']
186 | rows, cols, ch = fg.shape
187 | if np.maximum(rows, cols) < 1024:
188 | params = self.get_params(
189 | (0, 0), self.translate, self.scale, self.shear, self.flip, fg.size)
190 | else:
191 | params = self.get_params(
192 | self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size)
193 |
194 | center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
195 | M = self._get_inverse_affine_matrix(center, *params)
196 | M = np.array(M).reshape((2, 3))
197 |
198 | fg = cv2.warpAffine(fg, M, (cols, rows),
199 | flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
200 | alpha = cv2.warpAffine(alpha, M, (cols, rows),
201 | flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
202 |
203 | sample['fg'], sample['alpha'] = fg, alpha
204 |
205 | return sample
206 |
207 |     @staticmethod
208 | def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
209 | # Helper method to compute inverse matrix for affine transformation
210 |
211 | # As it is explained in PIL.Image.rotate
212 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
213 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
214 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
215 | # RSS is rotation with scale and shear matrix
216 | # It is different from the original function in torchvision
217 | # The order are changed to flip -> scale -> rotation -> shear
218 | # x and y have different scale factors
219 | # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y 0]
220 | # [ sin(a)*scale_x*f cos(a)*scale_y 0]
221 | # [ 0 0 1]
222 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
223 |
224 | angle = math.radians(angle)
225 | shear = math.radians(shear)
226 | scale_x = 1.0 / scale[0] * flip[0]
227 | scale_y = 1.0 / scale[1] * flip[1]
228 |
229 | # Inverted rotation matrix with scale and shear
230 | d = math.cos(angle + shear) * math.cos(angle) + \
231 | math.sin(angle + shear) * math.sin(angle)
232 | matrix = [
233 | math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
234 | -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
235 | ]
236 | matrix = [m / d for m in matrix]
237 |
238 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
239 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + \
240 | matrix[1] * (-center[1] - translate[1])
241 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + \
242 | matrix[4] * (-center[1] - translate[1])
243 |
244 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1
245 | matrix[2] += center[0]
246 | matrix[5] += center[1]
247 |
248 | return matrix
249 |
250 |
251 | class RandomJitter(object):
252 | """
253 | Random change the hue of the image
254 | """
255 |
256 | def __call__(self, sample):
257 | sample_ori = sample.copy()
258 | fg, alpha = sample['fg'], sample['alpha']
259 | # if alpha is all 0 skip
260 | if np.all(alpha == 0):
261 | return sample_ori
262 | # convert to HSV space, convert to float32 image to keep precision during space conversion.
263 | fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
264 | # Hue noise
265 | hue_jitter = np.random.randint(-40, 40)
266 | fg[:, :, 0] = np.remainder(
267 | fg[:, :, 0].astype(np.float32) + hue_jitter, 360)
268 | # Saturation noise
269 | sat_bar = fg[:, :, 1][alpha > 0].mean()
270 | if np.isnan(sat_bar):
271 | return sample_ori
272 | sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
273 | sat = fg[:, :, 1]
274 | sat = np.abs(sat + sat_jitter)
275 | sat[sat > 1] = 2 - sat[sat > 1]
276 | fg[:, :, 1] = sat
277 | # Value noise
278 | val_bar = fg[:, :, 2][alpha > 0].mean()
279 | if np.isnan(val_bar):
280 | return sample_ori
281 | val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
282 | val = fg[:, :, 2]
283 | val = np.abs(val + val_jitter)
284 | val[val > 1] = 2 - val[val > 1]
285 | fg[:, :, 2] = val
286 | # convert back to BGR space
287 | fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR)
288 | sample['fg'] = fg*255
289 |
290 | return sample
291 |
292 |
293 | class RandomHorizontalFlip(object):
294 | """
295 | Random flip image and label horizontally
296 | """
297 |
298 | def __init__(self, prob=0.5):
299 | self.prob = prob
300 |
301 | def __call__(self, sample):
302 | fg, alpha = sample['fg'], sample['alpha']
303 | if np.random.uniform(0, 1) < self.prob:
304 | fg = cv2.flip(fg, 1)
305 | alpha = cv2.flip(alpha, 1)
306 | sample['fg'], sample['alpha'] = fg, alpha
307 |
308 | return sample
309 |
310 |
311 | class RandomCrop(object):
312 | """
313 | Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size'
314 |
315 | :param output_size (tuple or int): Desired output size. If int, square crop
316 | is made.
317 | """
318 |
319 | def __init__(self, output_size):
320 | assert isinstance(output_size, (int, tuple))
321 | if isinstance(output_size, int):
322 | self.output_size = (output_size, output_size)
323 | else:
324 | assert len(output_size) == 2
325 | self.output_size = output_size
326 | self.margin = output_size[0] // 2
327 | self.logger = logging.getLogger("Logger")
328 |
329 | def __call__(self, sample):
330 | fg, alpha, trimap, mask, name = sample['fg'], sample[
331 | 'alpha'], sample['trimap'], sample['mask'], sample['image_name']
332 | bg = sample['bg']
333 | h, w = trimap.shape
334 | bg = cv2.resize(
335 | bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
336 | if w < self.output_size[0]+1 or h < self.output_size[1]+1:
337 | ratio = 1.1*self.output_size[0] / \
338 | h if h < w else 1.1*self.output_size[1]/w
339 | # self.logger.warning("Size of {} is {}.".format(name, (h, w)))
340 | while h < self.output_size[0]+1 or w < self.output_size[1]+1:
341 | fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)),
342 | interpolation=maybe_random_interp(cv2.INTER_NEAREST))
343 | alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
344 | interpolation=maybe_random_interp(cv2.INTER_NEAREST))
345 | trimap = cv2.resize(
346 | trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
347 | bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)),
348 | interpolation=maybe_random_interp(cv2.INTER_CUBIC))
349 | mask = cv2.resize(
350 | mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
351 | h, w = trimap.shape
352 | small_trimap = cv2.resize(
353 | trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
354 | unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
355 | self.margin//4:(w-self.margin)//4] == 128)))
356 | unknown_num = len(unknown_list)
357 | if len(unknown_list) < 10:
358 | left_top = (np.random.randint(
359 | 0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
360 | else:
361 | idx = np.random.randint(unknown_num)
362 | left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
363 |
364 | fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0],
365 | left_top[1]:left_top[1]+self.output_size[1], :]
366 | alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0],
367 | left_top[1]:left_top[1]+self.output_size[1]]
368 | bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0],
369 | left_top[1]:left_top[1]+self.output_size[1], :]
370 | trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0],
371 | left_top[1]:left_top[1]+self.output_size[1]]
372 | mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0],
373 | left_top[1]:left_top[1]+self.output_size[1]]
374 |
375 | if len(np.where(trimap == 128)[0]) == 0:
376 | self.logger.error("{} does not have enough unknown area for crop. Resized to target size."
377 | "left_top: {}".format(name, left_top))
378 | fg_crop = cv2.resize(
379 | fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
380 | alpha_crop = cv2.resize(
381 | alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
382 | trimap_crop = cv2.resize(
383 | trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
384 | bg_crop = cv2.resize(
385 | bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC))
386 | mask_crop = cv2.resize(
387 | mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
388 |
389 | sample.update({'fg': fg_crop, 'alpha': alpha_crop,
390 | 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop})
391 | return sample
392 |
393 |
394 | class CropResize(object):
395 | # crop the image to square, and resize to target size
396 | def __init__(self, output_size):
397 | assert isinstance(output_size, (int, tuple))
398 | if isinstance(output_size, int):
399 | self.output_size = (output_size, output_size)
400 | else:
401 | assert len(output_size) == 2
402 | self.output_size = output_size
403 |
404 | def __call__(self, sample):
405 | img, alpha, trimap, mask = sample['image'], sample['alpha'], sample['trimap'], sample['mask']
406 | # crop the image to square, and resize to target size
407 |
408 | h, w = img.shape[:2]
409 | if h == w:
410 | img_crop = cv2.resize(
411 | img, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST))
412 | alpha_crop = cv2.resize(
413 | alpha, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST))
414 | trimap_crop = cv2.resize(
415 | trimap, self.output_size, interpolation=cv2.INTER_NEAREST)
416 | mask_crop = cv2.resize(
417 | mask, self.output_size, interpolation=cv2.INTER_NEAREST)
418 | elif h > w:
419 | margin = (h-w)//2
420 | img = img[margin:margin+w, :]
421 | alpha = alpha[margin:margin+w, :]
422 | trimap = trimap[margin:margin+w, :]
423 | mask = mask[margin:margin+w, :]
424 | img_crop = cv2.resize(
425 | img, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST))
426 | alpha_crop = cv2.resize(
427 | alpha, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST))
428 | trimap_crop = cv2.resize(
429 | trimap, self.output_size, interpolation=cv2.INTER_NEAREST)
430 | mask_crop = cv2.resize(
431 | mask, self.output_size, interpolation=cv2.INTER_NEAREST)
432 | else:
433 | margin = (w-h)//2
434 | img = img[:, margin:margin+h]
435 | alpha = alpha[:, margin:margin+h]
436 | trimap = trimap[:, margin:margin+h]
437 | mask = mask[:, margin:margin+h]
438 | img_crop = cv2.resize(
439 | img, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST))
440 | alpha_crop = cv2.resize(
441 | alpha, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST))
442 | trimap_crop = cv2.resize(
443 | trimap, self.output_size, interpolation=cv2.INTER_NEAREST)
444 | mask_crop = cv2.resize(
445 | mask, self.output_size, interpolation=cv2.INTER_NEAREST)
446 | sample.update({'image': img_crop, 'alpha': alpha_crop,
447 | 'trimap': trimap_crop, 'mask': mask_crop})
448 | return sample
449 |
450 |
451 | class OriginScale(object):
452 | def __call__(self, sample):
453 | h, w = sample["alpha_shape"]
454 |
455 | if h % 32 == 0 and w % 32 == 0:
456 | return sample
457 |
458 | target_h = 32 * ((h - 1) // 32 + 1)
459 | target_w = 32 * ((w - 1) // 32 + 1)
460 | pad_h = target_h - h
461 | pad_w = target_w - w
462 |
463 | padded_image = np.pad(
464 | sample['image'], ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
465 | padded_trimap = np.pad(
466 | sample['trimap'], ((0, pad_h), (0, pad_w)), mode="reflect")
467 | padded_mask = np.pad(
468 | sample['mask'], ((0, pad_h), (0, pad_w)), mode="reflect")
469 |
470 | sample['image'] = padded_image
471 | sample['trimap'] = padded_trimap
472 | sample['mask'] = padded_mask
473 |
474 | return sample
475 |
476 |
477 | class GenMask(object):
478 | def __init__(self):
479 | self.erosion_kernels = [None] + [cv2.getStructuringElement(
480 | cv2.MORPH_ELLIPSE, (size, size)) for size in range(1, 30)]
481 |
482 | def __call__(self, sample):
483 | alpha_ori = sample['alpha']
484 | h, w = alpha_ori.shape
485 |
486 | max_kernel_size = 30
487 | alpha = cv2.resize(alpha_ori, (640, 640),
488 | interpolation=maybe_random_interp(cv2.INTER_NEAREST))
489 |
490 | # generate trimap
491 |         fg_mask = (alpha + 1e-5).astype(int).astype(np.uint8)
492 |         bg_mask = (1 - alpha + 1e-5).astype(int).astype(np.uint8)
493 | fg_mask = cv2.erode(
494 | fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
495 | bg_mask = cv2.erode(
496 | bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
497 |
498 | fg_width = np.random.randint(1, 30)
499 | bg_width = np.random.randint(1, 30)
500 |         fg_mask = (alpha + 1e-5).astype(int).astype(np.uint8)
501 |         bg_mask = (1 - alpha + 1e-5).astype(int).astype(np.uint8)
502 | fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width])
503 | bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width])
504 |
505 | trimap = np.ones_like(alpha) * 128
506 | trimap[fg_mask == 1] = 255
507 | trimap[bg_mask == 1] = 0
508 |
509 | trimap = cv2.resize(trimap, (w, h), interpolation=cv2.INTER_NEAREST)
510 | sample['trimap'] = trimap
511 |
512 | # generate mask
513 | low = 0.01
514 | high = 1.0
515 | thres = random.random() * (high - low) + low
516 |         seg_mask = (alpha >= thres).astype(int).astype(np.uint8)
517 | random_num = random.randint(0, 3)
518 | if random_num == 0:
519 | seg_mask = cv2.erode(
520 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
521 | elif random_num == 1:
522 | seg_mask = cv2.dilate(
523 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
524 | elif random_num == 2:
525 | seg_mask = cv2.erode(
526 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
527 | seg_mask = cv2.dilate(
528 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
529 | elif random_num == 3:
530 | seg_mask = cv2.dilate(
531 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
532 | seg_mask = cv2.erode(
533 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
534 |
535 | seg_mask = cv2.resize(
536 | seg_mask, (w, h), interpolation=cv2.INTER_NEAREST)
537 | sample['mask'] = seg_mask
538 |
539 | return sample
540 |
541 |
542 | class Composite(object):
543 | def __call__(self, sample):
544 | fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
545 | alpha[alpha < 0] = 0
546 | alpha[alpha > 1] = 1
547 | fg[fg < 0] = 0
548 | fg[fg > 255] = 255
549 | bg[bg < 0] = 0
550 | bg[bg > 255] = 255
551 |
552 | image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])
553 | sample['image'] = image
554 | return sample
555 |
556 |
557 | class CutMask(object):
558 | def __init__(self, perturb_prob=0):
559 | self.perturb_prob = perturb_prob
560 |
561 | def __call__(self, sample):
562 | if np.random.rand() < self.perturb_prob:
563 | return sample
564 |
565 | mask = sample['mask'] # H x W, trimap 0--255, segmask 0--1, alpha 0--1
566 | h, w = mask.shape
567 | perturb_size_h, perturb_size_w = random.randint(
568 | h // 4, h // 2), random.randint(w // 4, w // 2)
569 | x = random.randint(0, h - perturb_size_h)
570 | y = random.randint(0, w - perturb_size_w)
571 | x1 = random.randint(0, h - perturb_size_h)
572 | y1 = random.randint(0, w - perturb_size_w)
573 |
574 | mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1 +
575 | perturb_size_h, y1:y1+perturb_size_w].copy()
576 |
577 | sample['mask'] = mask
578 | return sample
579 |
580 |
581 | class DataGenerator(Dataset):
582 | def __init__(self, data, crop_size=512, phase="train"):
583 | self.phase = phase
584 | self.crop_size = crop_size
585 | self.alpha = data.alpha
586 |
587 | if self.phase == "train":
588 | self.fg = data.fg
589 | self.bg = data.bg
590 | self.merged = []
591 | self.trimap = []
592 |
593 | else:
594 | self.fg = []
595 | self.bg = []
596 | self.merged = data.merged
597 | self.trimap = data.trimap
598 |
599 | train_trans = [
600 | RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
601 | GenMask(),
602 | CutMask(perturb_prob=CUTMASK_PROB),
603 | RandomCrop((self.crop_size, self.crop_size)),
604 | RandomJitter(),
605 | Composite(),
606 | ToTensor(phase="train")]
607 |
608 | test_trans = [OriginScale(), ToTensor()]
609 |
610 | self.transform = {
611 | 'train':
612 | transforms.Compose(train_trans),
613 | 'val':
614 | transforms.Compose([
615 | OriginScale(),
616 | ToTensor()
617 | ]),
618 | 'test':
619 | transforms.Compose(test_trans)
620 | }[phase]
621 |
622 | self.fg_num = len(self.fg)
623 |
624 | def __getitem__(self, idx):
625 | if self.phase == "train":
626 | fg = cv2.imread(self.fg[idx % self.fg_num])
627 | alpha = cv2.imread(
628 | self.alpha[idx % self.fg_num], 0).astype(np.float32)/255
629 | bg = cv2.imread(self.bg[idx], 1)
630 |
631 | fg, alpha = self._composite_fg(fg, alpha, idx)
632 |
633 | image_name = os.path.split(self.fg[idx % self.fg_num])[-1]
634 | sample = {'fg': fg, 'alpha': alpha,
635 | 'bg': bg, 'image_name': image_name}
636 |
637 | else:
638 | image = cv2.imread(self.merged[idx])
639 | alpha = cv2.imread(self.alpha[idx], 0)/255.
640 | trimap = cv2.imread(self.trimap[idx], 0)
641 | mask = (trimap >= 170).astype(np.float32)
642 | image_name = os.path.split(self.merged[idx])[-1]
643 |
644 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap,
645 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape}
646 |
647 | sample = self.transform(sample)
648 |
649 | return sample
650 |
651 | def _composite_fg(self, fg, alpha, idx):
652 |
653 | if np.random.rand() < 0.5:
654 | idx2 = np.random.randint(self.fg_num) + idx
655 | fg2 = cv2.imread(self.fg[idx2 % self.fg_num])
656 | alpha2 = cv2.imread(
657 | self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255.
658 | h, w = alpha.shape
659 | fg2 = cv2.resize(
660 | fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
661 | alpha2 = cv2.resize(
662 | alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
663 |
664 | alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
665 | if np.any(alpha_tmp < 1):
666 | fg = fg.astype(
667 | np.float32) * alpha[:, :, None] + fg2.astype(np.float32) * (1 - alpha[:, :, None])
668 | # The overlap of two 50%-transparent foregrounds should be 25% transparent
669 | alpha = alpha_tmp
670 | fg = fg.astype(np.uint8)
671 |
672 | if np.random.rand() < 0.25:
673 | fg = cv2.resize(
674 | fg, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
675 | alpha = cv2.resize(
676 | alpha, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
677 |
678 | return fg, alpha
679 |
680 | def __len__(self):
681 | if self.phase == "train":
682 | return len(self.bg)
683 | else:
684 | return len(self.alpha)
685 |
686 |
687 | class MultiDataGeneratorDoubleSet(Dataset):
688 | # divide a dataset into train set and validation set
689 | def __init__(self, data, crop_size=1024, phase="train",norm_type='imagenet'):
690 | self.phase = phase
691 | self.crop_size = crop_size
692 | data = instantiate_from_config(data)
693 |
694 | if self.phase == "train":
695 | self.alpha = data.alpha_train
696 | self.merged = data.merged_train
697 | self.trimap = data.trimap_train
698 |
699 | elif self.phase == "val":
700 | self.alpha = data.alpha_val
701 | self.merged = data.merged_val
702 | self.trimap = data.trimap_val
703 |
704 | train_trans = [
705 | # RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
706 |
707 | # CutMask(perturb_prob=CUTMASK_PROB),
708 | CropResize((self.crop_size, self.crop_size)),
709 | # RandomJitter(),
710 | ToTensor(phase="val",norm_type=norm_type)]
711 |
712 | # val_trans = [ OriginScale(), ToTensor() ]
713 | val_trans = [CropResize(
714 | (self.crop_size, self.crop_size)), ToTensor(phase="val",norm_type=norm_type)]
715 |
716 | self.transform = {
717 | 'train':
718 | transforms.Compose(train_trans),
719 |
720 | 'val':
721 | transforms.Compose(val_trans)
722 | }[phase]
723 |
724 | self.alpha_num = len(self.alpha)
725 |
726 | def __getitem__(self, idx):
727 |
728 | image = cv2.imread(self.merged[idx])
729 | alpha = cv2.imread(self.alpha[idx], 0)/255.
730 | trimap = cv2.imread(self.trimap[idx], 0).astype(np.float32)
731 | mask = (trimap >= 170).astype(np.float32)
732 | image_name = os.path.split(self.merged[idx])[-1]
733 |
734 | dataset_name = self.get_dataset_name(image_name)
735 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap,
736 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name}
737 |
738 | sample = self.transform(sample)
739 |
740 | return sample
741 |
742 | def __len__(self):
743 | return len(self.alpha)
744 |
745 | def get_dataset_name(self, image_name):
746 | image_name = image_name.split('.')[0]
747 | if image_name.startswith('o_'):
748 | return 'AIM'
749 | elif image_name.endswith('_o') or image_name.endswith('_5k'):
750 | return 'PPM'
751 | elif image_name.startswith('m_'):
752 | return 'AM2k'
753 | elif image_name.endswith('_input'):
754 | return 'RWP636'
755 | elif image_name.startswith('p_'):
756 | return 'P3M'
757 |
758 | else:
759 | # raise ValueError('image_name {} not recognized'.format(image_name))
760 | return 'RM1k'
761 |
762 | class ContextDataset(Dataset):
763 | # divide a dataset into train set and validation set
764 | def __init__(self, data, crop_size=1024, phase="train",norm_type='imagenet'):
765 | self.phase = phase
766 | self.crop_size = crop_size
767 | data = instantiate_from_config(data)
768 |
769 | if self.phase == "train":
770 | self.dataset = data.dataset_train
771 | self.image_class_dict = data.image_class_dict_train
772 |
773 | elif self.phase == "val":
774 | self.dataset = data.dataset_val
775 | self.image_class_dict = data.image_class_dict_val
776 |
777 | # dict to list
778 | for key, value in self.image_class_dict.items():
779 | self.image_class_dict[key] = list(value.items())
780 | self.dataset = list(self.dataset.items())
781 |
782 | train_trans = [
783 | # RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
784 |
785 | # CutMask(perturb_prob=CUTMASK_PROB),
786 | CropResize((self.crop_size, self.crop_size)),
787 | # RandomJitter(),
788 | ToTensor(phase="val",norm_type=norm_type)]
789 |
790 | # val_trans = [ OriginScale(), ToTensor() ]
791 | val_trans = [CropResize(
792 | (self.crop_size, self.crop_size)), ToTensor(phase="val",norm_type=norm_type)]
793 |
794 | self.transform = {
795 | 'train':
796 | transforms.Compose(train_trans),
797 |
798 | 'val':
799 | transforms.Compose(val_trans)
800 | }[phase]
801 |
802 | def __getitem__(self, idx):
803 | cv2.setNumThreads(0)
804 |
805 | image_name, image_info = self.dataset[idx]
806 |
807 | # get image sample
808 | dataset_name = image_info['dataset_name']
809 | image_sample = self.get_sample(image_name, dataset_name)
810 |
811 | # get context image
812 | class_name = str(
813 | image_info['class'])+'-'+str(image_info['sub_class'])+'-'+str(image_info['HalfOrFull'])
814 | (context_image_name, context_dataset_name) = self.image_class_dict[class_name][np.random.randint(
815 | len(self.image_class_dict[class_name]))]
816 | context_image_sample = self.get_sample(
817 | context_image_name, context_dataset_name)
818 |
819 | # merge image and context
820 | image_sample['context_image'] = context_image_sample['image']
821 | image_sample['context_guidance'] = context_image_sample['alpha']
822 | image_sample['context_image_name'] = context_image_sample['image_name']
823 |
824 | return image_sample
825 |
826 | def __len__(self):
827 | return len(self.dataset)
828 |
829 | def get_sample(self, image_name, dataset_name):
830 | cv2.setNumThreads(0)
831 | image_dir, label_dir, trimap_dir, merged_ext, alpha_ext, trimap_ext = get_dir_ext(
832 | dataset_name)
833 | image_path = os.path.join(image_dir, image_name + merged_ext) if 'open-images' not in dataset_name else os.path.join(
834 | image_dir, image_name.split('_')[0] + merged_ext)
835 | label_path = os.path.join(label_dir, image_name + alpha_ext)
836 | trimap_path = os.path.join(trimap_dir, image_name + trimap_ext)
837 |
838 | image = cv2.imread(image_path)
839 | alpha = cv2.imread(label_path, 0)/255.
840 | trimap = cv2.imread(trimap_path, 0).astype(np.float32)
841 | mask = (trimap >= 170).astype(np.float32)
842 | image_name = os.path.split(image_path)[-1]
843 |
844 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap,
845 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name}
846 |
847 | sample = self.transform(sample)
848 | return sample
849 | def get_sample_example(self, image_dir, mask_dir, img_list, mask_list, index):
850 | image = cv2.imread(os.path.join(image_dir, img_list[index]))
851 | alpha = cv2.imread(os.path.join(mask_dir, mask_list[index]), 0)/255.
852 |
853 | # resize alpha to image size
854 | alpha = cv2.resize(alpha, (image.shape[1], image.shape[0]))
855 |
856 | # unused
857 | trimap = cv2.imread(os.path.join(mask_dir, mask_list[index]))/255.
858 | mask = (trimap >= 170).astype(np.float32)
859 | image_name = ''
860 | dataset_name = ''
861 |
862 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap,
863 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name}
864 |
865 | sample = self.transform(sample)
866 | return sample['image'], sample['alpha']
867 |
868 | class InContextDataset(Dataset):
869 | # divide a dataset into train set and validation set
870 | def __init__(self, data, crop_size=1024, phase="train",norm_type='imagenet'):
871 | self.phase = phase
872 | self.crop_size = crop_size
873 | data = instantiate_from_config(data)
874 |
875 | if self.phase == "train":
876 | self.dataset = data.dataset_train
877 | self.image_class_dict = data.image_class_dict_train
878 |
879 | elif self.phase == "val":
880 | self.dataset = data.dataset_val
881 | self.image_class_dict = data.image_class_dict_val
882 |
883 | # dict to list
884 | for key, value in self.image_class_dict.items():
885 | self.image_class_dict[key] = list(value.items())
886 | self.dataset = list(self.dataset.items())
887 |
888 | train_trans = [
889 | # RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
890 |
891 | # CutMask(perturb_prob=CUTMASK_PROB),
892 | CropResize((self.crop_size, self.crop_size)),
893 | # RandomJitter(),
894 | ToTensor(phase="val",norm_type=norm_type)]
895 |
896 | # val_trans = [ OriginScale(), ToTensor() ]
897 | val_trans = [CropResize(
898 | (self.crop_size, self.crop_size)), ToTensor(phase="val",norm_type=norm_type)]
899 |
900 | self.transform = {
901 | 'train':
902 | transforms.Compose(train_trans),
903 |
904 | 'val':
905 | transforms.Compose(val_trans)
906 | }[phase]
907 |
908 | def __getitem__(self, idx):
909 | cv2.setNumThreads(0)
910 |
911 | image_name, image_info = self.dataset[idx]
912 |
913 | # get image sample
914 | dataset_name = image_info['dataset_name']
915 | image_sample = self.get_sample(image_name, dataset_name)
916 |
917 | # get context image
918 | class_name = str(
919 | image_info['class'])+'-'+str(image_info['sub_class'])+'-'+str(image_info['HalfOrFull'])
920 |
921 | context_set = self.image_class_dict[class_name]
922 | if len(context_set) > 2:
923 | # remove image_name from context_set so the reference differs from the source image
924 | context_set = [x for x in context_set if x[0] != image_name]
925 |
926 | (reference_image_name, context_dataset_name) = context_set[np.random.randint(
927 | len(context_set))]
928 | reference_image_sample = self.get_sample(
929 | reference_image_name, context_dataset_name)
930 |
931 | # merge image and context
932 | image_sample['reference_image'] = reference_image_sample['source_image']
933 | image_sample['guidance_on_reference_image'] = reference_image_sample['alpha']
934 | image_sample['reference_image_name'] = reference_image_sample['image_name']
935 |
936 | return image_sample
937 |
938 | def __len__(self):
939 | return len(self.dataset)
940 |
941 | def get_sample(self, image_name, dataset_name):
942 | cv2.setNumThreads(0)
943 | image_dir, label_dir, trimap_dir, merged_ext, alpha_ext, trimap_ext = get_dir_ext(
944 | dataset_name)
945 | image_path = os.path.join(image_dir, image_name + merged_ext) if 'open-images' not in dataset_name else os.path.join(
946 | image_dir, image_name.split('_')[0] + merged_ext)
947 | label_path = os.path.join(label_dir, image_name + alpha_ext)
948 | trimap_path = os.path.join(trimap_dir, image_name + trimap_ext)
949 |
950 | image = cv2.imread(image_path)
951 | alpha = cv2.imread(label_path, 0)/255.
952 | trimap = cv2.imread(trimap_path, 0).astype(np.float32)
953 | mask = (trimap >= 170).astype(np.float32)
954 | image_name = os.path.split(image_path)[-1]
955 |
956 | if 'open-images' in dataset_name:
957 | # quantize alpha to 0 or 1 for the open-images data
958 | alpha[alpha < 0.5] = 0
959 | alpha[alpha >= 0.5] = 1
960 |
961 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap,
962 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name}
963 |
964 | sample = self.transform(sample)
965 |
966 | # modify 'image' to 'source_image'
967 | sample['source_image'] = sample['image']
968 | del sample['image']
969 |
970 | return sample
971 |
972 |
--------------------------------------------------------------------------------
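
The pipeline above composites training images as `fg * alpha + bg * (1 - alpha)` and derives trimaps and segmentation masks from the alpha matte by erosion/dilation. The sketch below is a minimal, self-contained illustration of those two steps on a synthetic matte; the kernel shape/size and array resolutions are illustrative assumptions, not values from the training configuration.

```python
# Minimal sketch (not part of the repository) of the GenMask/Composite ideas above.
import cv2
import numpy as np

h, w = 256, 256
alpha = np.zeros((h, w), dtype=np.float32)
cv2.circle(alpha, (w // 2, h // 2), 60, 1.0, -1)       # toy foreground: a filled disc
alpha = cv2.GaussianBlur(alpha, (21, 21), 0)           # soften the boundary

# trimap: erode the certain-foreground and certain-background regions;
# whatever remains becomes the unknown band (value 128)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))  # assumed kernel shape/size
fg_mask = cv2.erode((alpha + 1e-5).astype(np.uint8), kernel)
bg_mask = cv2.erode((1 - alpha + 1e-5).astype(np.uint8), kernel)
trimap = np.full((h, w), 128, dtype=np.float32)
trimap[fg_mask == 1] = 255
trimap[bg_mask == 1] = 0

# compositing, as in the Composite transform: image = fg * alpha + bg * (1 - alpha)
fg = np.full((h, w, 3), 200, dtype=np.float32)
bg = np.full((h, w, 3), 30, dtype=np.float32)
image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])

print(trimap.shape, sorted(np.unique(trimap)), image.shape)
```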
/icm/data/data_module.py:
--------------------------------------------------------------------------------
1 |
2 | import pytorch_lightning as pl
3 | from icm.util import instantiate_from_config
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | import torch
7 | def worker_init_fn(worker_id):
8 | # set numpy random seed with torch.randint so each worker has a different seed
9 | np.random.seed(torch.randint(0, 2**32 - 1, size=(1,)).item())
10 |
11 | class DataModuleFromConfig(pl.LightningDataModule):
12 | def __init__(self, train=None, validation=None, test=None, predict=None, num_workers=None,
13 | batch_size=None, shuffle_train=False,batch_size_val=None):
14 | super().__init__()
15 | self.batch_size = batch_size
16 | self.batch_size_val = batch_size_val
17 | self.dataset_configs = dict()
18 | self.num_workers = num_workers if num_workers is not None else batch_size * 2
19 | self.shuffle_train = shuffle_train
20 | # If a dataset is passed, add it to the dataset configs and create a corresponding dataloader method
21 | if train is not None:
22 | self.dataset_configs["train"] = train
23 | self.train_dataloader = self._train_dataloader
24 | if validation is not None:
25 | self.dataset_configs["validation"] = validation
26 | self.val_dataloader = self._val_dataloader
27 |
28 | # for debugging
29 | # self.setup()
30 |
31 |
32 | def setup(self, stage=None):
33 | # Instantiate datasets from the dataset configs
34 | self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
35 |
36 | def _train_dataloader(self):
37 | return DataLoader(self.datasets["train"],
38 | batch_size=self.batch_size,
39 | num_workers=self.num_workers,
40 | shuffle=self.shuffle_train,
41 | worker_init_fn=worker_init_fn,)
42 |
43 | def _val_dataloader(self):
44 |
45 | return DataLoader(self.datasets["validation"],
46 | batch_size=self.batch_size if self.batch_size_val is None else self.batch_size_val,
47 | num_workers=self.num_workers,
48 | shuffle=True,
49 | worker_init_fn=worker_init_fn,
50 | )
51 |
52 | def prepare_data(self):
53 | return super().prepare_data()
54 |
--------------------------------------------------------------------------------
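
`DataModuleFromConfig` defers dataset construction to `instantiate_from_config`, which (following the Stable-Diffusion-style `{"target": ..., "params": ...}` convention this codebase appears to adopt) resolves a config dict into an object. The snippet below is a hypothetical sketch of wiring a validation split this way; the nested targets and parameter values are placeholders, not the contents of `config/eval.yaml`.

```python
from icm.data.data_module import DataModuleFromConfig

# Assumed convention for icm.util.instantiate_from_config: a nested
# {"target", "params"} dict. The concrete targets/params here are placeholders.
val_cfg = {
    "target": "icm.data.data_generator.ContextDataset",
    "params": {
        "data": {"target": "icm.data.image_file.ContextData",
                 "params": {"ratio": 0.9}},
        "crop_size": 1024,
        "phase": "val",
    },
}

dm = DataModuleFromConfig(validation=val_cfg, batch_size=1, num_workers=2)
dm.setup()                         # instantiates the configured dataset(s)
val_loader = dm.val_dataloader()   # a standard PyTorch DataLoader
```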
/icm/data/image_file.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import glob
4 | import functools
5 | import numpy as np
6 |
7 |
8 | class ImageFile(object):
9 | def __init__(self, phase='train'):
10 | self.phase = phase
11 | self.rng = np.random.RandomState(0)
12 |
13 | def _get_valid_names(self, *dirs, shuffle=True):
14 | # Extract valid names
15 | name_sets = [self._get_name_set(d) for d in dirs]
16 |
17 | # Reduce
18 | def _join_and(a, b):
19 | return a & b
20 |
21 | valid_names = list(functools.reduce(_join_and, name_sets))
22 | if shuffle:
23 | self.rng.shuffle(valid_names)
24 |
25 | return valid_names
26 |
27 | @staticmethod
28 | def _get_name_set(dir_name):
29 | path_list = glob.glob(os.path.join(dir_name, '*'))
30 | name_set = set()
31 | for path in path_list:
32 | name = os.path.basename(path)
33 | name = os.path.splitext(name)[0]
34 | name_set.add(name)
35 | return name_set
36 |
37 | @staticmethod
38 | def _list_abspath(data_dir, ext, data_list):
39 | return [os.path.join(data_dir, name + ext)
40 | for name in data_list]
41 |
42 |
43 | class ImageFileTrain(ImageFile):
44 | def __init__(self,
45 | alpha_dir="train_alpha",
46 | fg_dir="train_fg",
47 | bg_dir="train_bg",
48 | alpha_ext=".jpg",
49 | fg_ext=".jpg",
50 | bg_ext=".jpg"):
51 | super(ImageFileTrain, self).__init__(phase="train")
52 |
53 | self.alpha_dir = alpha_dir
54 | self.fg_dir = fg_dir
55 | self.bg_dir = bg_dir
56 | self.alpha_ext = alpha_ext
57 | self.fg_ext = fg_ext
58 | self.bg_ext = bg_ext
59 |
60 | self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir)
61 | self.valid_bg_list = [os.path.splitext(
62 | name)[0] for name in os.listdir(self.bg_dir)]
63 |
64 | self.alpha = self._list_abspath(
65 | self.alpha_dir, self.alpha_ext, self.valid_fg_list)
66 | self.fg = self._list_abspath(
67 | self.fg_dir, self.fg_ext, self.valid_fg_list)
68 | self.bg = self._list_abspath(
69 | self.bg_dir, self.bg_ext, self.valid_bg_list)
70 |
71 | def __len__(self):
72 | return len(self.alpha)
73 |
74 |
75 | class ImageFileTest(ImageFile):
76 | def __init__(self,
77 | alpha_dir="test_alpha",
78 | merged_dir="test_merged",
79 | trimap_dir="test_trimap",
80 | alpha_ext=".png",
81 | merged_ext=".png",
82 | trimap_ext=".png"):
83 | super(ImageFileTest, self).__init__(phase="test")
84 |
85 | self.alpha_dir = alpha_dir
86 | self.merged_dir = merged_dir
87 | self.trimap_dir = trimap_dir
88 | self.alpha_ext = alpha_ext
89 | self.merged_ext = merged_ext
90 | self.trimap_ext = trimap_ext
91 |
92 | self.valid_image_list = self._get_valid_names(
93 | self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False)
94 |
95 | self.alpha = self._list_abspath(
96 | self.alpha_dir, self.alpha_ext, self.valid_image_list)
97 | self.merged = self._list_abspath(
98 | self.merged_dir, self.merged_ext, self.valid_image_list)
99 | self.trimap = self._list_abspath(
100 | self.trimap_dir, self.trimap_ext, self.valid_image_list)
101 |
102 | def __len__(self):
103 | return len(self.alpha)
104 |
105 |
106 | dataset = {'AIM', 'PPM', 'AM2k_train', 'AM2k_val',
107 | 'RWP636', 'P3M_val_np', 'P3M_train', 'P3M_val_p'}
108 | # filename patterns per dataset: AIM 'o_*', PPM '*_o'/'*_5k', AM2k 'm_*', RWP636 '*_input.jpg', P3M 'p_*'
109 |
110 |
111 | def get_dir_ext(dataset):
112 |
113 | # assert dataset in ['ICM57', 'ICM']:
114 | image_dir = './datasets/ICM57/image'
115 | label_dir = './datasets/ICM57/alpha'
116 | trimap_dir = './datasets/ICM57/trimap'
117 |
118 | merged_ext = '.jpg'
119 | alpha_ext = '.png'
120 | trimap_ext = '.png'
121 | return image_dir, label_dir, trimap_dir, merged_ext, alpha_ext, trimap_ext
122 |
123 | class MultiImageFile(object):
124 | def __init__(self):
125 |
126 | self.rng = np.random.RandomState(1)
127 |
128 | def _get_valid_names(self, *dirs, shuffle=True):
129 | # Extract valid names
130 | name_sets = [self._get_name_set(d) for d in dirs]
131 |
132 | # Reduce
133 | def _join_and(a, b):
134 | return a & b
135 |
136 | valid_names = list(functools.reduce(_join_and, name_sets))
137 |
138 | # ensure the order is the same for both training and validation
139 | if shuffle:
140 | valid_names.sort()
141 | self.rng.shuffle(valid_names)
142 |
143 | return valid_names
144 |
145 | @staticmethod
146 | def _get_name_set(dir_name):
147 | path_list = glob.glob(os.path.join(dir_name, '*'))
148 | name_set = set()
149 | for path in path_list:
150 | name = os.path.basename(path)
151 | name = os.path.splitext(name)[0]
152 | name_set.add(name)
153 | return name_set
154 |
155 | @staticmethod
156 | def _list_abspath(data_dir, ext, data_list):
157 | return [os.path.join(data_dir, name + ext)
158 | for name in data_list]
159 |
160 |
161 | class MultiImageFileDoubleSet(MultiImageFile):
162 | def __init__(self, ratio=0.9, dataset_name=['AIM', 'PPM', 'AM2k_train', 'AM2k_val', 'RWP636', 'P3M_val_np']):
163 |
164 | super(MultiImageFileDoubleSet, self).__init__()
165 |
166 | self.alpha_train = []
167 | self.merged_train = []
168 | self.trimap_train = []
169 | self.alpha_val = []
170 | self.merged_val = []
171 | self.trimap_val = []
172 |
173 | for dataset_name_ in dataset_name:
174 | merged_dir, alpha_dir, trimap_dir, merged_ext, alpha_ext, trimap_ext = get_dir_ext(
175 | dataset_name_)
176 | valid_image_list = self._get_valid_names(
177 | alpha_dir, merged_dir, trimap_dir)
178 |
179 | alpha = self._list_abspath(alpha_dir, alpha_ext, valid_image_list)
180 | merged = self._list_abspath(
181 | merged_dir, merged_ext, valid_image_list)
182 | trimap = self._list_abspath(
183 | trimap_dir, trimap_ext, valid_image_list)
184 |
185 | alpha_train, alpha_val = self._split(alpha, ratio)
186 | merged_train, merged_val = self._split(merged, ratio)
187 | trimap_train, trimap_val = self._split(trimap, ratio)
188 |
189 | self.alpha_train.extend(alpha_train)
190 | self.merged_train.extend(merged_train)
191 | self.trimap_train.extend(trimap_train)
192 | self.alpha_val.extend(alpha_val)
193 | self.merged_val.extend(merged_val)
194 | self.trimap_val.extend(trimap_val)
195 |
196 | def _split(self, data_list, ratio):
197 | num = len(data_list)
198 | split = int(num * ratio)
199 | return data_list[:split], data_list[split:]
200 |
201 |
202 | class ContextData():
203 | '''
204 | dataset_name: names of the dataset JSON files under /datasets
205 | 
206 | attributes (built for both the train and val splits):
207 | dataset: dict, key: image_name,
208 | value: {"dataset_name": "AIM", "class": "animal",
209 | "sub_class": null, "HalfOrFull": "half", "TransparentOrOpaque": "SO"}
210 | 
211 | image_class_dict: dict, key: class_name ("class-sub_class-HalfOrFull"),
212 | value: dict, key: image_name, value: dataset_name
213 | '''
214 |
215 | def __init__(self, ratio=0.9, dataset_name=['PPM', 'AM2k', 'RWP636', 'P3M_val_np']):
216 | dataset = {}
217 | for dataset_name_ in dataset_name:
218 | json_dir = os.path.join('datasets', dataset_name_+'.json')
219 | # read json file and append to dataset
220 | with open(json_dir) as f:
221 | new_data = json.load(f)
222 | # stochastically drop items whose minimum "instance_area_ratio" is below 0.3
223 | # check if "instance_area_ratio" exists
224 | if 'instance_area_ratio' in new_data[list(new_data.keys())[0]].keys():
225 | new_data_ = {}
226 | for k, v in new_data.items():
227 | if min(v['instance_area_ratio']) < 0.3:
228 | x = min(v['instance_area_ratio'])
229 | r = np.random.rand()
230 | if 100*(0.6-x)*x/9 > r**0.2:
231 | new_data_[k] = v
232 | else:
233 | new_data_[k] = v
234 | new_data = new_data_
235 | dataset.update(new_data)
236 |
237 | # shuffle dataset with seed
238 | self.rng = np.random.RandomState(1)
239 | dataset_list = list(dataset.items())
240 | dataset_list.sort()
241 | self.rng.shuffle(dataset_list)
242 |
243 | # split dataset into train and val
244 | dataset_train, dataset_val = self._split(dataset_list, ratio)
245 |
246 | # get image_class_dict
247 | image_class_dict_train = self.get_image_class_dict(dataset_train)
248 | image_class_dict_val = self.get_image_class_dict(dataset_val)
249 |
250 | self.image_class_dict_train = image_class_dict_train
251 | self.image_class_dict_val = image_class_dict_val
252 | self.dataset_train = dataset_train
253 | self.dataset_val = dataset_val
254 |
255 | def _split(self, data_list, ratio):
256 | num = len(data_list)
257 | split = int(num * ratio)
258 |
259 | dataset_train = dict(data_list[:split])
260 | dataset_val = dict(data_list[split:])
261 | return dataset_train, dataset_val
262 |
263 | def get_image_class_dict(self, dataset):
264 | image_class_dict = {}
265 | for image_name, image_info in dataset.items():
266 | class_name = str(
267 | image_info['class'])+'-'+str(image_info['sub_class'])+'-'+str(image_info['HalfOrFull'])
268 | if class_name not in image_class_dict.keys():
269 | image_class_dict[class_name] = {}
270 | image_class_dict[class_name][image_name] = image_info['dataset_name']
271 | else:
272 | image_class_dict[class_name][image_name] = image_info['dataset_name']
273 | return image_class_dict
274 |
275 |
276 | if __name__ == "__main__":
277 |
278 | test = MultiImageFileDoubleSet()
279 | print(0)
280 |
--------------------------------------------------------------------------------
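
`ContextData` buckets images by a `class-sub_class-HalfOrFull` key so that a same-class reference image can be sampled for each source image. The toy example below reproduces that grouping and sampling logic with dummy metadata (no dataset JSON files required); the image names and classes are invented for illustration.

```python
# Self-contained walk-through of the class-bucket grouping used for in-context sampling.
import numpy as np

dataset = {
    "m_0001": {"dataset_name": "AM2k", "class": "animal", "sub_class": "dog", "HalfOrFull": "full"},
    "m_0002": {"dataset_name": "AM2k", "class": "animal", "sub_class": "dog", "HalfOrFull": "full"},
    "p_0001": {"dataset_name": "P3M",  "class": "human",  "sub_class": None,  "HalfOrFull": "half"},
}

# group images by "class-sub_class-HalfOrFull", as in ContextData.get_image_class_dict
image_class_dict = {}
for image_name, info in dataset.items():
    class_name = f"{info['class']}-{info['sub_class']}-{info['HalfOrFull']}"
    image_class_dict.setdefault(class_name, {})[image_name] = info["dataset_name"]

# pick a source image and sample a reference from the same class bucket
source_name, source_info = "m_0001", dataset["m_0001"]
key = f"{source_info['class']}-{source_info['sub_class']}-{source_info['HalfOrFull']}"
bucket = [item for item in image_class_dict[key].items() if item[0] != source_name]
reference_name, reference_dataset = bucket[np.random.randint(len(bucket))]
print(reference_name, reference_dataset)   # -> m_0002 AM2k
```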
/icm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
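
These schedulers return a learning-rate multiplier rather than a learning rate, which is why the docstrings say to pair them with a base LR of 1.0. Below is a minimal usage sketch with `torch.optim.lr_scheduler.LambdaLR`; the step counts and LR values are arbitrary examples, not the repository's training settings.

```python
import torch
from icm.lr_scheduler import LambdaWarmUpCosineScheduler

# arbitrary example values: linear warm-up for 100 steps, cosine decay until step 1000
sched_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=100, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=1000)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # base LR of 1.0, as the docstring notes
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)

for step in range(5):
    optimizer.step()          # forward/backward omitted in this sketch
    scheduler.step()
    print(step, scheduler.get_last_lr())  # warms up linearly from lr_start towards lr_max
```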
/icm/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiny-smart/in-context-matting/a58bcf4b948babcf5f1c2e1c41e2fa040bc53d1e/icm/models/__init__.py
--------------------------------------------------------------------------------
/icm/models/decoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiny-smart/in-context-matting/a58bcf4b948babcf5f1c2e1c41e2fa040bc53d1e/icm/models/decoder/__init__.py
--------------------------------------------------------------------------------
/icm/models/decoder/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | import os
14 | import numpy as np
15 | from PIL import Image
16 |
17 |
18 | class Attention(nn.Module):
19 | """
20 | An attention layer that allows for downscaling the size of the embedding
21 | after projection to queries, keys, and values.
22 | """
23 |
24 | def __init__(
25 | self,
26 | embedding_dim: int,
27 | num_heads: int,
28 | downsample_rate: int = 1,
29 | ) -> None:
30 | super().__init__()
31 | self.embedding_dim = embedding_dim
32 | self.internal_dim = embedding_dim // downsample_rate
33 | self.num_heads = num_heads
34 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
35 |
36 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
37 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
38 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
39 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
40 |
41 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
42 | b, n, c = x.shape
43 | x = x.reshape(b, n, num_heads, c // num_heads)
44 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
45 |
46 | def _recombine_heads(self, x: Tensor) -> Tensor:
47 | b, n_heads, n_tokens, c_per_head = x.shape
48 | x = x.transpose(1, 2)
49 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
50 |
51 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
52 | # Input projections
53 | q = self.q_proj(q)
54 | k = self.k_proj(k)
55 | v = self.v_proj(v)
56 |
57 | # Separate into heads
58 | q = self._separate_heads(q, self.num_heads)
59 | k = self._separate_heads(k, self.num_heads)
60 | v = self._separate_heads(v, self.num_heads)
61 |
62 | # Attention
63 | _, _, _, c_per_head = q.shape
64 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
65 | attn = attn / math.sqrt(c_per_head)
66 | attn = torch.softmax(attn, dim=-1) # [b,head,token,token]
67 |
68 | # save_attn_map(attn)
69 |
70 | # Get output
71 | out = attn @ v
72 | out = self._recombine_heads(out)
73 | out = self.out_proj(out)
74 |
75 | return out
76 |
77 | class MLPBlock(nn.Module):
78 | def __init__(
79 | self,
80 | embedding_dim: int,
81 | mlp_dim: int,
82 | act: Type[nn.Module] = nn.GELU,
83 | ) -> None:
84 | super().__init__()
85 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
86 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
87 | self.act = act()
88 |
89 | def forward(self, x: torch.Tensor) -> torch.Tensor:
90 | return self.lin2(self.act(self.lin1(x)))
91 |
92 |
93 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
94 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
95 | class LayerNorm2d(nn.Module):
96 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
97 | super().__init__()
98 | self.weight = nn.Parameter(torch.ones(num_channels))
99 | self.bias = nn.Parameter(torch.zeros(num_channels))
100 | self.eps = eps
101 |
102 | def forward(self, x: torch.Tensor) -> torch.Tensor:
103 | u = x.mean(1, keepdim=True)
104 | s = (x - u).pow(2).mean(1, keepdim=True)
105 | x = (x - u) / torch.sqrt(s + self.eps)
106 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
107 | return x
108 |
--------------------------------------------------------------------------------
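
A quick shape check for the `Attention` block: with `downsample_rate=2` the queries, keys, and values are projected to `embedding_dim // 2` internally, and the output is projected back to `embedding_dim`. The tensor sizes below are illustrative only.

```python
import torch
from icm.models.decoder.attention import Attention, MLPBlock

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
q = torch.randn(2, 1024, 256)    # B x N_tokens x C
k = torch.randn(2, 4096, 256)    # keys/values may come from a different token set
v = torch.randn(2, 4096, 256)
out = attn(q, k, v)
print(out.shape)                 # torch.Size([2, 1024, 256])

mlp = MLPBlock(embedding_dim=256, mlp_dim=512)
print(mlp(out).shape)            # torch.Size([2, 1024, 256])
```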
/icm/models/decoder/bottleneck_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | def c2_msra_fill(module: nn.Module) -> None:
6 | """
7 | Initialize `module.weight` using the "MSRAFill" implemented in Caffe2.
8 | Also initializes `module.bias` to 0.
9 |
10 | Args:
11 | module (torch.nn.Module): module to initialize.
12 | """
13 | # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`.
14 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
15 | if module.bias is not None:
16 | # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
17 | # torch.Tensor]`.
18 | nn.init.constant_(module.bias, 0)
19 |
20 | class Conv2d(torch.nn.Conv2d):
21 | """
22 | A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
23 | """
24 |
25 | def __init__(self, *args, **kwargs):
26 | """
27 | Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
28 |
29 | Args:
30 | norm (nn.Module, optional): a normalization layer
31 | activation (callable(Tensor) -> Tensor): a callable activation function
32 |
33 | It assumes that norm layer is used before activation.
34 | """
35 | norm = kwargs.pop("norm", None)
36 | activation = kwargs.pop("activation", None)
37 | super().__init__(*args, **kwargs)
38 |
39 | self.norm = norm
40 | self.activation = activation
41 |
42 | def forward(self, x):
43 | # torchscript does not support SyncBatchNorm yet
44 | # https://github.com/pytorch/pytorch/issues/40507
45 | # and we skip these codes in torchscript since:
46 | # 1. currently we only support torchscript in evaluation mode
47 | # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
48 | # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
49 | if not torch.jit.is_scripting():
50 | if x.numel() == 0 and self.training:
51 | # https://github.com/pytorch/pytorch/issues/12013
52 | assert not isinstance(
53 | self.norm, torch.nn.SyncBatchNorm
54 | ), "SyncBatchNorm does not support empty inputs!"
55 |
56 | x = F.conv2d(
57 | x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
58 | )
59 | if self.norm is not None:
60 | x = self.norm(x)
61 | if self.activation is not None:
62 | x = self.activation(x)
63 | return x
64 |
65 | def get_norm(norm, out_channels):
66 | """
67 | Args:
68 | norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
69 | or a callable that takes a channel number and returns
70 | the normalization layer as a nn.Module.
71 |
72 | Returns:
73 | nn.Module or None: the normalization layer
74 | """
75 | if norm is None:
76 | return None
77 | if isinstance(norm, str):
78 | if len(norm) == 0:
79 | return None
80 | norm = {
81 | "BN": nn.BatchNorm2d,
82 | "GN": lambda channels: nn.GroupNorm(32, channels),
83 | }[norm]
84 | return norm(out_channels)
85 |
86 | class CNNBlockBase(nn.Module):
87 | """
88 | A CNN block is assumed to have input channels, output channels and a stride.
89 | The input and output of `forward()` method must be NCHW tensors.
90 | The method can perform arbitrary computation but must match the given
91 | channels and stride specification.
92 |
93 | Attribute:
94 | in_channels (int):
95 | out_channels (int):
96 | stride (int):
97 | """
98 |
99 | def __init__(self, in_channels, out_channels, stride):
100 | """
101 | The `__init__` method of any subclass should also contain these arguments.
102 |
103 | Args:
104 | in_channels (int):
105 | out_channels (int):
106 | stride (int):
107 | """
108 | super().__init__()
109 | self.in_channels = in_channels
110 | self.out_channels = out_channels
111 | self.stride = stride
112 |
113 | def freeze(self):
114 | """
115 | Make this block not trainable.
116 | This method sets all parameters to `requires_grad=False`,
117 | and convert all BatchNorm layers to FrozenBatchNorm
118 |
119 | Returns:
120 | the block itself
121 | """
122 | for p in self.parameters():
123 | p.requires_grad = False
124 | return self
125 |
126 | class BottleneckBlock(CNNBlockBase):
127 | """
128 | The standard bottleneck residual block used by ResNet-50, 101 and 152
129 | defined in :paper:`ResNet`. It contains 3 conv layers with kernels
130 | 1x1, 3x3, 1x1, and a projection shortcut if needed.
131 | """
132 |
133 | def __init__(
134 | self,
135 | in_channels,
136 | out_channels,
137 | *,
138 | bottleneck_channels,
139 | stride=1,
140 | num_groups=1,
141 | norm="GN",
142 | stride_in_1x1=False,
143 | dilation=1,
144 | ):
145 | """
146 | Args:
147 | bottleneck_channels (int): number of output channels for the 3x3
148 | "bottleneck" conv layers.
149 | num_groups (int): number of groups for the 3x3 conv layer.
150 | norm (str or callable): normalization for all conv layers.
151 | See :func:`layers.get_norm` for supported format.
152 | stride_in_1x1 (bool): when stride>1, whether to put stride in the
153 | first 1x1 convolution or the bottleneck 3x3 convolution.
154 | dilation (int): the dilation rate of the 3x3 conv layer.
155 | """
156 | super().__init__(in_channels, out_channels, stride)
157 |
158 | if in_channels != out_channels:
159 | self.shortcut = Conv2d(
160 | in_channels,
161 | out_channels,
162 | kernel_size=1,
163 | stride=stride,
164 | bias=False,
165 | norm=get_norm(norm, out_channels),
166 | )
167 | else:
168 | self.shortcut = None
169 |
170 | # The original MSRA ResNet models have stride in the first 1x1 conv
171 | # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
172 | # stride in the 3x3 conv
173 | stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
174 |
175 | self.conv1 = Conv2d(
176 | in_channels,
177 | bottleneck_channels,
178 | kernel_size=1,
179 | stride=stride_1x1,
180 | bias=False,
181 | norm=get_norm(norm, bottleneck_channels),
182 | )
183 |
184 | self.conv2 = Conv2d(
185 | bottleneck_channels,
186 | bottleneck_channels,
187 | kernel_size=3,
188 | stride=stride_3x3,
189 | padding=1 * dilation,
190 | bias=False,
191 | groups=num_groups,
192 | dilation=dilation,
193 | norm=get_norm(norm, bottleneck_channels),
194 | )
195 |
196 | self.conv3 = Conv2d(
197 | bottleneck_channels,
198 | out_channels,
199 | kernel_size=1,
200 | bias=False,
201 | norm=get_norm(norm, out_channels),
202 | )
203 |
204 | for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
205 | if layer is not None: # shortcut can be None
206 | c2_msra_fill(layer)
207 |
208 | # Zero-initialize the last normalization in each residual branch,
209 | # so that at the beginning, the residual branch starts with zeros,
210 | # and each residual block behaves like an identity.
211 | # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
212 | # "For BN layers, the learnable scaling coefficient γ is initialized
213 | # to be 1, except for each residual block's last BN
214 | # where γ is initialized to be 0."
215 |
216 | # nn.init.constant_(self.conv3.norm.weight, 0)
217 | # TODO this somehow hurts performance when training GN models from scratch.
218 | # Add it as an option when we need to use this code to train a backbone.
219 |
220 | def forward(self, x):
221 | out = self.conv1(x)
222 | out = F.relu_(out)
223 |
224 | out = self.conv2(out)
225 | out = F.relu_(out)
226 |
227 | out = self.conv3(out)
228 |
229 | if self.shortcut is not None:
230 | shortcut = self.shortcut(x)
231 | else:
232 | shortcut = x
233 |
234 | out += shortcut
235 | out = F.relu_(out)
236 | return out
--------------------------------------------------------------------------------
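
Usage sketch for `BottleneckBlock` with assumed channel sizes: when `in_channels != out_channels` a 1x1 shortcut convolution is created so the residual addition still matches, and with the default `stride_in_1x1=False` the stride is applied in the 3x3 convolution.

```python
import torch
from icm.models.decoder.bottleneck_block import BottleneckBlock

# illustrative channel sizes; GroupNorm ("GN") uses 32 groups, so channels must divide by 32
block = BottleneckBlock(
    in_channels=128, out_channels=256, bottleneck_channels=64, stride=2, norm="GN")
x = torch.randn(1, 128, 64, 64)
print(block(x).shape)            # torch.Size([1, 256, 32, 32])
```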
/icm/models/decoder/detail_capture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class Basic_Conv3x3(nn.Module):
6 | """
7 | Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
8 | """
9 | def __init__(
10 | self,
11 | in_chans,
12 | out_chans,
13 | stride=1,
14 | padding=1,
15 | ):
16 | super().__init__()
17 | self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
18 | self.bn = nn.BatchNorm2d(out_chans)
19 | self.relu = nn.ReLU(True)
20 |
21 | def forward(self, x):
22 | x = self.conv(x)
23 | x = self.bn(x)
24 | x = self.relu(x)
25 |
26 | return x
27 |
28 |
29 | class Basic_Conv3x3_attn(nn.Module):
30 | """
31 | Basic convolution block including: LayerNorm, ReLU, Conv3x3 layers.
32 | """
33 | def __init__(
34 | self,
35 | in_chans,
36 | out_chans,
37 | res = False,
38 | stride=1,
39 | padding=1,
40 | ):
41 | super().__init__()
42 | self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
43 | self.ln = nn.LayerNorm(in_chans, elementwise_affine=True)
44 | self.relu = nn.ReLU(True)
45 |
46 | def forward(self, x):
47 | x = self.ln(x)
48 | x = x.permute(0, 3, 1, 2)
49 | x = self.relu(x)
50 | x = self.conv(x)
51 |
52 | return x
53 |
54 | # class Basic_Conv3x3_attn(nn.Module):
55 | # """
56 | # Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
57 | # """
58 | # def __init__(
59 | # self,
60 | # in_chans,
61 | # out_chans,
62 | # res = False,
63 | # stride=1,
64 | # padding=1,
65 | # ):
66 | # super().__init__()
67 | # self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
68 | # self.ln = nn.LayerNorm([in_chans, res, res], elementwise_affine=True)
69 | # self.relu = nn.ReLU(True)
70 |
71 | # def forward(self, x):
72 | # x = x.permute(0, 3, 1, 2)
73 | # x = self.ln(x)
74 | # x = self.relu(x)
75 | # x = self.conv(x)
76 |
77 | # return x
78 |
79 | class ConvStream(nn.Module):
80 | """
81 | Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
82 | """
83 | def __init__(
84 | self,
85 | in_chans = 4,
86 | out_chans = [48, 96, 192],
87 | ):
88 | super().__init__()
89 | self.convs = nn.ModuleList()
90 |
91 | self.conv_chans = out_chans.copy()  # copy to avoid mutating the shared default list
92 | self.conv_chans.insert(0, in_chans)
93 |
94 | for i in range(len(self.conv_chans)-1):
95 | in_chan_ = self.conv_chans[i]
96 | out_chan_ = self.conv_chans[i+1]
97 | self.convs.append(
98 | Basic_Conv3x3(in_chan_, out_chan_, stride=2)
99 | )
100 |
101 | def forward(self, x):
102 | out_dict = {'D0': x}
103 | for i in range(len(self.convs)):
104 | x = self.convs[i](x)
105 | check = self.convs[i]
106 | name_ = 'D'+str(i+1)
107 | out_dict[name_] = x
108 |
109 | return out_dict
110 |
111 | class Fusion_Block(nn.Module):
112 | """
113 | Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer.
114 | """
115 | def __init__(
116 | self,
117 | in_chans,
118 | out_chans,
119 | ):
120 | super().__init__()
121 | self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
122 |
123 | def forward(self, x, D):
124 | F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
125 | out = torch.cat([D, F_up], dim=1)
126 | out = self.conv(out)
127 |
128 | return out
129 |
130 | class Matting_Head(nn.Module):
131 | """
132 | Simple Matting Head, containing only conv3x3 and conv1x1 layers.
133 | """
134 | def __init__(
135 | self,
136 | in_chans = 32,
137 | mid_chans = 16,
138 | ):
139 | super().__init__()
140 | self.matting_convs = nn.Sequential(
141 | nn.Conv2d(in_chans, mid_chans, 3, 1, 1),
142 | nn.BatchNorm2d(mid_chans),
143 | nn.ReLU(True),
144 | nn.Conv2d(mid_chans, 1, 1, 1, 0)
145 | )
146 |
147 | def forward(self, x):
148 | x = self.matting_convs(x)
149 |
150 | return x
151 |
152 | # TODO: implement groupnorm and ws. In ODISE (bs=2) they work well; with bs=1 the MSE loss becomes NaN - why?
153 | class DetailCapture(nn.Module):
154 | """
155 | Simple and Lightweight Detail Capture Module for ViT Matting.
156 | """
157 | def __init__(
158 | self,
159 | in_chans = 384,
160 | img_chans=4,
161 | convstream_out = [48, 96, 192],
162 | fusion_out = [256, 128, 64, 32],
163 | ckpt=None,
164 | use_sigmoid = True,
165 | ):
166 | super().__init__()
167 | assert len(fusion_out) == len(convstream_out) + 1
168 |
169 | self.convstream = ConvStream(in_chans = img_chans)
170 | self.conv_chans = self.convstream.conv_chans
171 |
172 | self.fusion_blks = nn.ModuleList()
173 | self.fus_channs = fusion_out.copy()
174 | self.fus_channs.insert(0, in_chans)
175 | for i in range(len(self.fus_channs)-1):
176 | self.fusion_blks.append(
177 | Fusion_Block(
178 | in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)],
179 | out_chans = self.fus_channs[i+1],
180 | )
181 | )
182 |
183 | self.matting_head = Matting_Head(
184 | in_chans = fusion_out[-1],
185 | )
186 |
187 | if ckpt is not None and ckpt != '':
188 | self.load_state_dict(ckpt['state_dict'], strict=False)
189 | print('load detail capture ckpt from', ckpt['path'])
190 |
191 | self.use_sigmoid = use_sigmoid
192 | self.img_chans = img_chans
193 | def forward(self, features, images):
194 |
195 | if isinstance(features, dict):
196 |
197 | trimap = features['trimap']
198 | features = features['feature']
199 | if self.img_chans ==4:
200 | images = torch.cat([images, trimap], dim=1)
201 |
202 | detail_features = self.convstream(images)
203 | # D0 2 4 512 512
204 | # D1 2 48 256 256
205 | # D2 2 96 128 128
206 | # D3 2 192 64 64
207 | for i in range(len(self.fusion_blks)):
208 | d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
209 | features = self.fusion_blks[i](features, detail_features[d_name_])
210 |
211 | if self.use_sigmoid:
212 | phas = torch.sigmoid(self.matting_head(features))
213 | else:
214 | phas = self.matting_head(features)
215 | return phas
216 |
217 | def get_trainable_params(self):
218 | return list(self.parameters())
219 |
220 | class MaskDecoder(nn.Module):
221 | '''
222 | use trans-conv to decode mask
223 | '''
224 | def __init__(
225 | self,
226 | in_chans = 384,
227 | ):
228 | super().__init__()
229 | self.output_upscaling = nn.Sequential(
230 | nn.ConvTranspose2d(in_chans, in_chans // 4, kernel_size=2, stride=2),
231 | # LayerNorm2d(in_chans // 4),
232 | nn.BatchNorm2d(in_chans // 4),
233 | nn.ReLU(),
234 | nn.ConvTranspose2d(in_chans // 4, in_chans // 8, kernel_size=2, stride=2),
235 | nn.BatchNorm2d(in_chans // 8),
236 | nn.ReLU(),
237 | )
238 |
239 | self.matting_head = Matting_Head(
240 | in_chans = in_chans // 8,
241 | )
242 |
243 | def forward(self, x, images):
244 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
245 | x = self.output_upscaling(x)
246 | x = self.matting_head(x)
247 | x = torch.sigmoid(x)
248 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
249 | return x
--------------------------------------------------------------------------------
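
`DetailCapture` fuses a coarse feature map with progressively finer `ConvStream` details (D3 down to D0) before the matting head predicts alpha at full resolution. The shape walk-through below uses random tensors and illustrative resolutions; passing the feature as a dict with a `'trimap'` entry matches the `img_chans=4` default, where the trimap is concatenated to the RGB image before the ConvStream.

```python
import torch
from icm.models.decoder.detail_capture import DetailCapture

decoder = DetailCapture(in_chans=384, img_chans=4).eval()
images = torch.randn(1, 3, 512, 512)
features = {
    "feature": torch.randn(1, 384, 32, 32),   # coarse feature map, 1/16 of the image resolution
    "trimap": torch.rand(1, 1, 512, 512),     # pseudo-trimap at image resolution
}
with torch.no_grad():
    alpha = decoder(features, images)
print(alpha.shape)                            # torch.Size([1, 1, 512, 512])
```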
/icm/models/decoder/in_context_correspondence.py:
--------------------------------------------------------------------------------
1 |
2 | from einops import rearrange
3 | from torch import einsum
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 |
8 | from icm.models.decoder.bottleneck_block import BottleneckBlock
9 |
10 | from icm.models.decoder.detail_capture import Basic_Conv3x3, Basic_Conv3x3_attn, Fusion_Block
11 | import math
12 | from icm.models.decoder.attention import Attention, MLPBlock
13 |
14 |
15 | class OneWayAttentionBlock(nn.Module):
16 | def __init__(
17 | self,
18 | dim,
19 | n_heads,
20 | d_head,
21 | mlp_dim_rate,
22 | ):
23 | super().__init__()
24 |
25 | self.attn = Attention(dim, n_heads, downsample_rate=dim//d_head)
26 |
27 | self.norm1 = nn.LayerNorm(dim)
28 |
29 | self.mlp = MLPBlock(dim, int(dim*mlp_dim_rate))
30 |
31 | self.norm2 = nn.LayerNorm(dim)
32 |
33 | def forward(self, q, context_all):
34 | output = []
35 | for i in range(len(q)):
36 | x = q[i].unsqueeze(0)
37 | context = context_all[i].unsqueeze(0)
38 | x = self.norm1(x)
39 | context = self.norm1(context)
40 | x = self.attn(q=x, k=context, v=context) + x
41 | x = self.norm2(x)
42 | x = self.mlp(x) + x
43 | output.append(x.squeeze(0))
44 |
45 | return output
46 |
47 |
48 | def compute_correspondence_matrix(source_feature, ref_feature):
49 | """
50 | Compute correspondence matrix between source and reference features.
51 | Args:
52 | source_feature: [B, C, H, W]
53 | ref_feature: [B, C, H, W]
54 | Returns:
55 | correspondence_matrix: [B, H*W, H*W]
56 | """
57 | # [B, C, H, W] -> [B, H, W, C]
58 | source_feature = source_feature.permute(0, 2, 3, 1)
59 | ref_feature = ref_feature.permute(0, 2, 3, 1)
60 |
61 | # [B, H, W, C] -> [B, H*W, C]
62 | source_feature = torch.reshape(
63 | source_feature, [source_feature.shape[0], -1, source_feature.shape[-1]])
64 | ref_feature = torch.reshape(
65 | ref_feature, [ref_feature.shape[0], -1, ref_feature.shape[-1]])
66 |
67 | # norm
68 | source_feature = F.normalize(source_feature, p=2, dim=-1)
69 | ref_feature = F.normalize(ref_feature, p=2, dim=-1)
70 |
71 | # cosine similarity
72 | cos_sim = torch.matmul(
73 | source_feature, ref_feature.transpose(1, 2)) # [B, H*W, H*W]
74 |
75 | return cos_sim
76 |
77 |
78 | def maskpooling(mask, res):
79 | '''
80 | Mask pooling to reduce the resolution of mask
81 | Input:
82 | mask: [B, 1, H, W]
83 | res: resolution
84 | Output: [B, 1, res, res]
85 | '''
86 | # mask[mask < 1] = 0
87 | mask[mask > 0] = 1
88 | mask = -1 * mask
89 | kernel_size = mask.shape[2] // res
90 | mask = F.max_pool2d(mask, kernel_size=kernel_size,
91 | stride=kernel_size, padding=0)
92 | mask = -1*mask
93 | return mask
94 |
95 |
96 | def dilate(bin_img, ksize=5):
97 | pad = (ksize - 1) // 2
98 | bin_img = F.pad(bin_img, pad=[pad, pad, pad, pad], mode='reflect')
99 | out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
100 | return out
101 |
102 |
103 | def erode(bin_img, ksize=5):
104 | out = 1 - dilate(1 - bin_img, ksize)
105 | return out
106 |
107 |
108 | def generate_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10):
109 | eroded = erode(mask, erode_kernel_size)
110 | dilated = dilate(mask, dilate_kernel_size)
111 | trimap = torch.zeros_like(mask)
112 | trimap[dilated == 1] = 0.5
113 | trimap[eroded == 1] = 1
114 | return trimap
115 |
116 |
117 | def calculate_attention_score_(mask, attention_map, score_type):
118 | '''
119 | Calculate the attention score of the attention map
120 | mask: [B, H, W] value:0 or 1
121 | attention_map: [B, H*W, H, W]
122 | '''
123 | B, H, W = mask.shape
124 | mask_pos = mask.repeat(1, attention_map.shape[1], 1, 1) # [B, H*W, H, W]
125 | score_pos = torch.sum(attention_map * mask_pos, dim=(2, 3)) # [B, H*W]
126 | score_pos = score_pos / torch.sum(mask_pos, dim=(2, 3)) # [B, H*W]
127 |
128 | mask_neg = 1 - mask_pos
129 | score_neg = torch.sum(attention_map * mask_neg, dim=(2, 3))
130 | score_neg = score_neg / torch.sum(mask_neg, dim=(2, 3))
131 |
132 | assert score_type in ['classification', 'softmax', 'ratio']
133 |
134 | if score_type == 'classification':
135 | score = torch.zeros_like(score_pos)
136 | score[score_pos > score_neg] = 1
137 |
138 | return score.reshape(B, H, W)
139 |
140 |
141 | def refine_mask_by_attention(mask, attention_maps, iterations=10, score_type='classification'):
142 | '''
143 | mask: [B, H, W]
144 | attention_maps: [B, H*W, H, W]
145 | '''
146 | assert mask.shape[1] == attention_maps.shape[2]
147 | for i in range(iterations):
148 | score = calculate_attention_score_(
149 | mask, attention_maps, score_type=score_type) # [B, H, W]
150 |
151 | if score.equal(mask):
152 | # print("iteration: ", i, "score equal to mask")
153 | break
154 | else:
155 | mask = score
156 |
157 | assert i != iterations - 1
158 | return mask
159 |
160 |
161 | class InContextCorrespondence(nn.Module):
162 | '''
163 | one implementation of in_context_fusion
164 |
165 | forward(feature_of_reference_image, feature_of_source_image, guidance_on_reference_image)
166 | '''
167 |
168 | def __init__(self,
169 | use_bottle_neck=False,
170 | in_dim=1280,
171 | bottle_neck_dim=512,
172 | refine_with_attention=False,
173 | ):
174 | super().__init__()
175 | self.use_bottle_neck = use_bottle_neck
176 | self.refine_with_attention = refine_with_attention
177 |
178 | def forward(self, feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image):
179 | '''
180 | feature_of_reference_image: [B, C, H, W]
181 | ft_attn_of_source_image: {"ft": [B, C, H, W], "attn": [B, H_1, W_1, H_1*W_1]}
182 | guidance_on_reference_image: [B, 1, H_2, W_2]
183 | '''
184 |
185 | # get source_image h,w
186 | h, w = guidance_on_reference_image.shape[-2:]
187 | h_attn, w_attn = ft_attn_of_source_image['attn'].shape[-3:-1]
188 |
189 | feature_of_source_image = ft_attn_of_source_image['ft']
190 | attention_map_of_source_image = ft_attn_of_source_image['attn']
191 |
192 | cos_sim = compute_correspondence_matrix(
193 | feature_of_source_image, feature_of_reference_image)
194 |
195 | # index of the per-row maximum of cos_sim, i.e. the best-matching reference location
196 | index = torch.argmax(cos_sim, dim=-1)  # [B, H*W]
197 |
198 | mask_ref = maskpooling(guidance_on_reference_image,
199 | h_attn)
200 |
201 | mask_ref = mask_ref.reshape(mask_ref.shape[0], -1) # 1*1024
202 |
203 | new_index = torch.gather(mask_ref, 1, index) # 1*1024
204 | res = int(new_index.shape[-1]**0.5)
205 | new_index = new_index.reshape(
206 | new_index.shape[0], res, res).unsqueeze(1)
207 |
208 | # resize mask_result to 512*512
209 | mask_result = new_index
210 |
211 | if self.refine_with_attention:
212 | mask_result = refine_mask_by_attention(
213 | mask_result, attention_map_of_source_image, iterations=10, score_type='classification')
214 |
215 | mask_result = F.interpolate(
216 | mask_result.float(), size=(h, w), mode='bilinear')
217 |
218 | # get trimap
219 |
220 | pseudo_trimap = generate_trimap(
221 | mask_result, erode_kernel_size=self.kernel_size, dilate_kernel_size=self.kernel_size)
222 | 
223 | output = {}
224 | output['trimap'] = pseudo_trimap
225 | output['feature'] = feature_of_source_image
226 | output['mask'] = mask_result
227 |
228 | return output
229 |
230 |
231 | class TrainingFreeAttention(nn.Module):
232 | def __init__(self, res_ratio=4, pool_type='average', temp_softmax=1, use_scale=False, upsample_mode='bilinear', use_norm=False) -> None:
233 | super().__init__()
234 | self.res_ratio = res_ratio
235 | self.pool_type = pool_type
236 | self.temp_softmax = temp_softmax
237 | self.use_scale = use_scale
238 | self.upsample_mode = upsample_mode
239 | if use_norm:
240 | self.norm = nn.LayerNorm(use_norm, elementwise_affine=True)
241 | else:
242 | self.idt = nn.Identity()
243 |
244 | def forward(self, features, features_ref, roi_mask,):
245 | # roi_mask: [B, 1, H, W]
246 | # features: [B, C, h, w]
247 | # features_ref: [B, C, h, w]
248 | B, _, H, W = roi_mask.shape
249 | if self.res_ratio is None:
250 | H_attn, W_attn = features.shape[2], features.shape[3]
251 | else:
252 | H_attn = H//self.res_ratio
253 | W_attn = W//self.res_ratio
254 | features, features_ref = self.resize_input_to_res(
255 | features, features_ref, (H, W)) # [H//res_ratio, W//res_ratio]
256 |
257 | # List, len = B, each element: [C_q, dim], dim = H//res_ratio * W//res_ratio
258 | features_ref = self.get_roi_features(features_ref, roi_mask)
259 |
260 | features = features.reshape(
261 | B, -1, features.shape[2] * features.shape[3]).permute(0, 2, 1) # [B, C, dim]
262 | # List, len = B, each element: [C_q, C]
263 | attn_output = self.compute_attention(features, features_ref)
264 |
265 | # List, len = B, each element: [C_q, H_attn, W_attn]
266 | attn_output = self.reshape_attn_output(attn_output, (H_attn, W_attn))
267 |
268 | return attn_output
269 |
270 | def resize_input_to_res(self, features, features_ref, size):
271 | # features: [B, C, h, w]
272 | # features_ref: [B, C, h, w]
273 | H, W = size
274 | target_H, target_W = H//self.res_ratio, W//self.res_ratio
275 | features = F.interpolate(features, size=(
276 | target_H, target_W), mode=self.upsample_mode)
277 | features_ref = F.interpolate(features_ref, size=(
278 | target_H, target_W), mode=self.upsample_mode)
279 | return features, features_ref
280 |
281 | def get_roi_features(self, feature, mask):
282 | '''
283 | get feature tokens by maskpool
284 | feature: [B, C, h, w]
285 | mask: [B, 1, H, W] [0,1]
286 | return: List, len = B, each element: [token_num, C]
287 | '''
288 |
289 | # assert mask only has elements 0 and 1
290 | assert torch.all(torch.logical_or(mask == 0, mask == 1))
291 | # assert mask.max() == 1 and mask.min() == 0
292 |
293 | B, _, H, W = mask.shape
294 | h, w = feature.shape[2:]
295 |
296 | output = []
297 | for i in range(B):
298 | mask_ = mask[i]
299 | feature_ = feature[i]
300 | feature_ = self.maskpool(feature_, mask_)
301 | output.append(feature_)
302 | return output
303 |
304 | def maskpool(self, feature, mask):
305 | '''
306 | get feature tokens by maskpool
307 | feature: [C, h, w]
308 | mask: [1, H, W] [0,1]
309 | return: [token_num, C]
310 | '''
311 | kernel_size = mask.shape[1] // feature.shape[1] if self.res_ratio is None else self.res_ratio
312 | if self.pool_type == 'max':
313 | mask = F.max_pool2d(mask, kernel_size=kernel_size,
314 | stride=kernel_size, padding=0)
315 | elif self.pool_type == 'average':
316 | mask = F.avg_pool2d(mask, kernel_size=kernel_size,
317 | stride=kernel_size, padding=0)
318 | elif self.pool_type == 'min':
319 | mask = -1*mask
320 | mask = F.max_pool2d(mask, kernel_size=kernel_size,
321 | stride=kernel_size, padding=0)
322 | mask = -1*mask
323 | else:
324 | raise NotImplementedError
325 |
326 | # element-wise multiplication mask and feature
327 | feature = feature * mask
328 |
329 | index = (mask > 0).reshape(1, -1).squeeze()
330 | feature = feature.reshape(feature.shape[0], -1).permute(1, 0)
331 |
332 | feature = feature[index]
333 | return feature
334 |
335 | def compute_attention(self, features, features_ref):
336 | '''
337 | features: [B, C, dim]
338 | features_ref: List, len = B, each element: [C_q, dim]
339 | return: List, len = B, each element: [C_q, C]
340 | '''
341 | output = []
342 | for i in range(len(features_ref)):
343 | feature_ref = features_ref[i]
344 | feature = features[i]
345 | feature = self.compute_attention_single(feature, feature_ref)
346 | output.append(feature)
347 | return output
348 |
349 | def compute_attention_single(self, feature, feature_ref):
350 | '''
351 | compute attention with softmax
352 | feature: [C, dim]
353 | feature_ref: [C_q, dim]
354 | return: [C_q, C]
355 | '''
356 | scale = feature.shape[-1]**-0.5 if self.use_scale else 1.0
357 | feature = self.norm(feature) if hasattr(self, 'norm') else feature
358 | feature_ref = self.norm(feature_ref) if hasattr(
359 | self, 'norm') else feature_ref
360 | sim = einsum('i d, j d -> i j', feature_ref, feature)*scale
361 | sim = sim/self.temp_softmax
362 | sim = sim.softmax(dim=-1)
363 | return sim
364 |
365 | def reshape_attn_output(self, attn_output, attn_size):
366 | '''
367 | attn_output: List, len = B, each element: [C_q, C]
368 | return: List, len = B, each element: [C_q, H_attn, W_attn]
369 | '''
370 | # attn_output[0].shape[1] sqrt to get H_attn, W_attn
371 | H_attn, W_attn = attn_size
372 |
373 | output = []
374 | for i in range(len(attn_output)):
375 | attn_output_ = attn_output[i]
376 | attn_output_ = attn_output_.reshape(
377 | attn_output_.shape[0], H_attn, W_attn)
378 | output.append(attn_output_)
379 | return output
380 |
381 |
382 | class TrainingCrossAttention(nn.Module):
383 | def __init__(self, res_ratio=4, pool_type='average', temp_softmax=1, use_scale=False, upsample_mode='bilinear', use_norm=False, dim=1280,
384 | n_heads=4,
385 | d_head=320,
386 | mlp_dim_rate=0.5,) -> None:
387 | super().__init__()
388 | self.res_ratio = res_ratio
389 | self.pool_type = pool_type
390 | self.temp_softmax = temp_softmax
391 | self.use_scale = use_scale
392 | self.upsample_mode = upsample_mode
393 | if use_norm:
394 | self.norm = nn.LayerNorm(use_norm, elementwise_affine=True)
395 | else:
396 | self.idt = nn.Identity()
397 |
398 | self.attn_module = OneWayAttentionBlock(
399 | dim,
400 | n_heads,
401 | d_head,
402 | mlp_dim_rate,
403 | )
404 |
405 | def forward(self, features, features_ref, roi_mask,):
406 | # roi_mask: [B, 1, H, W]
407 | # features: [B, C, h, w]
408 | # features_ref: [B, C, h, w]
409 | B, _, H, W = roi_mask.shape
410 | if self.res_ratio is None:
411 | H_attn, W_attn = features.shape[2], features.shape[3]
412 | else:
413 | H_attn = H//self.res_ratio
414 | W_attn = W//self.res_ratio
415 | features, features_ref = self.resize_input_to_res(
416 | features, features_ref, (H, W)) # [H//res_ratio, W//res_ratio]
417 |
418 | # List, len = B, each element: [C_q, dim], dim = H//res_ratio * W//res_ratio
419 | features_ref = self.get_roi_features(features_ref, roi_mask)
420 |
421 | features = features.reshape(
422 | B, -1, features.shape[2] * features.shape[3]).permute(0, 2, 1) # [B, C, dim]
423 | # List, len = B, each element: [C_q, C]
424 |
425 | features_ref = self.attn_module(features_ref, features)
426 |
427 | attn_output = self.compute_attention(features, features_ref)
428 |
429 | # List, len = B, each element: [C_q, H_attn, W_attn]
430 | attn_output = self.reshape_attn_output(attn_output, (H_attn, W_attn))
431 |
432 | return attn_output
433 |
434 | def resize_input_to_res(self, features, features_ref, size):
435 | # features: [B, C, h, w]
436 | # features_ref: [B, C, h, w]
437 | H, W = size
438 | target_H, target_W = H//self.res_ratio, W//self.res_ratio
439 | features = F.interpolate(features, size=(
440 | target_H, target_W), mode=self.upsample_mode)
441 | features_ref = F.interpolate(features_ref, size=(
442 | target_H, target_W), mode=self.upsample_mode)
443 | return features, features_ref
444 |
445 | def get_roi_features(self, feature, mask):
446 | '''
447 | get feature tokens by maskpool
448 | feature: [B, C, h, w]
449 | mask: [B, 1, H, W] [0,1]
450 | return: List, len = B, each element: [token_num, C]
451 | '''
452 |
453 | # assert mask only has elements 0 and 1
454 | assert torch.all(torch.logical_or(mask == 0, mask == 1))
455 | # assert mask.max() == 1 and mask.min() == 0
456 |
457 | B, _, H, W = mask.shape
458 | h, w = feature.shape[2:]
459 |
460 | output = []
461 | for i in range(B):
462 | mask_ = mask[i]
463 | feature_ = feature[i]
464 | feature_ = self.maskpool(feature_, mask_)
465 | output.append(feature_)
466 | return output
467 |
468 | def maskpool(self, feature, mask):
469 | '''
470 | get feature tokens by maskpool
471 | feature: [C, h, w]
472 | mask: [1, H, W] [0,1]
473 | return: [token_num, C]
474 | '''
475 | kernel_size = mask.shape[1] // feature.shape[1] if self.res_ratio is None else self.res_ratio
476 | if self.pool_type == 'max':
477 | mask = F.max_pool2d(mask, kernel_size=kernel_size,
478 | stride=kernel_size, padding=0)
479 | elif self.pool_type == 'average':
480 | mask = F.avg_pool2d(mask, kernel_size=kernel_size,
481 | stride=kernel_size, padding=0)
482 | elif self.pool_type == 'min':
483 | mask = -1*mask
484 | mask = F.max_pool2d(mask, kernel_size=kernel_size,
485 | stride=kernel_size, padding=0)
486 | mask = -1*mask
487 | else:
488 | raise NotImplementedError
489 |
490 | # element-wise multiplication mask and feature
491 | feature = feature * mask
492 |
493 | index = (mask > 0).reshape(1, -1).squeeze()
494 | feature = feature.reshape(feature.shape[0], -1).permute(1, 0)
495 |
496 | feature = feature[index]
497 | return feature
498 |
499 | def compute_attention(self, features, features_ref):
500 | '''
501 | features: [B, C, dim]
502 | features_ref: List, len = B, each element: [C_q, dim]
503 | return: List, len = B, each element: [C_q, C]
504 | '''
505 | output = []
506 | for i in range(len(features_ref)):
507 | feature_ref = features_ref[i]
508 | feature = features[i]
509 | feature = self.compute_attention_single(feature, feature_ref)
510 | output.append(feature)
511 | return output
512 |
513 | def compute_attention_single(self, feature, feature_ref):
514 | '''
515 | compute attention with softmax
516 | feature: [C, dim]
517 | feature_ref: [C_q, dim]
518 | return: [C_q, C]
519 | '''
520 | scale = feature.shape[-1]**-0.5 if self.use_scale else 1.0
521 | feature = self.norm(feature) if hasattr(self, 'norm') else feature
522 | feature_ref = self.norm(feature_ref) if hasattr(
523 | self, 'norm') else feature_ref
524 | sim = einsum('i d, j d -> i j', feature_ref, feature)*scale
525 | sim = sim/self.temp_softmax
526 | sim = sim.softmax(dim=-1)
527 | return sim
528 |
529 | def reshape_attn_output(self, attn_output, attn_size):
530 | '''
531 | attn_output: List, len = B, each element: [C_q, C]
532 | return: List, len = B, each element: [C_q, H_attn, W_attn]
533 | '''
534 | # attn_output[0].shape[1] sqrt to get H_attn, W_attn
535 | H_attn, W_attn = attn_size
536 |
537 | output = []
538 | for i in range(len(attn_output)):
539 | attn_output_ = attn_output[i]
540 | attn_output_ = attn_output_.reshape(
541 | attn_output_.shape[0], H_attn, W_attn)
542 | output.append(attn_output_)
543 | return output
544 |
545 |
546 | class TrainingFreeAttentionBlocks(nn.Module):
547 | '''
548 | one implementation of in_context_fusion
549 |
550 | forward(feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image)
551 | '''
552 |
553 | def __init__(self,
554 | res_ratio=8,
555 | pool_type='min',
556 | temp_softmax=1000,
557 | use_scale=False,
558 | upsample_mode='bicubic',
559 | bottle_neck_dim=None,
560 | use_norm=False,
561 |
562 | ):
563 | super().__init__()
564 |
565 | self.attn_module = TrainingFreeAttention(res_ratio=res_ratio,
566 | pool_type=pool_type,
567 | temp_softmax=temp_softmax,
568 | use_scale=use_scale,
569 | upsample_mode=upsample_mode,
570 | use_norm=use_norm,)
571 |
572 | def forward(self, feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image):
573 | '''
574 | feature_of_reference_image: [B, C, H, W]
575 | ft_attn_of_source_image: {"ft_cor": [B, C, H, W], "attn": {'24': [B, H_1, W_1, H_1*W_1]}, "ft_matting": [B, C, H, W]}
576 | guidance_on_reference_image: [B, 1, H_2, W_2]
577 | '''
578 | # assert feature_of_reference_image.shape[0] == 1
579 | # get source_image h,w
580 | h, w = guidance_on_reference_image.shape[-2:]
581 |
582 | features_cor = ft_attn_of_source_image['ft_cor']
583 | features_matting = ft_attn_of_source_image['ft_matting']
584 | features_ref = feature_of_reference_image
585 |
586 | guidance_on_reference_image[guidance_on_reference_image > 0.5] = 1
587 | guidance_on_reference_image[guidance_on_reference_image <= 0.5] = 0
588 | attn_output = self.attn_module(
589 | features_cor, features_ref, guidance_on_reference_image)
590 |
591 | attn_output = [attn_output_.sum(dim=0).unsqueeze(
592 | 0).unsqueeze(0) for attn_output_ in attn_output]
593 | attn_output = torch.cat(attn_output, dim=0)
594 |
595 | self_attn_output = self.training_free_self_attention(
596 | attn_output, ft_attn_of_source_image['attn'])
597 |
598 | # resize
599 | self_attn_output = F.interpolate(
600 | self_attn_output, size=(h, w), mode='bilinear')
601 |
602 | output = {}
603 | output['trimap'] = self_attn_output
604 | output['feature'] = features_matting
605 | output['mask'] = attn_output
606 |
607 | return output
608 |
609 | def training_free_self_attention(self, x, self_attn_maps):
610 | '''
611 | Compute self-attention using the attention maps.
612 |
613 | Parameters:
614 | x (torch.Tensor): The input tensor. Shape: [B, 1, H, W]
615 | self_attn_maps (torch.Tensor): The attention maps. Shape: {'24': [B, H1, W1, H1*W1]}
616 |
617 | Returns:
618 | torch.Tensor: The result of the self-attention computation.
619 | '''
620 |
621 | # Original dimensions of x
622 | # x is expected to have shape [B, 1, H, W]
623 | B, _, H, W = x.shape
624 |
625 | # Dimensions of the attention maps
626 | assert len(self_attn_maps) == 1
627 | # get only one value in dict
628 | self_attn_maps = list(self_attn_maps.values())[0]
629 | _, H1, W1, _ = self_attn_maps.shape
630 |
631 | # Resize x to match the spatial dimensions of the attention maps
632 | # align_corners behavior may vary across PyTorch versions
633 | x = F.interpolate(x, size=(H1, W1), mode='bilinear',
634 | align_corners=True)
635 |
636 | # Reshape the attention maps and x for matrix multiplication
637 | # Reshaping from [B, H1, W1, H1*W1] to [B, H1*W1, H1*W1]
638 | self_attn_maps = self_attn_maps.view(B, H1 * W1, H1 * W1)
639 | # Reshaping from [B, 1, H1, W1] to [B, 1, H1*W1]
640 | x = x.view(B, 1, H1 * W1)
641 |
642 | # Apply the self-attention mechanism
643 | # Matrix multiplication between the attention maps and the input feature map
644 | # This step essentially computes the weighted sum of feature vectors in the input,
645 | # where the weights are defined by the attention maps.
646 | # Multiplying with the transpose to get shape [B, 1, H1*W1]
647 | out = torch.matmul(x, self_attn_maps.transpose(1, 2))
648 |
649 | # Reshape the output tensor to the original spatial dimensions
650 | out = out.view(B, 1, H1, W1) # Reshaping back to spatial dimensions
651 |
652 | # # Resize the output back to the input's original dimensions (if necessary)
653 | # out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=True)
654 |
655 | return out
656 |
657 |
658 | class SemiTrainingAttentionBlocks(nn.Module):
659 | '''
660 | one implementation of in_context_fusion
661 |
662 | forward(feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image)
663 | '''
664 |
665 | def __init__(self,
666 | res_ratio=8,
667 | pool_type='min',
668 | upsample_mode='bicubic',
669 | bottle_neck_dim=None,
670 | use_norm=False,
671 | in_ft_dim=[1280, 960],
672 | in_attn_dim=[24**2, 48**2],
673 | attn_out_dim=256,
674 | ft_out_dim=512,
675 | training_cross_attn=False,
676 | ):
677 | super().__init__()
678 | if training_cross_attn:
679 | self.attn_module = TrainingCrossAttention(
680 | res_ratio=res_ratio,
681 | pool_type=pool_type,
682 | temp_softmax=1,
683 | use_scale=True,
684 | upsample_mode=upsample_mode,
685 | use_norm=use_norm,
686 | )
687 | else:
688 | self.attn_module = TrainingFreeAttention(res_ratio=res_ratio,
689 | pool_type=pool_type,
690 | temp_softmax=1,
691 | use_scale=True,
692 | upsample_mode=upsample_mode,
693 | use_norm=use_norm,)
694 |
695 | # init module list for attn, with basic 3*3 conv_attn
696 | self.attn_module_list = nn.ModuleList()
697 | self.ft_attn_module_list = nn.ModuleList()
698 | for i in range(len(in_attn_dim)):
699 | self.attn_module_list.append(Basic_Conv3x3_attn(
700 | in_attn_dim[i], attn_out_dim, int(math.sqrt(in_attn_dim[i]))))
701 | self.ft_attn_module_list.append(Basic_Conv3x3(
702 | ft_out_dim[i] + attn_out_dim, ft_out_dim[i]))
703 | # init module list for ft, with basic 3*3 conv
704 | self.ft_module_list = nn.ModuleList()
705 | for i in range(len(in_ft_dim)):
706 | self.ft_module_list.append(
707 | Basic_Conv3x3(in_ft_dim[i], ft_out_dim[i]))
708 |
709 | ft_out_dim_ = [2*d for d in ft_out_dim]
710 | self.fusion = MultiScaleFeatureFusion(ft_out_dim_, ft_out_dim)
711 |
712 | def forward(self, feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image):
713 | '''
714 | feature_of_reference_image: [B, C, H, W]
715 | ft_attn_of_source_image: {"ft_cor": [B, C, H, W], "attn": {'24': [B, H_1, W_1, H_1*W_1], ...}, "ft_matting": {'24': [B, C, H, W], ...}}
716 | guidance_on_reference_image: [B, 1, H_2, W_2]
717 | '''
718 | # assert feature_of_reference_image.shape[0] == 1
719 | # get source_image h,w
720 | h, w = guidance_on_reference_image.shape[-2:]
721 |
722 | features_cor = ft_attn_of_source_image['ft_cor']
723 | features_matting = ft_attn_of_source_image['ft_matting']
724 | features_ref = feature_of_reference_image
725 |
726 | guidance_on_reference_image[guidance_on_reference_image > 0.5] = 1
727 | guidance_on_reference_image[guidance_on_reference_image <= 0.5] = 0
728 | attn_output = self.attn_module(
729 | features_cor, features_ref, guidance_on_reference_image)
730 |
731 | attn_output = [attn_output_.sum(dim=0).unsqueeze(
732 | 0).unsqueeze(0) for attn_output_ in attn_output]
733 | attn_output = torch.cat(attn_output, dim=0)
734 |
735 | self_attn_output = self.training_free_self_attention(
736 | attn_output, ft_attn_of_source_image['attn'])
737 |
738 | # concat attn and ft_matting
739 |
740 | attn_ft_matting = {}
741 | for i, key in enumerate(features_matting.keys()):
742 | if key in self_attn_output.keys():
743 | features_matting[key] = self.ft_module_list[i](
744 | features_matting[key])
745 | attn_ft_matting[key] = torch.cat(
746 | [features_matting[key], self_attn_output[key]], dim=1)
747 |
748 | attn_ft_matting[key] = self.ft_attn_module_list[i](
749 | attn_ft_matting[key])
750 |
751 | else:
752 | attn_ft_matting[key] = self.ft_module_list[i](
753 | features_matting[key])
754 |
755 | # forward in multi-scale fusion block
756 | attn_ft_matting = self.fusion(attn_ft_matting)
757 |
758 | att_look = []
759 | # resize and average self_attn_output
760 | for i, key in enumerate(self_attn_output.keys()):
761 | att__ = F.interpolate(
762 | self_attn_output[key].mean(dim=1).unsqueeze(1), size=(h, w), mode='bilinear')
763 | att_look.append(att__)
764 | att_look = torch.cat(att_look, dim=1)
765 | att_look = att_look.mean(dim=1).unsqueeze(1)
766 |
767 | output = {}
768 |
769 | output['trimap'] = att_look
770 | output['feature'] = attn_ft_matting
771 | output['mask'] = attn_output
772 |
773 | return output
774 |
775 | def training_free_self_attention(self, x, self_attn_maps):
776 | '''
777 | Compute weighted attn maps using the attention maps.
778 |
779 | Parameters:
780 | x (torch.Tensor): The input tensor. Shape: [B, 1, H, W]
781 | self_attn_maps (torch.Tensor): The attention maps. Shape: {'24':[B, H1, W1, H1*W1], '48':[B, H2, W2, H2*W2]}
782 |
783 | Returns:
784 | torch.Tensor: The result of the attention computation. {'24':[B, 1, H1*W1, H1, W1], '48':[B, 1, H2*W2, H2, W2]}
785 | '''
786 |
787 | # Original dimensions of x
788 | # x is expected to have shape [B, 1, H, W]
789 | B, _, H, W = x.shape
790 | out = {}
791 | for i, key in enumerate(self_attn_maps.keys()):
792 | # Dimensions of the attention maps
793 | _, H1, W1, _ = self_attn_maps[key].shape
794 |
795 | # Resize x to match the spatial dimensions of the attention maps
796 | # align_corners behavior may vary across PyTorch versions
797 | x_ = F.interpolate(x, size=(H1, W1), mode='bilinear',
798 | align_corners=True)
799 |
800 | # Reshape the attention maps and x for matrix multiplication
801 | # Reshaping from [B, H1, W1, H1*W1] to [B, H1*W1, H1*W1]
802 | self_attn_map_ = self_attn_maps[key].view(
803 | B, H1 * W1, H1 * W1).transpose(1, 2)
804 | # Reshaping from [B, 1, H1, W1] to [B, 1, H1*W1]
805 | x_ = x_.reshape(B, H1 * W1, 1)
806 |
807 | # propagate: element-wise multiplication of x_ with the attention maps
808 | x_ = x_ * self_attn_map_
809 | x_ = x_.reshape(B, H1 * W1, H1, W1)
810 | x_ = x_.permute(0, 2, 3, 1)
811 | x_ = self.attn_module_list[i](x_)
812 | out[key] = x_
813 |
814 | return out
815 |
816 |
817 | class MultiScaleFeatureFusion(nn.Module):
818 | '''
819 | N conv layers or bottleneck blocks to compress the feature dimension
820 |
821 | M conv layers and upsampling to fusion the features
822 |
823 | '''
824 |
825 | def __init__(self,
826 | in_feature_dim=[],
827 | out_feature_dim=[],
828 | use_bottleneck=False) -> None:
829 | super().__init__()
830 | assert len(in_feature_dim) == len(out_feature_dim)
831 | # init module list
832 | self.module_list = nn.ModuleList()
833 | for i in range(len(in_feature_dim)-1):
834 | self.module_list.append(Fusion_Block(
835 | in_feature_dim[i], out_feature_dim[i]))
836 |
837 | def forward(self, features):
838 | # features: {'32': tensor, '16': tensor, '8': tensor}
839 |
840 | key_list = list(features.keys())
841 | ft = features[key_list[0]]
842 | for i in range(len(key_list)-1):
843 | ft = self.module_list[i](ft, features[key_list[i+1]])
844 |
845 | return ft
846 |
--------------------------------------------------------------------------------
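
For orientation, here is a minimal sketch of driving `TrainingFreeAttention` on random tensors. The shapes follow the docstrings in the class above; the import path and the concrete sizes are assumptions made for illustration only.

```python
import torch

# Assumed import path, following the repository layout.
from icm.models.decoder.in_context_correspondence import TrainingFreeAttention

attn = TrainingFreeAttention(res_ratio=4, pool_type='average', temp_softmax=1)

B, C, H, W = 1, 1280, 64, 64
features = torch.randn(B, C, H, W)       # source-image features
features_ref = torch.randn(B, C, H, W)   # reference-image features
roi_mask = torch.zeros(B, 1, H, W)
roi_mask[:, :, 16:48, 16:48] = 1         # strictly {0, 1}, as asserted in get_roi_features

out = attn(features, features_ref, roi_mask)
# len(out) == B; each element has shape [token_num, H // res_ratio, W // res_ratio]
print(len(out), out[0].shape)
```
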
/icm/models/decoder/in_context_decoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from icm.util import instantiate_from_config
3 |
4 | class InContextDecoder(nn.Module):
5 | '''
6 | InContextDecoder is the decoder of InContextMatting.
7 |
8 | in-context decoder:
9 |
10 | list get_trainable_params()
11 |
12 | forward(source, reference)
13 | reference = {'feature': feature_of_reference_image,
14 | 'guidance': guidance_on_reference_image}
15 |
16 | source = {'feature': feature_of_source_image, 'image': source_images}
17 |
18 | '''
19 |
20 | def __init__(self,
21 | cfg_detail_decoder,
22 | cfg_in_context_fusion,
23 | freeze_in_context_fusion=False,
24 | ):
25 | super().__init__()
26 |
27 | self.in_context_fusion = instantiate_from_config(
28 | cfg_in_context_fusion)
29 | self.detail_decoder = instantiate_from_config(cfg_detail_decoder)
30 |
31 | self.freeze_in_context_fusion = freeze_in_context_fusion
32 | if freeze_in_context_fusion:
33 | self.__freeze_in_context_fusion()
34 |
35 | def forward(self, source, reference):
36 | feature_of_reference_image = reference['feature']
37 | guidance_on_reference_image = reference['guidance']
38 |
39 | feature_of_source_image = source['feature']
40 | source_images = source['image']
41 |
42 | features = self.in_context_fusion(
43 | feature_of_reference_image, feature_of_source_image, guidance_on_reference_image)
44 |
45 | output = self.detail_decoder(features, source_images)
46 |
47 | return output, features['mask'], features['trimap']
48 |
49 | def get_trainable_params(self):
50 | params = []
51 | params = params + list(self.detail_decoder.parameters())
52 | if not self.freeze_in_context_fusion:
53 | params = params + list(self.in_context_fusion.parameters())
54 | return params
55 |
56 | def __freeze_in_context_fusion(self):
57 | for param in self.in_context_fusion.parameters():
58 | param.requires_grad = False
59 |
--------------------------------------------------------------------------------
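
As a rough guide, the decoder is meant to be built through `instantiate_from_config` (see `icm/util.py`) from a nested `target`/`params` config, mirroring what `config/eval.yaml` encodes. The sub-module class names and parameter values below are illustrative assumptions, not the shipped configuration.

```python
# A hedged sketch of the config structure; values here are placeholders.
cfg_in_context_decoder = {
    "target": "icm.models.decoder.in_context_decoder.InContextDecoder",
    "params": {
        "cfg_in_context_fusion": {
            "target": "icm.models.decoder.in_context_correspondence.SemiTrainingAttentionBlocks",
            "params": {"res_ratio": 8, "pool_type": "min", "upsample_mode": "bicubic"},
        },
        "cfg_detail_decoder": {
            # class name assumed from the file icm/models/decoder/detail_capture.py
            "target": "icm.models.decoder.detail_capture.DetailCapture",
            "params": {},
        },
        "freeze_in_context_fusion": False,
    },
}

# from icm.util import instantiate_from_config
# decoder = instantiate_from_config(cfg_in_context_decoder)
```
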
/icm/models/feature_extractor/attention_controllers.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Tuple, List, Callable, Dict
2 | import torch
3 | import torch.nn.functional as nnf
4 | import numpy as np
5 | import abc
6 |
7 |
8 | LOW_RESOURCE = False
9 |
10 |
11 | class AttentionControl(abc.ABC):
12 |
13 | def step_callback(self, x_t):
14 | return x_t
15 |
16 | def between_steps(self):
17 | return
18 |
19 | @property
20 | def num_uncond_att_layers(self):
21 | return self.num_att_layers if LOW_RESOURCE else 0
22 |
23 | @abc.abstractmethod
24 | def forward(self, attn, is_cross: bool, place_in_unet: str, ensemble_size=1, token_batch_size=1):
25 | raise NotImplementedError
26 |
27 | def __call__(self, attn, is_cross: bool, place_in_unet: str, ensemble_size=1, token_batch_size=1):
28 | if self.cur_att_layer >= self.num_uncond_att_layers:
29 | if LOW_RESOURCE:
30 | attn = self.forward(attn, is_cross, place_in_unet, ensemble_size, token_batch_size)
31 | else:
32 | h = attn.shape[0]
33 | # attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
34 | attn = self.forward(attn, is_cross, place_in_unet, ensemble_size, token_batch_size)
35 | self.cur_att_layer += 1
36 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
37 | self.cur_att_layer = 0
38 | self.cur_step += 1
39 | self.between_steps()
40 | return attn
41 |
42 | def reset(self):
43 | self.cur_step = 0
44 | self.cur_att_layer = 0
45 |
46 | def __init__(self):
47 | self.cur_step = 0
48 | self.num_att_layers = -1
49 | self.cur_att_layer = 0
50 |
51 | class EmptyControl(AttentionControl):
52 |
53 | def forward(self, attn, is_cross: bool, place_in_unet: str, ensemble_size=1, token_batch_size=1):
54 | return attn
55 |
56 |
57 | class AttentionStore(AttentionControl):
58 |
59 | @staticmethod
60 | def get_empty_store():
61 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
62 | "down_self": [], "mid_self": [], "up_self": []}
63 |
64 | def forward(self, attn, is_cross: bool, place_in_unet: str, ensemble_size=1, token_batch_size=1):
65 | num_head = attn.shape[0]//token_batch_size
66 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
67 | if self.store_res is not None:
68 | if attn.shape[1] in self.store_res and (is_cross is False):
69 | attn = attn.reshape(-1, ensemble_size, *attn.shape[1:])
70 | attn = attn.mean(dim=1)
71 | attn = attn.reshape(-1,num_head , *attn.shape[1:])
72 | attn = attn.mean(dim=1)
73 | self.step_store[key].append(attn)
74 | elif attn.shape[1] <= 48 ** 2 and (is_cross is False): # avoid memory overhead
75 | attn = attn.reshape(-1, ensemble_size, *attn.shape[1:])
76 | attn = attn.mean(dim=1)
77 | attn = attn.reshape(-1,num_head , *attn.shape[1:])
78 | attn = attn.mean(dim=1)
79 | self.step_store[key].append(attn)
80 |
81 | torch.cuda.empty_cache()
82 |
83 | def between_steps(self):
84 | if len(self.attention_store) == 0:
85 | self.attention_store = self.step_store
86 | else:
87 | for key in self.attention_store:
88 | for i in range(len(self.attention_store[key])):
89 | self.attention_store[key][i] += self.step_store[key][i]
90 | self.step_store = self.get_empty_store()
91 |
92 | def get_average_attention(self):
93 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
94 | return average_attention
95 |
96 |
97 | def reset(self):
98 | super(AttentionStore, self).reset()
99 | del self.step_store
100 | torch.cuda.empty_cache()
101 | self.step_store = self.get_empty_store()
102 | self.attention_store = {}
103 |
104 | def __init__(self,store_res = None):
105 | super(AttentionStore, self).__init__()
106 | self.step_store = self.get_empty_store()
107 | self.attention_store = {}
108 |
109 | store_res = [store_res] if isinstance(store_res, int) else list(store_res)
110 | self.store_res = []
111 | for res in store_res:
112 | self.store_res.append(res**2)
113 |
114 |
--------------------------------------------------------------------------------
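
A small sketch of how `AttentionStore` accumulates self-attention maps when driven by `register_attention_control` in `dift_sd.py`. Here the controller is called by hand with random maps; the resolution, head count, and layer count are assumptions.

```python
import torch
from icm.models.feature_extractor.attention_controllers import AttentionStore

controller = AttentionStore(store_res=[24])   # keep self-attention at 24x24 tokens
controller.num_att_layers = 2                 # normally set by register_attention_control

attn = torch.rand(8, 24 * 24, 24 * 24)        # [heads * token_batch, tokens, tokens]
controller(attn, is_cross=False, place_in_unet="down", ensemble_size=1, token_batch_size=1)
controller(attn, is_cross=False, place_in_unet="up", ensemble_size=1, token_batch_size=1)

# After num_att_layers calls, between_steps() folds step_store into attention_store.
avg = controller.get_average_attention()
print({k: v[0].shape for k, v in avg.items() if v})
```
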
/icm/models/feature_extractor/dift_sd.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline
2 | import torch
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from typing import Any, Callable, Dict, List, Optional, Union
7 | from diffusers.models.unet_2d_condition import UNet2DConditionModel
8 | from diffusers import DDIMScheduler
9 | import gc
10 | from PIL import Image
11 |
12 | from icm.models.feature_extractor.attention_controllers import AttentionStore
13 | import xformers
14 |
15 |
16 | def register_attention_control(model, controller, if_softmax=True, ensemble_size=1):
17 | def ca_forward(self, place_in_unet, att_opt_b):
18 |
19 | class MyXFormersAttnProcessor:
20 | r"""
21 | Processor for implementing memory efficient attention using xFormers.
22 |
23 | Args:
24 | attention_op (`Callable`, *optional*, defaults to `None`):
25 | The base
26 | [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
27 | use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
28 | operator.
29 | """
30 |
31 | def __init__(self, attention_op=None):
32 | self.attention_op = attention_op
33 |
34 | def __call__(
35 | self,
36 | attn,
37 | hidden_states: torch.FloatTensor,
38 | encoder_hidden_states=None,
39 | attention_mask=None,
40 | temb=None,
41 | ):
42 | residual = hidden_states
43 |
44 | if attn.spatial_norm is not None:
45 | hidden_states = attn.spatial_norm(hidden_states, temb)
46 |
47 | input_ndim = hidden_states.ndim
48 |
49 | if input_ndim == 4:
50 | batch_size, channel, height, width = hidden_states.shape
51 | hidden_states = hidden_states.view(
52 | batch_size, channel, height * width).transpose(1, 2)
53 |
54 | batch_size, key_tokens, _ = (
55 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
56 | )
57 |
58 | attention_mask = attn.prepare_attention_mask(
59 | attention_mask, key_tokens, batch_size)
60 | if attention_mask is not None:
61 | # expand our mask's singleton query_tokens dimension:
62 | # [batch*heads, 1, key_tokens] ->
63 | # [batch*heads, query_tokens, key_tokens]
64 | # so that it can be added as a bias onto the attention scores that xformers computes:
65 | # [batch*heads, query_tokens, key_tokens]
66 | # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
67 | _, query_tokens, _ = hidden_states.shape
68 | attention_mask = attention_mask.expand(
69 | -1, query_tokens, -1)
70 |
71 | if attn.group_norm is not None:
72 | hidden_states = attn.group_norm(
73 | hidden_states.transpose(1, 2)).transpose(1, 2)
74 |
75 | query = attn.to_q(hidden_states)
76 |
77 | is_cross = False if encoder_hidden_states is None else True
78 |
79 | if encoder_hidden_states is None:
80 | encoder_hidden_states = hidden_states
81 | elif attn.norm_cross:
82 | encoder_hidden_states = attn.norm_encoder_hidden_states(
83 | encoder_hidden_states)
84 |
85 | key = attn.to_k(encoder_hidden_states)
86 | value = attn.to_v(encoder_hidden_states)
87 |
88 | query = attn.head_to_batch_dim(query).contiguous()
89 | key = attn.head_to_batch_dim(key).contiguous()
90 | value = attn.head_to_batch_dim(value).contiguous()
91 |
92 | # controller
93 | if query.shape[1] in controller.store_res:
94 | sim = torch.einsum('b i d, b j d -> b i j',
95 | query, key) * attn.scale
96 |
97 | if if_softmax:
98 | sim = sim / if_softmax
99 | my_attn = sim.softmax(dim=-1).detach()
100 | del sim
101 | else:
102 | my_attn = sim.detach()
103 |
104 | controller(my_attn, is_cross, place_in_unet, ensemble_size, batch_size)
105 |
106 | # end controller
107 |
108 | hidden_states = xformers.ops.memory_efficient_attention(
109 | query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
110 | )
111 | hidden_states = hidden_states.to(query.dtype)
112 | hidden_states = attn.batch_to_head_dim(hidden_states)
113 |
114 | # linear proj
115 | hidden_states = attn.to_out[0](hidden_states)
116 | # dropout
117 | hidden_states = attn.to_out[1](hidden_states)
118 |
119 | if input_ndim == 4:
120 | hidden_states = hidden_states.transpose(
121 | -1, -2).reshape(batch_size, channel, height, width)
122 |
123 | if attn.residual_connection:
124 | hidden_states = hidden_states + residual
125 |
126 | hidden_states = hidden_states / attn.rescale_output_factor
127 |
128 | return hidden_states
129 |
130 | return MyXFormersAttnProcessor(att_opt_b)
131 |
132 | class DummyController:
133 |
134 | def __call__(self, *args):
135 | return args[0]
136 |
137 | def __init__(self):
138 | self.num_att_layers = 0
139 |
140 | if controller is None:
141 | controller = DummyController()
142 |
143 | def register_recr(net_, count, place_in_unet):
144 | if net_.__class__.__name__ == 'Attention':
145 | net_.processor = ca_forward(
146 | net_, place_in_unet, net_.processor.attention_op)
147 | return count + 1
148 | elif hasattr(net_, 'children'):
149 | for net__ in net_.children():
150 | count = register_recr(net__, count, place_in_unet)
151 | return count
152 |
153 | cross_att_count = 0
154 | # sub_nets = model.unet.named_children()
155 | sub_nets = model.unet.named_children()
156 | # for net in sub_nets:
157 | # if "down" in net[0]:
158 | # cross_att_count += register_recr(net[1], 0, "down")
159 | # elif "up" in net[0]:
160 | # cross_att_count += register_recr(net[1], 0, "up")
161 | # elif "mid" in net[0]:
162 | # cross_att_count += register_recr(net[1], 0, "mid")
163 | for net in sub_nets:
164 | if "down_blocks" in net[0]:
165 | cross_att_count += register_recr(net[1], 0, "down")
166 | elif "up_blocks" in net[0]:
167 | cross_att_count += register_recr(net[1], 0, "up")
168 | elif "mid_block" in net[0]:
169 | cross_att_count += register_recr(net[1], 0, "mid")
170 | controller.num_att_layers = cross_att_count
171 |
172 |
173 | class MyUNet2DConditionModel(UNet2DConditionModel):
174 | def forward(
175 | self,
176 | sample: torch.FloatTensor,
177 | timestep: Union[torch.Tensor, float, int],
178 | up_ft_indices,
179 | encoder_hidden_states: torch.Tensor,
180 | class_labels: Optional[torch.Tensor] = None,
181 | timestep_cond: Optional[torch.Tensor] = None,
182 | attention_mask: Optional[torch.Tensor] = None,
183 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
184 | ):
185 | r"""
186 | Args:
187 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
188 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
189 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
190 | cross_attention_kwargs (`dict`, *optional*):
191 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
192 | `self.processor` in
193 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
194 | """
195 | # By default, samples have to be at least a multiple of the overall upsampling factor.
196 | # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
197 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
198 | # on the fly if necessary.
199 | default_overall_up_factor = 2**self.num_upsamplers
200 |
201 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
202 | forward_upsample_size = False
203 | upsample_size = None
204 |
205 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
206 | # logger.info("Forward upsample size to force interpolation output size.")
207 | forward_upsample_size = True
208 |
209 | # prepare attention_mask
210 | if attention_mask is not None:
211 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
212 | attention_mask = attention_mask.unsqueeze(1)
213 |
214 | # 0. center input if necessary
215 | if self.config.center_input_sample:
216 | sample = 2 * sample - 1.0
217 |
218 | # 1. time
219 | timesteps = timestep
220 | if not torch.is_tensor(timesteps):
221 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
222 | # This would be a good case for the `match` statement (Python 3.10+)
223 | is_mps = sample.device.type == "mps"
224 | if isinstance(timestep, float):
225 | dtype = torch.float32 if is_mps else torch.float64
226 | else:
227 | dtype = torch.int32 if is_mps else torch.int64
228 | timesteps = torch.tensor(
229 | [timesteps], dtype=dtype, device=sample.device)
230 | elif len(timesteps.shape) == 0:
231 | timesteps = timesteps[None].to(sample.device)
232 |
233 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
234 | timesteps = timesteps.expand(sample.shape[0])
235 |
236 | t_emb = self.time_proj(timesteps)
237 |
238 | # timesteps does not contain any weights and will always return f32 tensors
239 | # but time_embedding might actually be running in fp16. so we need to cast here.
240 | # there might be better ways to encapsulate this.
241 | t_emb = t_emb.to(dtype=self.dtype)
242 |
243 | emb = self.time_embedding(t_emb, timestep_cond)
244 |
245 | if self.class_embedding is not None:
246 | if class_labels is None:
247 | raise ValueError(
248 | "class_labels should be provided when num_class_embeds > 0"
249 | )
250 |
251 | if self.config.class_embed_type == "timestep":
252 | class_labels = self.time_proj(class_labels)
253 |
254 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
255 | emb = emb + class_emb
256 |
257 | # 2. pre-process
258 | sample = self.conv_in(sample)
259 |
260 | # 3. down
261 | down_block_res_samples = (sample,)
262 | for downsample_block in self.down_blocks:
263 | if (
264 | hasattr(downsample_block, "has_cross_attention")
265 | and downsample_block.has_cross_attention
266 | ):
267 | sample, res_samples = downsample_block(
268 | hidden_states=sample,
269 | temb=emb,
270 | encoder_hidden_states=encoder_hidden_states,
271 | attention_mask=attention_mask,
272 | cross_attention_kwargs=cross_attention_kwargs,
273 | )
274 | else:
275 | sample, res_samples = downsample_block(
276 | hidden_states=sample, temb=emb)
277 |
278 | down_block_res_samples += res_samples
279 |
280 | # 4. mid
281 | if self.mid_block is not None:
282 | sample = self.mid_block(
283 | sample,
284 | emb,
285 | encoder_hidden_states=encoder_hidden_states,
286 | attention_mask=attention_mask,
287 | cross_attention_kwargs=cross_attention_kwargs,
288 | )
289 |
290 | # 5. up
291 | up_ft = {}
292 | for i, upsample_block in enumerate(self.up_blocks):
293 | # if i > np.max(up_ft_indices):
294 | # break
295 |
296 | is_final_block = i == len(self.up_blocks) - 1
297 |
298 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
299 | down_block_res_samples = down_block_res_samples[
300 | : -len(upsample_block.resnets)
301 | ]
302 |
303 | # if we have not reached the final block and need to forward the
304 | # upsample size, we do it here
305 | if not is_final_block and forward_upsample_size:
306 | upsample_size = down_block_res_samples[-1].shape[2:]
307 |
308 | if (
309 | hasattr(upsample_block, "has_cross_attention")
310 | and upsample_block.has_cross_attention
311 | ):
312 | sample = upsample_block(
313 | hidden_states=sample,
314 | temb=emb,
315 | res_hidden_states_tuple=res_samples,
316 | encoder_hidden_states=encoder_hidden_states,
317 | cross_attention_kwargs=cross_attention_kwargs,
318 | upsample_size=upsample_size,
319 | attention_mask=attention_mask,
320 | )
321 | else:
322 | sample = upsample_block(
323 | hidden_states=sample,
324 | temb=emb,
325 | res_hidden_states_tuple=res_samples,
326 | upsample_size=upsample_size,
327 | )
328 |
329 | if i in up_ft_indices:
330 | up_ft[i] = sample.detach()
331 |
332 | output = {}
333 | output["up_ft"] = up_ft
334 | return output
335 |
336 |
337 | class OneStepSDPipeline(StableDiffusionPipeline):
338 | @torch.no_grad()
339 | def __call__(
340 | self,
341 | img_tensor,
342 | t,
343 | up_ft_indices,
344 | negative_prompt: Optional[Union[str, List[str]]] = None,
345 | generator: Optional[Union[torch.Generator,
346 | List[torch.Generator]]] = None,
347 | prompt_embeds: Optional[torch.FloatTensor] = None,
348 | callback: Optional[Callable[[
349 | int, int, torch.FloatTensor], None]] = None,
350 | callback_steps: int = 1,
351 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
352 | ):
353 | device = self._execution_device
354 | latents = (
355 | self.vae.encode(img_tensor).latent_dist.sample()
356 | * self.vae.config.scaling_factor
357 | )
358 | t = torch.tensor(t, dtype=torch.long, device=device)
359 | noise = torch.randn_like(latents).to(device)
360 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
361 | unet_output = self.unet(
362 | latents_noisy,
363 | t,
364 | up_ft_indices,
365 | encoder_hidden_states=prompt_embeds,
366 | cross_attention_kwargs=cross_attention_kwargs,
367 | )
368 | return unet_output
369 |
370 |
371 | class SDFeaturizer(nn.Module):
372 | def __init__(self, sd_id='pretrained_models/stable-diffusion-2-1',
373 | load_local=True, ):
374 | super().__init__()
375 | # sd_id="stabilityai/stable-diffusion-2-1", load_local=False):
376 | unet = MyUNet2DConditionModel.from_pretrained(
377 | sd_id,
378 | subfolder="unet",
379 | # output_loading_info=True,
380 | local_files_only=load_local,
381 | low_cpu_mem_usage=True,
382 | use_safetensors=False,
383 | # torch_dtype=torch.float16,
384 | # device_map="auto",
385 | )
386 | onestep_pipe = OneStepSDPipeline.from_pretrained(
387 | sd_id,
388 | unet=unet,
389 | safety_checker=None,
390 | local_files_only=load_local,
391 | low_cpu_mem_usage=True,
392 | use_safetensors=False,
393 | # torch_dtype=torch.float16,
394 | # device_map="auto",
395 | )
396 | onestep_pipe.vae.decoder = None
397 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained(
398 | sd_id, subfolder="scheduler"
399 | )
400 | gc.collect()
401 |
402 | onestep_pipe = onestep_pipe.to("cuda")
403 |
404 | onestep_pipe.enable_attention_slicing()
405 | onestep_pipe.enable_xformers_memory_efficient_attention()
406 | self.pipe = onestep_pipe
407 |
408 | # register nn.module for ddp
409 | self.vae = self.pipe.vae
410 | self.unet = self.pipe.unet
411 |
412 | # freeze vae and unet
413 | for param in self.vae.parameters():
414 | param.requires_grad = False
415 | for param in self.unet.parameters():
416 | param.requires_grad = False
417 |
418 | @torch.no_grad()
419 | def forward(self, img_tensor, prompt='', t=261, up_ft_index=3, ensemble_size=8):
420 | """
421 | Args:
422 | img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W]
423 | prompt: the prompt to use, a string
424 | t: the time step to use, should be an int in the range of [0, 1000]
425 | up_ft_index: which upsampling block of the U-Net to extract features from; choose from [0, 1, 2, 3]
426 | ensemble_size: the number of repeated images used in the batch to extract features
427 | Return:
428 | unet_ft: a torch tensor in the shape of [1, c, h, w]
429 | """
430 | img_tensor = img_tensor.repeat(
431 | ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w
432 | prompt_embeds = self.pipe._encode_prompt(
433 | prompt=prompt,
434 | device="cuda",
435 | num_images_per_prompt=1,
436 | do_classifier_free_guidance=False,
437 | ) # [1, 77, dim]
438 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1)
439 | unet_ft_all = self.pipe(
440 | img_tensor=img_tensor,
441 | t=t,
442 | up_ft_indices=[up_ft_index],
443 | prompt_embeds=prompt_embeds,
444 | )
445 | unet_ft = unet_ft_all["up_ft"][up_ft_index] # ensem, c, h, w
446 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w
447 | return unet_ft
448 | # index 0: 1280, 24, 24
449 | # index 1: 1280, 48, 48
450 | # index 2: 640, 96, 96
451 | # index 3: 320, 96,96
452 | @torch.no_grad()
453 | def forward_feature_extractor(self, uc, img_tensor, t=261, up_ft_index=[0, 1, 2, 3], ensemble_size=8):
454 | """
455 | Args:
456 | img_tensor: a batch of images, a torch tensor of shape [B, C, H, W]
457 | uc: precomputed prompt embeddings (e.g., the empty-prompt embedding), broadcast over the batch internally
458 | t: the time step to use, should be an int in the range of [0, 1000]
459 | up_ft_index: which upsampling blocks of the U-Net to extract features from, a subset of [0, 1, 2, 3]
460 | ensemble_size: the number of noised copies of each image used to extract features
461 | Return:
462 | unet_ft: a dict mapping each requested index to features of shape [B*ensemble_size, c, h, w]
463 | """
464 | batch_size = img_tensor.shape[0]
465 |
466 | img_tensor = img_tensor.unsqueeze(1).repeat(1, ensemble_size, 1, 1, 1)
467 |
468 | img_tensor = img_tensor.reshape(-1, *img_tensor.shape[2:])
469 |
470 | prompt_embeds = uc.repeat(
471 | img_tensor.shape[0], 1, 1).to(img_tensor.device)
472 | unet_ft_all = self.pipe(
473 | img_tensor=img_tensor,
474 | t=t,
475 | up_ft_indices=up_ft_index,
476 | prompt_embeds=prompt_embeds,
477 | )
478 | unet_ft = unet_ft_all["up_ft"] # dict: {index: [B*ensem, c, h, w]}
479 |
480 | return unet_ft
481 |
482 |
483 | class FeatureExtractor(nn.Module):
484 | def __init__(self,
485 | sd_id='stabilityai/stable-diffusion-2-1', # 'pretrained_models/stable-diffusion-2-1',
486 | load_local=True,
487 | if_softmax=False,
488 | feature_index_cor=1,
489 | feature_index_matting=4,
490 | attention_res=32, # [16, 32],
491 | set_diag_to_one=True,
492 | time_steps=[0],
493 | extract_feature_inputted_to_layer=False,
494 | ensemble_size=8):
495 | super().__init__()
496 |
497 | self.dift_sd = SDFeaturizer(sd_id=sd_id, load_local=load_local)
498 | # register buffer for prompt embedding
499 | self.register_buffer("prompt_embeds", self.dift_sd.pipe._encode_prompt(
500 | prompt='',
501 | num_images_per_prompt=1,
502 | do_classifier_free_guidance=False,
503 | device="cuda",
504 | ))
505 | # free self.pipe.tokenizer and self.pipe.text_encoder
506 | del self.dift_sd.pipe.tokenizer
507 | del self.dift_sd.pipe.text_encoder
508 | gc.collect()
509 | torch.cuda.empty_cache()
510 | self.feature_index_cor = feature_index_cor
511 | self.feature_index_matting = feature_index_matting
512 | self.attention_res = attention_res
513 | self.set_diag_to_one = set_diag_to_one
514 | self.time_steps = time_steps
515 | self.extract_feature_inputted_to_layer = extract_feature_inputted_to_layer
516 | self.ensemble_size = ensemble_size
517 | self.register_attention_store(
518 | if_softmax=if_softmax, attention_res=attention_res)
519 |
520 |
521 | def register_attention_store(self, if_softmax=False, attention_res=[16, 32]):
522 | self.controller = AttentionStore(store_res=attention_res)
523 |
524 | register_attention_control(
525 | self.dift_sd.pipe, self.controller, if_softmax=if_softmax, ensemble_size=self.ensemble_size)
526 |
527 | def get_trainable_params(self):
528 | return []
529 |
530 | def get_reference_feature(self, images):
531 | self.controller.reset()
532 | batch_size = images.shape[0]
533 | features = self.dift_sd.forward_feature_extractor(
534 | self.prompt_embeds, images, t=self.time_steps[0], ensemble_size=self.ensemble_size) # b*e, c, h, w
535 |
536 | features = self.ensemble_feature(
537 | features, self.feature_index_cor, batch_size)
538 |
539 | return features.detach()
540 |
541 | def ensemble_feature(self, features, index, batch_size):
542 | if isinstance(index, int):
543 |
544 | features_ = features[index].reshape(
545 | batch_size, self.ensemble_size, *features[index].shape[1:])
546 | features_ = features_.mean(1, keepdim=False).detach()
547 | else:
548 | index = list(index)
549 | res = ['24','48','96']
550 | res = res[:len(index)]
551 | features_ = {}
552 | for i in range(len(index)):
553 | features_[res[i]] = features[index[i]].reshape(
554 | batch_size, self.ensemble_size, *features[index[i]].shape[1:])
555 | features_[res[i]] = features_[res[i]].mean(1, keepdim=False).detach()
556 | return features_
557 |
558 | def get_source_feature(self, images):
559 | # return {"ft": [B, C, H, W], "attn": [B, H, W, H*W]}
560 |
561 | self.controller.reset()
562 | torch.cuda.empty_cache()
563 | batch_size = images.shape[0]
564 |
565 | ft = self.dift_sd.forward_feature_extractor(
566 | self.prompt_embeds, images, t=self.time_steps[0], ensemble_size=self.ensemble_size) # b*e, c, h, w
567 |
568 |
569 | attention_maps = self.get_feature_attention(batch_size)
570 |
571 | output = {"ft_cor": self.ensemble_feature(ft, self.feature_index_cor, batch_size),
572 | "attn": attention_maps, 'ft_matting': self.ensemble_feature(ft, self.feature_index_matting, batch_size)}
573 | return output
574 |
575 | def get_feature_attention(self, batch_size):
576 |
577 | attention_maps = self.__aggregate_attention(
578 | from_where=["down", "mid", "up"], is_cross=False, batch_size=batch_size)
579 |
580 | for attn_map in attention_maps.keys():
581 | attention_maps[attn_map] = attention_maps[attn_map].permute(0, 2, 1).reshape(
582 | (batch_size, -1, int(attn_map), int(attn_map))) # [bs, h*w, h, w]
583 | attention_maps[attn_map] = attention_maps[attn_map].permute(0, 2, 3, 1) # [bs, h, w, h*w]
584 | return attention_maps
585 |
586 | def __aggregate_attention(self, from_where: List[str], is_cross: bool, batch_size: int):
587 | out = {}
588 | self.controller.between_steps()
589 | self.controller.cur_step=1
590 | attention_maps = self.controller.get_average_attention()
591 | for res in self.attention_res:
592 | out[str(res)] = self.__aggregate_attention_single_res(
593 | from_where, is_cross, batch_size, res, attention_maps)
594 | return out
595 |
596 | def __aggregate_attention_single_res(self, from_where: List[str], is_cross: bool, batch_size: int, res: int, attention_maps):
597 | out = []
598 | num_pixels = res ** 2
599 | for location in from_where:
600 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
601 | if item.shape[1] == num_pixels:
602 | cross_maps = item.reshape(
603 | batch_size, -1, res, res, item.shape[-1])
604 | out.append(cross_maps)
605 | out = torch.cat(out, dim=1)
606 | out = out.sum(1) / out.shape[1]
607 | out = out.reshape(batch_size, out.shape[-1], out.shape[-1])
608 |
609 | if self.set_diag_to_one:
610 | # modify each map in place; rebinding `o` would leave `out` unchanged
611 | for o in out:
612 | o.copy_(o - torch.diag(torch.diag(o)) + torch.eye(o.shape[0], device=o.device))
613 | return out
614 |
--------------------------------------------------------------------------------
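
A usage sketch for `SDFeaturizer`, assuming a CUDA device, xformers, and access to the Stable Diffusion 2.1 weights; loading from the Hugging Face hub (`load_local=False`) is an assumption for illustration, while the default expects a local `pretrained_models/` copy.

```python
import torch
from icm.models.feature_extractor.dift_sd import SDFeaturizer

featurizer = SDFeaturizer(sd_id='stabilityai/stable-diffusion-2-1', load_local=False)

img = torch.randn(1, 3, 768, 768)   # stand-in for a normalized source image
ft = featurizer(img, prompt='', t=261, up_ft_index=1, ensemble_size=8)
print(ft.shape)                     # 1280 channels at up-block index 1, per the index comments above
```
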
/icm/models/in_context_matting.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from pytorch_lightning.utilities.types import STEP_OUTPUT
3 | import torch
4 | import pytorch_lightning as pl
5 | from torch.optim.lr_scheduler import LambdaLR
6 |
7 | from icm.criterion.matting_criterion_eval import compute_mse_loss_torch, compute_sad_loss_torch
8 | from icm.util import instantiate_from_config
9 | from pytorch_lightning.utilities import rank_zero_only
10 | import os
11 | import cv2
12 | class InContextMatting(pl.LightningModule):
13 | '''
14 | In Context Matting Model
15 | consists of a feature extractor and a in-context decoder
16 | train model with lr, scheduler, loss
17 | '''
18 |
19 | def __init__(
20 | self,
21 | cfg_feature_extractor,
22 | cfg_in_context_decoder,
23 | cfg_loss_function,
24 | learning_rate,
25 | cfg_scheduler=None,
26 | **kwargs,
27 | ):
28 | super().__init__()
29 |
30 | self.feature_extractor = instantiate_from_config(
31 | cfg_feature_extractor)
32 | self.in_context_decoder = instantiate_from_config(
33 | cfg_in_context_decoder)
34 |
35 | self.loss_function = instantiate_from_config(cfg_loss_function)
36 |
37 | self.learning_rate = learning_rate
38 | self.cfg_scheduler = cfg_scheduler
39 |
40 | def forward(self, reference_images, guidance_on_reference_image, source_images):
41 |
42 | feature_of_reference_image = self.feature_extractor.get_reference_feature(
43 | reference_images)
44 |
45 | feature_of_source_image = self.feature_extractor.get_source_feature(
46 | source_images)
47 |
48 | reference = {'feature': feature_of_reference_image,
49 | 'guidance': guidance_on_reference_image}
50 |
51 | source = {'feature': feature_of_source_image, 'image': source_images}
52 |
53 | output, cross_map, self_map = self.in_context_decoder(source, reference)
54 |
55 | return output, cross_map, self_map
56 |
57 | def on_train_epoch_start(self):
58 | self.log("epoch", self.current_epoch, on_step=False,
59 | on_epoch=True, prog_bar=False, sync_dist=True)
60 |
61 | def training_step(self, batch, batch_idx):
62 | loss_dict, loss, _, _, _ = self.__shared_step(batch)
63 |
64 | self.__log_loss(loss_dict, loss, "train")
65 |
66 | # log learning rate
67 | self.log("lr", self.trainer.optimizers[0].param_groups[0]
68 | ["lr"], on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
69 |
70 | return loss
71 |
72 | def validation_step(self, batch, batch_idx):
73 | loss_dict, loss, preds, cross_map, self_map = self.__shared_step(batch)
74 |
75 | self.__log_loss(loss_dict, loss, "val")
76 | batch['cross_map'] = cross_map
77 | batch['self_map'] = self_map
78 | return preds, batch
79 |
80 | def __shared_step(self, batch):
81 | reference_images, guidance_on_reference_image, source_images, labels, trimaps = batch[
82 | "reference_image"], batch["guidance_on_reference_image"], batch["source_image"], batch["alpha"], batch["trimap"]
83 |
84 | outputs, cross_map, self_map = self(reference_images,
85 | guidance_on_reference_image, source_images)
86 |
87 | sample_map = torch.zeros_like(trimaps)
88 | sample_map[trimaps==0.5] = 1
89 |
90 | loss_dict = self.loss_function(sample_map, outputs, labels)
91 |
92 | loss = sum(loss_dict.values())
93 | if loss > 1e4 or torch.isnan(loss):
94 | raise ValueError(f"Loss explosion: {loss}")
95 | return loss_dict, loss, outputs, cross_map, self_map
96 |
97 | def __log_loss(self, loss_dict, loss, prefix):
98 | loss_dict = {
99 | f"{prefix}/{key}": loss_dict.get(key) for key in loss_dict}
100 | self.log_dict(loss_dict, on_step=True, on_epoch=True,
101 | prog_bar=False, sync_dist=True)
102 | self.log(f"{prefix}/loss", loss, on_step=True,
103 | on_epoch=True, prog_bar=True, sync_dist=True)
104 |
105 | def validation_step_end(self, outputs):
106 |
107 | preds, batch = outputs
108 | h, w = batch['alpha_shape']
109 |
110 |
111 | cross_map = batch['cross_map']
112 | self_map = batch['self_map']
113 | # resize cross_map and self_map to the same size as preds
114 | cross_map = torch.nn.functional.interpolate(
115 | cross_map, size=preds.shape[2:], mode='bilinear', align_corners=False)
116 | self_map = torch.nn.functional.interpolate(
117 | self_map, size=preds.shape[2:], mode='bilinear', align_corners=False)
118 |
119 | # normalize cross_map and self_map
120 | cross_map = (cross_map - cross_map.min()) / \
121 | (cross_map.max() - cross_map.min())
122 | self_map = (self_map - self_map.min()) / \
123 | (self_map.max() - self_map.min())
124 |
125 | cross_map = cross_map[0].squeeze()*255.0
126 | self_map = self_map[0].squeeze()*255.0
127 |
128 | # get one sample from batch
129 | pred = preds[0].squeeze()*255.0
130 | source_image = batch['source_image'][0]
131 | label = batch["alpha"][0].squeeze()*255.0
132 | trimap = batch["trimap"][0].squeeze()*255.0
133 | trimap[trimap == 127.5] = 128
134 | reference_image = batch["reference_image"][0]
135 | guidance_on_reference_image = batch["guidance_on_reference_image"][0]
136 | dataset_name = batch["dataset_name"][0]
137 | image_name = batch["image_name"][0].split('.')[0]
138 |
139 | # save pre to model.val_save_path
140 |
141 | # if self.val_save_path is not None:
142 | if hasattr(self, 'val_save_path'):
143 | os.makedirs(self.val_save_path, exist_ok=True)
144 | # resize preds to h,w
145 | pred_ = torch.nn.functional.interpolate(
146 | pred.unsqueeze(0).unsqueeze(0), size=(h, w), mode='bilinear', align_corners=False)
147 | pred_ = pred_.squeeze().cpu().numpy()
148 | pred_ = pred_.astype('uint8')
149 | cv2.imwrite(os.path.join(self.val_save_path, image_name+'.png'), pred_)
150 |
151 | masked_reference_image = reference_image*guidance_on_reference_image
152 | # self.__compute_and_log_mse_sad_of_one_sample(
153 | # pred, label, trimap, prefix="val")
154 |
155 | self.__log_image(
156 | source_image, masked_reference_image, pred, label, dataset_name, image_name, prefix='val', self_map=self_map, cross_map=cross_map)
157 |
158 |
159 | # def validation_step_end(self, outputs):
160 |
161 | # preds, batch = outputs
162 |
163 | # cross_map = batch['cross_map']
164 | # self_map = batch['self_map']
165 | # # resize cross_map and self_map to the same size as preds
166 | # cross_map = torch.nn.functional.interpolate(
167 | # cross_map, size=preds.shape[2:], mode='bilinear', align_corners=False)
168 | # self_map = torch.nn.functional.interpolate(
169 | # self_map, size=preds.shape[2:], mode='bilinear', align_corners=False)
170 |
171 | # # normalize cross_map and self_map
172 | # cross_map = (cross_map - cross_map.min()) / \
173 | # (cross_map.max() - cross_map.min())
174 | # self_map = (self_map - self_map.min()) / \
175 | # (self_map.max() - self_map.min())
176 |
177 | # cross_map = cross_map[0].squeeze()*255.0
178 | # self_map = self_map[0].squeeze()*255.0
179 |
180 | # # get one sample from batch
181 | # pred = preds[0].squeeze()*255.0
182 | # source_image = batch['source_image'][0]
183 | # label = batch["alpha"][0].squeeze()*255.0
184 | # trimap = batch["trimap"][0].squeeze()*255.0
185 | # trimap[trimap == 127.5] = 128
186 | # reference_image = batch["reference_image"][0]
187 | # guidance_on_reference_image = batch["guidance_on_reference_image"][0]
188 | # dataset_name = batch["dataset_name"][0]
189 | # image_name = batch["image_name"][0].split('.')[0]
190 |
191 | # masked_reference_image = reference_image*guidance_on_reference_image
192 |
193 | # self.__compute_and_log_mse_sad_of_one_sample(
194 | # pred, label, trimap, prefix="val")
195 |
196 | # self.__log_image(
197 | # source_image, masked_reference_image, pred, label, dataset_name, image_name, prefix='val', self_map=self_map, cross_map=cross_map)
198 |
199 | def __compute_and_log_mse_sad_of_one_sample(self, pred, label, trimap, prefix="val"):
200 | # compute loss for unknown pixels
201 | mse_loss_unknown_ = compute_mse_loss_torch(pred, label, trimap)
202 | sad_loss_unknown_ = compute_sad_loss_torch(pred, label, trimap)
203 |
204 | # compute loss for all pixels
205 | trimap = torch.ones_like(label)*128
206 | mse_loss_all_ = compute_mse_loss_torch(pred, label, trimap)
207 | sad_loss_all_ = compute_sad_loss_torch(pred, label, trimap)
208 |
209 | # log
210 | metrics_unknown = {f'{prefix}/mse_unknown': mse_loss_unknown_,
211 | f'{prefix}/sad_unknown': sad_loss_unknown_, }
212 |
213 | metrics_all = {f'{prefix}/mse_all': mse_loss_all_,
214 | f'{prefix}/sad_all': sad_loss_all_, }
215 |
216 | self.log_dict(metrics_unknown, on_step=False,
217 | on_epoch=True, prog_bar=False, sync_dist=True)
218 | self.log_dict(metrics_all, on_step=False,
219 | on_epoch=True, prog_bar=False, sync_dist=True)
220 |
221 | def __log_image(self, source_image, masked_reference_image, pred, label, dataset_name, image_name, prefix='val', self_map=None, cross_map=None):
222 | ########### log source_image, masked_reference_image, output and gt ###########
223 | # process image, masked_reference_image, pred, label
224 | source_image = self.__revert_normalize(source_image)
225 | masked_reference_image = self.__revert_normalize(
226 | masked_reference_image)
227 | pred = torch.stack((pred/255.0,)*3, axis=-1)
228 | label = torch.stack((label/255.0,)*3, axis=-1)
229 | self_map = torch.stack((self_map/255.0,)*3, axis=-1)
230 | cross_map = torch.stack((cross_map/255.0,)*3, axis=-1)
231 |
232 |         # stack source_image, masked_reference_image, label, pred, self_map, cross_map for logging
233 | image_for_log = torch.stack(
234 | (source_image, masked_reference_image, label, pred, self_map, cross_map), axis=0)
235 |
236 | # log image
237 | self.logger.experiment.add_images(
238 | f'{prefix}-{dataset_name}/{image_name}', image_for_log, self.current_epoch, dataformats='NHWC')
239 |
240 | def __revert_normalize(self, image):
241 |         # image: [C, H, W]; undo the ImageNet mean/std normalization applied at input time
242 | image = image.permute(1, 2, 0)
243 | image = image * torch.tensor([0.229, 0.224, 0.225], device=self.device) + \
244 | torch.tensor([0.485, 0.456, 0.406], device=self.device)
245 | image = torch.clamp(image, 0, 1)
246 | return image
247 |
248 | def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
249 | torch.cuda.empty_cache()
250 | def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
251 | torch.cuda.empty_cache()
252 | def test_step(self, batch, batch_idx):
253 | loss_dict, loss, preds = self.__shared_step(batch)
254 |
255 | return loss_dict, loss, preds
256 |
257 | def configure_optimizers(self):
258 | params = self.__get_trainable_params()
259 | opt = torch.optim.Adam(params, lr=self.learning_rate)
260 |
261 | if self.cfg_scheduler is not None:
262 | scheduler = self.__get_scheduler(opt)
263 | return [opt], scheduler
264 | return opt
265 |
266 | def __get_trainable_params(self):
267 | params = []
268 | params = params + self.in_context_decoder.get_trainable_params() + \
269 | self.feature_extractor.get_trainable_params()
270 | return params
271 |
272 | def __get_scheduler(self, opt):
273 | scheduler = instantiate_from_config(self.cfg_scheduler)
274 | scheduler = [
275 | {
276 | "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
277 |                 "interval": "step",  # step the LR scheduler after every optimizer step, not once per epoch
278 | "frequency": 1,
279 | }
280 | ]
281 | return scheduler
282 |
283 | from pytorch_lightning.callbacks import ModelCheckpoint
284 |
285 | class ModifiedModelCheckpoint(ModelCheckpoint):
286 | def delete_frozen_params(self, ckpt):
287 |         # delete frozen params: drop every state_dict key that belongs to the feature_extractor
288 | for k in list(ckpt["state_dict"].keys()):
289 | # remove ckpt['state_dict'][k] if 'feature_extractor' in k
290 | if "feature_extractor" in k:
291 | del ckpt["state_dict"][k]
292 | return ckpt
293 |
294 | def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
295 | super()._save_model(trainer, filepath)
296 |
297 | if trainer.is_global_zero:
298 | ckpt = torch.load(filepath)
299 | ckpt = self.delete_frozen_params(ckpt)
300 | torch.save(ckpt, filepath)
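301 |
302 | # Usage note (a sketch, not part of the released code): checkpoints written by
303 | # ModifiedModelCheckpoint no longer contain the frozen feature_extractor weights,
304 | # so they must be loaded with strict=False after the model has been rebuilt from
305 | # its config, mirroring what train.py does for fine-tuning:
306 | #   model = instantiate_from_config(cfg.model)
307 | #   ckpt = torch.load("PATH_TO_CKPT")
308 | #   model.load_state_dict(ckpt["state_dict"], strict=False)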
--------------------------------------------------------------------------------
/icm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch
3 | def instantiate_from_config(config):
4 |     if "target" not in config:
5 | raise KeyError("Expected key `target` to instantiate.")
6 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
7 |
8 | def get_obj_from_str(string, reload=False):
9 | module, cls = string.rsplit(".", 1)
10 | if reload:
11 | module_imp = importlib.import_module(module)
12 | importlib.reload(module_imp)
13 | return getattr(importlib.import_module(module, package=None), cls)
14 |
15 | def instantiate_feature_extractor(cfg):
16 | model = instantiate_from_config(cfg)
17 | load_odise_params = cfg.get("load_odise_params", False)
18 | if load_odise_params:
19 |         # load the pretrained ODISE checkpoint
20 | params = torch.load('ckpt/odise_label_coco_50e-b67d2efc.pth')
21 | # gather the params with "backbone.feature_extractor." prefix
22 | params = {k.replace("backbone.feature_extractor.", ""): v for k, v in params['model'].items() if "backbone.feature_extractor." in k}
23 | model.load_state_dict(params, strict=False)
24 | '''
25 | alpha_cond
26 | alpha_cond_time_embed
27 | clip_project.positional_embedding
28 | clip_project.linear.weight
29 | clip_project.linear.bias
30 | time_embed_project.positional_embedding
31 | time_embed_project.linear.weight
32 | time_embed_project.linear.bias
33 | '''
34 |
35 | model.eval()
36 | return model
37 |
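38 | # Example (a sketch; the target path and keys below are illustrative, not taken
39 | # from the released configs): a YAML node such as
40 | #   feature_extractor:
41 | #     target: icm.models.feature_extractor.dift_sd.FeatureExtractor
42 | #     load_odise_params: True
43 | #     params: {}
44 | # is resolved by instantiate_from_config: get_obj_from_str imports everything left
45 | # of the last dot as a module, fetches the attribute after it, and the object is
46 | # constructed with **params; instantiate_feature_extractor then loads the ODISE
47 | # weights when load_odise_params is set and switches the model to eval mode.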
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | if __name__ == '__main__':
2 | import datetime
3 | import argparse
4 | from omegaconf import OmegaConf
5 |
6 | import os
7 |
8 | from icm.util import instantiate_from_config
9 | import torch
10 | from pytorch_lightning import Trainer, seed_everything
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument(
16 | "--experiment_name",
17 | type=str,
18 | default="in_context_matting",
19 | )
20 |         parser.add_argument(
21 |             "--debug",
22 |             # argparse's type=bool treats any non-empty string as True; use a flag instead
23 |             action="store_true",
24 |         )
25 | parser.add_argument(
26 | "--resume",
27 | type=str,
28 | default="",
29 | )
30 |         parser.add_argument(
31 |             "--fine_tune",
32 |             # same flag-style fix as --debug
33 |             action="store_true",
34 |         )
35 | parser.add_argument(
36 | "--config",
37 | type=str,
38 | default="",
39 | )
40 | parser.add_argument(
41 | "--logdir",
42 | type=str,
43 | default="logs",
44 | )
45 | parser.add_argument(
46 | "--seed",
47 | type=int,
48 | default=42,
49 | )
50 |
51 | args = parser.parse_args()
52 | return args
53 |
54 | import multiprocessing
55 | multiprocessing.set_start_method('spawn')
56 |
57 | args = parse_args()
58 | if args.resume:
59 | path = args.resume.split('checkpoints')[0]
60 |         # find the most recent version folder
61 | all_folder = os.listdir(path)
62 | all_folder = [os.path.join(path, folder) for folder in all_folder if 'version' in folder]
63 | all_folder.sort()
64 | last_version_folder = all_folder[-1]
65 | # get the hparams.yaml path
66 | hparams_path = os.path.join(last_version_folder, 'hparams.yaml')
67 | cfg = OmegaConf.load(hparams_path)
68 | else:
69 | cfg = OmegaConf.load(args.config)
70 |
71 | if args.fine_tune:
72 | cfg_ft = OmegaConf.load(args.config)
73 |         # merge cfg and cfg_ft; values in cfg_ft override those in cfg
74 | cfg = OmegaConf.merge(cfg, cfg_ft)
75 |
76 | # set seed
77 | seed_everything(args.seed)
78 |
79 | """=== Init data ==="""
80 | cfg_data = cfg.get('data')
81 |
82 | data = instantiate_from_config(cfg_data)
83 |
84 | """=== Init model ==="""
85 | cfg_model = cfg.get('model')
86 |
87 | model = instantiate_from_config(cfg_model)
88 |
89 | """=== Init trainer ==="""
90 | cfg_trainer = cfg.get('trainer')
91 | # omegaconf to dict
92 | cfg_trainer = OmegaConf.to_container(cfg_trainer)
93 |
94 | if args.debug:
95 | cfg_trainer['limit_train_batches'] = 2
96 | # cfg_trainer['log_every_n_steps'] = 1
97 | cfg_trainer['limit_val_batches'] = 3
98 | # cfg_trainer['overfit_batches'] = 2
99 |
100 | # init logger
101 | cfg_logger = cfg_trainer.pop('cfg_logger')
102 |
103 | if args.resume:
104 | name = args.resume.split('/')[-3]
105 | else:
106 | name = datetime.datetime.now().strftime(
107 | "%Y-%m-%d_%H-%M-%S")+'-'+args.experiment_name
108 | cfg_logger['params']['save_dir'] = args.logdir
109 | cfg_logger['params']['name'] = name
110 | cfg_trainer['logger'] = instantiate_from_config(cfg_logger)
111 |
112 | # plugin
113 | cfg_plugin = cfg_trainer.pop('plugins')
114 | cfg_trainer['plugins'] = instantiate_from_config(cfg_plugin)
115 |
116 | # init callbacks
117 | cfg_callbacks = cfg_trainer.pop('cfg_callbacks')
118 | callbacks = []
119 | for callback_name in cfg_callbacks:
120 | if callback_name == 'modelcheckpoint':
121 | cfg_callbacks[callback_name]['params']['dirpath'] = os.path.join(
122 | args.logdir, name, 'checkpoints')
123 | callbacks.append(instantiate_from_config(cfg_callbacks[callback_name]))
124 | cfg_trainer['callbacks'] = callbacks
125 |
126 | if args.resume and not args.fine_tune:
127 | cfg_trainer['resume_from_checkpoint'] = args.resume
128 |
129 | if args.fine_tune:
130 | # load state_dict
131 | ckpt = torch.load(args.resume)
132 | model.load_state_dict(ckpt['state_dict'], strict=False)
133 | # init trainer
134 | trainer_opt = argparse.Namespace(**cfg_trainer)
135 | trainer = Trainer.from_argparse_args(trainer_opt)
136 |
137 | # save configs to log
138 | trainer.logger.log_hyperparams(cfg)
139 |
140 | """=== Start training ==="""
141 |
142 | trainer.fit(model, data)
143 |
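144 | # Example invocations (assumed; no training config or dataset ships with this
145 | # release, so the paths below are placeholders):
146 | #   fresh run:  python train.py --config PATH_TO_TRAIN_CONFIG --experiment_name in_context_matting
147 | #   resume:     python train.py --resume logs/RUN_NAME/checkpoints/CKPT_NAME.ckpt
148 | #   fine-tune:  python train.py --resume PATH_TO_CKPT --fine_tune --config PATH_TO_TRAIN_CONFIG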
--------------------------------------------------------------------------------