├── .gitattributes ├── .gitignore ├── README.md ├── config └── eval.yaml ├── demo └── src │ └── icon │ ├── arXiv-Paper.svg │ ├── bilibili-demo.svg │ ├── colab-badge.svg │ ├── license-MIT.svg │ ├── publication-Paper.svg │ └── youtube-demo.svg ├── environment.yml ├── eval.py ├── icm ├── criterion │ ├── __init__.py │ ├── loss_function.py │ └── matting_criterion_eval.py ├── data │ ├── __init__.py │ ├── data_generator.py │ ├── data_module.py │ └── image_file.py ├── lr_scheduler.py ├── models │ ├── __init__.py │ ├── decoder │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── bottleneck_block.py │ │ ├── detail_capture.py │ │ ├── in_context_correspondence.py │ │ └── in_context_decoder.py │ ├── feature_extractor │ │ ├── attention_controllers.py │ │ └── dift_sd.py │ └── in_context_matting.py └── util.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | icm/data/image_file.py merge=ours -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | !/.gitignore 3 | !/.gitattributes 4 | reference_code/ 5 | __pycache__/ 6 | logs/ 7 | ckpt/ 8 | pretrained_models/ 9 | datasets/ 10 | lightning_logs/ 11 | old_logs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# In-Context Matting [CVPR 2024, Highlight]

This is the official repository of the paper In-Context Matting.

Details of the model architecture and experimental results can be found on our homepage.

## TODO:
- [x] Release code
- [x] Release pre-trained models and instructions for inference
- [x] Release ICM-57 dataset
- [ ] Release training dataset and instructions for training

## Requirements
We follow the environment setup of [Stable Diffusion Version 2](https://github.com/Stability-AI/StableDiffusion#requirements).
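With Conda installed, the pinned environment in `environment.yml` (Python 3.10, PyTorch 1.13.1 with CUDA 11.7, xformers) can be created with `conda env create -f environment.yml` and then activated with `conda activate icm`, the environment name declared in that file.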
## Usage

To evaluate the performance on the ICM-57 dataset using the `eval.py` script, follow these instructions:

1. **Download the pretrained model:**
   - Download the pretrained model from [this link](https://pan.baidu.com/s/1HPbRRE5ZtPRpOSocm9qOmA?pwd=BA1c).

2. **Prepare the dataset:**
   Ensure that your copy of the ICM-57 dataset is ready (see the Dataset section below for the download link and expected layout).

3. **Run the evaluation:**
   Use the following command to run the evaluation script, replacing the placeholders with the actual paths if they differ:

   ```bash
   python eval.py --checkpoint PATH_TO_MODEL --save_path results/ --config config/eval.yaml
   ```

A short sketch of how `eval.py` consumes `config/eval.yaml` is given at the end of this README.

### Dataset
**ICM-57**
- Download link: [ICM-57 Dataset](https://pan.baidu.com/s/1ZJU_XHEVhIaVzGFPK_XCRg?pwd=BA1c)
- **Installation Guide**:
  1. After downloading, unzip the dataset into the `datasets/` directory of the project.
  2. Ensure the structure of the dataset folder is as follows:
     ```
     datasets/ICM57/
     ├── image
     └── alpha
     ```

### Acknowledgments

We would like to express our gratitude to the developers and contributors of the [DIFT](https://github.com/Tsingularity/dift) and [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/) projects. Their shared resources and insights have significantly aided the development of our work.

## Statement

This project is under the MIT license. For technical questions, please contact He Guo at [hguo01@hust.edu.cn](mailto:hguo01@hust.edu.cn). For commercial use, please contact Hao Lu at [hlu@hust.edu.cn](mailto:hlu@hust.edu.cn).
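**Note on `config/eval.yaml`:** `eval.py` builds the data module and the model from the config file below via `instantiate_from_config`; every block names a `target` class by its dotted import path and passes `params` as constructor keyword arguments. The following is a minimal sketch of that pattern, using the same helper `eval.py` imports but omitting checkpoint loading and the Lightning trainer:

```python
from omegaconf import OmegaConf
from icm.util import instantiate_from_config  # same helper eval.py uses

cfg = OmegaConf.load("config/eval.yaml")

# Each block follows the target/params convention:
# 'target' is a dotted class path, 'params' are its constructor kwargs.
data = instantiate_from_config(cfg.get("data"))    # icm.data.data_module.DataModuleFromConfig
data.setup()                                       # mirrors the call in eval.py

model = instantiate_from_config(cfg.get("model"))  # icm.models.in_context_matting.InContextMatting
```

In the actual script, the model weights are additionally restored from `--checkpoint` via `load_model_from_config` before `trainer.validate(model, data.val_dataloader())` is called.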
-------------------------------------------------------------------------------- /config/eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: icm.models.in_context_matting.InContextMatting 3 | params: 4 | learning_rate: 0.0004 5 | cfg_loss_function: 6 | target: icm.criterion.loss_function.LossFunction2 7 | params: 8 | losses_seg: 9 | - known_smooth_l1_loss 10 | losses_matting: 11 | - unknown_l1_loss 12 | - known_l1_loss 13 | - loss_pha_laplacian 14 | - loss_gradient_penalty 15 | cfg_scheduler: 16 | target: icm.lr_scheduler.LambdaLinearScheduler 17 | params: 18 | warm_up_steps: 19 | - 250 20 | cycle_lengths: 21 | - 10000000000000 22 | f_start: 23 | - 1.0e-06 24 | f_max: 25 | - 1.0 26 | f_min: 27 | - 1.0 28 | cfg_feature_extractor: 29 | target: icm.models.feature_extractor.dift_sd.FeatureExtractor 30 | params: 31 | sd_id: stabilityai/stable-diffusion-2-1 32 | load_local: true 33 | if_softmax: 1 34 | feature_index_cor: 1 35 | feature_index_matting: 36 | - 0 37 | - 1 38 | attention_res: 39 | - 24 40 | - 48 41 | set_diag_to_one: false 42 | time_steps: 43 | - 200 44 | extract_feature_inputted_to_layer: false 45 | ensemble_size: 4 46 | cfg_in_context_decoder: 47 | target: icm.models.decoder.in_context_decoder.InContextDecoder 48 | params: 49 | freeze_in_context_fusion: false 50 | cfg_detail_decoder: 51 | target: icm.models.decoder.detail_capture.DetailCapture 52 | params: 53 | use_sigmoid: true 54 | ckpt: '' 55 | in_chans: 320 56 | img_chans: 3 57 | convstream_out: 58 | - 48 59 | - 96 60 | - 192 61 | fusion_out: 62 | - 256 63 | - 128 64 | - 64 65 | - 32 66 | cfg_in_context_fusion: 67 | target: icm.models.decoder.in_context_correspondence.SemiTrainingAttentionBlocks 68 | params: 69 | res_ratio: null 70 | pool_type: min 71 | upsample_mode: bicubic 72 | bottle_neck_dim: null 73 | use_norm: 1280 74 | in_ft_dim: 75 | - 1280 76 | - 1280 77 | in_attn_dim: 78 | - 576 79 | - 2304 80 | attn_out_dim: 256 81 | ft_out_dim: 82 | - 320 83 | - 320 84 | training_cross_attn: false 85 | data: 86 | target: icm.data.data_module.DataModuleFromConfig 87 | params: 88 | batch_size: 2 89 | batch_size_val: 1 90 | num_workers: 8 91 | shuffle_train: false 92 | validation: 93 | target: icm.data.data_generator.InContextDataset 94 | params: 95 | crop_size: 768 96 | phase: val 97 | norm_type: sd 98 | data: 99 | target: icm.data.image_file.ContextData 100 | params: 101 | ratio: 0 102 | dataset_name: 103 | - ICM57 104 | trainer: 105 | accelerator: ddp 106 | gpus: 1 107 | max_epochs: 1000 108 | auto_select_gpus: false 109 | num_sanity_val_steps: 0 110 | cfg_logger: 111 | target: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 112 | params: 113 | save_dir: logs 114 | default_hp_metric: false 115 | plugins: 116 | target: pytorch_lightning.plugins.DDPPlugin 117 | params: 118 | find_unused_parameters: false 119 | -------------------------------------------------------------------------------- /demo/src/icon/arXiv-Paper.svg: -------------------------------------------------------------------------------- 1 | arXiv: PaperarXivPaper -------------------------------------------------------------------------------- /demo/src/icon/bilibili-demo.svg: -------------------------------------------------------------------------------- 1 | bilibili: demobilibilidemo -------------------------------------------------------------------------------- /demo/src/icon/colab-badge.svg: 
-------------------------------------------------------------------------------- 1 | Open in ColabOpen in Colab 2 | -------------------------------------------------------------------------------- /demo/src/icon/license-MIT.svg: -------------------------------------------------------------------------------- 1 | license: MITlicenseMIT -------------------------------------------------------------------------------- /demo/src/icon/publication-Paper.svg: -------------------------------------------------------------------------------- 1 | publication: PaperpublicationPaper -------------------------------------------------------------------------------- /demo/src/icon/youtube-demo.svg: -------------------------------------------------------------------------------- 1 | youtube: demoyoutubedemo -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: icm 2 | channels: 3 | - xformers 4 | - pytorch 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - brotlipy=0.7.0=py310h7f8727e_1002 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2023.01.10=h06a4308_0 14 | - certifi=2023.5.7=py310h06a4308_0 15 | - cffi=1.15.1=py310h5eee18b_3 16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 17 | - cryptography=39.0.1=py310h9ce1e76_0 18 | - cuda-cudart=11.7.99=0 19 | - cuda-cupti=11.7.101=0 20 | - cuda-libraries=11.7.1=0 21 | - cuda-nvrtc=11.7.99=0 22 | - cuda-nvtx=11.7.91=0 23 | - cuda-runtime=11.7.1=0 24 | - ffmpeg=4.3=hf484d3e_0 25 | - freetype=2.12.1=h4a9f257_0 26 | - giflib=5.2.1=h5eee18b_3 27 | - gmp=6.2.1=h295c915_3 28 | - gnutls=3.6.15=he1e5248_0 29 | - idna=3.4=py310h06a4308_0 30 | - intel-openmp=2023.1.0=hdb19cb5_46305 31 | - jpeg=9e=h5eee18b_1 32 | - lame=3.100=h7b6447c_0 33 | - lcms2=2.12=h3be6417_0 34 | - ld_impl_linux-64=2.38=h1181459_1 35 | - lerc=3.0=h295c915_0 36 | - libcublas=11.10.3.66=0 37 | - libcufft=10.7.2.124=h4fbf590_0 38 | - libcufile=1.6.1.9=0 39 | - libcurand=10.3.2.106=0 40 | - libcusolver=11.4.0.1=0 41 | - libcusparse=11.7.4.91=0 42 | - libdeflate=1.17=h5eee18b_0 43 | - libffi=3.4.4=h6a678d5_0 44 | - libgcc-ng=11.2.0=h1234567_1 45 | - libgomp=11.2.0=h1234567_1 46 | - libiconv=1.16=h7f8727e_2 47 | - libidn2=2.3.4=h5eee18b_0 48 | - libnpp=11.7.4.75=0 49 | - libnvjpeg=11.8.0.2=0 50 | - libpng=1.6.39=h5eee18b_0 51 | - libstdcxx-ng=11.2.0=h1234567_1 52 | - libtasn1=4.19.0=h5eee18b_0 53 | - libtiff=4.5.0=h6a678d5_2 54 | - libunistring=0.9.10=h27cfd23_0 55 | - libuuid=1.41.5=h5eee18b_0 56 | - libwebp=1.2.4=h11a3e52_1 57 | - libwebp-base=1.2.4=h5eee18b_1 58 | - lz4-c=1.9.4=h6a678d5_0 59 | - mkl=2023.1.0=h6d00ec8_46342 60 | - mkl-service=2.4.0=py310h5eee18b_1 61 | - mkl_fft=1.3.6=py310h1128e8f_1 62 | - mkl_random=1.2.2=py310h1128e8f_1 63 | - ncurses=6.4=h6a678d5_0 64 | - nettle=3.7.3=hbbd107a_1 65 | - numpy=1.24.3=py310h5f9d8c6_1 66 | - numpy-base=1.24.3=py310hb5e798b_1 67 | - openh264=2.1.1=h4ff587b_0 68 | - openssl=1.1.1t=h7f8727e_0 69 | - pillow=9.4.0=py310h6a678d5_0 70 | - pip=23.0.1=py310h06a4308_0 71 | - pycparser=2.21=pyhd3eb1b0_0 72 | - pyopenssl=23.0.0=py310h06a4308_0 73 | - pysocks=1.7.1=py310h06a4308_0 74 | - python=3.10.9=h7a1cb2a_2 75 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0 76 | - pytorch-cuda=11.7=h778d358_5 77 | - pytorch-mutex=1.0=cuda 78 | - readline=8.2=h5eee18b_0 79 | - requests=2.29.0=py310h06a4308_0 80 | - setuptools=67.8.0=py310h06a4308_0 81 | - 
sqlite=3.41.2=h5eee18b_0 82 | - tbb=2021.8.0=hdb19cb5_0 83 | - tk=8.6.12=h1ccaba5_0 84 | - torchaudio=0.13.1=py310_cu117 85 | - torchvision=0.14.1=py310_cu117 86 | - typing_extensions=4.5.0=py310h06a4308_0 87 | - tzdata=2023c=h04d1e81_0 88 | - urllib3=1.26.15=py310h06a4308_0 89 | - wheel=0.38.4=py310h06a4308_0 90 | - xformers=0.0.20=py310_cu11.7.1_pyt1.13.1 91 | - xz=5.4.2=h5eee18b_0 92 | - zlib=1.2.13=h5eee18b_0 93 | - zstd=1.5.5=hc292b87_0 94 | - pip: 95 | - accelerate==0.19.0 96 | - aiohttp==3.8.4 97 | - aiosignal==1.3.1 98 | - anyio==3.7.0 99 | - argon2-cffi==21.3.0 100 | - argon2-cffi-bindings==21.2.0 101 | - arrow==1.2.3 102 | - asttokens==2.2.1 103 | - async-lru==2.0.2 104 | - async-timeout==4.0.2 105 | - attrs==23.1.0 106 | - babel==2.12.1 107 | - backcall==0.2.0 108 | - beautifulsoup4==4.12.2 109 | - bleach==6.0.0 110 | - brotli==1.0.9 111 | - cmake==3.26.3 112 | - comm==0.1.3 113 | - contourpy==1.0.7 114 | - cycler==0.11.0 115 | - debugpy==1.6.7 116 | - decorator==5.1.1 117 | - defusedxml==0.7.1 118 | - diffusers==0.18.1 119 | - exceptiongroup==1.1.1 120 | - executing==1.2.0 121 | - fastjsonschema==2.17.1 122 | - filelock==3.12.0 123 | - fonttools==4.39.4 124 | - fqdn==1.5.1 125 | - frozenlist==1.3.3 126 | - fsspec==2023.5.0 127 | - gevent==22.10.2 128 | - geventhttpclient==2.0.2 129 | - greenlet==2.0.2 130 | - grpcio==1.54.2 131 | - huggingface-hub==0.14.1 132 | - importlib-metadata==6.6.0 133 | - ipykernel==6.23.1 134 | - ipympl==0.9.3 135 | - ipython==8.13.2 136 | - ipython-genutils==0.2.0 137 | - ipywidgets==8.0.6 138 | - isoduration==20.11.0 139 | - jedi==0.18.2 140 | - jinja2==3.1.2 141 | - json5==0.9.14 142 | - jsonpointer==2.3 143 | - jsonschema==4.17.3 144 | - jupyter-client==8.2.0 145 | - jupyter-core==5.3.0 146 | - jupyter-events==0.6.3 147 | - jupyter-lsp==2.2.0 148 | - jupyter-server 149 | - jupyter-server-terminals==0.4.4 150 | - jupyterlab==4.0.0 151 | - jupyterlab-pygments==0.2.2 152 | - jupyterlab-server==2.22.1 153 | - jupyterlab-widgets==3.0.7 154 | - kiwisolver==1.4.4 155 | - lit==16.0.5 156 | - markupsafe==2.1.2 157 | - matplotlib==3.7.1 158 | - matplotlib-inline==0.1.6 159 | - mistune==2.0.5 160 | - multidict==6.0.4 161 | - mypy-extensions==1.0.0 162 | - nbclient==0.8.0 163 | - nbconvert==7.4.0 164 | - nbformat==5.8.0 165 | - nest-asyncio==1.5.6 166 | - notebook-shim==0.2.3 167 | - overrides==7.3.1 168 | - packaging==23.1 169 | - pandocfilters==1.5.0 170 | - parso==0.8.3 171 | - pexpect==4.8.0 172 | - pickleshare==0.7.5 173 | - platformdirs==3.5.1 174 | - prometheus-client==0.17.0 175 | - prompt-toolkit==3.0.38 176 | - protobuf==3.20.3 177 | - psutil==5.9.5 178 | - ptyprocess==0.7.0 179 | - pure-eval==0.2.2 180 | - pygments==2.15.1 181 | - pyparsing==3.0.9 182 | - pyrsistent==0.19.3 183 | - python-dateutil==2.8.2 184 | - python-json-logger==2.0.7 185 | - python-rapidjson==1.10 186 | - pyyaml==6.0 187 | - regex==2023.5.5 188 | - rfc3339-validator==0.1.4 189 | - rfc3986-validator==0.1.1 190 | - send2trash==1.8.2 191 | - sh==1.14.3 192 | - six==1.16.0 193 | - sniffio==1.3.0 194 | - soupsieve==2.4.1 195 | - stack-data==0.6.2 196 | - terminado==0.17.1 197 | - tinycss2==1.2.1 198 | - tokenizers==0.13.3 199 | - tomli==2.0.1 200 | - tornado==6.3.2 201 | - tqdm==4.65.0 202 | - traitlets==5.9.0 203 | - transformers==4.29.2 204 | - triton==2.0.0.post1 205 | - tritonclient==2.33.0 206 | - typing-inspect==0.6.0 207 | - uri-template==1.2.0 208 | - wcwidth==0.2.6 209 | - webcolors==1.13 210 | - webencodings==0.5.1 211 | - websocket-client==1.5.2 212 | - 
widgetsnbextension==4.0.7 213 | - wrapt==1.15.0 214 | - yarl==1.9.2 215 | - zipp==3.15.0 216 | - zope-event==4.6 217 | - zope-interface==6.0 -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse 3 | from omegaconf import OmegaConf 4 | from icm.util import instantiate_from_config 5 | import torch 6 | from pytorch_lightning import Trainer, seed_everything 7 | import os 8 | from tqdm import tqdm 9 | 10 | def load_model_from_config(config, ckpt, verbose=False): 11 | print(f"Loading model from {ckpt}") 12 | pl_sd = torch.load(ckpt, map_location="cpu") 13 | if "global_step" in pl_sd: 14 | print(f"Global Step: {pl_sd['global_step']}") 15 | sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd 16 | model = instantiate_from_config(config) 17 | m, u = model.load_state_dict(sd, strict=False) 18 | if len(m) > 0 and verbose: 19 | print("missing keys:") 20 | print(m) 21 | if len(u) > 0 and verbose: 22 | print("unexpected keys:") 23 | print(u) 24 | 25 | # model.eval() 26 | return model 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument( 33 | "--checkpoint", 34 | type=str, 35 | default="", 36 | ) 37 | parser.add_argument( 38 | "--save_path", 39 | type=str, 40 | default="", 41 | ) 42 | parser.add_argument( 43 | "--config", 44 | type=str, 45 | default="", 46 | ) 47 | parser.add_argument( 48 | "--seed", 49 | type=int, 50 | default=42, 51 | ) 52 | 53 | args = parser.parse_args() 54 | return args 55 | 56 | 57 | if __name__ == '__main__': 58 | args = parse_args() 59 | # if args.checkpoint: 60 | # path = args.checkpoint.split('checkpoints')[0] 61 | # # get the folder of last version folder 62 | # all_folder = os.listdir(path) 63 | # all_folder = [os.path.join(path, folder) 64 | # for folder in all_folder if 'version' in folder] 65 | # all_folder.sort() 66 | # last_version_folder = all_folder[-1] 67 | # # get the hparams.yaml path 68 | # hparams_path = os.path.join(last_version_folder, 'hparams.yaml') 69 | # cfg = OmegaConf.load(hparams_path) 70 | # else: 71 | # raise ValueError('Please input the checkpoint path') 72 | 73 | # set seed 74 | seed_everything(args.seed) 75 | 76 | cfg = OmegaConf.load(args.config) 77 | 78 | """=== Init data ===""" 79 | 80 | cfg_data = cfg.get('data') 81 | 82 | data = instantiate_from_config(cfg_data) 83 | data.setup() 84 | 85 | """=== Init model ===""" 86 | cfg_model = cfg.get('model') 87 | 88 | # model = instantiate_from_config(cfg_model) 89 | model = load_model_from_config(cfg_model, args.checkpoint, verbose=True) 90 | 91 | """=== Start validation ===""" 92 | model.on_train_start() 93 | model.eval() 94 | model.cuda() 95 | # model.train() 96 | # loss_list = [] 97 | # for batch in tqdm(data._val_dataloader()): 98 | # # move tensor in batch to cuda 99 | # for key in batch: 100 | # if isinstance(batch[key], torch.Tensor): 101 | # batch[key] = batch[key].cuda() 102 | # output, loss = model.test_step(batch, None) 103 | # loss_list.append(loss.item()) 104 | # print('Validation loss: ', sum(loss_list)/len(loss_list)) 105 | # print('Validation loss: ', sum(loss_list)/len(loss_list)) 106 | # print('Finish validation') 107 | 108 | 109 | # init trainer for validation 110 | cfg_trainer = cfg.get('trainer') 111 | # set gpu = 1 112 | cfg_trainer.gpus = 1 113 | 114 | 115 | # omegaconf to dict 116 | cfg_trainer = OmegaConf.to_container(cfg_trainer) 117 | cfg_trainer.pop('cfg_callbacks') 
if 'cfg_callbacks' in cfg_trainer else None 118 | # init logger 119 | cfg_logger = cfg_trainer.pop('cfg_logger') if 'cfg_logger' in cfg_trainer else None 120 | cfg_logger['params']['save_dir'] = 'logs/' 121 | cfg_logger['params']['name'] = 'eval' 122 | cfg_trainer['logger'] = instantiate_from_config(cfg_logger) 123 | 124 | # plugin 125 | cfg_plugin = cfg_trainer.pop('plugins') if 'plugins' in cfg_trainer else None 126 | 127 | # init trainer 128 | trainer_opt = argparse.Namespace(**cfg_trainer) 129 | trainer = Trainer.from_argparse_args(trainer_opt) 130 | # init logger 131 | model.val_save_path = args.save_path 132 | trainer.validate(model, data.val_dataloader()) 133 | -------------------------------------------------------------------------------- /icm/criterion/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss_function import LossFunction 2 | -------------------------------------------------------------------------------- /icm/criterion/loss_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchvision.ops import focal_loss 6 | 7 | class LossFunction(nn.Module): 8 | ''' 9 | Loss function set 10 | losses=['unknown_l1_loss', 'known_l1_loss', 11 | 'loss_pha_laplacian', 'loss_gradient_penalty', 12 | 'smooth_l1_loss', 'cross_entropy_loss', 'focal_loss'] 13 | ''' 14 | def __init__(self, 15 | *, 16 | losses, 17 | ): 18 | super(LossFunction, self).__init__() 19 | self.losses = losses 20 | 21 | def loss_gradient_penalty(self, sample_map ,preds, targets): 22 | preds = preds['phas'] 23 | targets = targets['phas'] 24 | h,w = sample_map.shape[2:] 25 | if torch.sum(sample_map) == 0: 26 | scale = 0 27 | else: 28 | #sample_map for unknown area 29 | scale = sample_map.shape[0]*262144/torch.sum(sample_map) 30 | 31 | #gradient in x 32 | sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type()) 33 | delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1) 34 | delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1) 35 | 36 | #gradient in y 37 | sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type()) 38 | delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1) 39 | delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1) 40 | 41 | #loss 42 | loss = (F.l1_loss(delta_pred_x*sample_map, delta_gt_x*sample_map)* scale + \ 43 | F.l1_loss(delta_pred_y*sample_map, delta_gt_y*sample_map)* scale + \ 44 | 0.01 * torch.mean(torch.abs(delta_pred_x*sample_map))* scale + \ 45 | 0.01 * torch.mean(torch.abs(delta_pred_y*sample_map))* scale) 46 | 47 | return dict(loss_gradient_penalty=loss) 48 | 49 | def loss_pha_laplacian(self, preds, targets): 50 | assert 'phas' in preds and 'phas' in targets 51 | loss = laplacian_loss(preds['phas'], targets['phas']) 52 | 53 | return dict(loss_pha_laplacian=loss) 54 | 55 | def unknown_l1_loss(self, sample_map, preds, targets): 56 | h,w = sample_map.shape[2:] 57 | if torch.sum(sample_map) == 0: 58 | scale = 0 59 | else: 60 | #sample_map for unknown area 61 | scale = sample_map.shape[0]*262144/torch.sum(sample_map) 62 | 63 | # scale = 1 64 | 65 | loss = F.l1_loss(preds['phas']*sample_map, targets['phas']*sample_map)*scale 66 | return dict(unknown_l1_loss=loss) 67 | 68 | def known_l1_loss(self, sample_map, preds, targets): 69 | new_sample_map = torch.zeros_like(sample_map) 70 | 
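# sample_map is 1 on the unknown (transition) region of the trimap; the complement built
# below therefore selects the known foreground/background pixels. The 262144 factor used
# for 'scale' is 512*512, so the rescaled masked L1 behaves like a mean over the known
# region for 512x512 crops.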
new_sample_map[sample_map==0] = 1 71 | h,w = sample_map.shape[2:] 72 | if torch.sum(new_sample_map) == 0: 73 | scale = 0 74 | else: 75 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map) 76 | # scale = 1 77 | 78 | loss = F.l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale 79 | return dict(known_l1_loss=loss) 80 | 81 | def smooth_l1_loss(self, preds, targets): 82 | assert 'phas' in preds and 'phas' in targets 83 | loss = F.smooth_l1_loss(preds['phas'], targets['phas']) 84 | 85 | return dict(smooth_l1_loss=loss) 86 | 87 | def known_smooth_l1_loss(self, sample_map, preds, targets): 88 | new_sample_map = torch.zeros_like(sample_map) 89 | new_sample_map[sample_map==0] = 1 90 | h,w = sample_map.shape[2:] 91 | if torch.sum(new_sample_map) == 0: 92 | scale = 0 93 | else: 94 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map) 95 | # scale = 1 96 | 97 | loss = F.smooth_l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale 98 | return dict(known_l1_loss=loss) 99 | 100 | def cross_entropy_loss(self, preds, targets): 101 | assert 'phas' in preds and 'phas' in targets 102 | loss = F.binary_cross_entropy_with_logits(preds['phas'], targets['phas']) 103 | 104 | return dict(cross_entropy_loss=loss) 105 | 106 | def focal_loss(self, preds, targets): 107 | assert 'phas' in preds and 'phas' in targets 108 | loss = focal_loss.sigmoid_focal_loss(preds['phas'], targets['phas'], reduction='mean') 109 | 110 | return dict(focal_loss=loss) 111 | def forward(self, sample_map, preds, targets): 112 | 113 | preds = {'phas': preds} 114 | targets = {'phas': targets} 115 | losses = dict() 116 | for k in self.losses: 117 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty' or k=='known_smooth_l1_loss': 118 | losses.update(getattr(self, k)(sample_map, preds, targets)) 119 | else: 120 | losses.update(getattr(self, k)(preds, targets)) 121 | return losses 122 | 123 | class LossFunction2(nn.Module): 124 | ''' 125 | Loss function set 126 | losses=['unknown_l1_loss', 'known_l1_loss', 127 | 'loss_pha_laplacian', 'loss_gradient_penalty', 128 | 'smooth_l1_loss', 'cross_entropy_loss', 'focal_loss'] 129 | ''' 130 | def __init__(self, 131 | *, 132 | losses_seg = ['known_smooth_l1_loss'], 133 | losses_matting = ['unknown_l1_loss', 'known_l1_loss','loss_pha_laplacian', 'loss_gradient_penalty',], 134 | ): 135 | super(LossFunction2, self).__init__() 136 | self.losses_seg = losses_seg 137 | self.losses_matting = losses_matting 138 | 139 | def loss_gradient_penalty(self, sample_map ,preds, targets): 140 | preds = preds['phas'] 141 | targets = targets['phas'] 142 | h,w = sample_map.shape[2:] 143 | if torch.sum(sample_map) == 0: 144 | scale = 0 145 | else: 146 | #sample_map for unknown area 147 | scale = sample_map.shape[0]*262144/torch.sum(sample_map) 148 | 149 | #gradient in x 150 | sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type()) 151 | delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1) 152 | delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1) 153 | 154 | #gradient in y 155 | sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type()) 156 | delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1) 157 | delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1) 158 | 159 | #loss 160 | loss = (F.l1_loss(delta_pred_x*sample_map, delta_gt_x*sample_map)* scale + \ 161 | F.l1_loss(delta_pred_y*sample_map, 
delta_gt_y*sample_map)* scale + \ 162 | 0.01 * torch.mean(torch.abs(delta_pred_x*sample_map))* scale + \ 163 | 0.01 * torch.mean(torch.abs(delta_pred_y*sample_map))* scale) 164 | 165 | return dict(loss_gradient_penalty=loss) 166 | 167 | def loss_pha_laplacian(self, preds, targets): 168 | assert 'phas' in preds and 'phas' in targets 169 | loss = laplacian_loss(preds['phas'], targets['phas']) 170 | 171 | return dict(loss_pha_laplacian=loss) 172 | 173 | def unknown_l1_loss(self, sample_map, preds, targets): 174 | h,w = sample_map.shape[2:] 175 | if torch.sum(sample_map) == 0: 176 | scale = 0 177 | else: 178 | #sample_map for unknown area 179 | scale = sample_map.shape[0]*262144/torch.sum(sample_map) 180 | 181 | # scale = 1 182 | 183 | loss = F.l1_loss(preds['phas']*sample_map, targets['phas']*sample_map)*scale 184 | return dict(unknown_l1_loss=loss) 185 | 186 | def known_l1_loss(self, sample_map, preds, targets): 187 | new_sample_map = torch.zeros_like(sample_map) 188 | new_sample_map[sample_map==0] = 1 189 | h,w = sample_map.shape[2:] 190 | if torch.sum(new_sample_map) == 0: 191 | scale = 0 192 | else: 193 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map) 194 | # scale = 1 195 | 196 | loss = F.l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale 197 | return dict(known_l1_loss=loss) 198 | 199 | def smooth_l1_loss(self, preds, targets): 200 | assert 'phas' in preds and 'phas' in targets 201 | loss = F.smooth_l1_loss(preds['phas'], targets['phas']) 202 | 203 | return dict(smooth_l1_loss=loss) 204 | 205 | def known_smooth_l1_loss(self, sample_map, preds, targets): 206 | new_sample_map = torch.zeros_like(sample_map) 207 | new_sample_map[sample_map==0] = 1 208 | h,w = sample_map.shape[2:] 209 | if torch.sum(new_sample_map) == 0: 210 | scale = 0 211 | else: 212 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map) 213 | # scale = 1 214 | 215 | loss = F.smooth_l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale 216 | return dict(known_l1_loss=loss) 217 | 218 | def cross_entropy_loss(self, preds, targets): 219 | assert 'phas' in preds and 'phas' in targets 220 | loss = F.binary_cross_entropy_with_logits(preds['phas'], targets['phas']) 221 | 222 | return dict(cross_entropy_loss=loss) 223 | 224 | def focal_loss(self, preds, targets): 225 | assert 'phas' in preds and 'phas' in targets 226 | loss = focal_loss.sigmoid_focal_loss(preds['phas'], targets['phas'], reduction='mean') 227 | 228 | return dict(focal_loss=loss) 229 | def forward_single_sample(self, sample_map, preds, targets): 230 | # check if targets only have element 0 and 1 231 | if torch.all(targets == 0) or torch.all(targets == 1): 232 | 233 | preds = {'phas': preds} 234 | targets = {'phas': targets} 235 | losses = dict() 236 | for k in self.losses_seg: 237 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty' or k=='known_smooth_l1_loss': 238 | losses.update(getattr(self, k)(sample_map, preds, targets)) 239 | else: 240 | losses.update(getattr(self, k)(preds, targets)) 241 | return losses 242 | else: 243 | preds = {'phas': preds} 244 | targets = {'phas': targets} 245 | losses = dict() 246 | for k in self.losses_matting: 247 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty' or k=='known_smooth_l1_loss': 248 | losses.update(getattr(self, k)(sample_map, preds, targets)) 249 | else: 250 | losses.update(getattr(self, k)(preds, targets)) 251 | return losses 252 | 253 | def forward(self, sample_map, preds, targets): 254 | 
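# Dispatch per sample: a target alpha that is entirely 0 or entirely 1 is scored with the
# segmentation losses (losses_seg), any soft alpha matte with the matting losses
# (losses_matting); the resulting per-sample loss dicts are summed key by key over the batch.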
losses = dict() 255 | for i in range(preds.shape[0]): 256 | losses_ = self.forward_single_sample(sample_map[i].unsqueeze(0), preds[i].unsqueeze(0), targets[i].unsqueeze(0)) 257 | for k in losses_: 258 | if k in losses: 259 | losses[k] += losses_[k] 260 | else: 261 | losses[k] = losses_[k] 262 | return losses 263 | #-----------------Laplacian Loss-------------------------# 264 | def laplacian_loss(pred, true, max_levels=5): 265 | kernel = gauss_kernel(device=pred.device, dtype=pred.dtype) 266 | pred_pyramid = laplacian_pyramid(pred, kernel, max_levels) 267 | true_pyramid = laplacian_pyramid(true, kernel, max_levels) 268 | loss = 0 269 | for level in range(max_levels): 270 | loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level]) 271 | return loss / max_levels 272 | 273 | def laplacian_pyramid(img, kernel, max_levels): 274 | current = img 275 | pyramid = [] 276 | for _ in range(max_levels): 277 | current = crop_to_even_size(current) 278 | down = downsample(current, kernel) 279 | up = upsample(down, kernel) 280 | diff = current - up 281 | pyramid.append(diff) 282 | current = down 283 | return pyramid 284 | 285 | def gauss_kernel(device='cpu', dtype=torch.float32): 286 | kernel = torch.tensor([[1, 4, 6, 4, 1], 287 | [4, 16, 24, 16, 4], 288 | [6, 24, 36, 24, 6], 289 | [4, 16, 24, 16, 4], 290 | [1, 4, 6, 4, 1]], device=device, dtype=dtype) 291 | kernel /= 256 292 | kernel = kernel[None, None, :, :] 293 | return kernel 294 | 295 | def gauss_convolution(img, kernel): 296 | B, C, H, W = img.shape 297 | img = img.reshape(B * C, 1, H, W) 298 | img = F.pad(img, (2, 2, 2, 2), mode='reflect') 299 | img = F.conv2d(img, kernel) 300 | img = img.reshape(B, C, H, W) 301 | return img 302 | 303 | def downsample(img, kernel): 304 | img = gauss_convolution(img, kernel) 305 | img = img[:, :, ::2, ::2] 306 | return img 307 | 308 | def upsample(img, kernel): 309 | B, C, H, W = img.shape 310 | out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype) 311 | out[:, :, ::2, ::2] = img * 4 312 | out = gauss_convolution(out, kernel) 313 | return out 314 | 315 | def crop_to_even_size(img): 316 | H, W = img.shape[2:] 317 | H = H - H % 2 318 | W = W - W % 2 319 | return img[:, :, :H, :W] -------------------------------------------------------------------------------- /icm/criterion/matting_criterion_eval.py: -------------------------------------------------------------------------------- 1 | import scipy.ndimage 2 | import numpy as np 3 | from skimage.measure import label 4 | import scipy.ndimage.morphology 5 | import torch 6 | 7 | def compute_mse_loss(pred, target, trimap): 8 | error_map = (pred - target) / 255.0 9 | loss = np.sum((error_map ** 2) * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8) 10 | 11 | return loss 12 | 13 | 14 | def compute_sad_loss(pred, target, trimap): 15 | error_map = np.abs((pred - target) / 255.0) 16 | loss = np.sum(error_map * (trimap == 128)) 17 | 18 | return loss / 1000, np.sum(trimap == 128) / 1000 19 | 20 | def gauss(x, sigma): 21 | y = np.exp(-x ** 2 / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi)) 22 | return y 23 | 24 | 25 | def dgauss(x, sigma): 26 | y = -x * gauss(x, sigma) / (sigma ** 2) 27 | return y 28 | 29 | 30 | def gaussgradient(im, sigma): 31 | epsilon = 1e-2 32 | halfsize = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))).astype(int) 33 | size = 2 * halfsize + 1 34 | hx = np.zeros((size, size)) 35 | for i in range(0, size): 36 | for j in range(0, size): 37 | u = [i - halfsize, j - halfsize] 38 | hx[i, j] = 
gauss(u[0], sigma) * dgauss(u[1], sigma) 39 | 40 | hx = hx / np.sqrt(np.sum(np.abs(hx) * np.abs(hx))) 41 | hy = hx.transpose() 42 | 43 | gx = scipy.ndimage.convolve(im, hx, mode='nearest') 44 | gy = scipy.ndimage.convolve(im, hy, mode='nearest') 45 | 46 | return gx, gy 47 | 48 | def compute_gradient_loss(pred, target, trimap): 49 | 50 | pred = pred / 255.0 51 | target = target / 255.0 52 | 53 | pred_x, pred_y = gaussgradient(pred, 1.4) 54 | target_x, target_y = gaussgradient(target, 1.4) 55 | 56 | pred_amp = np.sqrt(pred_x ** 2 + pred_y ** 2) 57 | target_amp = np.sqrt(target_x ** 2 + target_y ** 2) 58 | 59 | error_map = (pred_amp - target_amp) ** 2 60 | loss = np.sum(error_map[trimap == 128]) 61 | 62 | return loss / 1000. 63 | 64 | 65 | def compute_connectivity_error(pred, target, trimap, step): 66 | pred = pred / 255.0 67 | target = target / 255.0 68 | h, w = pred.shape 69 | 70 | thresh_steps = list(np.arange(0, 1 + step, step)) 71 | l_map = np.ones_like(pred, dtype=float) * -1 72 | for i in range(1, len(thresh_steps)): 73 | pred_alpha_thresh = (pred >= thresh_steps[i]).astype(int) 74 | target_alpha_thresh = (target >= thresh_steps[i]).astype(int) 75 | 76 | omega = getLargestCC(pred_alpha_thresh * target_alpha_thresh).astype(int) 77 | flag = ((l_map == -1) & (omega == 0)).astype(int) 78 | l_map[flag == 1] = thresh_steps[i - 1] 79 | 80 | l_map[l_map == -1] = 1 81 | 82 | pred_d = pred - l_map 83 | target_d = target - l_map 84 | pred_phi = 1 - pred_d * (pred_d >= 0.15).astype(int) 85 | target_phi = 1 - target_d * (target_d >= 0.15).astype(int) 86 | loss = np.sum(np.abs(pred_phi - target_phi)[trimap == 128]) 87 | 88 | return loss / 1000. 89 | 90 | def getLargestCC(segmentation): 91 | labels = label(segmentation, connectivity=1) 92 | largestCC = labels == np.argmax(np.bincount(labels.flat)) 93 | return largestCC 94 | 95 | 96 | def compute_mse_loss_torch(pred, target, trimap): 97 | error_map = (pred - target) / 255.0 98 | # rewrite the loss with torch 99 | # loss = np.sum((error_map ** 2) * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8) 100 | loss = torch.sum((error_map ** 2) * (trimap == 128).float()) / (torch.sum(trimap == 128).float() + 1e-8) 101 | 102 | return loss 103 | 104 | 105 | def compute_sad_loss_torch(pred, target, trimap): 106 | # rewrite the map with torch 107 | # error_map = np.abs((pred - target) / 255.0) 108 | error_map = torch.abs((pred - target) / 255.0) 109 | # loss = np.sum(error_map * (trimap == 128)) 110 | loss = torch.sum(error_map * (trimap == 128).float()) 111 | 112 | return loss / 1000 113 | -------------------------------------------------------------------------------- /icm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiny-smart/in-context-matting/a58bcf4b948babcf5f1c2e1c41e2fa040bc53d1e/icm/data/__init__.py -------------------------------------------------------------------------------- /icm/data/data_generator.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import math 4 | import numbers 5 | import random 6 | import logging 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torch.nn import functional as F 12 | from torchvision import transforms 13 | 14 | from icm.util import instantiate_from_config 15 | from icm.data.image_file import get_dir_ext 16 | # one-hot or class, choice: [3, 1] 17 | TRIMAP_CHANNEL = 1 18 | 19 | RANDOM_INTERP = True 20 | 21 | CUTMASK_PROB 
= 0 22 | 23 | interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, 24 | cv2.INTER_CUBIC, cv2.INTER_LANCZOS4] 25 | 26 | 27 | def maybe_random_interp(cv2_interp): 28 | if RANDOM_INTERP: 29 | return np.random.choice(interp_list) 30 | else: 31 | return cv2_interp 32 | 33 | 34 | class ToTensor(object): 35 | """ 36 | Convert ndarrays in sample to Tensors with normalization. 37 | """ 38 | 39 | def __init__(self, phase="test", norm_type='imagenet'): 40 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) 41 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) 42 | self.phase = phase 43 | self.norm_type = norm_type 44 | 45 | def __call__(self, sample): 46 | # convert GBR images to RGB 47 | image, alpha, trimap, mask = sample['image'][:, :, ::- 48 | 1], sample['alpha'], sample['trimap'], sample['mask'] 49 | 50 | alpha[alpha < 0] = 0 51 | alpha[alpha > 1] = 1 52 | 53 | # swap color axis because 54 | # numpy image: H x W x C 55 | # torch image: C X H X W 56 | image = image.transpose((2, 0, 1)).astype(np.float32) 57 | alpha = np.expand_dims(alpha.astype(np.float32), axis=0) 58 | trimap[trimap < 85] = 0 59 | trimap[trimap >= 170] = 1 60 | trimap[trimap >= 85] = 0.5 61 | 62 | mask = np.expand_dims(mask.astype(np.float32), axis=0) 63 | 64 | if self.phase == "train": 65 | # convert GBR images to RGB 66 | fg = sample['fg'][:, :, ::- 67 | 1].transpose((2, 0, 1)).astype(np.float32) / 255. 68 | sample['fg'] = torch.from_numpy(fg).sub_(self.mean).div_(self.std) 69 | bg = sample['bg'][:, :, ::- 70 | 1].transpose((2, 0, 1)).astype(np.float32) / 255. 71 | sample['bg'] = torch.from_numpy(bg).sub_(self.mean).div_(self.std) 72 | # del sample['image_name'] 73 | 74 | sample['image'], sample['alpha'], sample['trimap'] = \ 75 | torch.from_numpy(image), torch.from_numpy( 76 | alpha), torch.from_numpy(trimap) 77 | 78 | if self.norm_type == 'imagenet': 79 | # normalize image 80 | sample['image'] /= 255. 81 | 82 | sample['image'] = sample['image'].sub_(self.mean).div_(self.std) 83 | elif self.norm_type == 'sd': 84 | sample['image'] = sample['image'].to(dtype=torch.float32) / 127.5 - 1.0 85 | else: 86 | raise NotImplementedError( 87 | "norm_type {} is not implemented".format(self.norm_type)) 88 | 89 | if TRIMAP_CHANNEL == 3: 90 | sample['trimap'] = F.one_hot( 91 | sample['trimap'], num_classes=3).permute(2, 0, 1).float() 92 | elif TRIMAP_CHANNEL == 1: 93 | sample['trimap'] = sample['trimap'][None, ...].float() 94 | else: 95 | raise NotImplementedError("TRIMAP_CHANNEL can only be 3 or 1") 96 | 97 | sample['mask'] = torch.from_numpy(mask).float() 98 | 99 | return sample 100 | 101 | 102 | class RandomAffine(object): 103 | """ 104 | Random affine translation 105 | """ 106 | 107 | def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0): 108 | if isinstance(degrees, numbers.Number): 109 | if degrees < 0: 110 | raise ValueError( 111 | "If degrees is a single number, it must be positive.") 112 | self.degrees = (-degrees, degrees) 113 | else: 114 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 115 | "degrees should be a list or tuple and it must be of length 2." 116 | self.degrees = degrees 117 | 118 | if translate is not None: 119 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 120 | "translate should be a list or tuple and it must be of length 2." 
121 | for t in translate: 122 | if not (0.0 <= t <= 1.0): 123 | raise ValueError( 124 | "translation values should be between 0 and 1") 125 | self.translate = translate 126 | 127 | if scale is not None: 128 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 129 | "scale should be a list or tuple and it must be of length 2." 130 | for s in scale: 131 | if s <= 0: 132 | raise ValueError("scale values should be positive") 133 | self.scale = scale 134 | 135 | if shear is not None: 136 | if isinstance(shear, numbers.Number): 137 | if shear < 0: 138 | raise ValueError( 139 | "If shear is a single number, it must be positive.") 140 | self.shear = (-shear, shear) 141 | else: 142 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 143 | "shear should be a list or tuple and it must be of length 2." 144 | self.shear = shear 145 | else: 146 | self.shear = shear 147 | 148 | self.resample = resample 149 | self.fillcolor = fillcolor 150 | self.flip = flip 151 | 152 | @staticmethod 153 | def get_params(degrees, translate, scale_ranges, shears, flip, img_size): 154 | """Get parameters for affine transformation 155 | 156 | Returns: 157 | sequence: params to be passed to the affine transformation 158 | """ 159 | angle = random.uniform(degrees[0], degrees[1]) 160 | if translate is not None: 161 | max_dx = translate[0] * img_size[0] 162 | max_dy = translate[1] * img_size[1] 163 | translations = (np.round(random.uniform(-max_dx, max_dx)), 164 | np.round(random.uniform(-max_dy, max_dy))) 165 | else: 166 | translations = (0, 0) 167 | 168 | if scale_ranges is not None: 169 | scale = (random.uniform(scale_ranges[0], scale_ranges[1]), 170 | random.uniform(scale_ranges[0], scale_ranges[1])) 171 | else: 172 | scale = (1.0, 1.0) 173 | 174 | if shears is not None: 175 | shear = random.uniform(shears[0], shears[1]) 176 | else: 177 | shear = 0.0 178 | 179 | if flip is not None: 180 | flip = (np.random.rand(2) < flip).astype(np.int) * 2 - 1 181 | 182 | return angle, translations, scale, shear, flip 183 | 184 | def __call__(self, sample): 185 | fg, alpha = sample['fg'], sample['alpha'] 186 | rows, cols, ch = fg.shape 187 | if np.maximum(rows, cols) < 1024: 188 | params = self.get_params( 189 | (0, 0), self.translate, self.scale, self.shear, self.flip, fg.size) 190 | else: 191 | params = self.get_params( 192 | self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size) 193 | 194 | center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5) 195 | M = self._get_inverse_affine_matrix(center, *params) 196 | M = np.array(M).reshape((2, 3)) 197 | 198 | fg = cv2.warpAffine(fg, M, (cols, rows), 199 | flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP) 200 | alpha = cv2.warpAffine(alpha, M, (cols, rows), 201 | flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP) 202 | 203 | sample['fg'], sample['alpha'] = fg, alpha 204 | 205 | return sample 206 | 207 | @ staticmethod 208 | def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip): 209 | # Helper method to compute inverse matrix for affine transformation 210 | 211 | # As it is explained in PIL.Image.rotate 212 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 213 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] 214 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] 215 | # RSS is rotation with scale and shear matrix 216 | # It is different from the original function in torchvision 217 | # The order are changed to flip -> scale -> 
rotation -> shear 218 | # x and y have different scale factors 219 | # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y 0] 220 | # [ sin(a)*scale_x*f cos(a)*scale_y 0] 221 | # [ 0 0 1] 222 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 223 | 224 | angle = math.radians(angle) 225 | shear = math.radians(shear) 226 | scale_x = 1.0 / scale[0] * flip[0] 227 | scale_y = 1.0 / scale[1] * flip[1] 228 | 229 | # Inverted rotation matrix with scale and shear 230 | d = math.cos(angle + shear) * math.cos(angle) + \ 231 | math.sin(angle + shear) * math.sin(angle) 232 | matrix = [ 233 | math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0, 234 | -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0 235 | ] 236 | matrix = [m / d for m in matrix] 237 | 238 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 239 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + \ 240 | matrix[1] * (-center[1] - translate[1]) 241 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + \ 242 | matrix[4] * (-center[1] - translate[1]) 243 | 244 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 245 | matrix[2] += center[0] 246 | matrix[5] += center[1] 247 | 248 | return matrix 249 | 250 | 251 | class RandomJitter(object): 252 | """ 253 | Random change the hue of the image 254 | """ 255 | 256 | def __call__(self, sample): 257 | sample_ori = sample.copy() 258 | fg, alpha = sample['fg'], sample['alpha'] 259 | # if alpha is all 0 skip 260 | if np.all(alpha == 0): 261 | return sample_ori 262 | # convert to HSV space, convert to float32 image to keep precision during space conversion. 263 | fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV) 264 | # Hue noise 265 | hue_jitter = np.random.randint(-40, 40) 266 | fg[:, :, 0] = np.remainder( 267 | fg[:, :, 0].astype(np.float32) + hue_jitter, 360) 268 | # Saturation noise 269 | sat_bar = fg[:, :, 1][alpha > 0].mean() 270 | if np.isnan(sat_bar): 271 | return sample_ori 272 | sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10 273 | sat = fg[:, :, 1] 274 | sat = np.abs(sat + sat_jitter) 275 | sat[sat > 1] = 2 - sat[sat > 1] 276 | fg[:, :, 1] = sat 277 | # Value noise 278 | val_bar = fg[:, :, 2][alpha > 0].mean() 279 | if np.isnan(val_bar): 280 | return sample_ori 281 | val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10 282 | val = fg[:, :, 2] 283 | val = np.abs(val + val_jitter) 284 | val[val > 1] = 2 - val[val > 1] 285 | fg[:, :, 2] = val 286 | # convert back to BGR space 287 | fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR) 288 | sample['fg'] = fg*255 289 | 290 | return sample 291 | 292 | 293 | class RandomHorizontalFlip(object): 294 | """ 295 | Random flip image and label horizontally 296 | """ 297 | 298 | def __init__(self, prob=0.5): 299 | self.prob = prob 300 | 301 | def __call__(self, sample): 302 | fg, alpha = sample['fg'], sample['alpha'] 303 | if np.random.uniform(0, 1) < self.prob: 304 | fg = cv2.flip(fg, 1) 305 | alpha = cv2.flip(alpha, 1) 306 | sample['fg'], sample['alpha'] = fg, alpha 307 | 308 | return sample 309 | 310 | 311 | class RandomCrop(object): 312 | """ 313 | Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size' 314 | 315 | :param output_size (tuple or int): Desired output size. If int, square crop 316 | is made. 
317 | """ 318 | 319 | def __init__(self, output_size): 320 | assert isinstance(output_size, (int, tuple)) 321 | if isinstance(output_size, int): 322 | self.output_size = (output_size, output_size) 323 | else: 324 | assert len(output_size) == 2 325 | self.output_size = output_size 326 | self.margin = output_size[0] // 2 327 | self.logger = logging.getLogger("Logger") 328 | 329 | def __call__(self, sample): 330 | fg, alpha, trimap, mask, name = sample['fg'], sample[ 331 | 'alpha'], sample['trimap'], sample['mask'], sample['image_name'] 332 | bg = sample['bg'] 333 | h, w = trimap.shape 334 | bg = cv2.resize( 335 | bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC)) 336 | if w < self.output_size[0]+1 or h < self.output_size[1]+1: 337 | ratio = 1.1*self.output_size[0] / \ 338 | h if h < w else 1.1*self.output_size[1]/w 339 | # self.logger.warning("Size of {} is {}.".format(name, (h, w))) 340 | while h < self.output_size[0]+1 or w < self.output_size[1]+1: 341 | fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), 342 | interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 343 | alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)), 344 | interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 345 | trimap = cv2.resize( 346 | trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST) 347 | bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), 348 | interpolation=maybe_random_interp(cv2.INTER_CUBIC)) 349 | mask = cv2.resize( 350 | mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST) 351 | h, w = trimap.shape 352 | small_trimap = cv2.resize( 353 | trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST) 354 | unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4, 355 | self.margin//4:(w-self.margin)//4] == 128))) 356 | unknown_num = len(unknown_list) 357 | if len(unknown_list) < 10: 358 | left_top = (np.random.randint( 359 | 0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1)) 360 | else: 361 | idx = np.random.randint(unknown_num) 362 | left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4) 363 | 364 | fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], 365 | left_top[1]:left_top[1]+self.output_size[1], :] 366 | alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], 367 | left_top[1]:left_top[1]+self.output_size[1]] 368 | bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], 369 | left_top[1]:left_top[1]+self.output_size[1], :] 370 | trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], 371 | left_top[1]:left_top[1]+self.output_size[1]] 372 | mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0], 373 | left_top[1]:left_top[1]+self.output_size[1]] 374 | 375 | if len(np.where(trimap == 128)[0]) == 0: 376 | self.logger.error("{} does not have enough unknown area for crop. Resized to target size." 
377 | "left_top: {}".format(name, left_top)) 378 | fg_crop = cv2.resize( 379 | fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 380 | alpha_crop = cv2.resize( 381 | alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 382 | trimap_crop = cv2.resize( 383 | trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST) 384 | bg_crop = cv2.resize( 385 | bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC)) 386 | mask_crop = cv2.resize( 387 | mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST) 388 | 389 | sample.update({'fg': fg_crop, 'alpha': alpha_crop, 390 | 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop}) 391 | return sample 392 | 393 | 394 | class CropResize(object): 395 | # crop the image to square, and resize to target size 396 | def __init__(self, output_size): 397 | assert isinstance(output_size, (int, tuple)) 398 | if isinstance(output_size, int): 399 | self.output_size = (output_size, output_size) 400 | else: 401 | assert len(output_size) == 2 402 | self.output_size = output_size 403 | 404 | def __call__(self, sample): 405 | img, alpha, trimap, mask = sample['image'], sample['alpha'], sample['trimap'], sample['mask'] 406 | # crop the image to square, and resize to target size 407 | 408 | h, w = img.shape[:2] 409 | if h == w: 410 | img_crop = cv2.resize( 411 | img, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 412 | alpha_crop = cv2.resize( 413 | alpha, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 414 | trimap_crop = cv2.resize( 415 | trimap, self.output_size, interpolation=cv2.INTER_NEAREST) 416 | mask_crop = cv2.resize( 417 | mask, self.output_size, interpolation=cv2.INTER_NEAREST) 418 | elif h > w: 419 | margin = (h-w)//2 420 | img = img[margin:margin+w, :] 421 | alpha = alpha[margin:margin+w, :] 422 | trimap = trimap[margin:margin+w, :] 423 | mask = mask[margin:margin+w, :] 424 | img_crop = cv2.resize( 425 | img, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 426 | alpha_crop = cv2.resize( 427 | alpha, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 428 | trimap_crop = cv2.resize( 429 | trimap, self.output_size, interpolation=cv2.INTER_NEAREST) 430 | mask_crop = cv2.resize( 431 | mask, self.output_size, interpolation=cv2.INTER_NEAREST) 432 | else: 433 | margin = (w-h)//2 434 | img = img[:, margin:margin+h] 435 | alpha = alpha[:, margin:margin+h] 436 | trimap = trimap[:, margin:margin+h] 437 | mask = mask[:, margin:margin+h] 438 | img_crop = cv2.resize( 439 | img, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 440 | alpha_crop = cv2.resize( 441 | alpha, self.output_size, interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 442 | trimap_crop = cv2.resize( 443 | trimap, self.output_size, interpolation=cv2.INTER_NEAREST) 444 | mask_crop = cv2.resize( 445 | mask, self.output_size, interpolation=cv2.INTER_NEAREST) 446 | sample.update({'image': img_crop, 'alpha': alpha_crop, 447 | 'trimap': trimap_crop, 'mask': mask_crop}) 448 | return sample 449 | 450 | 451 | class OriginScale(object): 452 | def __call__(self, sample): 453 | h, w = sample["alpha_shape"] 454 | 455 | if h % 32 == 0 and w % 32 == 0: 456 | return sample 457 | 458 | target_h = 32 * ((h - 1) // 32 + 1) 459 | target_w = 32 * ((w - 1) // 32 + 1) 460 | pad_h = target_h - h 461 | pad_w = target_w - w 462 | 463 | padded_image = np.pad( 464 | sample['image'], ((0, pad_h), (0, pad_w), (0, 0)), 
mode="reflect") 465 | padded_trimap = np.pad( 466 | sample['trimap'], ((0, pad_h), (0, pad_w)), mode="reflect") 467 | padded_mask = np.pad( 468 | sample['mask'], ((0, pad_h), (0, pad_w)), mode="reflect") 469 | 470 | sample['image'] = padded_image 471 | sample['trimap'] = padded_trimap 472 | sample['mask'] = padded_mask 473 | 474 | return sample 475 | 476 | 477 | class GenMask(object): 478 | def __init__(self): 479 | self.erosion_kernels = [None] + [cv2.getStructuringElement( 480 | cv2.MORPH_ELLIPSE, (size, size)) for size in range(1, 30)] 481 | 482 | def __call__(self, sample): 483 | alpha_ori = sample['alpha'] 484 | h, w = alpha_ori.shape 485 | 486 | max_kernel_size = 30 487 | alpha = cv2.resize(alpha_ori, (640, 640), 488 | interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 489 | 490 | # generate trimap 491 | fg_mask = (alpha + 1e-5).astype(np.int).astype(np.uint8) 492 | bg_mask = (1 - alpha + 1e-5).astype(np.int).astype(np.uint8) 493 | fg_mask = cv2.erode( 494 | fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 495 | bg_mask = cv2.erode( 496 | bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 497 | 498 | fg_width = np.random.randint(1, 30) 499 | bg_width = np.random.randint(1, 30) 500 | fg_mask = (alpha + 1e-5).astype(np.int).astype(np.uint8) 501 | bg_mask = (1 - alpha + 1e-5).astype(np.int).astype(np.uint8) 502 | fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width]) 503 | bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width]) 504 | 505 | trimap = np.ones_like(alpha) * 128 506 | trimap[fg_mask == 1] = 255 507 | trimap[bg_mask == 1] = 0 508 | 509 | trimap = cv2.resize(trimap, (w, h), interpolation=cv2.INTER_NEAREST) 510 | sample['trimap'] = trimap 511 | 512 | # generate mask 513 | low = 0.01 514 | high = 1.0 515 | thres = random.random() * (high - low) + low 516 | seg_mask = (alpha >= thres).astype(np.int).astype(np.uint8) 517 | random_num = random.randint(0, 3) 518 | if random_num == 0: 519 | seg_mask = cv2.erode( 520 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 521 | elif random_num == 1: 522 | seg_mask = cv2.dilate( 523 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 524 | elif random_num == 2: 525 | seg_mask = cv2.erode( 526 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 527 | seg_mask = cv2.dilate( 528 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 529 | elif random_num == 3: 530 | seg_mask = cv2.dilate( 531 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 532 | seg_mask = cv2.erode( 533 | seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)]) 534 | 535 | seg_mask = cv2.resize( 536 | seg_mask, (w, h), interpolation=cv2.INTER_NEAREST) 537 | sample['mask'] = seg_mask 538 | 539 | return sample 540 | 541 | 542 | class Composite(object): 543 | def __call__(self, sample): 544 | fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha'] 545 | alpha[alpha < 0] = 0 546 | alpha[alpha > 1] = 1 547 | fg[fg < 0] = 0 548 | fg[fg > 255] = 255 549 | bg[bg < 0] = 0 550 | bg[bg > 255] = 255 551 | 552 | image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None]) 553 | sample['image'] = image 554 | return sample 555 | 556 | 557 | class CutMask(object): 558 | def __init__(self, perturb_prob=0): 559 | self.perturb_prob = perturb_prob 560 | 561 | def __call__(self, sample): 562 | if np.random.rand() < self.perturb_prob: 563 | return sample 564 | 565 | mask = sample['mask'] # H x W, trimap 0--255, segmask 
0--1, alpha 0--1 566 | h, w = mask.shape 567 | perturb_size_h, perturb_size_w = random.randint( 568 | h // 4, h // 2), random.randint(w // 4, w // 2) 569 | x = random.randint(0, h - perturb_size_h) 570 | y = random.randint(0, w - perturb_size_w) 571 | x1 = random.randint(0, h - perturb_size_h) 572 | y1 = random.randint(0, w - perturb_size_w) 573 | 574 | mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1 + 575 | perturb_size_h, y1:y1+perturb_size_w].copy() 576 | 577 | sample['mask'] = mask 578 | return sample 579 | 580 | 581 | class DataGenerator(Dataset): 582 | def __init__(self, data, crop_size=512, phase="train"): 583 | self.phase = phase 584 | self.crop_size = crop_size 585 | self.alpha = data.alpha 586 | 587 | if self.phase == "train": 588 | self.fg = data.fg 589 | self.bg = data.bg 590 | self.merged = [] 591 | self.trimap = [] 592 | 593 | else: 594 | self.fg = [] 595 | self.bg = [] 596 | self.merged = data.merged 597 | self.trimap = data.trimap 598 | 599 | train_trans = [ 600 | RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5), 601 | GenMask(), 602 | CutMask(perturb_prob=CUTMASK_PROB), 603 | RandomCrop((self.crop_size, self.crop_size)), 604 | RandomJitter(), 605 | Composite(), 606 | ToTensor(phase="train")] 607 | 608 | test_trans = [OriginScale(), ToTensor()] 609 | 610 | self.transform = { 611 | 'train': 612 | transforms.Compose(train_trans), 613 | 'val': 614 | transforms.Compose([ 615 | OriginScale(), 616 | ToTensor() 617 | ]), 618 | 'test': 619 | transforms.Compose(test_trans) 620 | }[phase] 621 | 622 | self.fg_num = len(self.fg) 623 | 624 | def __getitem__(self, idx): 625 | if self.phase == "train": 626 | fg = cv2.imread(self.fg[idx % self.fg_num]) 627 | alpha = cv2.imread( 628 | self.alpha[idx % self.fg_num], 0).astype(np.float32)/255 629 | bg = cv2.imread(self.bg[idx], 1) 630 | 631 | fg, alpha = self._composite_fg(fg, alpha, idx) 632 | 633 | image_name = os.path.split(self.fg[idx % self.fg_num])[-1] 634 | sample = {'fg': fg, 'alpha': alpha, 635 | 'bg': bg, 'image_name': image_name} 636 | 637 | else: 638 | image = cv2.imread(self.merged[idx]) 639 | alpha = cv2.imread(self.alpha[idx], 0)/255. 640 | trimap = cv2.imread(self.trimap[idx], 0) 641 | mask = (trimap >= 170).astype(np.float32) 642 | image_name = os.path.split(self.merged[idx])[-1] 643 | 644 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 645 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape} 646 | 647 | sample = self.transform(sample) 648 | 649 | return sample 650 | 651 | def _composite_fg(self, fg, alpha, idx): 652 | 653 | if np.random.rand() < 0.5: 654 | idx2 = np.random.randint(self.fg_num) + idx 655 | fg2 = cv2.imread(self.fg[idx2 % self.fg_num]) 656 | alpha2 = cv2.imread( 657 | self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255. 
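# The second foreground/alpha pair is resized to the base sample's resolution before
# compositing; the merged alpha below is 1 - (1 - alpha) * (1 - alpha2).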
658 | h, w = alpha.shape 659 | fg2 = cv2.resize( 660 | fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 661 | alpha2 = cv2.resize( 662 | alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 663 | 664 | alpha_tmp = 1 - (1 - alpha) * (1 - alpha2) 665 | if np.any(alpha_tmp < 1): 666 | fg = fg.astype( 667 | np.float32) * alpha[:, :, None] + fg2.astype(np.float32) * (1 - alpha[:, :, None]) 668 | # The overlap of two 50% transparency should be 25% 669 | alpha = alpha_tmp 670 | fg = fg.astype(np.uint8) 671 | 672 | if np.random.rand() < 0.25: 673 | fg = cv2.resize( 674 | fg, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 675 | alpha = cv2.resize( 676 | alpha, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST)) 677 | 678 | return fg, alpha 679 | 680 | def __len__(self): 681 | if self.phase == "train": 682 | return len(self.bg) 683 | else: 684 | return len(self.alpha) 685 | 686 | 687 | class MultiDataGeneratorDoubleSet(Dataset): 688 | # divide a dataset into train set and validation set 689 | def __init__(self, data, crop_size=1024, phase="train",norm_type='imagenet'): 690 | self.phase = phase 691 | self.crop_size = crop_size 692 | data = instantiate_from_config(data) 693 | 694 | if self.phase == "train": 695 | self.alpha = data.alpha_train 696 | self.merged = data.merged_train 697 | self.trimap = data.trimap_train 698 | 699 | elif self.phase == "val": 700 | self.alpha = data.alpha_val 701 | self.merged = data.merged_val 702 | self.trimap = data.trimap_val 703 | 704 | train_trans = [ 705 | # RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5), 706 | 707 | # CutMask(perturb_prob=CUTMASK_PROB), 708 | CropResize((self.crop_size, self.crop_size)), 709 | # RandomJitter(), 710 | ToTensor(phase="val",norm_type=norm_type)] 711 | 712 | # val_trans = [ OriginScale(), ToTensor() ] 713 | val_trans = [CropResize( 714 | (self.crop_size, self.crop_size)), ToTensor(phase="val",norm_type=norm_type)] 715 | 716 | self.transform = { 717 | 'train': 718 | transforms.Compose(train_trans), 719 | 720 | 'val': 721 | transforms.Compose(val_trans) 722 | }[phase] 723 | 724 | self.alpha_num = len(self.alpha) 725 | 726 | def __getitem__(self, idx): 727 | 728 | image = cv2.imread(self.merged[idx]) 729 | alpha = cv2.imread(self.alpha[idx], 0)/255. 
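        # The trimap is a grayscale map with 0 = background, 128 = unknown, 255 = foreground;
        # thresholding it at 170 below keeps only the certain-foreground region as the mask.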
730 | trimap = cv2.imread(self.trimap[idx], 0).astype(np.float32) 731 | mask = (trimap >= 170).astype(np.float32) 732 | image_name = os.path.split(self.merged[idx])[-1] 733 | 734 | dataset_name = self.get_dataset_name(image_name) 735 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 736 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name} 737 | 738 | sample = self.transform(sample) 739 | 740 | return sample 741 | 742 | def __len__(self): 743 | return len(self.alpha) 744 | 745 | def get_dataset_name(self, image_name): 746 | image_name = image_name.split('.')[0] 747 | if image_name.startswith('o_'): 748 | return 'AIM' 749 | elif image_name.endswith('_o') or image_name.endswith('_5k'): 750 | return 'PPM' 751 | elif image_name.startswith('m_'): 752 | return 'AM2k' 753 | elif image_name.endswith('_input'): 754 | return 'RWP636' 755 | elif image_name.startswith('p_'): 756 | return 'P3M' 757 | 758 | else: 759 | # raise ValueError('image_name {} not recognized'.format(image_name)) 760 | return 'RM1k' 761 | 762 | class ContextDataset(Dataset): 763 | # divide a dataset into train set and validation set 764 | def __init__(self, data, crop_size=1024, phase="train",norm_type='imagenet'): 765 | self.phase = phase 766 | self.crop_size = crop_size 767 | data = instantiate_from_config(data) 768 | 769 | if self.phase == "train": 770 | self.dataset = data.dataset_train 771 | self.image_class_dict = data.image_class_dict_train 772 | 773 | elif self.phase == "val": 774 | self.dataset = data.dataset_val 775 | self.image_class_dict = data.image_class_dict_val 776 | 777 | # dict to list 778 | for key, value in self.image_class_dict.items(): 779 | self.image_class_dict[key] = list(value.items()) 780 | self.dataset = list(self.dataset.items()) 781 | 782 | train_trans = [ 783 | # RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5), 784 | 785 | # CutMask(perturb_prob=CUTMASK_PROB), 786 | CropResize((self.crop_size, self.crop_size)), 787 | # RandomJitter(), 788 | ToTensor(phase="val",norm_type=norm_type)] 789 | 790 | # val_trans = [ OriginScale(), ToTensor() ] 791 | val_trans = [CropResize( 792 | (self.crop_size, self.crop_size)), ToTensor(phase="val",norm_type=norm_type)] 793 | 794 | self.transform = { 795 | 'train': 796 | transforms.Compose(train_trans), 797 | 798 | 'val': 799 | transforms.Compose(val_trans) 800 | }[phase] 801 | 802 | def __getitem__(self, idx): 803 | cv2.setNumThreads(0) 804 | 805 | image_name, image_info = self.dataset[idx] 806 | 807 | # get image sample 808 | dataset_name = image_info['dataset_name'] 809 | image_sample = self.get_sample(image_name, dataset_name) 810 | 811 | # get context image 812 | class_name = str( 813 | image_info['class'])+'-'+str(image_info['sub_class'])+'-'+str(image_info['HalfOrFull']) 814 | (context_image_name, context_dataset_name) = self.image_class_dict[class_name][np.random.randint( 815 | len(self.image_class_dict[class_name]))] 816 | context_image_sample = self.get_sample( 817 | context_image_name, context_dataset_name) 818 | 819 | # merge image and context 820 | image_sample['context_image'] = context_image_sample['image'] 821 | image_sample['context_guidance'] = context_image_sample['alpha'] 822 | image_sample['context_image_name'] = context_image_sample['image_name'] 823 | 824 | return image_sample 825 | 826 | def __len__(self): 827 | return len(self.dataset) 828 | 829 | def get_sample(self, image_name, dataset_name): 830 | cv2.setNumThreads(0) 831 | image_dir, label_dir, trimap_dir, 
merged_ext, alpha_ext, trimap_ext = get_dir_ext( 832 | dataset_name) 833 | image_path = os.path.join(image_dir, image_name + merged_ext) if 'open-images' not in dataset_name else os.path.join( 834 | image_dir, image_name.split('_')[0] + merged_ext) 835 | label_path = os.path.join(label_dir, image_name + alpha_ext) 836 | trimap_path = os.path.join(trimap_dir, image_name + trimap_ext) 837 | 838 | image = cv2.imread(image_path) 839 | alpha = cv2.imread(label_path, 0)/255. 840 | trimap = cv2.imread(trimap_path, 0).astype(np.float32) 841 | mask = (trimap >= 170).astype(np.float32) 842 | image_name = os.path.split(image_path)[-1] 843 | 844 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 845 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name} 846 | 847 | sample = self.transform(sample) 848 | return sample 849 | def get_sample_example(self, image_dir, mask_dir, img_list, mask_list, index): 850 | image = cv2.imread(os.path.join(image_dir, img_list[index])) 851 | alpha = cv2.imread(os.path.join(mask_dir, mask_list[index]), 0)/255. 852 | 853 | # resize alpha to image size 854 | alpha = cv2.resize(alpha, (image.shape[1], image.shape[0])) 855 | 856 | # unused 857 | trimap = cv2.imread(os.path.join(mask_dir, mask_list[index]))/255. 858 | mask = (trimap >= 170).astype(np.float32) 859 | image_name = '' 860 | dataset_name = '' 861 | 862 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 863 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name} 864 | 865 | sample = self.transform(sample) 866 | return sample['image'], sample['alpha'] 867 | 868 | class InContextDataset(Dataset): 869 | # divide a dataset into train set and validation set 870 | def __init__(self, data, crop_size=1024, phase="train",norm_type='imagenet'): 871 | self.phase = phase 872 | self.crop_size = crop_size 873 | data = instantiate_from_config(data) 874 | 875 | if self.phase == "train": 876 | self.dataset = data.dataset_train 877 | self.image_class_dict = data.image_class_dict_train 878 | 879 | elif self.phase == "val": 880 | self.dataset = data.dataset_val 881 | self.image_class_dict = data.image_class_dict_val 882 | 883 | # dict to list 884 | for key, value in self.image_class_dict.items(): 885 | self.image_class_dict[key] = list(value.items()) 886 | self.dataset = list(self.dataset.items()) 887 | 888 | train_trans = [ 889 | # RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5), 890 | 891 | # CutMask(perturb_prob=CUTMASK_PROB), 892 | CropResize((self.crop_size, self.crop_size)), 893 | # RandomJitter(), 894 | ToTensor(phase="val",norm_type=norm_type)] 895 | 896 | # val_trans = [ OriginScale(), ToTensor() ] 897 | val_trans = [CropResize( 898 | (self.crop_size, self.crop_size)), ToTensor(phase="val",norm_type=norm_type)] 899 | 900 | self.transform = { 901 | 'train': 902 | transforms.Compose(train_trans), 903 | 904 | 'val': 905 | transforms.Compose(val_trans) 906 | }[phase] 907 | 908 | def __getitem__(self, idx): 909 | cv2.setNumThreads(0) 910 | 911 | image_name, image_info = self.dataset[idx] 912 | 913 | # get image sample 914 | dataset_name = image_info['dataset_name'] 915 | image_sample = self.get_sample(image_name, dataset_name) 916 | 917 | # get context image 918 | class_name = str( 919 | image_info['class'])+'-'+str(image_info['sub_class'])+'-'+str(image_info['HalfOrFull']) 920 | 921 | context_set = self.image_class_dict[class_name] 922 | if len(context_set) > 2: 923 | # delet image_name from context_set 
(dict) 924 | context_set = [x for x in context_set if x[0] != image_name] 925 | 926 | (reference_image_name, context_dataset_name) = context_set[np.random.randint( 927 | len(context_set))] 928 | reference_image_sample = self.get_sample( 929 | reference_image_name, context_dataset_name) 930 | 931 | # merge image and context 932 | image_sample['reference_image'] = reference_image_sample['source_image'] 933 | image_sample['guidance_on_reference_image'] = reference_image_sample['alpha'] 934 | image_sample['reference_image_name'] = reference_image_sample['image_name'] 935 | 936 | return image_sample 937 | 938 | def __len__(self): 939 | return len(self.dataset) 940 | 941 | def get_sample(self, image_name, dataset_name): 942 | cv2.setNumThreads(0) 943 | image_dir, label_dir, trimap_dir, merged_ext, alpha_ext, trimap_ext = get_dir_ext( 944 | dataset_name) 945 | image_path = os.path.join(image_dir, image_name + merged_ext) if 'open-images' not in dataset_name else os.path.join( 946 | image_dir, image_name.split('_')[0] + merged_ext) 947 | label_path = os.path.join(label_dir, image_name + alpha_ext) 948 | trimap_path = os.path.join(trimap_dir, image_name + trimap_ext) 949 | 950 | image = cv2.imread(image_path) 951 | alpha = cv2.imread(label_path, 0)/255. 952 | trimap = cv2.imread(trimap_path, 0).astype(np.float32) 953 | mask = (trimap >= 170).astype(np.float32) 954 | image_name = os.path.split(image_path)[-1] 955 | 956 | if 'open-images' in dataset_name: 957 | # quantize alpha to binary {0, 1} values 958 | alpha[alpha < 0.5] = 0 959 | alpha[alpha >= 0.5] = 1 960 | 961 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 962 | 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape, 'dataset_name': dataset_name} 963 | 964 | sample = self.transform(sample) 965 | 966 | # rename 'image' to 'source_image' 967 | sample['source_image'] = sample['image'] 968 | del sample['image'] 969 | 970 | return sample 971 | 972 | -------------------------------------------------------------------------------- /icm/data/data_module.py: -------------------------------------------------------------------------------- 1 | 2 | import pytorch_lightning as pl 3 | from icm.util import instantiate_from_config 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | import torch 7 | def worker_init_fn(worker_id): 8 | # set numpy random seed with torch.randint so each worker has a different seed 9 | np.random.seed(torch.randint(0, 2**32 - 1, size=(1,)).item()) 10 | 11 | class DataModuleFromConfig(pl.LightningDataModule): 12 | def __init__(self, train=None, validation=None, test=None, predict=None, num_workers=None, 13 | batch_size=None, shuffle_train=False,batch_size_val=None): 14 | super().__init__() 15 | self.batch_size = batch_size 16 | self.batch_size_val = batch_size_val 17 | self.dataset_configs = dict() 18 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 19 | self.shuffle_train = shuffle_train 20 | # If a dataset is passed, add it to the dataset configs and create a corresponding dataloader method 21 | if train is not None: 22 | self.dataset_configs["train"] = train 23 | self.train_dataloader = self._train_dataloader 24 | if validation is not None: 25 | self.dataset_configs["validation"] = validation 26 | self.val_dataloader = self._val_dataloader 27 | 28 | # for debugging 29 | # self.setup() 30 | 31 | 32 | def setup(self, stage=None): 33 | # Instantiate datasets from the dataset configs 34 | self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in
self.dataset_configs) 35 | 36 | def _train_dataloader(self): 37 | return DataLoader(self.datasets["train"], 38 | batch_size=self.batch_size, 39 | num_workers=self.num_workers, 40 | shuffle=self.shuffle_train, 41 | worker_init_fn=worker_init_fn,) 42 | 43 | def _val_dataloader(self): 44 | 45 | return DataLoader(self.datasets["validation"], 46 | batch_size=self.batch_size if self.batch_size_val is None else self.batch_size_val, 47 | num_workers=self.num_workers, 48 | shuffle=True, 49 | worker_init_fn=worker_init_fn, 50 | ) 51 | 52 | def prepare_data(self): 53 | return super().prepare_data() 54 | -------------------------------------------------------------------------------- /icm/data/image_file.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import glob 4 | import functools 5 | import numpy as np 6 | 7 | 8 | class ImageFile(object): 9 | def __init__(self, phase='train'): 10 | self.phase = phase 11 | self.rng = np.random.RandomState(0) 12 | 13 | def _get_valid_names(self, *dirs, shuffle=True): 14 | # Extract valid names 15 | name_sets = [self._get_name_set(d) for d in dirs] 16 | 17 | # Reduce 18 | def _join_and(a, b): 19 | return a & b 20 | 21 | valid_names = list(functools.reduce(_join_and, name_sets)) 22 | if shuffle: 23 | self.rng.shuffle(valid_names) 24 | 25 | return valid_names 26 | 27 | @staticmethod 28 | def _get_name_set(dir_name): 29 | path_list = glob.glob(os.path.join(dir_name, '*')) 30 | name_set = set() 31 | for path in path_list: 32 | name = os.path.basename(path) 33 | name = os.path.splitext(name)[0] 34 | name_set.add(name) 35 | return name_set 36 | 37 | @staticmethod 38 | def _list_abspath(data_dir, ext, data_list): 39 | return [os.path.join(data_dir, name + ext) 40 | for name in data_list] 41 | 42 | 43 | class ImageFileTrain(ImageFile): 44 | def __init__(self, 45 | alpha_dir="train_alpha", 46 | fg_dir="train_fg", 47 | bg_dir="train_bg", 48 | alpha_ext=".jpg", 49 | fg_ext=".jpg", 50 | bg_ext=".jpg"): 51 | super(ImageFileTrain, self).__init__(phase="train") 52 | 53 | self.alpha_dir = alpha_dir 54 | self.fg_dir = fg_dir 55 | self.bg_dir = bg_dir 56 | self.alpha_ext = alpha_ext 57 | self.fg_ext = fg_ext 58 | self.bg_ext = bg_ext 59 | 60 | self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir) 61 | self.valid_bg_list = [os.path.splitext( 62 | name)[0] for name in os.listdir(self.bg_dir)] 63 | 64 | self.alpha = self._list_abspath( 65 | self.alpha_dir, self.alpha_ext, self.valid_fg_list) 66 | self.fg = self._list_abspath( 67 | self.fg_dir, self.fg_ext, self.valid_fg_list) 68 | self.bg = self._list_abspath( 69 | self.bg_dir, self.bg_ext, self.valid_bg_list) 70 | 71 | def __len__(self): 72 | return len(self.alpha) 73 | 74 | 75 | class ImageFileTest(ImageFile): 76 | def __init__(self, 77 | alpha_dir="test_alpha", 78 | merged_dir="test_merged", 79 | trimap_dir="test_trimap", 80 | alpha_ext=".png", 81 | merged_ext=".png", 82 | trimap_ext=".png"): 83 | super(ImageFileTest, self).__init__(phase="test") 84 | 85 | self.alpha_dir = alpha_dir 86 | self.merged_dir = merged_dir 87 | self.trimap_dir = trimap_dir 88 | self.alpha_ext = alpha_ext 89 | self.merged_ext = merged_ext 90 | self.trimap_ext = trimap_ext 91 | 92 | self.valid_image_list = self._get_valid_names( 93 | self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False) 94 | 95 | self.alpha = self._list_abspath( 96 | self.alpha_dir, self.alpha_ext, self.valid_image_list) 97 | self.merged = self._list_abspath( 98 | self.merged_dir, 
self.merged_ext, self.valid_image_list) 99 | self.trimap = self._list_abspath( 100 | self.trimap_dir, self.trimap_ext, self.valid_image_list) 101 | 102 | def __len__(self): 103 | return len(self.alpha) 104 | 105 | 106 | dataset = {'AIM', 'PPM', 'AM2k_train', 'AM2k_val', 107 | 'RWP636', 'P3M_val_np', 'P3M_train', 'P3M_val_p'} 108 | # o_ _o or _5k m_ m_ _input.jpg p_ p_ p_ 109 | 110 | 111 | def get_dir_ext(dataset): 112 | 113 | # assert dataset in ['ICM57', 'ICM']: 114 | image_dir = './datasets/ICM57/image' 115 | label_dir = './datasets/ICM57/alpha' 116 | trimap_dir = './datasets/ICM57/trimap' 117 | 118 | merged_ext = '.jpg' 119 | alpha_ext = '.png' 120 | trimap_ext = '.png' 121 | return image_dir, label_dir, trimap_dir, merged_ext, alpha_ext, trimap_ext 122 | 123 | class MultiImageFile(object): 124 | def __init__(self): 125 | 126 | self.rng = np.random.RandomState(1) 127 | 128 | def _get_valid_names(self, *dirs, shuffle=True): 129 | # Extract valid names 130 | name_sets = [self._get_name_set(d) for d in dirs] 131 | 132 | # Reduce 133 | def _join_and(a, b): 134 | return a & b 135 | 136 | valid_names = list(functools.reduce(_join_and, name_sets)) 137 | 138 | # ensure the order is the same for both training and validation 139 | if shuffle: 140 | valid_names.sort() 141 | self.rng.shuffle(valid_names) 142 | 143 | return valid_names 144 | 145 | @staticmethod 146 | def _get_name_set(dir_name): 147 | path_list = glob.glob(os.path.join(dir_name, '*')) 148 | name_set = set() 149 | for path in path_list: 150 | name = os.path.basename(path) 151 | name = os.path.splitext(name)[0] 152 | name_set.add(name) 153 | return name_set 154 | 155 | @staticmethod 156 | def _list_abspath(data_dir, ext, data_list): 157 | return [os.path.join(data_dir, name + ext) 158 | for name in data_list] 159 | 160 | 161 | class MultiImageFileDoubleSet(MultiImageFile): 162 | def __init__(self, ratio=0.9, dataset_name=['AIM', 'PPM', 'AM2k_train', 'AM2k_val', 'RWP636', 'P3M_val_np']): 163 | 164 | super(MultiImageFileDoubleSet, self).__init__() 165 | 166 | self.alpha_train = [] 167 | self.merged_train = [] 168 | self.trimap_train = [] 169 | self.alpha_val = [] 170 | self.merged_val = [] 171 | self.trimap_val = [] 172 | 173 | for dataset_name_ in dataset_name: 174 | merged_dir, alpha_dir, trimap_dir, merged_ext, alpha_ext, trimap_ext = get_dir_ext( 175 | dataset_name_) 176 | valid_image_list = self._get_valid_names( 177 | alpha_dir, merged_dir, trimap_dir) 178 | 179 | alpha = self._list_abspath(alpha_dir, alpha_ext, valid_image_list) 180 | merged = self._list_abspath( 181 | merged_dir, merged_ext, valid_image_list) 182 | trimap = self._list_abspath( 183 | trimap_dir, trimap_ext, valid_image_list) 184 | 185 | alpha_train, alpha_val = self._split(alpha, ratio) 186 | merged_train, merged_val = self._split(merged, ratio) 187 | trimap_train, trimap_val = self._split(trimap, ratio) 188 | 189 | self.alpha_train.extend(alpha_train) 190 | self.merged_train.extend(merged_train) 191 | self.trimap_train.extend(trimap_train) 192 | self.alpha_val.extend(alpha_val) 193 | self.merged_val.extend(merged_val) 194 | self.trimap_val.extend(trimap_val) 195 | 196 | def _split(self, data_list, ratio): 197 | num = len(data_list) 198 | split = int(num * ratio) 199 | return data_list[:split], data_list[split:] 200 | 201 | 202 | class ContextData(): 203 | ''' 204 | dataset_name: corresponding to the dataset file in /datasets 205 | 206 | return: 207 | dataset: dict2list, key: image_name, 208 | value: {"dataset_name": "AIM", "class": "animal", 209 | "sub_class": 
null, "HalfOrFull": "half", "TransparentOrOpaque": "SO"} 210 | 211 | image_class_dict: dict, key: class_name, (class-sub_class) 212 | value: dict, key: image_name, value: dataset_name 213 | ''' 214 | 215 | def __init__(self, ratio=0.9, dataset_name=['PPM', 'AM2k', 'RWP636', 'P3M_val_np']): 216 | dataset = {} 217 | for dataset_name_ in dataset_name: 218 | json_dir = os.path.join('datasets', dataset_name_+'.json') 219 | # read json file and append to dataset 220 | with open(json_dir) as f: 221 | new_data = json.load(f) 222 | # filter out the items if each element in "instance_area_ratio" list >0.1 223 | # check if "instance_area_ratio" exists 224 | if 'instance_area_ratio' in new_data[list(new_data.keys())[0]].keys(): 225 | new_data_ = {} 226 | for k, v in new_data.items(): 227 | if min(v['instance_area_ratio']) < 0.3: 228 | x = min(v['instance_area_ratio']) 229 | r = np.random.rand() 230 | if 100*(0.6-x)*x/9 > r**0.2: 231 | new_data_[k] = v 232 | else: 233 | new_data_[k] = v 234 | new_data = new_data_ 235 | dataset.update(new_data) 236 | 237 | # shuffle dataset with seed 238 | self.rng = np.random.RandomState(1) 239 | dataset_list = list(dataset.items()) 240 | dataset_list.sort() 241 | self.rng.shuffle(dataset_list) 242 | 243 | # split dataset into train and val 244 | dataset_train, dataset_val = self._split(dataset_list, ratio) 245 | 246 | # get image_class_dict 247 | image_class_dict_train = self.get_image_class_dict(dataset_train) 248 | image_class_dict_val = self.get_image_class_dict(dataset_val) 249 | 250 | self.image_class_dict_train = image_class_dict_train 251 | self.image_class_dict_val = image_class_dict_val 252 | self.dataset_train = dataset_train 253 | self.dataset_val = dataset_val 254 | 255 | def _split(self, data_list, ratio): 256 | num = len(data_list) 257 | split = int(num * ratio) 258 | 259 | dataset_train = dict(data_list[:split]) 260 | dataset_val = dict(data_list[split:]) 261 | return dataset_train, dataset_val 262 | 263 | def get_image_class_dict(self, dataset): 264 | image_class_dict = {} 265 | for image_name, image_info in dataset.items(): 266 | class_name = str( 267 | image_info['class'])+'-'+str(image_info['sub_class'])+'-'+str(image_info['HalfOrFull']) 268 | if class_name not in image_class_dict.keys(): 269 | image_class_dict[class_name] = {} 270 | image_class_dict[class_name][image_name] = image_info['dataset_name'] 271 | else: 272 | image_class_dict[class_name][image_name] = image_info['dataset_name'] 273 | return image_class_dict 274 | 275 | 276 | if __name__ == "__main__": 277 | 278 | test = MultiImageFileDoubleSet() 279 | print(0) 280 | -------------------------------------------------------------------------------- /icm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
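        # Schedule sketch (the returned multiplier is applied to a base_lr of 1.0, cf. the note above;
        # see schedule() below for the actual implementation):
        #   n <  warm_up_steps : lr = lr_start + (lr_max - lr_start) * n / warm_up_steps
        #   n >= warm_up_steps : t  = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1)
        #                        lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t))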
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /icm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiny-smart/in-context-matting/a58bcf4b948babcf5f1c2e1c41e2fa040bc53d1e/icm/models/__init__.py -------------------------------------------------------------------------------- /icm/models/decoder/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiny-smart/in-context-matting/a58bcf4b948babcf5f1c2e1c41e2fa040bc53d1e/icm/models/decoder/__init__.py -------------------------------------------------------------------------------- /icm/models/decoder/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | import os 14 | import numpy as np 15 | from PIL import Image 16 | 17 | 18 | class Attention(nn.Module): 19 | """ 20 | An attention layer that allows for downscaling the size of the embedding 21 | after projection to queries, keys, and values. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | embedding_dim: int, 27 | num_heads: int, 28 | downsample_rate: int = 1, 29 | ) -> None: 30 | super().__init__() 31 | self.embedding_dim = embedding_dim 32 | self.internal_dim = embedding_dim // downsample_rate 33 | self.num_heads = num_heads 34 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 35 | 36 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 37 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 38 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 39 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 40 | 41 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 42 | b, n, c = x.shape 43 | x = x.reshape(b, n, num_heads, c // num_heads) 44 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 45 | 46 | def _recombine_heads(self, x: Tensor) -> Tensor: 47 | b, n_heads, n_tokens, c_per_head = x.shape 48 | x = x.transpose(1, 2) 49 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 50 | 51 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 52 | # Input projections 53 | q = self.q_proj(q) 54 | k = self.k_proj(k) 55 | v = self.v_proj(v) 56 | 57 | # Separate into heads 58 | q = self._separate_heads(q, self.num_heads) 59 | k = self._separate_heads(k, self.num_heads) 60 | v = self._separate_heads(v, self.num_heads) 61 | 62 | # Attention 63 | _, _, _, c_per_head = q.shape 64 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 65 | attn = attn / math.sqrt(c_per_head) 66 | attn = torch.softmax(attn, dim=-1) # [b,head,token,token] 67 | 68 | # save_attn_map(attn) 69 | 70 | # Get output 71 | out = attn @ v 72 | out = self._recombine_heads(out) 73 | out = self.out_proj(out) 74 | 75 | return out 76 | 77 | class MLPBlock(nn.Module): 78 | def __init__( 79 | self, 80 | embedding_dim: int, 81 | mlp_dim: int, 82 | act: Type[nn.Module] = nn.GELU, 83 | ) -> None: 84 | super().__init__() 85 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 86 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 87 | self.act = act() 88 | 89 | def forward(self, x: torch.Tensor) -> torch.Tensor: 90 | return self.lin2(self.act(self.lin1(x))) 91 | 92 | 93 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 94 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 95 | class LayerNorm2d(nn.Module): 96 | def __init__(self, 
num_channels: int, eps: float = 1e-6) -> None: 97 | super().__init__() 98 | self.weight = nn.Parameter(torch.ones(num_channels)) 99 | self.bias = nn.Parameter(torch.zeros(num_channels)) 100 | self.eps = eps 101 | 102 | def forward(self, x: torch.Tensor) -> torch.Tensor: 103 | u = x.mean(1, keepdim=True) 104 | s = (x - u).pow(2).mean(1, keepdim=True) 105 | x = (x - u) / torch.sqrt(s + self.eps) 106 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 107 | return x 108 | -------------------------------------------------------------------------------- /icm/models/decoder/bottleneck_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | def c2_msra_fill(module: nn.Module) -> None: 6 | """ 7 | Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. 8 | Also initializes `module.bias` to 0. 9 | 10 | Args: 11 | module (torch.nn.Module): module to initialize. 12 | """ 13 | # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Module, Tensor]`. 14 | nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") 15 | if module.bias is not None: 16 | # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, 17 | # torch.Tensor]`. 18 | nn.init.constant_(module.bias, 0) 19 | 20 | class Conv2d(torch.nn.Conv2d): 21 | """ 22 | A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. 23 | """ 24 | 25 | def __init__(self, *args, **kwargs): 26 | """ 27 | Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: 28 | 29 | Args: 30 | norm (nn.Module, optional): a normalization layer 31 | activation (callable(Tensor) -> Tensor): a callable activation function 32 | 33 | It assumes that norm layer is used before activation. 34 | """ 35 | norm = kwargs.pop("norm", None) 36 | activation = kwargs.pop("activation", None) 37 | super().__init__(*args, **kwargs) 38 | 39 | self.norm = norm 40 | self.activation = activation 41 | 42 | def forward(self, x): 43 | # torchscript does not support SyncBatchNorm yet 44 | # https://github.com/pytorch/pytorch/issues/40507 45 | # and we skip these codes in torchscript since: 46 | # 1. currently we only support torchscript in evaluation mode 47 | # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or 48 | # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. 49 | if not torch.jit.is_scripting(): 50 | if x.numel() == 0 and self.training: 51 | # https://github.com/pytorch/pytorch/issues/12013 52 | assert not isinstance( 53 | self.norm, torch.nn.SyncBatchNorm 54 | ), "SyncBatchNorm does not support empty inputs!" 55 | 56 | x = F.conv2d( 57 | x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups 58 | ) 59 | if self.norm is not None: 60 | x = self.norm(x) 61 | if self.activation is not None: 62 | x = self.activation(x) 63 | return x 64 | 65 | def get_norm(norm, out_channels): 66 | """ 67 | Args: 68 | norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; 69 | or a callable that takes a channel number and returns 70 | the normalization layer as a nn.Module. 
71 | 72 | Returns: 73 | nn.Module or None: the normalization layer 74 | """ 75 | if norm is None: 76 | return None 77 | if isinstance(norm, str): 78 | if len(norm) == 0: 79 | return None 80 | norm = { 81 | "BN": nn.BatchNorm2d, 82 | "GN": lambda channels: nn.GroupNorm(32, channels), 83 | }[norm] 84 | return norm(out_channels) 85 | 86 | class CNNBlockBase(nn.Module): 87 | """ 88 | A CNN block is assumed to have input channels, output channels and a stride. 89 | The input and output of `forward()` method must be NCHW tensors. 90 | The method can perform arbitrary computation but must match the given 91 | channels and stride specification. 92 | 93 | Attribute: 94 | in_channels (int): 95 | out_channels (int): 96 | stride (int): 97 | """ 98 | 99 | def __init__(self, in_channels, out_channels, stride): 100 | """ 101 | The `__init__` method of any subclass should also contain these arguments. 102 | 103 | Args: 104 | in_channels (int): 105 | out_channels (int): 106 | stride (int): 107 | """ 108 | super().__init__() 109 | self.in_channels = in_channels 110 | self.out_channels = out_channels 111 | self.stride = stride 112 | 113 | def freeze(self): 114 | """ 115 | Make this block not trainable. 116 | This method sets all parameters to `requires_grad=False`, 117 | and convert all BatchNorm layers to FrozenBatchNorm 118 | 119 | Returns: 120 | the block itself 121 | """ 122 | for p in self.parameters(): 123 | p.requires_grad = False 124 | return self 125 | 126 | class BottleneckBlock(CNNBlockBase): 127 | """ 128 | The standard bottleneck residual block used by ResNet-50, 101 and 152 129 | defined in :paper:`ResNet`. It contains 3 conv layers with kernels 130 | 1x1, 3x3, 1x1, and a projection shortcut if needed. 131 | """ 132 | 133 | def __init__( 134 | self, 135 | in_channels, 136 | out_channels, 137 | *, 138 | bottleneck_channels, 139 | stride=1, 140 | num_groups=1, 141 | norm="GN", 142 | stride_in_1x1=False, 143 | dilation=1, 144 | ): 145 | """ 146 | Args: 147 | bottleneck_channels (int): number of output channels for the 3x3 148 | "bottleneck" conv layers. 149 | num_groups (int): number of groups for the 3x3 conv layer. 150 | norm (str or callable): normalization for all conv layers. 151 | See :func:`layers.get_norm` for supported format. 152 | stride_in_1x1 (bool): when stride>1, whether to put stride in the 153 | first 1x1 convolution or the bottleneck 3x3 convolution. 154 | dilation (int): the dilation rate of the 3x3 conv layer. 
155 | """ 156 | super().__init__(in_channels, out_channels, stride) 157 | 158 | if in_channels != out_channels: 159 | self.shortcut = Conv2d( 160 | in_channels, 161 | out_channels, 162 | kernel_size=1, 163 | stride=stride, 164 | bias=False, 165 | norm=get_norm(norm, out_channels), 166 | ) 167 | else: 168 | self.shortcut = None 169 | 170 | # The original MSRA ResNet models have stride in the first 1x1 conv 171 | # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have 172 | # stride in the 3x3 conv 173 | stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) 174 | 175 | self.conv1 = Conv2d( 176 | in_channels, 177 | bottleneck_channels, 178 | kernel_size=1, 179 | stride=stride_1x1, 180 | bias=False, 181 | norm=get_norm(norm, bottleneck_channels), 182 | ) 183 | 184 | self.conv2 = Conv2d( 185 | bottleneck_channels, 186 | bottleneck_channels, 187 | kernel_size=3, 188 | stride=stride_3x3, 189 | padding=1 * dilation, 190 | bias=False, 191 | groups=num_groups, 192 | dilation=dilation, 193 | norm=get_norm(norm, bottleneck_channels), 194 | ) 195 | 196 | self.conv3 = Conv2d( 197 | bottleneck_channels, 198 | out_channels, 199 | kernel_size=1, 200 | bias=False, 201 | norm=get_norm(norm, out_channels), 202 | ) 203 | 204 | for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: 205 | if layer is not None: # shortcut can be None 206 | c2_msra_fill(layer) 207 | 208 | # Zero-initialize the last normalization in each residual branch, 209 | # so that at the beginning, the residual branch starts with zeros, 210 | # and each residual block behaves like an identity. 211 | # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": 212 | # "For BN layers, the learnable scaling coefficient γ is initialized 213 | # to be 1, except for each residual block's last BN 214 | # where γ is initialized to be 0." 215 | 216 | # nn.init.constant_(self.conv3.norm.weight, 0) 217 | # TODO this somehow hurts performance when training GN models from scratch. 218 | # Add it as an option when we need to use this code to train a backbone. 219 | 220 | def forward(self, x): 221 | out = self.conv1(x) 222 | out = F.relu_(out) 223 | 224 | out = self.conv2(out) 225 | out = F.relu_(out) 226 | 227 | out = self.conv3(out) 228 | 229 | if self.shortcut is not None: 230 | shortcut = self.shortcut(x) 231 | else: 232 | shortcut = x 233 | 234 | out += shortcut 235 | out = F.relu_(out) 236 | return out -------------------------------------------------------------------------------- /icm/models/decoder/detail_capture.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Basic_Conv3x3(nn.Module): 6 | """ 7 | Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. 8 | """ 9 | def __init__( 10 | self, 11 | in_chans, 12 | out_chans, 13 | stride=1, 14 | padding=1, 15 | ): 16 | super().__init__() 17 | self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False) 18 | self.bn = nn.BatchNorm2d(out_chans) 19 | self.relu = nn.ReLU(True) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | x = self.bn(x) 24 | x = self.relu(x) 25 | 26 | return x 27 | 28 | 29 | class Basic_Conv3x3_attn(nn.Module): 30 | """ 31 | Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. 
32 | """ 33 | def __init__( 34 | self, 35 | in_chans, 36 | out_chans, 37 | res = False, 38 | stride=1, 39 | padding=1, 40 | ): 41 | super().__init__() 42 | self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False) 43 | self.ln = nn.LayerNorm(in_chans, elementwise_affine=True) 44 | self.relu = nn.ReLU(True) 45 | 46 | def forward(self, x): 47 | x = self.ln(x) 48 | x = x.permute(0, 3, 1, 2) 49 | x = self.relu(x) 50 | x = self.conv(x) 51 | 52 | return x 53 | 54 | # class Basic_Conv3x3_attn(nn.Module): 55 | # """ 56 | # Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. 57 | # """ 58 | # def __init__( 59 | # self, 60 | # in_chans, 61 | # out_chans, 62 | # res = False, 63 | # stride=1, 64 | # padding=1, 65 | # ): 66 | # super().__init__() 67 | # self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False) 68 | # self.ln = nn.LayerNorm([in_chans, res, res], elementwise_affine=True) 69 | # self.relu = nn.ReLU(True) 70 | 71 | # def forward(self, x): 72 | # x = x.permute(0, 3, 1, 2) 73 | # x = self.ln(x) 74 | # x = self.relu(x) 75 | # x = self.conv(x) 76 | 77 | # return x 78 | 79 | class ConvStream(nn.Module): 80 | """ 81 | Simple ConvStream containing a series of basic conv3x3 layers to extract detail features. 82 | """ 83 | def __init__( 84 | self, 85 | in_chans = 4, 86 | out_chans = [48, 96, 192], 87 | ): 88 | super().__init__() 89 | self.convs = nn.ModuleList() 90 | 91 | self.conv_chans = out_chans 92 | self.conv_chans.insert(0, in_chans) 93 | 94 | for i in range(len(self.conv_chans)-1): 95 | in_chan_ = self.conv_chans[i] 96 | out_chan_ = self.conv_chans[i+1] 97 | self.convs.append( 98 | Basic_Conv3x3(in_chan_, out_chan_, stride=2) 99 | ) 100 | 101 | def forward(self, x): 102 | out_dict = {'D0': x} 103 | for i in range(len(self.convs)): 104 | x = self.convs[i](x) 105 | check = self.convs[i] 106 | name_ = 'D'+str(i+1) 107 | out_dict[name_] = x 108 | 109 | return out_dict 110 | 111 | class Fusion_Block(nn.Module): 112 | """ 113 | Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer. 114 | """ 115 | def __init__( 116 | self, 117 | in_chans, 118 | out_chans, 119 | ): 120 | super().__init__() 121 | self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1) 122 | 123 | def forward(self, x, D): 124 | F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 125 | out = torch.cat([D, F_up], dim=1) 126 | out = self.conv(out) 127 | 128 | return out 129 | 130 | class Matting_Head(nn.Module): 131 | """ 132 | Simple Matting Head, containing only conv3x3 and conv1x1 layers. 133 | """ 134 | def __init__( 135 | self, 136 | in_chans = 32, 137 | mid_chans = 16, 138 | ): 139 | super().__init__() 140 | self.matting_convs = nn.Sequential( 141 | nn.Conv2d(in_chans, mid_chans, 3, 1, 1), 142 | nn.BatchNorm2d(mid_chans), 143 | nn.ReLU(True), 144 | nn.Conv2d(mid_chans, 1, 1, 1, 0) 145 | ) 146 | 147 | def forward(self, x): 148 | x = self.matting_convs(x) 149 | 150 | return x 151 | 152 | # TODO: implement groupnorm and ws. In ODISE, bs=2, they work well; when bs = 1, mse loss will be nan, why? 153 | class DetailCapture(nn.Module): 154 | """ 155 | Simple and Lightweight Detail Capture Module for ViT Matting. 
156 | """ 157 | def __init__( 158 | self, 159 | in_chans = 384, 160 | img_chans=4, 161 | convstream_out = [48, 96, 192], 162 | fusion_out = [256, 128, 64, 32], 163 | ckpt=None, 164 | use_sigmoid = True, 165 | ): 166 | super().__init__() 167 | assert len(fusion_out) == len(convstream_out) + 1 168 | 169 | self.convstream = ConvStream(in_chans = img_chans) 170 | self.conv_chans = self.convstream.conv_chans 171 | 172 | self.fusion_blks = nn.ModuleList() 173 | self.fus_channs = fusion_out.copy() 174 | self.fus_channs.insert(0, in_chans) 175 | for i in range(len(self.fus_channs)-1): 176 | self.fusion_blks.append( 177 | Fusion_Block( 178 | in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)], 179 | out_chans = self.fus_channs[i+1], 180 | ) 181 | ) 182 | 183 | self.matting_head = Matting_Head( 184 | in_chans = fusion_out[-1], 185 | ) 186 | 187 | if ckpt != None and ckpt != '': 188 | self.load_state_dict(ckpt['state_dict'], strict=False) 189 | print('load detail capture ckpt from', ckpt['path']) 190 | 191 | self.use_sigmoid = use_sigmoid 192 | self.img_chans = img_chans 193 | def forward(self, features, images): 194 | 195 | if isinstance(features, dict): 196 | 197 | trimap = features['trimap'] 198 | features = features['feature'] 199 | if self.img_chans ==4: 200 | images = torch.cat([images, trimap], dim=1) 201 | 202 | detail_features = self.convstream(images) 203 | # D0 2 4 512 512 204 | # D1 2 48 256 256 205 | # D2 2 96 128 128 206 | # D3 2 192 64 64 207 | for i in range(len(self.fusion_blks)): 208 | d_name_ = 'D'+str(len(self.fusion_blks)-i-1) 209 | features = self.fusion_blks[i](features, detail_features[d_name_]) 210 | 211 | if self.use_sigmoid: 212 | phas = torch.sigmoid(self.matting_head(features)) 213 | else: 214 | phas = self.matting_head(features) 215 | return phas 216 | 217 | def get_trainable_params(self): 218 | return list(self.parameters()) 219 | 220 | class MaskDecoder(nn.Module): 221 | ''' 222 | use trans-conv to decode mask 223 | ''' 224 | def __init__( 225 | self, 226 | in_chans = 384, 227 | ): 228 | super().__init__() 229 | self.output_upscaling = nn.Sequential( 230 | nn.ConvTranspose2d(in_chans, in_chans // 4, kernel_size=2, stride=2), 231 | # LayerNorm2d(in_chans // 4), 232 | nn.BatchNorm2d(in_chans // 4), 233 | nn.ReLU(), 234 | nn.ConvTranspose2d(in_chans // 4, in_chans // 8, kernel_size=2, stride=2), 235 | nn.BatchNorm2d(in_chans // 8), 236 | nn.ReLU(), 237 | ) 238 | 239 | self.matting_head = Matting_Head( 240 | in_chans = in_chans // 8, 241 | ) 242 | 243 | def forward(self, x, images): 244 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 245 | x = self.output_upscaling(x) 246 | x = self.matting_head(x) 247 | x = torch.sigmoid(x) 248 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 249 | return x -------------------------------------------------------------------------------- /icm/models/decoder/in_context_correspondence.py: -------------------------------------------------------------------------------- 1 | 2 | from einops import rearrange 3 | from torch import einsum 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from icm.models.decoder.bottleneck_block import BottleneckBlock 9 | 10 | from icm.models.decoder.detail_capture import Basic_Conv3x3, Basic_Conv3x3_attn, Fusion_Block 11 | import math 12 | from icm.models.decoder.attention import Attention, MLPBlock 13 | 14 | 15 | class OneWayAttentionBlock(nn.Module): 16 | def __init__( 17 | self, 18 | dim, 19 | n_heads, 20 | 
d_head, 21 | mlp_dim_rate, 22 | ): 23 | super().__init__() 24 | 25 | self.attn = Attention(dim, n_heads, downsample_rate=dim//d_head) 26 | 27 | self.norm1 = nn.LayerNorm(dim) 28 | 29 | self.mlp = MLPBlock(dim, int(dim*mlp_dim_rate)) 30 | 31 | self.norm2 = nn.LayerNorm(dim) 32 | 33 | def forward(self, q, context_all): 34 | output = [] 35 | for i in range(len(q)): 36 | x = q[i].unsqueeze(0) 37 | context = context_all[i].unsqueeze(0) 38 | x = self.norm1(x) 39 | context = self.norm1(context) 40 | x = self.attn(q=x, k=context, v=context) + x 41 | x = self.norm2(x) 42 | x = self.mlp(x) + x 43 | output.append(x.squeeze(0)) 44 | 45 | return output 46 | 47 | 48 | def compute_correspondence_matrix(source_feature, ref_feature): 49 | """ 50 | Compute correspondence matrix between source and reference features. 51 | Args: 52 | source_feature: [B, C, H, W] 53 | ref_feature: [B, C, H, W] 54 | Returns: 55 | correspondence_matrix: [B, H*W, H*W] 56 | """ 57 | # [B, C, H, W] -> [B, H, W, C] 58 | source_feature = source_feature.permute(0, 2, 3, 1) 59 | ref_feature = ref_feature.permute(0, 2, 3, 1) 60 | 61 | # [B, H, W, C] -> [B, H*W, C] 62 | source_feature = torch.reshape( 63 | source_feature, [source_feature.shape[0], -1, source_feature.shape[-1]]) 64 | ref_feature = torch.reshape( 65 | ref_feature, [ref_feature.shape[0], -1, ref_feature.shape[-1]]) 66 | 67 | # norm 68 | source_feature = F.normalize(source_feature, p=2, dim=-1) 69 | ref_feature = F.normalize(ref_feature, p=2, dim=-1) 70 | 71 | # cosine similarity 72 | cos_sim = torch.matmul( 73 | source_feature, ref_feature.transpose(1, 2)) # [B, H*W, H*W] 74 | 75 | return cos_sim 76 | 77 | 78 | def maskpooling(mask, res): 79 | ''' 80 | Mask pooling to reduce the resolution of mask 81 | Input: 82 | mask: [B, 1, H, W] 83 | res: resolution 84 | Output: [B, 1, res, res] 85 | ''' 86 | # mask[mask < 1] = 0 87 | mask[mask > 0] = 1 88 | mask = -1 * mask 89 | kernel_size = mask.shape[2] // res 90 | mask = F.max_pool2d(mask, kernel_size=kernel_size, 91 | stride=kernel_size, padding=0) 92 | mask = -1*mask 93 | return mask 94 | 95 | 96 | def dilate(bin_img, ksize=5): 97 | pad = (ksize - 1) // 2 98 | bin_img = F.pad(bin_img, pad=[pad, pad, pad, pad], mode='reflect') 99 | out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0) 100 | return out 101 | 102 | 103 | def erode(bin_img, ksize=5): 104 | out = 1 - dilate(1 - bin_img, ksize) 105 | return out 106 | 107 | 108 | def generate_trimap(mask, erode_kernel_size=10, dilate_kernel_size=10): 109 | eroded = erode(mask, erode_kernel_size) 110 | dilated = dilate(mask, dilate_kernel_size) 111 | trimap = torch.zeros_like(mask) 112 | trimap[dilated == 1] = 0.5 113 | trimap[eroded == 1] = 1 114 | return trimap 115 | 116 | 117 | def calculate_attention_score_(mask, attention_map, score_type): 118 | ''' 119 | Calculate the attention score of the attention map 120 | mask: [B, H, W] value:0 or 1 121 | attention_map: [B, H*W, H, W] 122 | ''' 123 | B, H, W = mask.shape 124 | mask_pos = mask.repeat(1, attention_map.shape[1], 1, 1) # [B, H*W, H, W] 125 | score_pos = torch.sum(attention_map * mask_pos, dim=(2, 3)) # [B, H*W] 126 | score_pos = score_pos / torch.sum(mask_pos, dim=(2, 3)) # [B, H*W] 127 | 128 | mask_neg = 1 - mask_pos 129 | score_neg = torch.sum(attention_map * mask_neg, dim=(2, 3)) 130 | score_neg = score_neg / torch.sum(mask_neg, dim=(2, 3)) 131 | 132 | assert score_type in ['classification', 'softmax', 'ratio'] 133 | 134 | if score_type == 'classification': 135 | score = torch.zeros_like(score_pos) 136 | 
score[score_pos > score_neg] = 1 137 | 138 | return score.reshape(B, H, W) 139 | 140 | 141 | def refine_mask_by_attention(mask, attention_maps, iterations=10, score_type='classification'): 142 | ''' 143 | mask: [B, H, W] 144 | attention_maps: [B, H*W, H, W] 145 | ''' 146 | assert mask.shape[1] == attention_maps.shape[2] 147 | for i in range(iterations): 148 | score = calculate_attention_score_( 149 | mask, attention_maps, score_type=score_type) # [B, H, W] 150 | 151 | if score.equal(mask): 152 | # print("iteration: ", i, "score equal to mask") 153 | break 154 | else: 155 | mask = score 156 | 157 | assert i != iterations - 1 158 | return mask 159 | 160 | 161 | class InContextCorrespondence(nn.Module): 162 | ''' 163 | one implementation of in_context_fusion 164 | 165 | forward(feature_of_reference_image, feature_of_source_image, guidance_on_reference_image) 166 | ''' 167 | 168 | def __init__(self, 169 | use_bottle_neck=False, 170 | in_dim=1280, 171 | bottle_neck_dim=512, 172 | refine_with_attention=False, 173 | ): 174 | super().__init__() 175 | self.use_bottle_neck = use_bottle_neck 176 | self.refine_with_attention = refine_with_attention 177 | 178 | def forward(self, feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image): 179 | ''' 180 | feature_of_reference_image: [B, C, H, W] 181 | ft_attn_of_source_image: {"ft": [B, C, H, W], "attn": [B, H_1, W_1, H_1*W_1]} 182 | guidance_on_reference_image: [B, 1, H_2, W_2] 183 | ''' 184 | 185 | # get source_image h,w 186 | h, w = guidance_on_reference_image.shape[-2:] 187 | h_attn, w_attn = ft_attn_of_source_image['attn'].shape[-3:-1] 188 | 189 | feature_of_source_image = ft_attn_of_source_image['ft'] 190 | attention_map_of_source_image = ft_attn_of_source_image['attn'] 191 | 192 | cos_sim = compute_correspondence_matrix( 193 | feature_of_source_image, feature_of_reference_image) 194 | 195 | # for each source position, take the index of the best-matching reference position (row-wise argmax of cos_sim) 196 | index = torch.argmax(cos_sim, dim=-1) # 1*1024 197 | 198 | mask_ref = maskpooling(guidance_on_reference_image, 199 | h_attn) 200 | 201 | mask_ref = mask_ref.reshape(mask_ref.shape[0], -1) # 1*1024 202 | 203 | new_index = torch.gather(mask_ref, 1, index) # 1*1024 204 | res = int(new_index.shape[-1]**0.5) 205 | new_index = new_index.reshape( 206 | new_index.shape[0], res, res).unsqueeze(1) 207 | 208 | # transferred mask at attention resolution; upsampled to (h, w) below 209 | mask_result = new_index 210 | 211 | if self.refine_with_attention: 212 | mask_result = refine_mask_by_attention( 213 | mask_result, attention_map_of_source_image, iterations=10, score_type='classification') 214 | 215 | mask_result = F.interpolate( 216 | mask_result.float(), size=(h, w), mode='bilinear') 217 | 218 | # get trimap 219 | 220 | pseudo_trimap = generate_trimap( 221 | mask_result, erode_kernel_size=self.kernel_size, dilate_kernel_size=self.kernel_size) 222 | 223 | output = {} 224 | output['trimap'] = pseudo_trimap 225 | output['feature'] = feature_of_source_image 226 | output['mask'] = mask_result 227 | 228 | return output 229 | 230 | 231 | class TrainingFreeAttention(nn.Module): 232 | def __init__(self, res_ratio=4, pool_type='average', temp_softmax=1, use_scale=False, upsample_mode='bilinear', use_norm=False) -> None: 233 | super().__init__() 234 | self.res_ratio = res_ratio 235 | self.pool_type = pool_type 236 | self.temp_softmax = temp_softmax 237 | self.use_scale = use_scale 238 | self.upsample_mode = upsample_mode 239 | if use_norm: 240 | self.norm = nn.LayerNorm(use_norm, elementwise_affine=True) 241 | else: 242 | self.idt = nn.Identity() 243 | 244 | def
forward(self, features, features_ref, roi_mask,): 245 | # roi_mask: [B, 1, H, W] 246 | # features: [B, C, h, w] 247 | # features_ref: [B, C, h, w] 248 | B, _, H, W = roi_mask.shape 249 | if self.res_ratio == None: 250 | H_attn, W_attn = features.shape[2], features.shape[3] 251 | else: 252 | H_attn = H//self.res_ratio 253 | W_attn = W//self.res_ratio 254 | features, features_ref = self.resize_input_to_res( 255 | features, features_ref, (H, W)) # [H//res_ratio, W//res_ratio] 256 | 257 | # List, len = B, each element: [C_q, dim], dim = H//res_ratio * W//res_ratio 258 | features_ref = self.get_roi_features(features_ref, roi_mask) 259 | 260 | features = features.reshape( 261 | B, -1, features.shape[2] * features.shape[3]).permute(0, 2, 1) # [B, C, dim] 262 | # List, len = B, each element: [C_q, C] 263 | attn_output = self.compute_attention(features, features_ref) 264 | 265 | # List, len = B, each element: [C_q, H_attn, W_attn] 266 | attn_output = self.reshape_attn_output(attn_output, (H_attn, W_attn)) 267 | 268 | return attn_output 269 | 270 | def resize_input_to_res(self, features, features_ref, size): 271 | # features: [B, C, h, w] 272 | # features_ref: [B, C, h, w] 273 | H, W = size 274 | target_H, target_W = H//self.res_ratio, W//self.res_ratio 275 | features = F.interpolate(features, size=( 276 | target_H, target_W), mode=self.upsample_mode) 277 | features_ref = F.interpolate(features_ref, size=( 278 | target_H, target_W), mode=self.upsample_mode) 279 | return features, features_ref 280 | 281 | def get_roi_features(self, feature, mask): 282 | ''' 283 | get feature tokens by maskpool 284 | feature: [B, C, h, w] 285 | mask: [B, 1, H, W] [0,1] 286 | return: List, len = B, each element: [token_num, C] 287 | ''' 288 | 289 | # assert mask only has elements 0 and 1 290 | assert torch.all(torch.logical_or(mask == 0, mask == 1)) 291 | # assert mask.max() == 1 and mask.min() == 0 292 | 293 | B, _, H, W = mask.shape 294 | h, w = feature.shape[2:] 295 | 296 | output = [] 297 | for i in range(B): 298 | mask_ = mask[i] 299 | feature_ = feature[i] 300 | feature_ = self.maskpool(feature_, mask_) 301 | output.append(feature_) 302 | return output 303 | 304 | def maskpool(self, feature, mask): 305 | ''' 306 | get feature tokens by maskpool 307 | feature: [C, h, w] 308 | mask: [1, H, W] [0,1] 309 | return: [token_num, C] 310 | ''' 311 | kernel_size = mask.shape[1] // feature.shape[1] if self.res_ratio == None else self.res_ratio 312 | if self.pool_type == 'max': 313 | mask = F.max_pool2d(mask, kernel_size=kernel_size, 314 | stride=kernel_size, padding=0) 315 | elif self.pool_type == 'average': 316 | mask = F.avg_pool2d(mask, kernel_size=kernel_size, 317 | stride=kernel_size, padding=0) 318 | elif self.pool_type == 'min': 319 | mask = -1*mask 320 | mask = F.max_pool2d(mask, kernel_size=kernel_size, 321 | stride=kernel_size, padding=0) 322 | mask = -1*mask 323 | else: 324 | raise NotImplementedError 325 | 326 | # element-wise multiplication mask and feature 327 | feature = feature * mask 328 | 329 | index = (mask > 0).reshape(1, -1).squeeze() 330 | feature = feature.reshape(feature.shape[0], -1).permute(1, 0) 331 | 332 | feature = feature[index] 333 | return feature 334 | 335 | def compute_attention(self, features, features_ref): 336 | ''' 337 | features: [B, C, dim] 338 | features_ref: List, len = B, each element: [C_q, dim] 339 | return: List, len = B, each element: [C_q, C] 340 | ''' 341 | output = [] 342 | for i in range(len(features_ref)): 343 | feature_ref = features_ref[i] 344 | feature = features[i] 345 
| feature = self.compute_attention_single(feature, feature_ref) 346 | output.append(feature) 347 | return output 348 | 349 | def compute_attention_single(self, feature, feature_ref): 350 | ''' 351 | compute attention with softmax 352 | feature: [C, dim] 353 | feature_ref: [C_q, dim] 354 | return: [C_q, C] 355 | ''' 356 | scale = feature.shape[-1]**-0.5 if self.use_scale else 1.0 357 | feature = self.norm(feature) if hasattr(self, 'norm') else feature 358 | feature_ref = self.norm(feature_ref) if hasattr( 359 | self, 'norm') else feature_ref 360 | sim = einsum('i d, j d -> i j', feature_ref, feature)*scale 361 | sim = sim/self.temp_softmax 362 | sim = sim.softmax(dim=-1) 363 | return sim 364 | 365 | def reshape_attn_output(self, attn_output, attn_size): 366 | ''' 367 | attn_output: List, len = B, each element: [C_q, C] 368 | return: List, len = B, each element: [C_q, H_attn, W_attn] 369 | ''' 370 | # attn_output[0].shape[1] sqrt to get H_attn, W_attn 371 | H_attn, W_attn = attn_size 372 | 373 | output = [] 374 | for i in range(len(attn_output)): 375 | attn_output_ = attn_output[i] 376 | attn_output_ = attn_output_.reshape( 377 | attn_output_.shape[0], H_attn, W_attn) 378 | output.append(attn_output_) 379 | return output 380 | 381 | 382 | class TrainingCrossAttention(nn.Module): 383 | def __init__(self, res_ratio=4, pool_type='average', temp_softmax=1, use_scale=False, upsample_mode='bilinear', use_norm=False, dim=1280, 384 | n_heads=4, 385 | d_head=320, 386 | mlp_dim_rate=0.5,) -> None: 387 | super().__init__() 388 | self.res_ratio = res_ratio 389 | self.pool_type = pool_type 390 | self.temp_softmax = temp_softmax 391 | self.use_scale = use_scale 392 | self.upsample_mode = upsample_mode 393 | if use_norm: 394 | self.norm = nn.LayerNorm(use_norm, elementwise_affine=True) 395 | else: 396 | self.idt = nn.Identity() 397 | 398 | self.attn_module = OneWayAttentionBlock( 399 | dim, 400 | n_heads, 401 | d_head, 402 | mlp_dim_rate, 403 | ) 404 | 405 | def forward(self, features, features_ref, roi_mask,): 406 | # roi_mask: [B, 1, H, W] 407 | # features: [B, C, h, w] 408 | # features_ref: [B, C, h, w] 409 | B, _, H, W = roi_mask.shape 410 | if self.res_ratio == None: 411 | H_attn, W_attn = features.shape[2], features.shape[3] 412 | else: 413 | H_attn = H//self.res_ratio 414 | W_attn = W//self.res_ratio 415 | features, features_ref = self.resize_input_to_res( 416 | features, features_ref, (H, W)) # [H//res_ratio, W//res_ratio] 417 | 418 | # List, len = B, each element: [C_q, dim], dim = H//res_ratio * W//res_ratio 419 | features_ref = self.get_roi_features(features_ref, roi_mask) 420 | 421 | features = features.reshape( 422 | B, -1, features.shape[2] * features.shape[3]).permute(0, 2, 1) # [B, C, dim] 423 | # List, len = B, each element: [C_q, C] 424 | 425 | features_ref = self.attn_module(features_ref, features) 426 | 427 | attn_output = self.compute_attention(features, features_ref) 428 | 429 | # List, len = B, each element: [C_q, H_attn, W_attn] 430 | attn_output = self.reshape_attn_output(attn_output, (H_attn, W_attn)) 431 | 432 | return attn_output 433 | 434 | def resize_input_to_res(self, features, features_ref, size): 435 | # features: [B, C, h, w] 436 | # features_ref: [B, C, h, w] 437 | H, W = size 438 | target_H, target_W = H//self.res_ratio, W//self.res_ratio 439 | features = F.interpolate(features, size=( 440 | target_H, target_W), mode=self.upsample_mode) 441 | features_ref = F.interpolate(features_ref, size=( 442 | target_H, target_W), mode=self.upsample_mode) 443 | return features, 
features_ref 444 | 445 | def get_roi_features(self, feature, mask): 446 | ''' 447 | get feature tokens by maskpool 448 | feature: [B, C, h, w] 449 | mask: [B, 1, H, W] [0,1] 450 | return: List, len = B, each element: [token_num, C] 451 | ''' 452 | 453 | # assert mask only has elements 0 and 1 454 | assert torch.all(torch.logical_or(mask == 0, mask == 1)) 455 | # assert mask.max() == 1 and mask.min() == 0 456 | 457 | B, _, H, W = mask.shape 458 | h, w = feature.shape[2:] 459 | 460 | output = [] 461 | for i in range(B): 462 | mask_ = mask[i] 463 | feature_ = feature[i] 464 | feature_ = self.maskpool(feature_, mask_) 465 | output.append(feature_) 466 | return output 467 | 468 | def maskpool(self, feature, mask): 469 | ''' 470 | get feature tokens by maskpool 471 | feature: [C, h, w] 472 | mask: [1, H, W] [0,1] 473 | return: [token_num, C] 474 | ''' 475 | kernel_size = mask.shape[1] // feature.shape[1] if self.res_ratio == None else self.res_ratio 476 | if self.pool_type == 'max': 477 | mask = F.max_pool2d(mask, kernel_size=kernel_size, 478 | stride=kernel_size, padding=0) 479 | elif self.pool_type == 'average': 480 | mask = F.avg_pool2d(mask, kernel_size=kernel_size, 481 | stride=kernel_size, padding=0) 482 | elif self.pool_type == 'min': 483 | mask = -1*mask 484 | mask = F.max_pool2d(mask, kernel_size=kernel_size, 485 | stride=kernel_size, padding=0) 486 | mask = -1*mask 487 | else: 488 | raise NotImplementedError 489 | 490 | # element-wise multiplication mask and feature 491 | feature = feature * mask 492 | 493 | index = (mask > 0).reshape(1, -1).squeeze() 494 | feature = feature.reshape(feature.shape[0], -1).permute(1, 0) 495 | 496 | feature = feature[index] 497 | return feature 498 | 499 | def compute_attention(self, features, features_ref): 500 | ''' 501 | features: [B, C, dim] 502 | features_ref: List, len = B, each element: [C_q, dim] 503 | return: List, len = B, each element: [C_q, C] 504 | ''' 505 | output = [] 506 | for i in range(len(features_ref)): 507 | feature_ref = features_ref[i] 508 | feature = features[i] 509 | feature = self.compute_attention_single(feature, feature_ref) 510 | output.append(feature) 511 | return output 512 | 513 | def compute_attention_single(self, feature, feature_ref): 514 | ''' 515 | compute attention with softmax 516 | feature: [C, dim] 517 | feature_ref: [C_q, dim] 518 | return: [C_q, C] 519 | ''' 520 | scale = feature.shape[-1]**-0.5 if self.use_scale else 1.0 521 | feature = self.norm(feature) if hasattr(self, 'norm') else feature 522 | feature_ref = self.norm(feature_ref) if hasattr( 523 | self, 'norm') else feature_ref 524 | sim = einsum('i d, j d -> i j', feature_ref, feature)*scale 525 | sim = sim/self.temp_softmax 526 | sim = sim.softmax(dim=-1) 527 | return sim 528 | 529 | def reshape_attn_output(self, attn_output, attn_size): 530 | ''' 531 | attn_output: List, len = B, each element: [C_q, C] 532 | return: List, len = B, each element: [C_q, H_attn, W_attn] 533 | ''' 534 | # attn_output[0].shape[1] sqrt to get H_attn, W_attn 535 | H_attn, W_attn = attn_size 536 | 537 | output = [] 538 | for i in range(len(attn_output)): 539 | attn_output_ = attn_output[i] 540 | attn_output_ = attn_output_.reshape( 541 | attn_output_.shape[0], H_attn, W_attn) 542 | output.append(attn_output_) 543 | return output 544 | 545 | 546 | class TrainingFreeAttentionBlocks(nn.Module): 547 | ''' 548 | one implementation of in_context_fusion 549 | 550 | forward(feature_of_reference_image, feature_of_source_image, guidance_on_reference_image) 551 | ''' 552 | 553 | def 
__init__(self, 554 | res_ratio=8, 555 | pool_type='min', 556 | temp_softmax=1000, 557 | use_scale=False, 558 | upsample_mode='bicubic', 559 | bottle_neck_dim=None, 560 | use_norm=False, 561 | 562 | ): 563 | super().__init__() 564 | 565 | self.attn_module = TrainingFreeAttention(res_ratio=res_ratio, 566 | pool_type=pool_type, 567 | temp_softmax=temp_softmax, 568 | use_scale=use_scale, 569 | upsample_mode=upsample_mode, 570 | use_norm=use_norm,) 571 | 572 | def forward(self, feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image): 573 | ''' 574 | feature_of_reference_image: [B, C, H, W] 575 | ft_attn_of_source_image: {"ft_cor": [B, C, H, W], "attn": {'24':[B, H_1, W_1, H_1*W_1],} "ft_matting": [B, C, H, W]} 576 | guidance_on_reference_image: [B, 1, H_2, W_2] 577 | ''' 578 | # assert feature_of_reference_image.shape[0] == 1 579 | # get source_image h,w 580 | h, w = guidance_on_reference_image.shape[-2:] 581 | 582 | features_cor = ft_attn_of_source_image['ft_cor'] 583 | features_matting = ft_attn_of_source_image['ft_matting'] 584 | features_ref = feature_of_reference_image 585 | 586 | guidance_on_reference_image[guidance_on_reference_image > 0.5] = 1 587 | guidance_on_reference_image[guidance_on_reference_image <= 0.5] = 0 588 | attn_output = self.attn_module( 589 | features_cor, features_ref, guidance_on_reference_image) 590 | 591 | attn_output = [attn_output_.sum(dim=0).unsqueeze( 592 | 0).unsqueeze(0) for attn_output_ in attn_output] 593 | attn_output = torch.cat(attn_output, dim=0) 594 | 595 | self_attn_output = self.training_free_self_attention( 596 | attn_output, ft_attn_of_source_image['attn']) 597 | 598 | # resize 599 | self_attn_output = F.interpolate( 600 | self_attn_output, size=(h, w), mode='bilinear') 601 | 602 | output = {} 603 | output['trimap'] = self_attn_output 604 | output['feature'] = features_matting 605 | output['mask'] = attn_output 606 | 607 | return output 608 | 609 | def training_free_self_attention(self, x, self_attn_maps): 610 | ''' 611 | Compute self-attention using the attention maps. 612 | 613 | Parameters: 614 | x (torch.Tensor): The input tensor. Shape: [B, 1, H, W] 615 | self_attn_maps (torch.Tensor): The attention maps. Shape: {'24': [B, H1, W1, H1*W1]} 616 | 617 | Returns: 618 | torch.Tensor: The result of the self-attention computation. 619 | ''' 620 | 621 | # Original dimensions of x 622 | # Assuming x's shape is [B, 1, H, W] based on your comment 623 | B, _, H, W = x.shape 624 | 625 | # Dimensions of the attention maps 626 | assert len(self_attn_maps) == 1 627 | # get only one value in dict 628 | self_attn_maps = list(self_attn_maps.values())[0] 629 | _, H1, W1, _ = self_attn_maps.shape 630 | 631 | # Resize x to match the spatial dimensions of the attention maps 632 | # You might need align_corners depending on your version of PyTorch 633 | x = F.interpolate(x, size=(H1, W1), mode='bilinear', 634 | align_corners=True) 635 | 636 | # Reshape the attention maps and x for matrix multiplication 637 | # Reshaping from [B, H1, W1, H1*W1] to [B, H1*W1, H1*W1] 638 | self_attn_maps = self_attn_maps.view(B, H1 * W1, H1 * W1) 639 | # Reshaping from [B, 1, H1, W1] to [B, 1, H1*W1] 640 | x = x.view(B, 1, H1 * W1) 641 | 642 | # Apply the self-attention mechanism 643 | # Matrix multiplication between the attention maps and the input feature map 644 | # This step essentially computes the weighted sum of feature vectors in the input, 645 | # where the weights are defined by the attention maps. 
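        # Shape walk-through of the matmul below, taking H1 = W1 = 24 (the '24'
        # key in the docstrings above) as an example:
        #   x              : [B, 1, 576]    coarse mask from the cross-attention stage
        #   self_attn_maps : [B, 576, 576]  averaged self-attention of the source image
        #   matmul result  : [B, 1, 576]    mask propagated to pixels with similar features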
646 | # Multiplying with the transpose to get shape [B, 1, H1*W1] 647 | out = torch.matmul(x, self_attn_maps.transpose(1, 2)) 648 | 649 | # Reshape the output tensor to the original spatial dimensions 650 | out = out.view(B, 1, H1, W1) # Reshaping back to spatial dimensions 651 | 652 | # # Resize the output back to the input's original dimensions (if necessary) 653 | # out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=True) 654 | 655 | return out 656 | 657 | 658 | class SemiTrainingAttentionBlocks(nn.Module): 659 | ''' 660 | one implementation of in_context_fusion 661 | 662 | forward(feature_of_reference_image, feature_of_source_image, guidance_on_reference_image) 663 | ''' 664 | 665 | def __init__(self, 666 | res_ratio=8, 667 | pool_type='min', 668 | upsample_mode='bicubic', 669 | bottle_neck_dim=None, 670 | use_norm=False, 671 | in_ft_dim=[1280, 960], 672 | in_attn_dim=[24**2, 48**2], 673 | attn_out_dim=256, 674 | ft_out_dim=512, 675 | training_cross_attn=False, 676 | ): 677 | super().__init__() 678 | if training_cross_attn: 679 | self.attn_module = TrainingCrossAttention( 680 | res_ratio=res_ratio, 681 | pool_type=pool_type, 682 | temp_softmax=1, 683 | use_scale=True, 684 | upsample_mode=upsample_mode, 685 | use_norm=use_norm, 686 | ) 687 | else: 688 | self.attn_module = TrainingFreeAttention(res_ratio=res_ratio, 689 | pool_type=pool_type, 690 | temp_softmax=1, 691 | use_scale=True, 692 | upsample_mode=upsample_mode, 693 | use_norm=use_norm,) 694 | 695 | # init module list for attn, with basic 3*3 conv_attn 696 | self.attn_module_list = nn.ModuleList() 697 | self.ft_attn_module_list = nn.ModuleList() 698 | for i in range(len(in_attn_dim)): 699 | self.attn_module_list.append(Basic_Conv3x3_attn( 700 | in_attn_dim[i], attn_out_dim, int(math.sqrt(in_attn_dim[i])))) 701 | self.ft_attn_module_list.append(Basic_Conv3x3( 702 | ft_out_dim[i] + attn_out_dim, ft_out_dim[i])) 703 | # init module list for ft, with basic 3*3 conv 704 | self.ft_module_list = nn.ModuleList() 705 | for i in range(len(in_ft_dim)): 706 | self.ft_module_list.append( 707 | Basic_Conv3x3(in_ft_dim[i], ft_out_dim[i])) 708 | 709 | ft_out_dim_ = [2*d for d in ft_out_dim] 710 | self.fusion = MultiScaleFeatureFusion(ft_out_dim_, ft_out_dim) 711 | 712 | def forward(self, feature_of_reference_image, ft_attn_of_source_image, guidance_on_reference_image): 713 | ''' 714 | feature_of_reference_image: [B, C, H, W] 715 | ft_attn_of_source_image: {"ft_cor": [B, C, H, W], "attn": [B, H_1, W_1, H_1*W_1], "ft_matting": {'24':[B, C, H, W]} } 716 | guidance_on_reference_image: [B, 1, H_2, W_2] 717 | ''' 718 | # assert feature_of_reference_image.shape[0] == 1 719 | # get source_image h,w 720 | h, w = guidance_on_reference_image.shape[-2:] 721 | 722 | features_cor = ft_attn_of_source_image['ft_cor'] 723 | features_matting = ft_attn_of_source_image['ft_matting'] 724 | features_ref = feature_of_reference_image 725 | 726 | guidance_on_reference_image[guidance_on_reference_image > 0.5] = 1 727 | guidance_on_reference_image[guidance_on_reference_image <= 0.5] = 0 728 | attn_output = self.attn_module( 729 | features_cor, features_ref, guidance_on_reference_image) 730 | 731 | attn_output = [attn_output_.sum(dim=0).unsqueeze( 732 | 0).unsqueeze(0) for attn_output_ in attn_output] 733 | attn_output = torch.cat(attn_output, dim=0) 734 | 735 | self_attn_output = self.training_free_self_attention( 736 | attn_output, ft_attn_of_source_image['attn']) 737 | 738 | # concat attn and ft_matting 739 | 740 | attn_ft_matting = {} 741 | for i, key 
in enumerate(features_matting.keys()): 742 | if key in self_attn_output.keys(): 743 | features_matting[key] = self.ft_module_list[i]( 744 | features_matting[key]) 745 | attn_ft_matting[key] = torch.cat( 746 | [features_matting[key], self_attn_output[key]], dim=1) 747 | 748 | attn_ft_matting[key] = self.ft_attn_module_list[i]( 749 | attn_ft_matting[key]) 750 | 751 | else: 752 | attn_ft_matting[key] = self.ft_module_list[i]( 753 | features_matting[key]) 754 | 755 | # forward in multi-scale fusion block 756 | attn_ft_matting = self.fusion(attn_ft_matting) 757 | 758 | att_look = [] 759 | # resize and average self_attn_output 760 | for i, key in enumerate(self_attn_output.keys()): 761 | att__ = F.interpolate( 762 | self_attn_output[key].mean(dim=1).unsqueeze(1), size=(h, w), mode='bilinear') 763 | att_look.append(att__) 764 | att_look = torch.cat(att_look, dim=1) 765 | att_look = att_look.mean(dim=1).unsqueeze(1) 766 | 767 | output = {} 768 | 769 | output['trimap'] = att_look 770 | output['feature'] = attn_ft_matting 771 | output['mask'] = attn_output 772 | 773 | return output 774 | 775 | def training_free_self_attention(self, x, self_attn_maps): 776 | ''' 777 | Compute weighted attn maps using the attention maps. 778 | 779 | Parameters: 780 | x (torch.Tensor): The input tensor. Shape: [B, 1, H, W] 781 | self_attn_maps (torch.Tensor): The attention maps. Shape: {'24':[B, H1, W1, H1*W1], '48':[B, H2, W2, H2*W2]} 782 | 783 | Returns: 784 | torch.Tensor: The result of the attention computation. {'24':[B, 1, H1*W1, H1, W1], '48':[B, 1, H2*W2, H2, W2]} 785 | ''' 786 | 787 | # Original dimensions of x 788 | # Assuming x's shape is [B, 1, H, W] based on your comment 789 | B, _, H, W = x.shape 790 | out = {} 791 | for i, key in enumerate(self_attn_maps.keys()): 792 | # Dimensions of the attention maps 793 | _, H1, W1, _ = self_attn_maps[key].shape 794 | 795 | # Resize x to match the spatial dimensions of the attention maps 796 | # You might need align_corners depending on your version of PyTorch 797 | x_ = F.interpolate(x, size=(H1, W1), mode='bilinear', 798 | align_corners=True) 799 | 800 | # Reshape the attention maps and x for matrix multiplication 801 | # Reshaping from [B, H1, W1, H1*W1] to [B, H1*W1, H1*W1] 802 | self_attn_map_ = self_attn_maps[key].view( 803 | B, H1 * W1, H1 * W1).transpose(1, 2) 804 | # Reshaping from [B, 1, H1, W1] to [B, 1, H1*W1] 805 | x_ = x_.reshape(B, H1 * W1, 1) 806 | 807 | # propagate , element wise multiplication x_ and self_attn_maps 808 | x_ = x_ * self_attn_map_ 809 | x_ = x_.reshape(B, H1 * W1, H1, W1) 810 | x_ = x_.permute(0, 2, 3, 1) 811 | x_ = self.attn_module_list[i](x_) 812 | out[key] = x_ 813 | 814 | return out 815 | 816 | 817 | class MultiScaleFeatureFusion(nn.Module): 818 | ''' 819 | N conv layers or bottleneck blocks to compress the feature dimension 820 | 821 | M conv layers and upsampling to fusion the features 822 | 823 | ''' 824 | 825 | def __init__(self, 826 | in_feature_dim=[], 827 | out_feature_dim=[], 828 | use_bottleneck=False) -> None: 829 | super().__init__() 830 | assert len(in_feature_dim) == len(out_feature_dim) 831 | # init module list 832 | self.module_list = nn.ModuleList() 833 | for i in range(len(in_feature_dim)-1): 834 | self.module_list.append(Fusion_Block( 835 | in_feature_dim[i], out_feature_dim[i])) 836 | 837 | def forward(self, features): 838 | # features: {'32': tensor, '16': tensor, '8': tensor} 839 | 840 | key_list = list(features.keys()) 841 | ft = features[key_list[0]] 842 | for i in range(len(key_list)-1): 843 | ft = 
self.module_list[i](ft, features[key_list[i+1]]) 844 | 845 | return ft 846 | -------------------------------------------------------------------------------- /icm/models/decoder/in_context_decoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from icm.util import instantiate_from_config 3 | 4 | class InContextDecoder(nn.Module): 5 | ''' 6 | InContextDecoder is the decoder of InContextMatting. 7 | 8 | in-context decoder: 9 | 10 | list get_trainable_params() 11 | 12 | forward(source, reference) 13 | reference = {'feature': feature_of_reference_image, 14 | 'guidance': guidance_on_reference_image} 15 | 16 | source = {'feature': feature_of_source_image, 'image': source_images} 17 | 18 | ''' 19 | 20 | def __init__(self, 21 | cfg_detail_decoder, 22 | cfg_in_context_fusion, 23 | freeze_in_context_fusion=False, 24 | ): 25 | super().__init__() 26 | 27 | self.in_context_fusion = instantiate_from_config( 28 | cfg_in_context_fusion) 29 | self.detail_decoder = instantiate_from_config(cfg_detail_decoder) 30 | 31 | self.freeze_in_context_fusion = freeze_in_context_fusion 32 | if freeze_in_context_fusion: 33 | self.__freeze_in_context_fusion() 34 | 35 | def forward(self, source, reference): 36 | feature_of_reference_image = reference['feature'] 37 | guidance_on_reference_image = reference['guidance'] 38 | 39 | feature_of_source_image = source['feature'] 40 | source_images = source['image'] 41 | 42 | features = self.in_context_fusion( 43 | feature_of_reference_image, feature_of_source_image, guidance_on_reference_image) 44 | 45 | output = self.detail_decoder(features, source_images) 46 | 47 | return output, features['mask'], features['trimap'] 48 | 49 | def get_trainable_params(self): 50 | params = [] 51 | params = params + list(self.detail_decoder.parameters()) 52 | if not self.freeze_in_context_fusion: 53 | params = params + list(self.in_context_fusion.parameters()) 54 | return params 55 | 56 | def __freeze_in_context_fusion(self): 57 | for param in self.in_context_fusion.parameters(): 58 | param.requires_grad = False 59 | -------------------------------------------------------------------------------- /icm/models/feature_extractor/attention_controllers.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Tuple, List, Callable, Dict 2 | import torch 3 | import torch.nn.functional as nnf 4 | import numpy as np 5 | import abc 6 | 7 | 8 | LOW_RESOURCE = False 9 | 10 | 11 | class AttentionControl(abc.ABC): 12 | 13 | def step_callback(self, x_t): 14 | return x_t 15 | 16 | def between_steps(self): 17 | return 18 | 19 | @property 20 | def num_uncond_att_layers(self): 21 | return self.num_att_layers if LOW_RESOURCE else 0 22 | 23 | @abc.abstractmethod 24 | def forward (self, attn, is_cross: bool, place_in_unet: str): 25 | raise NotImplementedError 26 | 27 | def __call__(self, attn, is_cross: bool, place_in_unet: str, ensemble_size=1, token_batch_size=1): 28 | if self.cur_att_layer >= self.num_uncond_att_layers: 29 | if LOW_RESOURCE: 30 | attn = self.forward(attn, is_cross, place_in_unet, ensemble_size, token_batch_size) 31 | else: 32 | h = attn.shape[0] 33 | # attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 34 | attn = self.forward(attn, is_cross, place_in_unet, ensemble_size, token_batch_size) 35 | self.cur_att_layer += 1 36 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 37 | self.cur_att_layer = 0 38 | self.cur_step += 1 39 
| self.between_steps() 40 | return attn 41 | 42 | def reset(self): 43 | self.cur_step = 0 44 | self.cur_att_layer = 0 45 | 46 | def __init__(self): 47 | self.cur_step = 0 48 | self.num_att_layers = -1 49 | self.cur_att_layer = 0 50 | 51 | class EmptyControl(AttentionControl): 52 | 53 | def forward (self, attn, is_cross: bool, place_in_unet: str): 54 | return attn 55 | 56 | 57 | class AttentionStore(AttentionControl): 58 | 59 | @staticmethod 60 | def get_empty_store(): 61 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 62 | "down_self": [], "mid_self": [], "up_self": []} 63 | 64 | def forward(self, attn, is_cross: bool, place_in_unet: str, ensemble_size=1, token_batch_size=1): 65 | num_head = attn.shape[0]//token_batch_size 66 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 67 | if self.store_res is not None: 68 | if attn.shape[1] in self.store_res and (is_cross is False): 69 | attn = attn.reshape(-1, ensemble_size, *attn.shape[1:]) 70 | attn = attn.mean(dim=1) 71 | attn = attn.reshape(-1,num_head , *attn.shape[1:]) 72 | attn = attn.mean(dim=1) 73 | self.step_store[key].append(attn) 74 | elif attn.shape[1] <= 48 ** 2 and (is_cross is False): # avoid memory overhead 75 | attn = attn.reshape(-1, ensemble_size, *attn.shape[1:]) 76 | attn = attn.mean(dim=1) 77 | attn = attn.reshape(-1,num_head , *attn.shape[1:]) 78 | attn = attn.mean(dim=1) 79 | self.step_store[key].append(attn) 80 | 81 | torch.cuda.empty_cache() 82 | 83 | def between_steps(self): 84 | if len(self.attention_store) == 0: 85 | self.attention_store = self.step_store 86 | else: 87 | for key in self.attention_store: 88 | for i in range(len(self.attention_store[key])): 89 | self.attention_store[key][i] += self.step_store[key][i] 90 | self.step_store = self.get_empty_store() 91 | 92 | def get_average_attention(self): 93 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store} 94 | return average_attention 95 | 96 | 97 | def reset(self): 98 | super(AttentionStore, self).reset() 99 | del self.step_store 100 | torch.cuda.empty_cache() 101 | self.step_store = self.get_empty_store() 102 | self.attention_store = {} 103 | 104 | def __init__(self,store_res = None): 105 | super(AttentionStore, self).__init__() 106 | self.step_store = self.get_empty_store() 107 | self.attention_store = {} 108 | 109 | store_res = [store_res] if isinstance(store_res, int) else list(store_res) 110 | self.store_res = [] 111 | for res in store_res: 112 | self.store_res.append(res**2) 113 | 114 | -------------------------------------------------------------------------------- /icm/models/feature_extractor/dift_sd.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import torch 3 | import torch.nn as nn 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from typing import Any, Callable, Dict, List, Optional, Union 7 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 8 | from diffusers import DDIMScheduler 9 | import gc 10 | from PIL import Image 11 | 12 | from icm.models.feature_extractor.attention_controllers import AttentionStore 13 | import xformers 14 | 15 | 16 | def register_attention_control(model, controller, if_softmax=True, ensemble_size=1): 17 | def ca_forward(self, place_in_unet, att_opt_b): 18 | 19 | class MyXFormersAttnProcessor: 20 | r""" 21 | Processor for implementing memory efficient attention using xFormers. 
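            In this repository the processor additionally exposes the raw
            query-key similarities: when the number of query tokens matches an
            entry in `controller.store_res`, the similarity map is computed
            explicitly and handed to the AttentionStore controller before the
            memory-efficient attention call.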
22 | 23 | Args: 24 | attention_op (`Callable`, *optional*, defaults to `None`): 25 | The base 26 | [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to 27 | use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best 28 | operator. 29 | """ 30 | 31 | def __init__(self, attention_op=None): 32 | self.attention_op = attention_op 33 | 34 | def __call__( 35 | self, 36 | attn, 37 | hidden_states: torch.FloatTensor, 38 | encoder_hidden_states=None, 39 | attention_mask=None, 40 | temb=None, 41 | ): 42 | residual = hidden_states 43 | 44 | if attn.spatial_norm is not None: 45 | hidden_states = attn.spatial_norm(hidden_states, temb) 46 | 47 | input_ndim = hidden_states.ndim 48 | 49 | if input_ndim == 4: 50 | batch_size, channel, height, width = hidden_states.shape 51 | hidden_states = hidden_states.view( 52 | batch_size, channel, height * width).transpose(1, 2) 53 | 54 | batch_size, key_tokens, _ = ( 55 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 56 | ) 57 | 58 | attention_mask = attn.prepare_attention_mask( 59 | attention_mask, key_tokens, batch_size) 60 | if attention_mask is not None: 61 | # expand our mask's singleton query_tokens dimension: 62 | # [batch*heads, 1, key_tokens] -> 63 | # [batch*heads, query_tokens, key_tokens] 64 | # so that it can be added as a bias onto the attention scores that xformers computes: 65 | # [batch*heads, query_tokens, key_tokens] 66 | # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 67 | _, query_tokens, _ = hidden_states.shape 68 | attention_mask = attention_mask.expand( 69 | -1, query_tokens, -1) 70 | 71 | if attn.group_norm is not None: 72 | hidden_states = attn.group_norm( 73 | hidden_states.transpose(1, 2)).transpose(1, 2) 74 | 75 | query = attn.to_q(hidden_states) 76 | 77 | is_cross = False if encoder_hidden_states is None else True 78 | 79 | if encoder_hidden_states is None: 80 | encoder_hidden_states = hidden_states 81 | elif attn.norm_cross: 82 | encoder_hidden_states = attn.norm_encoder_hidden_states( 83 | encoder_hidden_states) 84 | 85 | key = attn.to_k(encoder_hidden_states) 86 | value = attn.to_v(encoder_hidden_states) 87 | 88 | query = attn.head_to_batch_dim(query).contiguous() 89 | key = attn.head_to_batch_dim(key).contiguous() 90 | value = attn.head_to_batch_dim(value).contiguous() 91 | 92 | # controller 93 | if query.shape[1] in controller.store_res: 94 | sim = torch.einsum('b i d, b j d -> b i j', 95 | query, key) * attn.scale 96 | 97 | if if_softmax: 98 | sim = sim / if_softmax 99 | my_attn = sim.softmax(dim=-1).detach() 100 | del sim 101 | else: 102 | my_attn = sim.detach() 103 | 104 | controller(my_attn, is_cross, place_in_unet, ensemble_size, batch_size) 105 | 106 | # end controller 107 | 108 | hidden_states = xformers.ops.memory_efficient_attention( 109 | query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale 110 | ) 111 | hidden_states = hidden_states.to(query.dtype) 112 | hidden_states = attn.batch_to_head_dim(hidden_states) 113 | 114 | # linear proj 115 | hidden_states = attn.to_out[0](hidden_states) 116 | # dropout 117 | hidden_states = attn.to_out[1](hidden_states) 118 | 119 | if input_ndim == 4: 120 | hidden_states = hidden_states.transpose( 121 | -1, -2).reshape(batch_size, channel, height, width) 122 | 123 | if attn.residual_connection: 124 | hidden_states = hidden_states + residual 125 | 126 | hidden_states = hidden_states 
/ attn.rescale_output_factor 127 | 128 | return hidden_states 129 | 130 | return MyXFormersAttnProcessor(att_opt_b) 131 | 132 | class DummyController: 133 | 134 | def __call__(self, *args): 135 | return args[0] 136 | 137 | def __init__(self): 138 | self.num_att_layers = 0 139 | 140 | if controller is None: 141 | controller = DummyController() 142 | 143 | def register_recr(net_, count, place_in_unet): 144 | if net_.__class__.__name__ == 'Attention': 145 | net_.processor = ca_forward( 146 | net_, place_in_unet, net_.processor.attention_op) 147 | return count + 1 148 | elif hasattr(net_, 'children'): 149 | for net__ in net_.children(): 150 | count = register_recr(net__, count, place_in_unet) 151 | return count 152 | 153 | cross_att_count = 0 154 | # sub_nets = model.unet.named_children() 155 | sub_nets = model.unet.named_children() 156 | # for net in sub_nets: 157 | # if "down" in net[0]: 158 | # cross_att_count += register_recr(net[1], 0, "down") 159 | # elif "up" in net[0]: 160 | # cross_att_count += register_recr(net[1], 0, "up") 161 | # elif "mid" in net[0]: 162 | # cross_att_count += register_recr(net[1], 0, "mid") 163 | for net in sub_nets: 164 | if "down_blocks" in net[0]: 165 | cross_att_count += register_recr(net[1], 0, "down") 166 | elif "up_blocks" in net[0]: 167 | cross_att_count += register_recr(net[1], 0, "up") 168 | elif "mid_block" in net[0]: 169 | cross_att_count += register_recr(net[1], 0, "mid") 170 | controller.num_att_layers = cross_att_count 171 | 172 | 173 | class MyUNet2DConditionModel(UNet2DConditionModel): 174 | def forward( 175 | self, 176 | sample: torch.FloatTensor, 177 | timestep: Union[torch.Tensor, float, int], 178 | up_ft_indices, 179 | encoder_hidden_states: torch.Tensor, 180 | class_labels: Optional[torch.Tensor] = None, 181 | timestep_cond: Optional[torch.Tensor] = None, 182 | attention_mask: Optional[torch.Tensor] = None, 183 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 184 | ): 185 | r""" 186 | Args: 187 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 188 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 189 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 190 | cross_attention_kwargs (`dict`, *optional*): 191 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 192 | `self.processor` in 193 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 194 | """ 195 | # By default samples have to be AT least a multiple of the overall upsampling factor. 196 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 197 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 198 | # on the fly if necessary. 199 | default_overall_up_factor = 2**self.num_upsamplers 200 | 201 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 202 | forward_upsample_size = False 203 | upsample_size = None 204 | 205 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 206 | # logger.info("Forward upsample size to force interpolation output size.") 207 | forward_upsample_size = True 208 | 209 | # prepare attention_mask 210 | if attention_mask is not None: 211 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 212 | attention_mask = attention_mask.unsqueeze(1) 213 | 214 | # 0. 
center input if necessary 215 | if self.config.center_input_sample: 216 | sample = 2 * sample - 1.0 217 | 218 | # 1. time 219 | timesteps = timestep 220 | if not torch.is_tensor(timesteps): 221 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 222 | # This would be a good case for the `match` statement (Python 3.10+) 223 | is_mps = sample.device.type == "mps" 224 | if isinstance(timestep, float): 225 | dtype = torch.float32 if is_mps else torch.float64 226 | else: 227 | dtype = torch.int32 if is_mps else torch.int64 228 | timesteps = torch.tensor( 229 | [timesteps], dtype=dtype, device=sample.device) 230 | elif len(timesteps.shape) == 0: 231 | timesteps = timesteps[None].to(sample.device) 232 | 233 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 234 | timesteps = timesteps.expand(sample.shape[0]) 235 | 236 | t_emb = self.time_proj(timesteps) 237 | 238 | # timesteps does not contain any weights and will always return f32 tensors 239 | # but time_embedding might actually be running in fp16. so we need to cast here. 240 | # there might be better ways to encapsulate this. 241 | t_emb = t_emb.to(dtype=self.dtype) 242 | 243 | emb = self.time_embedding(t_emb, timestep_cond) 244 | 245 | if self.class_embedding is not None: 246 | if class_labels is None: 247 | raise ValueError( 248 | "class_labels should be provided when num_class_embeds > 0" 249 | ) 250 | 251 | if self.config.class_embed_type == "timestep": 252 | class_labels = self.time_proj(class_labels) 253 | 254 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 255 | emb = emb + class_emb 256 | 257 | # 2. pre-process 258 | sample = self.conv_in(sample) 259 | 260 | # 3. down 261 | down_block_res_samples = (sample,) 262 | for downsample_block in self.down_blocks: 263 | if ( 264 | hasattr(downsample_block, "has_cross_attention") 265 | and downsample_block.has_cross_attention 266 | ): 267 | sample, res_samples = downsample_block( 268 | hidden_states=sample, 269 | temb=emb, 270 | encoder_hidden_states=encoder_hidden_states, 271 | attention_mask=attention_mask, 272 | cross_attention_kwargs=cross_attention_kwargs, 273 | ) 274 | else: 275 | sample, res_samples = downsample_block( 276 | hidden_states=sample, temb=emb) 277 | 278 | down_block_res_samples += res_samples 279 | 280 | # 4. mid 281 | if self.mid_block is not None: 282 | sample = self.mid_block( 283 | sample, 284 | emb, 285 | encoder_hidden_states=encoder_hidden_states, 286 | attention_mask=attention_mask, 287 | cross_attention_kwargs=cross_attention_kwargs, 288 | ) 289 | 290 | # 5. 
up 291 | up_ft = {} 292 | for i, upsample_block in enumerate(self.up_blocks): 293 | # if i > np.max(up_ft_indices): 294 | # break 295 | 296 | is_final_block = i == len(self.up_blocks) - 1 297 | 298 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 299 | down_block_res_samples = down_block_res_samples[ 300 | : -len(upsample_block.resnets) 301 | ] 302 | 303 | # if we have not reached the final block and need to forward the 304 | # upsample size, we do it here 305 | if not is_final_block and forward_upsample_size: 306 | upsample_size = down_block_res_samples[-1].shape[2:] 307 | 308 | if ( 309 | hasattr(upsample_block, "has_cross_attention") 310 | and upsample_block.has_cross_attention 311 | ): 312 | sample = upsample_block( 313 | hidden_states=sample, 314 | temb=emb, 315 | res_hidden_states_tuple=res_samples, 316 | encoder_hidden_states=encoder_hidden_states, 317 | cross_attention_kwargs=cross_attention_kwargs, 318 | upsample_size=upsample_size, 319 | attention_mask=attention_mask, 320 | ) 321 | else: 322 | sample = upsample_block( 323 | hidden_states=sample, 324 | temb=emb, 325 | res_hidden_states_tuple=res_samples, 326 | upsample_size=upsample_size, 327 | ) 328 | 329 | if i in up_ft_indices: 330 | up_ft[i] = sample.detach() 331 | 332 | output = {} 333 | output["up_ft"] = up_ft 334 | return output 335 | 336 | 337 | class OneStepSDPipeline(StableDiffusionPipeline): 338 | @torch.no_grad() 339 | def __call__( 340 | self, 341 | img_tensor, 342 | t, 343 | up_ft_indices, 344 | negative_prompt: Optional[Union[str, List[str]]] = None, 345 | generator: Optional[Union[torch.Generator, 346 | List[torch.Generator]]] = None, 347 | prompt_embeds: Optional[torch.FloatTensor] = None, 348 | callback: Optional[Callable[[ 349 | int, int, torch.FloatTensor], None]] = None, 350 | callback_steps: int = 1, 351 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 352 | ): 353 | device = self._execution_device 354 | latents = ( 355 | self.vae.encode(img_tensor).latent_dist.sample() 356 | * self.vae.config.scaling_factor 357 | ) 358 | t = torch.tensor(t, dtype=torch.long, device=device) 359 | noise = torch.randn_like(latents).to(device) 360 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 361 | unet_output = self.unet( 362 | latents_noisy, 363 | t, 364 | up_ft_indices, 365 | encoder_hidden_states=prompt_embeds, 366 | cross_attention_kwargs=cross_attention_kwargs, 367 | ) 368 | return unet_output 369 | 370 | 371 | class SDFeaturizer(nn.Module): 372 | def __init__(self, sd_id='pretrained_models/stable-diffusion-2-1', 373 | load_local=True, ): 374 | super().__init__() 375 | # sd_id="stabilityai/stable-diffusion-2-1", load_local=False): 376 | unet = MyUNet2DConditionModel.from_pretrained( 377 | sd_id, 378 | subfolder="unet", 379 | # output_loading_info=True, 380 | local_files_only=load_local, 381 | low_cpu_mem_usage=True, 382 | use_safetensors=False, 383 | # torch_dtype=torch.float16, 384 | # device_map="auto", 385 | ) 386 | onestep_pipe = OneStepSDPipeline.from_pretrained( 387 | sd_id, 388 | unet=unet, 389 | safety_checker=None, 390 | local_files_only=load_local, 391 | low_cpu_mem_usage=True, 392 | use_safetensors=False, 393 | # torch_dtype=torch.float16, 394 | # device_map="auto", 395 | ) 396 | onestep_pipe.vae.decoder = None 397 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained( 398 | sd_id, subfolder="scheduler" 399 | ) 400 | gc.collect() 401 | 402 | onestep_pipe = onestep_pipe.to("cuda") 403 | 404 | onestep_pipe.enable_attention_slicing() 405 | 
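        # Keeping the xFormers processors in place matters for the attention
        # hook: register_attention_control() later replaces each Attention
        # module's processor with MyXFormersAttnProcessor and reuses the
        # existing processor's `attention_op`.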
onestep_pipe.enable_xformers_memory_efficient_attention() 406 | self.pipe = onestep_pipe 407 | 408 | # register nn.module for ddp 409 | self.vae = self.pipe.vae 410 | self.unet = self.pipe.unet 411 | 412 | # freeze vae and unet 413 | for param in self.vae.parameters(): 414 | param.requires_grad = False 415 | for param in self.unet.parameters(): 416 | param.requires_grad = False 417 | 418 | @torch.no_grad() 419 | def forward(self, img_tensor, prompt='', t=261, up_ft_index=3, ensemble_size=8): 420 | """ 421 | Args: 422 | img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W] 423 | prompt: the prompt to use, a string 424 | t: the time step to use, should be an int in the range of [0, 1000] 425 | up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3] 426 | ensemble_size: the number of repeated images used in the batch to extract features 427 | Return: 428 | unet_ft: a torch tensor in the shape of [1, c, h, w] 429 | """ 430 | img_tensor = img_tensor.repeat( 431 | ensemble_size, 1, 1, 1).cuda() # ensem, c, h, w 432 | prompt_embeds = self.pipe._encode_prompt( 433 | prompt=prompt, 434 | device="cuda", 435 | num_images_per_prompt=1, 436 | do_classifier_free_guidance=False, 437 | ) # [1, 77, dim] 438 | prompt_embeds = prompt_embeds.repeat(ensemble_size, 1, 1) 439 | unet_ft_all = self.pipe( 440 | img_tensor=img_tensor, 441 | t=t, 442 | up_ft_indices=[up_ft_index], 443 | prompt_embeds=prompt_embeds, 444 | ) 445 | unet_ft = unet_ft_all["up_ft"][up_ft_index] # ensem, c, h, w 446 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w 447 | return unet_ft 448 | # index 0: 1280, 24, 24 449 | # index 1: 1280, 48, 48 450 | # index 2: 640, 96, 96 451 | # index 3: 320, 96,96 452 | @torch.no_grad() 453 | def forward_feature_extractor(self, uc, img_tensor, t=261, up_ft_index=[0, 1, 2, 3], ensemble_size=8): 454 | """ 455 | Args: 456 | img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W] 457 | prompt: the prompt to use, a string 458 | t: the time step to use, should be an int in the range of [0, 1000] 459 | up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3] 460 | ensemble_size: the number of repeated images used in the batch to extract features 461 | Return: 462 | unet_ft: a torch tensor in the shape of [1, c, h, w] 463 | """ 464 | batch_size = img_tensor.shape[0] 465 | 466 | img_tensor = img_tensor.unsqueeze(1).repeat(1, ensemble_size, 1, 1, 1) 467 | 468 | img_tensor = img_tensor.reshape(-1, *img_tensor.shape[2:]) 469 | 470 | prompt_embeds = uc.repeat( 471 | img_tensor.shape[0], 1, 1).to(img_tensor.device) 472 | unet_ft_all = self.pipe( 473 | img_tensor=img_tensor, 474 | t=t, 475 | up_ft_indices=up_ft_index, 476 | prompt_embeds=prompt_embeds, 477 | ) 478 | unet_ft = unet_ft_all["up_ft"] # ensem, c, h, w 479 | 480 | return unet_ft 481 | 482 | 483 | class FeatureExtractor(nn.Module): 484 | def __init__(self, 485 | sd_id='stabilityai/stable-diffusion-2-1', # 'pretrained_models/stable-diffusion-2-1', 486 | load_local=True, 487 | if_softmax=False, 488 | feature_index_cor=1, 489 | feature_index_matting=4, 490 | attention_res=32, # [16, 32], 491 | set_diag_to_one=True, 492 | time_steps=[0], 493 | extract_feature_inputted_to_layer=False, 494 | ensemble_size=8): 495 | super().__init__() 496 | 497 | self.dift_sd = SDFeaturizer(sd_id=sd_id, load_local=load_local) 498 | # register buffer for prompt embedding 499 | self.register_buffer("prompt_embeds", 
self.dift_sd.pipe._encode_prompt( 500 | prompt='', 501 | num_images_per_prompt=1, 502 | do_classifier_free_guidance=False, 503 | device="cuda", 504 | )) 505 | # free self.pipe.tokenizer and self.pipe.text_encoder 506 | del self.dift_sd.pipe.tokenizer 507 | del self.dift_sd.pipe.text_encoder 508 | gc.collect() 509 | torch.cuda.empty_cache() 510 | self.feature_index_cor = feature_index_cor 511 | self.feature_index_matting = feature_index_matting 512 | self.attention_res = attention_res 513 | self.set_diag_to_one = set_diag_to_one 514 | self.time_steps = time_steps 515 | self.extract_feature_inputted_to_layer = extract_feature_inputted_to_layer 516 | self.ensemble_size = ensemble_size 517 | self.register_attention_store( 518 | if_softmax=if_softmax, attention_res=attention_res) 519 | 520 | 521 | def register_attention_store(self, if_softmax=False, attention_res=[16, 32]): 522 | self.controller = AttentionStore(store_res=attention_res) 523 | 524 | register_attention_control( 525 | self.dift_sd.pipe, self.controller, if_softmax=if_softmax, ensemble_size=self.ensemble_size) 526 | 527 | def get_trainable_params(self): 528 | return [] 529 | 530 | def get_reference_feature(self, images): 531 | self.controller.reset() 532 | batch_size = images.shape[0] 533 | features = self.dift_sd.forward_feature_extractor( 534 | self.prompt_embeds, images, t=self.time_steps[0], ensemble_size=self.ensemble_size) # b*e, c, h, w 535 | 536 | features = self.ensemble_feature( 537 | features, self.feature_index_cor, batch_size) 538 | 539 | return features.detach() 540 | 541 | def ensemble_feature(self, features, index, batch_size): 542 | if isinstance(index, int): 543 | 544 | features_ = features[index].reshape( 545 | batch_size, self.ensemble_size, *features[index].shape[1:]) 546 | features_ = features_.mean(1, keepdim=False).detach() 547 | else: 548 | index = list(index) 549 | res = ['24','48','96'] 550 | res = res[:len(index)] 551 | features_ = {} 552 | for i in range(len(index)): 553 | features_[res[i]] = features[index[i]].reshape( 554 | batch_size, self.ensemble_size, *features[index[i]].shape[1:]) 555 | features_[res[i]] = features_[res[i]].mean(1, keepdim=False).detach() 556 | return features_ 557 | 558 | def get_source_feature(self, images): 559 | # return {"ft": [B, C, H, W], "attn": [B, H, W, H*W]} 560 | 561 | self.controller.reset() 562 | torch.cuda.empty_cache() 563 | batch_size = images.shape[0] 564 | 565 | ft = self.dift_sd.forward_feature_extractor( 566 | self.prompt_embeds, images, t=self.time_steps[0], ensemble_size=self.ensemble_size) # b*e, c, h, w 567 | 568 | 569 | attention_maps = self.get_feature_attention(batch_size) 570 | 571 | output = {"ft_cor": self.ensemble_feature(ft, self.feature_index_cor, batch_size), 572 | "attn": attention_maps, 'ft_matting': self.ensemble_feature(ft, self.feature_index_matting, batch_size)} 573 | return output 574 | 575 | def get_feature_attention(self, batch_size): 576 | 577 | attention_maps = self.__aggregate_attention( 578 | from_where=["down", "mid", "up"], is_cross=False, batch_size=batch_size) 579 | 580 | for attn_map in attention_maps.keys(): 581 | attention_maps[attn_map] = attention_maps[attn_map].permute(0, 2, 1).reshape( 582 | (batch_size, -1, int(attn_map), int(attn_map))) # [bs, h*w, h, w] 583 | attention_maps[attn_map] = attention_maps[attn_map].permute(0, 2, 3, 1) # [bs, h, w, h*w] 584 | return attention_maps 585 | 586 | def __aggregate_attention(self, from_where: List[str], is_cross: bool, batch_size: int): 587 | out = {} 588 | 
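        # Flush whatever remains in the controller's per-step store into
        # attention_store, and treat the single extraction pass as one step so
        # that get_average_attention() returns the stored maps unchanged.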
self.controller.between_steps() 589 | self.controller.cur_step=1 590 | attention_maps = self.controller.get_average_attention() 591 | for res in self.attention_res: 592 | out[str(res)] = self.__aggregate_attention_single_res( 593 | from_where, is_cross, batch_size, res, attention_maps) 594 | return out 595 | 596 | def __aggregate_attention_single_res(self, from_where: List[str], is_cross: bool, batch_size: int, res: int, attention_maps): 597 | out = [] 598 | num_pixels = res ** 2 599 | for location in from_where: 600 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 601 | if item.shape[1] == num_pixels: 602 | cross_maps = item.reshape( 603 | batch_size, -1, res, res, item.shape[-1]) 604 | out.append(cross_maps) 605 | out = torch.cat(out, dim=1) 606 | out = out.sum(1) / out.shape[1] 607 | out = out.reshape(batch_size, out.shape[-1], out.shape[-1]) 608 | 609 | if self.set_diag_to_one: 610 | for o in out: 611 | o = o - torch.diag(torch.diag(o)) + \ 612 | torch.eye(o.shape[0]).to(o.device) 613 | return out 614 | -------------------------------------------------------------------------------- /icm/models/in_context_matting.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from pytorch_lightning.utilities.types import STEP_OUTPUT 3 | import torch 4 | import pytorch_lightning as pl 5 | from torch.optim.lr_scheduler import LambdaLR 6 | 7 | from icm.criterion.matting_criterion_eval import compute_mse_loss_torch, compute_sad_loss_torch 8 | from icm.util import instantiate_from_config 9 | from pytorch_lightning.utilities import rank_zero_only 10 | import os 11 | import cv2 12 | class InContextMatting(pl.LightningModule): 13 | ''' 14 | In Context Matting Model 15 | consists of a feature extractor and a in-context decoder 16 | train model with lr, scheduler, loss 17 | ''' 18 | 19 | def __init__( 20 | self, 21 | cfg_feature_extractor, 22 | cfg_in_context_decoder, 23 | cfg_loss_function, 24 | learning_rate, 25 | cfg_scheduler=None, 26 | **kwargs, 27 | ): 28 | super().__init__() 29 | 30 | self.feature_extractor = instantiate_from_config( 31 | cfg_feature_extractor) 32 | self.in_context_decoder = instantiate_from_config( 33 | cfg_in_context_decoder) 34 | 35 | self.loss_function = instantiate_from_config(cfg_loss_function) 36 | 37 | self.learning_rate = learning_rate 38 | self.cfg_scheduler = cfg_scheduler 39 | 40 | def forward(self, reference_images, guidance_on_reference_image, source_images): 41 | 42 | feature_of_reference_image = self.feature_extractor.get_reference_feature( 43 | reference_images) 44 | 45 | feature_of_source_image = self.feature_extractor.get_source_feature( 46 | source_images) 47 | 48 | reference = {'feature': feature_of_reference_image, 49 | 'guidance': guidance_on_reference_image} 50 | 51 | source = {'feature': feature_of_source_image, 'image': source_images} 52 | 53 | output, cross_map, self_map = self.in_context_decoder(source, reference) 54 | 55 | return output, cross_map, self_map 56 | 57 | def on_train_epoch_start(self): 58 | self.log("epoch", self.current_epoch, on_step=False, 59 | on_epoch=True, prog_bar=False, sync_dist=True) 60 | 61 | def training_step(self, batch, batch_idx): 62 | loss_dict, loss, _, _, _ = self.__shared_step(batch) 63 | 64 | self.__log_loss(loss_dict, loss, "train") 65 | 66 | # log learning rate 67 | self.log("lr", self.trainer.optimizers[0].param_groups[0] 68 | ["lr"], on_step=True, on_epoch=False, prog_bar=True, sync_dist=True) 69 | 70 | return loss 71 | 72 | 
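    # During validation the cross-/self-attention maps returned by
    # __shared_step are stashed on the batch and visualised later in
    # validation_step_end.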
def validation_step(self, batch, batch_idx): 73 | loss_dict, loss, preds, cross_map, self_map = self.__shared_step(batch) 74 | 75 | self.__log_loss(loss_dict, loss, "val") 76 | batch['cross_map'] = cross_map 77 | batch['self_map'] = self_map 78 | return preds, batch 79 | 80 | def __shared_step(self, batch): 81 | reference_images, guidance_on_reference_image, source_images, labels, trimaps = batch[ 82 | "reference_image"], batch["guidance_on_reference_image"], batch["source_image"], batch["alpha"], batch["trimap"] 83 | 84 | outputs, cross_map, self_map = self(reference_images, 85 | guidance_on_reference_image, source_images) 86 | 87 | sample_map = torch.zeros_like(trimaps) 88 | sample_map[trimaps==0.5] = 1 89 | 90 | loss_dict = self.loss_function(sample_map, outputs, labels) 91 | 92 | loss = sum(loss_dict.values()) 93 | if loss > 1e4 or torch.isnan(loss): 94 | raise ValueError(f"Loss explosion: {loss}") 95 | return loss_dict, loss, outputs, cross_map, self_map 96 | 97 | def __log_loss(self, loss_dict, loss, prefix): 98 | loss_dict = { 99 | f"{prefix}/{key}": loss_dict.get(key) for key in loss_dict} 100 | self.log_dict(loss_dict, on_step=True, on_epoch=True, 101 | prog_bar=False, sync_dist=True) 102 | self.log(f"{prefix}/loss", loss, on_step=True, 103 | on_epoch=True, prog_bar=True, sync_dist=True) 104 | 105 | def validation_step_end(self, outputs): 106 | 107 | preds, batch = outputs 108 | h, w = batch['alpha_shape'] 109 | 110 | 111 | cross_map = batch['cross_map'] 112 | self_map = batch['self_map'] 113 | # resize cross_map and self_map to the same size as preds 114 | cross_map = torch.nn.functional.interpolate( 115 | cross_map, size=preds.shape[2:], mode='bilinear', align_corners=False) 116 | self_map = torch.nn.functional.interpolate( 117 | self_map, size=preds.shape[2:], mode='bilinear', align_corners=False) 118 | 119 | # normalize cross_map and self_map 120 | cross_map = (cross_map - cross_map.min()) / \ 121 | (cross_map.max() - cross_map.min()) 122 | self_map = (self_map - self_map.min()) / \ 123 | (self_map.max() - self_map.min()) 124 | 125 | cross_map = cross_map[0].squeeze()*255.0 126 | self_map = self_map[0].squeeze()*255.0 127 | 128 | # get one sample from batch 129 | pred = preds[0].squeeze()*255.0 130 | source_image = batch['source_image'][0] 131 | label = batch["alpha"][0].squeeze()*255.0 132 | trimap = batch["trimap"][0].squeeze()*255.0 133 | trimap[trimap == 127.5] = 128 134 | reference_image = batch["reference_image"][0] 135 | guidance_on_reference_image = batch["guidance_on_reference_image"][0] 136 | dataset_name = batch["dataset_name"][0] 137 | image_name = batch["image_name"][0].split('.')[0] 138 | 139 | # save pre to model.val_save_path 140 | 141 | # if self.val_save_path is not None: 142 | if hasattr(self, 'val_save_path'): 143 | os.makedirs(self.val_save_path, exist_ok=True) 144 | # resize preds to h,w 145 | pred_ = torch.nn.functional.interpolate( 146 | pred.unsqueeze(0).unsqueeze(0), size=(h, w), mode='bilinear', align_corners=False) 147 | pred_ = pred_.squeeze().cpu().numpy() 148 | pred_ = pred_.astype('uint8') 149 | cv2.imwrite(os.path.join(self.val_save_path, image_name+'.png'), pred_) 150 | 151 | masked_reference_image = reference_image*guidance_on_reference_image 152 | # self.__compute_and_log_mse_sad_of_one_sample( 153 | # pred, label, trimap, prefix="val") 154 | 155 | self.__log_image( 156 | source_image, masked_reference_image, pred, label, dataset_name, image_name, prefix='val', self_map=self_map, cross_map=cross_map) 157 | 158 | 159 | # def 
validation_step_end(self, outputs): 160 | 161 | # preds, batch = outputs 162 | 163 | # cross_map = batch['cross_map'] 164 | # self_map = batch['self_map'] 165 | # # resize cross_map and self_map to the same size as preds 166 | # cross_map = torch.nn.functional.interpolate( 167 | # cross_map, size=preds.shape[2:], mode='bilinear', align_corners=False) 168 | # self_map = torch.nn.functional.interpolate( 169 | # self_map, size=preds.shape[2:], mode='bilinear', align_corners=False) 170 | 171 | # # normalize cross_map and self_map 172 | # cross_map = (cross_map - cross_map.min()) / \ 173 | # (cross_map.max() - cross_map.min()) 174 | # self_map = (self_map - self_map.min()) / \ 175 | # (self_map.max() - self_map.min()) 176 | 177 | # cross_map = cross_map[0].squeeze()*255.0 178 | # self_map = self_map[0].squeeze()*255.0 179 | 180 | # # get one sample from batch 181 | # pred = preds[0].squeeze()*255.0 182 | # source_image = batch['source_image'][0] 183 | # label = batch["alpha"][0].squeeze()*255.0 184 | # trimap = batch["trimap"][0].squeeze()*255.0 185 | # trimap[trimap == 127.5] = 128 186 | # reference_image = batch["reference_image"][0] 187 | # guidance_on_reference_image = batch["guidance_on_reference_image"][0] 188 | # dataset_name = batch["dataset_name"][0] 189 | # image_name = batch["image_name"][0].split('.')[0] 190 | 191 | # masked_reference_image = reference_image*guidance_on_reference_image 192 | 193 | # self.__compute_and_log_mse_sad_of_one_sample( 194 | # pred, label, trimap, prefix="val") 195 | 196 | # self.__log_image( 197 | # source_image, masked_reference_image, pred, label, dataset_name, image_name, prefix='val', self_map=self_map, cross_map=cross_map) 198 | 199 | def __compute_and_log_mse_sad_of_one_sample(self, pred, label, trimap, prefix="val"): 200 | # compute loss for unknown pixels 201 | mse_loss_unknown_ = compute_mse_loss_torch(pred, label, trimap) 202 | sad_loss_unknown_ = compute_sad_loss_torch(pred, label, trimap) 203 | 204 | # compute loss for all pixels 205 | trimap = torch.ones_like(label)*128 206 | mse_loss_all_ = compute_mse_loss_torch(pred, label, trimap) 207 | sad_loss_all_ = compute_sad_loss_torch(pred, label, trimap) 208 | 209 | # log 210 | metrics_unknown = {f'{prefix}/mse_unknown': mse_loss_unknown_, 211 | f'{prefix}/sad_unknown': sad_loss_unknown_, } 212 | 213 | metrics_all = {f'{prefix}/mse_all': mse_loss_all_, 214 | f'{prefix}/sad_all': sad_loss_all_, } 215 | 216 | self.log_dict(metrics_unknown, on_step=False, 217 | on_epoch=True, prog_bar=False, sync_dist=True) 218 | self.log_dict(metrics_all, on_step=False, 219 | on_epoch=True, prog_bar=False, sync_dist=True) 220 | 221 | def __log_image(self, source_image, masked_reference_image, pred, label, dataset_name, image_name, prefix='val', self_map=None, cross_map=None): 222 | ########### log source_image, masked_reference_image, output and gt ########### 223 | # process image, masked_reference_image, pred, label 224 | source_image = self.__revert_normalize(source_image) 225 | masked_reference_image = self.__revert_normalize( 226 | masked_reference_image) 227 | pred = torch.stack((pred/255.0,)*3, axis=-1) 228 | label = torch.stack((label/255.0,)*3, axis=-1) 229 | self_map = torch.stack((self_map/255.0,)*3, axis=-1) 230 | cross_map = torch.stack((cross_map/255.0,)*3, axis=-1) 231 | 232 | # concat pred, masked_reference_image, label, source_image 233 | image_for_log = torch.stack( 234 | (source_image, masked_reference_image, label, pred, self_map, cross_map), axis=0) 235 | 236 | # log image 237 | 
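        # The grid logged to TensorBoard stacks, in order: source image, masked
        # reference image, ground-truth alpha, prediction, self-attention map
        # and cross-attention map; tensors are channel-last, hence
        # dataformats='NHWC'.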
self.logger.experiment.add_images( 238 | f'{prefix}-{dataset_name}/{image_name}', image_for_log, self.current_epoch, dataformats='NHWC') 239 | 240 | def __revert_normalize(self, image): 241 | # image: [C, H, W] 242 | image = image.permute(1, 2, 0) 243 | image = image * torch.tensor([0.229, 0.224, 0.225], device=self.device) + \ 244 | torch.tensor([0.485, 0.456, 0.406], device=self.device) 245 | image = torch.clamp(image, 0, 1) 246 | return image 247 | 248 | def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): 249 | torch.cuda.empty_cache() 250 | def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: 251 | torch.cuda.empty_cache() 252 | def test_step(self, batch, batch_idx): 253 | loss_dict, loss, preds = self.__shared_step(batch) 254 | 255 | return loss_dict, loss, preds 256 | 257 | def configure_optimizers(self): 258 | params = self.__get_trainable_params() 259 | opt = torch.optim.Adam(params, lr=self.learning_rate) 260 | 261 | if self.cfg_scheduler is not None: 262 | scheduler = self.__get_scheduler(opt) 263 | return [opt], scheduler 264 | return opt 265 | 266 | def __get_trainable_params(self): 267 | params = [] 268 | params = params + self.in_context_decoder.get_trainable_params() + \ 269 | self.feature_extractor.get_trainable_params() 270 | return params 271 | 272 | def __get_scheduler(self, opt): 273 | scheduler = instantiate_from_config(self.cfg_scheduler) 274 | scheduler = [ 275 | { 276 | "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), 277 | "interval": "step", 278 | "frequency": 1, 279 | } 280 | ] 281 | return scheduler 282 | 283 | from pytorch_lightning.callbacks import ModelCheckpoint 284 | 285 | class ModifiedModelCheckpoint(ModelCheckpoint): 286 | def delete_frozen_params(self, ckpt): 287 | # delete params with requires_grad=False 288 | for k in list(ckpt["state_dict"].keys()): 289 | # remove ckpt['state_dict'][k] if 'feature_extractor' in k 290 | if "feature_extractor" in k: 291 | del ckpt["state_dict"][k] 292 | return ckpt 293 | 294 | def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None: 295 | super()._save_model(trainer, filepath) 296 | 297 | if trainer.is_global_zero: 298 | ckpt = torch.load(filepath) 299 | ckpt = self.delete_frozen_params(ckpt) 300 | torch.save(ckpt, filepath) -------------------------------------------------------------------------------- /icm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | def instantiate_from_config(config): 4 | if not "target" in config: 5 | raise KeyError("Expected key `target` to instantiate.") 6 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 7 | 8 | def get_obj_from_str(string, reload=False): 9 | module, cls = string.rsplit(".", 1) 10 | if reload: 11 | module_imp = importlib.import_module(module) 12 | importlib.reload(module_imp) 13 | return getattr(importlib.import_module(module, package=None), cls) 14 | 15 | def instantiate_feature_extractor(cfg): 16 | model = instantiate_from_config(cfg) 17 | load_odise_params = cfg.get("load_odise_params", False) 18 | if load_odise_params: 19 | # load params 20 | params = torch.load('ckpt/odise_label_coco_50e-b67d2efc.pth') 21 | # gather the params with "backbone.feature_extractor." prefix 22 | params = {k.replace("backbone.feature_extractor.", ""): v for k, v in params['model'].items() if "backbone.feature_extractor." 
in k} 23 | model.load_state_dict(params, strict=False) 24 | ''' 25 | alpha_cond 26 | alpha_cond_time_embed 27 | clip_project.positional_embedding 28 | clip_project.linear.weight 29 | clip_project.linear.bias 30 | time_embed_project.positional_embedding 31 | time_embed_project.linear.weight 32 | time_embed_project.linear.bias 33 | ''' 34 | 35 | model.eval() 36 | return model 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | import datetime 3 | import argparse 4 | from omegaconf import OmegaConf 5 | 6 | import os 7 | 8 | from icm.util import instantiate_from_config 9 | import torch 10 | from pytorch_lightning import Trainer, seed_everything 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument( 16 | "--experiment_name", 17 | type=str, 18 | default="in_context_matting", 19 | ) 20 | parser.add_argument( 21 | "--debug", 22 | type=bool, 23 | default=False, 24 | ) 25 | parser.add_argument( 26 | "--resume", 27 | type=str, 28 | default="", 29 | ) 30 | parser.add_argument( 31 | "--fine_tune", 32 | type=bool, 33 | default=False, 34 | ) 35 | parser.add_argument( 36 | "--config", 37 | type=str, 38 | default="", 39 | ) 40 | parser.add_argument( 41 | "--logdir", 42 | type=str, 43 | default="logs", 44 | ) 45 | parser.add_argument( 46 | "--seed", 47 | type=int, 48 | default=42, 49 | ) 50 | 51 | args = parser.parse_args() 52 | return args 53 | 54 | import multiprocessing 55 | multiprocessing.set_start_method('spawn') 56 | 57 | args = parse_args() 58 | if args.resume: 59 | path = args.resume.split('checkpoints')[0] 60 | # get the folder of last version folder 61 | all_folder = os.listdir(path) 62 | all_folder = [os.path.join(path, folder) for folder in all_folder if 'version' in folder] 63 | all_folder.sort() 64 | last_version_folder = all_folder[-1] 65 | # get the hparams.yaml path 66 | hparams_path = os.path.join(last_version_folder, 'hparams.yaml') 67 | cfg = OmegaConf.load(hparams_path) 68 | else: 69 | cfg = OmegaConf.load(args.config) 70 | 71 | if args.fine_tune: 72 | cfg_ft = OmegaConf.load(args.config) 73 | # merge cfg and cfg_ft, cfg_ft will overwrite cfg 74 | cfg = OmegaConf.merge(cfg, cfg_ft) 75 | 76 | # set seed 77 | seed_everything(args.seed) 78 | 79 | """=== Init data ===""" 80 | cfg_data = cfg.get('data') 81 | 82 | data = instantiate_from_config(cfg_data) 83 | 84 | """=== Init model ===""" 85 | cfg_model = cfg.get('model') 86 | 87 | model = instantiate_from_config(cfg_model) 88 | 89 | """=== Init trainer ===""" 90 | cfg_trainer = cfg.get('trainer') 91 | # omegaconf to dict 92 | cfg_trainer = OmegaConf.to_container(cfg_trainer) 93 | 94 | if args.debug: 95 | cfg_trainer['limit_train_batches'] = 2 96 | # cfg_trainer['log_every_n_steps'] = 1 97 | cfg_trainer['limit_val_batches'] = 3 98 | # cfg_trainer['overfit_batches'] = 2 99 | 100 | # init logger 101 | cfg_logger = cfg_trainer.pop('cfg_logger') 102 | 103 | if args.resume: 104 | name = args.resume.split('/')[-3] 105 | else: 106 | name = datetime.datetime.now().strftime( 107 | "%Y-%m-%d_%H-%M-%S")+'-'+args.experiment_name 108 | cfg_logger['params']['save_dir'] = args.logdir 109 | cfg_logger['params']['name'] = name 110 | cfg_trainer['logger'] = instantiate_from_config(cfg_logger) 111 | 112 | # plugin 113 | cfg_plugin = cfg_trainer.pop('plugins') 114 | cfg_trainer['plugins'] = instantiate_from_config(cfg_plugin) 115 | 116 | # init callbacks 117 | 
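# NOTE (descriptive comment): cfg_callbacks is read as a mapping of callback
# name -> {target, params} spec, the same convention instantiate_from_config
# uses elsewhere; the 'modelcheckpoint' entry has its dirpath redirected to
# <logdir>/<run name>/checkpoints before the callbacks are built below.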
cfg_callbacks = cfg_trainer.pop('cfg_callbacks') 118 | callbacks = [] 119 | for callback_name in cfg_callbacks: 120 | if callback_name == 'modelcheckpoint': 121 | cfg_callbacks[callback_name]['params']['dirpath'] = os.path.join( 122 | args.logdir, name, 'checkpoints') 123 | callbacks.append(instantiate_from_config(cfg_callbacks[callback_name])) 124 | cfg_trainer['callbacks'] = callbacks 125 | 126 | if args.resume and not args.fine_tune: 127 | cfg_trainer['resume_from_checkpoint'] = args.resume 128 | 129 | if args.fine_tune: 130 | # load state_dict 131 | ckpt = torch.load(args.resume) 132 | model.load_state_dict(ckpt['state_dict'], strict=False) 133 | # init trainer 134 | trainer_opt = argparse.Namespace(**cfg_trainer) 135 | trainer = Trainer.from_argparse_args(trainer_opt) 136 | 137 | # save configs to log 138 | trainer.logger.log_hyperparams(cfg) 139 | 140 | """=== Start training ===""" 141 | 142 | trainer.fit(model, data) 143 | --------------------------------------------------------------------------------
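As a usage note on the config convention above: `instantiate_from_config` (`icm/util.py`) builds any object from a `{target, params}` mapping, which is how `config/eval.yaml` and `train.py` wire up the data module, model, logger, plugins, and callbacks. The sketch below is illustrative only; the `torch.nn.Conv2d` target is a stand-in, not one of the project's shipped configs, which point at classes such as those under `icm.models`.

```python
# Minimal sketch of the {target, params} convention consumed by
# instantiate_from_config. The target here is a placeholder class,
# not taken from the project's actual config files.
from icm.util import instantiate_from_config

cfg = {
    "target": "torch.nn.Conv2d",        # dotted import path of the class
    "params": {"in_channels": 3,        # forwarded as keyword arguments
               "out_channels": 8,
               "kernel_size": 3},
}

layer = instantiate_from_config(cfg)    # same as torch.nn.Conv2d(3, 8, 3)
print(type(layer))                      # <class 'torch.nn.conv.Conv2d'>
```

In practice the same mapping is loaded with OmegaConf (as in `train.py`), and nested specs for the logger, plugins, and callbacks are resolved through the same function.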