├── data
│   ├── __init__.py
│   └── dim_dataset.py
├── modeling
│   ├── meta_arch
│   │   ├── __init__.py
│   │   └── mematte.py
│   ├── decoder
│   │   ├── __init__.py
│   │   └── detail_capture.py
│   ├── criterion
│   │   ├── __init__.py
│   │   └── matting_criterion.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── utils.py
│   │   ├── vit_teacher.py
│   │   └── vit.py
│   └── __init__.py
├── engine
│   ├── __init__.py
│   └── mattingtrainer.py
├── requirements.txt
├── configs
│   ├── common
│   │   ├── scheduler.py
│   │   ├── train.py
│   │   ├── dataloader.py
│   │   ├── optimizer.py
│   │   └── model.py
│   ├── MEMatte_S_topk0.25_win_global_long.py
│   └── MEMatte_B_topk0.25_win_global_long.py
├── pretrained
│   └── preprocess.py
├── .gitignore
├── utils
│   └── logger.py
├── inference.py
├── main.py
└── README.md
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dim_dataset import *
--------------------------------------------------------------------------------
/modeling/meta_arch/__init__.py:
--------------------------------------------------------------------------------
1 | from .mematte import MEMatte
--------------------------------------------------------------------------------
/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .mattingtrainer import MattingTrainer
2 |
--------------------------------------------------------------------------------
/modeling/decoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .detail_capture import Detail_Capture
--------------------------------------------------------------------------------
/modeling/criterion/__init__.py:
--------------------------------------------------------------------------------
1 | from .matting_criterion import MattingCriterion
--------------------------------------------------------------------------------
/modeling/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbone import *
2 | from .vit import *
3 | from .vit_teacher import ViT_Teacher
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .backbone import *
2 | from .criterion import *
3 | from .decoder import *
4 | from .meta_arch import *
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.0
2 | torchvision
3 | tensorboard
4 | timm==0.5.4
5 | opencv-python==4.5.3.56
6 | setuptools==58.2.0
7 | easydict
8 | wget
9 | scikit-image
10 | fairscale
--------------------------------------------------------------------------------
/configs/common/scheduler.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import LazyCall as L
2 | from detectron2.solver import WarmupParamScheduler
3 | from fvcore.common.param_scheduler import MultiStepParamScheduler
4 |
5 | lr_multiplier = L(WarmupParamScheduler)(
6 | scheduler=L(MultiStepParamScheduler)(
7 | values=[1.0, 0.1, 0.01],
8 | milestones=[96778, 103579],
9 | num_updates=100,
10 | ),
11 | warmup_length=250 / 100,
12 | warmup_factor=0.001,
13 | )
--------------------------------------------------------------------------------
/configs/common/train.py:
--------------------------------------------------------------------------------
1 | train = dict(
2 | output_dir="./output",
3 | init_checkpoint="",
4 | max_iter=90000,
5 | amp=dict(enabled=False), # options for Automatic Mixed Precision
6 | ddp=dict( # options for DistributedDataParallel
7 | broadcast_buffers=True,
8 | find_unused_parameters=False,
9 | fp16_compression=True,
10 | ),
11 | checkpointer=dict(period=5000, max_to_keep=100), # options for PeriodicCheckpointer
12 | eval_period=5000,
13 | # log_period=20,
14 |     log_loss_period=10,
15 |     log_image_period=2000,
16 | device="cuda"
17 | # ...
18 | )
--------------------------------------------------------------------------------
/pretrained/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | def add_teacher(model, name):
3 | new_model = {}
4 | for k in model.keys():
5 | new_model[k] = model[k]
6 | if 'backbone' in k:
7 | new_model['teacher_'+k] = model[k]
8 | print(f'teacher_{k}')
9 |
10 | torch.save(new_model, name + '.pth')
11 |
12 | if __name__ == "__main__":
13 |     # Download the official ViTMatte checkpoints, then process them with this script.
14 | # ViTMatte_S_Com.pth: https://drive.google.com/file/d/12VKhSwE_miF9lWQQCgK7mv83rJIls3Xe/view
15 | # ViTMatte_B_Com.pth: https://drive.google.com/file/d/1mOO5MMU4kwhNX96AlfpwjAoMM4V5w3k-/view?pli=1
16 | teacher_model = torch.load('ViTMatte_S_Com.pth')['model']
17 | add_teacher(teacher_model, 'ViTMatte_S_Com_with_teacher')
--------------------------------------------------------------------------------
/configs/common/dataloader.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from torch.utils.data import DataLoader
3 | from detectron2.config import LazyCall as L
4 | from torch.utils.data.distributed import DistributedSampler
5 |
6 | from data import ImageFileTrain, DataGenerator
7 |
8 | #Dataloader
9 | train_dataset = L(DataGenerator)(
10 | data = L(ImageFileTrain)(
11 | alpha_dir='/opt/data/private/lyh/Datasets/AdobeImageMatting/Train/alpha',
12 | fg_dir='/opt/data/private/lyh/Datasets/AdobeImageMatting/Train/fg',
13 | bg_dir='/opt/data/private/lyh/Datasets/coco2014/raw/train2014',
14 | root='/opt/data/private/lyh/Datasets/AdobeImageMatting'
15 | ),
16 | phase = 'train'
17 | )
18 | #
19 | dataloader = OmegaConf.create()
20 | dataloader.train = L(DataLoader)(
21 | dataset = train_dataset,
22 | batch_size=15,
23 | shuffle=False,
24 | num_workers=4,
25 | pin_memory=True,
26 | sampler=L(DistributedSampler)(
27 | dataset = train_dataset,
28 | ),
29 | drop_last=True
30 | )
--------------------------------------------------------------------------------
/configs/common/optimizer.py:
--------------------------------------------------------------------------------
1 | from detectron2 import model_zoo
2 | from functools import partial
3 |
4 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
5 | """
6 | Calculate lr decay rate for different ViT blocks.
7 | Args:
8 | name (string): parameter name.
9 | lr_decay_rate (float): base lr decay rate.
10 | num_layers (int): number of ViT blocks.
11 |
12 | Returns:
13 | lr decay rate for the given parameter.
14 | """
15 | layer_id = num_layers + 1
16 | if name.startswith("backbone"):
17 | if ".pos_embed" in name or ".patch_embed" in name:
18 | layer_id = 0
19 | elif ".blocks." in name and ".residual." not in name:
20 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
21 | return lr_decay_rate ** (num_layers + 1 - layer_id)
22 |
23 | # Optimizer
24 | optimizer = model_zoo.get_config("common/optim.py").AdamW
25 | optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.65)
26 | optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
--------------------------------------------------------------------------------
/configs/MEMatte_S_topk0.25_win_global_long.py:
--------------------------------------------------------------------------------
1 | from .common.train import train
2 | from .common.model import model
3 | from .common.optimizer import optimizer
4 | from .common.scheduler import lr_multiplier
5 | from .common.dataloader import dataloader, train_dataset
6 |
7 |
8 | train.max_iter = int(43100 / 16 / 2 * 30)
9 | train.checkpointer.period = int(43100 / 16 / 2 * 2)
10 |
11 | model.backbone.use_rel_pos = True
12 | model.backbone.topk = 0.25
13 | model.backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention
14 | model.backbone.multi_score = True
15 | model.distill = True
16 |
17 | model.teacher_backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention
18 |
19 | optimizer.lr=5e-4
20 | lr_multiplier.scheduler.values=[1.0, 0.1, 0.05]
21 | lr_multiplier.scheduler.milestones=[int(43100 / 16 / 2 * 6), int(43100 / 16 / 2 * 26)]
22 | lr_multiplier.scheduler.num_updates = train.max_iter
23 | lr_multiplier.warmup_length = 250 / train.max_iter
24 |
25 | train.init_checkpoint = './pretrained/ViTMatte_S_Com_with_teacher.pth'
26 | train.output_dir = './output_of_train/MEMatte_S_topk0.25_win_global_long'
27 |
28 | dataloader.train.batch_size=16
29 | dataloader.train.num_workers=2
30 |
--------------------------------------------------------------------------------
/configs/MEMatte_B_topk0.25_win_global_long.py:
--------------------------------------------------------------------------------
1 | from .common.train import train
2 | from .common.model import model
3 | from .common.optimizer import optimizer
4 | from .common.scheduler import lr_multiplier
5 | from .common.dataloader import dataloader, train_dataset
6 |
7 | model.backbone.embed_dim = 768
8 | model.backbone.num_heads = 12
9 | model.decoder.in_chans = 768
10 |
11 | model.teacher_backbone.embed_dim = 768
12 | model.teacher_backbone.num_heads = 12
13 |
14 | train.max_iter = int(43100 / 10 / 4 * 60)
15 | train.checkpointer.period = int(43100 / 10 / 4 * 4)
16 |
17 | model.backbone.use_rel_pos = True
18 | model.backbone.topk = 0.25
19 | model.backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention
20 | model.backbone.multi_score = True
21 | model.distill = True
22 |
23 | model.teacher_backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention
24 |
25 | optimizer.lr=5e-4
26 | lr_multiplier.scheduler.values=[1.0, 0.1, 0.05]
27 | lr_multiplier.scheduler.milestones=[int(43100 / 10 / 4 * 30), int(43100 / 10 / 4 * 52)]
28 | lr_multiplier.scheduler.num_updates = train.max_iter
29 | lr_multiplier.warmup_length = 250 / train.max_iter
30 |
31 | train.init_checkpoint = './pretrained/ViTMatte_B_Com_with_teacher.pth'
32 | train.output_dir = './output_of_train/MEMatte_B_topk0.25_win_global_long'
33 |
34 | dataloader.train.batch_size=10
35 | dataloader.train.num_workers=4
36 |
--------------------------------------------------------------------------------
/configs/common/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from functools import partial
3 | from detectron2.config import LazyCall as L
4 | from modeling import MEMatte, MattingCriterion, Detail_Capture, ViT, ViT_Teacher
5 |
6 | # Base settings (ViT-S: embed_dim=384, num_heads=6); the ViT-B config overrides these
7 | embed_dim, num_heads = 384, 6
8 |
9 | model = L(MEMatte)(
10 | teacher_backbone = L(ViT_Teacher)(
11 | in_chans=4,
12 | img_size=512,
13 | patch_size=16,
14 | embed_dim=embed_dim,
15 | depth=12,
16 | num_heads=num_heads,
17 | drop_path_rate=0,
18 | window_size=14,
19 | mlp_ratio=4,
20 | qkv_bias=True,
21 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
22 | window_block_indexes=[],
23 | residual_block_indexes=[2, 5, 8, 11],
24 | use_rel_pos=True,
25 | out_feature="last_feat",
26 | ),
27 | backbone = L(ViT)( # Single-scale ViT backbone
28 | in_chans=4,
29 | img_size=512,
30 | patch_size=16,
31 | embed_dim=embed_dim,
32 | depth=12,
33 | num_heads=num_heads,
34 | drop_path_rate=0,
35 | window_size=14,
36 | mlp_ratio=4,
37 | qkv_bias=True,
38 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
39 | window_block_indexes=[
40 |             # # 2, 5, 8, 11 for global attention
41 | # 0,
42 | # 1,
43 | # 3,
44 | # 4,
45 | # 6,
46 | # 7,
47 | # 9,
48 | # 10,
49 | ],
50 | residual_block_indexes=[2, 5, 8, 11],
51 | use_rel_pos=True,
52 | out_feature="last_feat",
53 | topk = 1,
54 | ),
55 | criterion=L(MattingCriterion)(
56 | losses = ['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian', 'loss_gradient_penalty']
57 | ),
58 | pixel_mean = [123.675 / 255., 116.280 / 255., 103.530 / 255.],
59 | pixel_std = [58.395 / 255., 57.120 / 255., 57.375 / 255.],
60 | input_format = "RGB",
61 | size_divisibility=32,
62 | decoder=L(Detail_Capture)(),
63 | distill = True,
64 | distill_loss_ratio = 1,
65 | token_loss_ratio = 1,
66 | )
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /*.sh
2 | change_submit.py
3 | cluster_submit.yaml
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | work_dirs
15 | test
16 | val
17 | .Python
18 | build/
19 | ckpts/
20 | ckpts
21 | test/
22 | val/
23 | work_dirs/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 | db.sqlite3-journal
75 |
76 | # Flask stuff:
77 | instance/
78 | .webassets-cache
79 |
80 | # Scrapy stuff:
81 | .scrapy
82 |
83 | # Sphinx documentation
84 | docs/_build/
85 |
86 | # PyBuilder
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 |
96 | # pyenv
97 | .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | cluster.sh
144 |
145 | *.pth
146 |
147 | predAlpha
148 | evaluation_log.txt
--------------------------------------------------------------------------------
/modeling/backbone/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Dict
4 | import torch.nn as nn
5 |
6 | from detectron2.layers import ShapeSpec
7 |
8 | __all__ = ["Backbone"]
9 |
10 |
11 | class Backbone(nn.Module, metaclass=ABCMeta):
12 | """
13 | Abstract base class for network backbones.
14 | """
15 |
16 | def __init__(self):
17 | """
18 | The `__init__` method of any subclass can specify its own set of arguments.
19 | """
20 | super().__init__()
21 |
22 | @abstractmethod
23 | def forward(self):
24 | """
25 | Subclasses must override this method, but adhere to the same return type.
26 |
27 | Returns:
28 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
29 | """
30 | pass
31 |
32 | @property
33 | def size_divisibility(self) -> int:
34 | """
35 | Some backbones require the input height and width to be divisible by a
36 | specific integer. This is typically true for encoder / decoder type networks
37 | with lateral connection (e.g., FPN) for which feature maps need to match
38 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
39 | input size divisibility is required.
40 | """
41 | return 0
42 |
43 | @property
44 | def padding_constraints(self) -> Dict[str, int]:
45 | """
46 | This property is a generalization of size_divisibility. Some backbones and training
47 | recipes require specific padding constraints, such as enforcing divisibility by a specific
48 | integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
49 | in :paper:vitdet). `padding_constraints` contains these optional items like:
50 | {
51 | "size_divisibility": int,
52 | "square_size": int,
53 | # Future options are possible
54 | }
55 |         `size_divisibility` is read from here if present, and `square_size` indicates the
56 |         square padding size if `square_size` > 0.
57 | 
58 |         TODO: use type of Dict[str, int] to avoid torchscript issues. The type of padding_constraints
59 | could be generalized as TypedDict (Python 3.8+) to support more types in the future.
60 | """
61 | return {}
62 |
63 | def output_shape(self):
64 | """
65 | Returns:
66 | dict[str->ShapeSpec]
67 | """
68 | # this is a backward-compatible default
69 | return {
70 | name: ShapeSpec(
71 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
72 | )
73 | for name in self._out_features
74 | }
75 |
--------------------------------------------------------------------------------
/engine/mattingtrainer.py:
--------------------------------------------------------------------------------
1 | from detectron2.engine import AMPTrainer
2 | import torch
3 | import time
4 | import detectron2.utils.comm as comm
5 | import logging
6 |
7 | from detectron2.utils.events import EventWriter, get_event_storage
8 |
9 | def cycle(iterable):
10 | while True:
11 | for x in iterable:
12 | yield x
13 |
14 | class MattingTrainer(AMPTrainer):
15 |     def __init__(self, model, data_loader, optimizer, grad_scaler=None, log_image_period=2000):
16 |         super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
17 | self.data_loader_iter = iter(cycle(self.data_loader))
18 | self.log_image_period = log_image_period
19 |
20 |
21 | def run_step(self):
22 | """
23 | Implement the AMP training logic.
24 | """
25 | assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
26 | assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
27 | from torch.cuda.amp import autocast
28 |
29 | #matting pass
30 | start = time.perf_counter()
31 | data = next(self.data_loader_iter)
32 | data_time = time.perf_counter() - start
33 |
34 | with autocast():
35 | loss_dict, output_images, out_pred_prob = self.model(data)
36 | if isinstance(loss_dict, torch.Tensor):
37 | losses = loss_dict
38 | loss_dict = {"total_loss": loss_dict}
39 | else:
40 | losses = sum(loss_dict.values())
41 |
42 | self.optimizer.zero_grad()
43 | self.grad_scaler.scale(losses).backward()
44 |
45 | if self.iter % 20 == 0:
46 | self.write_ratios(out_pred_prob)
47 | self._write_metrics(loss_dict, data_time)
48 | self._write_images(output_images, data)
49 |
50 | self.grad_scaler.step(self.optimizer)
51 | self.grad_scaler.update()
52 |
53 | def write_ratios(self, out_pred_prob):
54 | storage = get_event_storage()
55 | for i in range(len(out_pred_prob)):
56 | storage.put_scalar(f"{i}_block_token_ratio", out_pred_prob[i].sum() / out_pred_prob[i].numel(), cur_iter = self.iter)
57 | storage.put_scalar("total_token_ratio", sum([p.sum() / p.numel() for p in out_pred_prob]) / len(out_pred_prob), cur_iter = self.iter)
58 |
59 | def _write_images(self, output_images: torch.Tensor, data: torch.Tensor, iter: int = None):
60 | logger = logging.getLogger(__name__)
61 | iter = self.iter if iter is None else iter
62 | if (iter + 1) % self.log_image_period == 0:
63 | try:
64 | MattingTrainer.write_images(output_images, data, iter)
65 | except Exception:
66 | logger.exception("Exception in writing images: ")
67 | raise
68 |
69 | @staticmethod
70 | def write_images(output_images: torch.Tensor, data: torch.Tensor, cur_iter:int = None):
71 | # output_images = output_images.detach().cpu()
72 | if comm.is_main_process():
73 | storage = get_event_storage()
74 | storage.put_image("fg", data["fg"])
75 | storage.put_image("alpha_gt", data["alpha"])
76 | storage.put_image("bg", data["bg"])
77 | storage.put_image("trimap", data["trimap"])
78 | storage.put_image("image", data["image"])
79 | # storage._block_ratio = (block_ratio, storage.iter)
80 | for key in output_images.keys():
81 | storage.put_image(key, output_images[key])
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 |
2 | from detectron2.engine.train_loop import HookBase
3 | from detectron2.utils.events import EventWriter, get_event_storage
4 | from functools import cached_property
5 | from typing import Optional
6 | from detectron2.utils.file_io import PathManager
7 | from detectron2.engine.defaults import CommonMetricPrinter, TensorboardXWriter, JSONWriter
8 |
9 | import os
10 | import torch
11 |
12 | def power_default_writers(output_dir: str, max_iter: Optional[int] = None):
13 | """
14 | Build a list of :class:`EventWriter` to be used.
15 | It now consists of a :class:`CommonMetricPrinter`,
16 | :class:`TensorboardXWriter` and :class:`JSONWriter`.
17 |
18 | Args:
19 | output_dir: directory to store JSON metrics and tensorboard events
20 | max_iter: the total number of iterations
21 |
22 | Returns:
23 | list[EventWriter]: a list of :class:`EventWriter` objects.
24 | """
25 | PathManager.mkdirs(output_dir)
26 | return [
27 | # It may not always print what you want to see, since it prints "common" metrics only.
28 | CommonMetricPrinter(max_iter),
29 | JSONWriter(os.path.join(output_dir, "metrics.json")),
30 | PowerTensorboardXWriter(output_dir),
31 | ]
32 |
33 |
34 | class PowerTensorboardXWriter(TensorboardXWriter):
35 | """
36 | Write all scalars to a tensorboard file.
37 |     Compared to the official code, we replace add_image with add_images.
38 | """
39 |
40 | def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
41 | """
42 | Args:
43 | log_dir (str): the directory to save the output events
44 | window_size (int): the scalars will be median-smoothed by this window size
45 |
46 | kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
47 | """
48 | super().__init__(log_dir, window_size, **kwargs)
49 |
50 | def write(self):
51 | storage = get_event_storage()
52 | new_last_write = self._last_write
53 | for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
54 | if iter > self._last_write:
55 | self._writer.add_scalar(k, v, iter)
56 | new_last_write = max(new_last_write, iter)
57 | self._last_write = new_last_write
58 |
59 | # storage.put_{image,histogram} is only meant to be used by
60 | # tensorboard writer. So we access its internal fields directly from here.
61 | if len(storage._vis_data) >= 1:
62 | for img_name, img, step_num in storage._vis_data:
63 | self._writer.add_images(img_name, img, step_num)
64 | # Storage stores all image data and rely on this writer to clear them.
65 | # As a result it assumes only one writer will use its image data.
66 | # An alternative design is to let storage store limited recent
67 | # data (e.g. only the most recent image) that all writers can access.
68 | # In that case a writer may not see all image data if its period is long.
69 | storage.clear_images()
70 |
71 | # if len(storage._block_ratio) == 2:
72 | # block_ratio, step_num = storage._block_ratio
73 | # self._writer.add_text("block_token_ratio", block_ratio, step_num)
74 | # storage._block_ratio = ()
75 |
76 | if len(storage._histograms) >= 1:
77 | for params in storage._histograms:
78 | self._writer.add_histogram_raw(**params)
79 | storage.clear_histograms()
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from PIL import Image
4 | from tqdm import tqdm
5 | from torch.utils.data import Dataset, DataLoader
6 | from torchvision.transforms import functional as F
7 | from os.path import join as opj
8 | from detectron2.checkpoint import DetectionCheckpointer
9 | from detectron2.config import LazyConfig, instantiate
10 | from detectron2.engine import default_argument_parser
11 |
12 | import warnings
13 | warnings.filterwarnings('ignore')
14 |
15 | #Dataset and Dataloader
16 | def collate_fn(batched_inputs):
17 | rets = dict()
18 | for k in batched_inputs[0].keys():
19 | rets[k] = torch.stack([_[k] for _ in batched_inputs])
20 | return rets
21 |
22 | class Composition_1k(Dataset):
23 | def __init__(self, data_dir, finished_list = None):
24 | self.data_dir = data_dir
25 | if "AIM" in data_dir:
26 | self.file_names = sorted(os.listdir(opj(self.data_dir, 'original')))
27 | else:
28 | self.file_names = sorted(os.listdir(opj(self.data_dir, 'merged')), reverse=True)
29 |
30 | # self.file_names = list(set(self.file_names).difference(set(finished_list))) # difference
31 |
32 | def __len__(self):
33 | return len(self.file_names)
34 |
35 | def __getitem__(self, idx):
36 | if "AIM" in self.data_dir:
37 | tris = Image.open(opj(self.data_dir, 'trimap', self.file_names[idx].replace('jpg','png')))
38 | imgs = Image.open(opj(self.data_dir, 'original', self.file_names[idx]))
39 | else:
40 | tris = Image.open(opj(self.data_dir, 'trimaps', self.file_names[idx].replace('jpeg','png').replace('jpg','png')))
41 | imgs = Image.open(opj(self.data_dir, 'merged', self.file_names[idx]))
42 |
43 | sample = {}
44 |
45 | sample['trimap'] = F.to_tensor(tris)[0:1, :, :]
46 | sample['image'] = F.to_tensor(imgs)
47 | sample['image_name'] = self.file_names[idx]
48 | return sample
49 |
50 |
51 | #model and output
52 | def matting_inference(
53 | config_dir='',
54 | checkpoint_dir='',
55 | inference_dir='',
56 | data_dir='',
57 | rank=None,
58 | max_number_token = 18500,
59 | ):
60 |
61 | #initializing model
62 | cfg = LazyConfig.load(config_dir)
63 | cfg.model.teacher_backbone = None
64 | cfg.model.backbone.max_number_token = max_number_token
65 | model = instantiate(cfg.model)
66 | model.to(cfg.train.device if rank is None else rank)
67 | model.eval()
68 | DetectionCheckpointer(model).load(checkpoint_dir)
69 |
70 | #initializing dataset
71 | composition_1k_dataloader = DataLoader(
72 | dataset = Composition_1k(
73 | data_dir = data_dir,
74 | ),
75 | shuffle = False,
76 | batch_size = 1,
77 | )
78 |
79 | #inferencing
80 | os.makedirs(inference_dir, exist_ok=True)
81 |
82 | for data in tqdm(composition_1k_dataloader):
83 | with torch.no_grad():
84 | for k in data.keys():
85 | if k == 'image_name':
86 | continue
87 | else:
88 |                     data[k] = data[k].to(model.device)
89 |
90 | output, _, _ = model(data, patch_decoder=True)
91 | output = output['phas'].flatten(0, 2)
92 | trimap = data['trimap'].squeeze(0).squeeze(0)
93 | output[trimap == 0] = 0
94 | output[trimap == 1] = 1
95 | output = F.to_pil_image(output)
96 | output.save(opj(inference_dir, data['image_name'][0].replace('.jpg', '.png')))
97 | torch.cuda.empty_cache()
98 |
99 | if __name__ == '__main__':
100 | #add argument we need:
101 | parser = default_argument_parser()
102 | parser.add_argument('--config-dir', type=str, required=True)
103 | parser.add_argument('--checkpoint-dir', type=str, required=True)
104 | parser.add_argument('--inference-dir', type=str, required=True)
105 | parser.add_argument('--data-dir', type=str, required=True)
106 |     parser.add_argument('--max-number-token', type=int, default=18500)
107 |
108 | args = parser.parse_args()
109 | matting_inference(
110 | config_dir = args.config_dir,
111 | checkpoint_dir = args.checkpoint_dir,
112 | inference_dir = args.inference_dir,
113 | data_dir = args.data_dir,
114 | max_number_token = args.max_number_token
115 | )
116 |
--------------------------------------------------------------------------------
/modeling/decoder/detail_capture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class Basic_Conv3x3(nn.Module):
6 | """
7 | Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
8 | """
9 | def __init__(
10 | self,
11 | in_chans,
12 | out_chans,
13 | stride=2,
14 | padding=1,
15 | ):
16 | super().__init__()
17 | self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
18 | self.bn = nn.BatchNorm2d(out_chans)
19 | self.relu = nn.ReLU(True)
20 |
21 | def forward(self, x):
22 | x = self.conv(x)
23 | x = self.bn(x)
24 | x = self.relu(x)
25 |
26 | return x
27 |
28 | class ConvStream(nn.Module):
29 | """
30 | Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
31 | """
32 | def __init__(
33 | self,
34 | in_chans = 4,
35 | out_chans = [48, 96, 192],
36 | ):
37 | super().__init__()
38 | self.convs = nn.ModuleList()
39 |
40 | self.conv_chans = out_chans.copy()
41 | self.conv_chans.insert(0, in_chans)
42 |
43 | for i in range(len(self.conv_chans)-1):
44 | in_chan_ = self.conv_chans[i]
45 | out_chan_ = self.conv_chans[i+1]
46 | self.convs.append(
47 | Basic_Conv3x3(in_chan_, out_chan_)
48 | )
49 |
50 | def forward(self, x):
51 | out_dict = {'D0': x}
52 | for i in range(len(self.convs)):
53 | x = self.convs[i](x)
54 | name_ = 'D'+str(i+1)
55 | out_dict[name_] = x
56 |
57 | return out_dict
58 |
59 | class Fusion_Block(nn.Module):
60 | """
61 | Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer.
62 | """
63 | def __init__(
64 | self,
65 | in_chans,
66 | out_chans,
67 | ):
68 | super().__init__()
69 | self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
70 |
71 | def forward(self, x, D):
72 | F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
73 | out = torch.cat([D, F_up], dim=1)
74 | out = self.conv(out)
75 |
76 | return out
77 |
78 | class Matting_Head(nn.Module):
79 | """
80 | Simple Matting Head, containing only conv3x3 and conv1x1 layers.
81 | """
82 | def __init__(
83 | self,
84 | in_chans = 32,
85 | mid_chans = 16,
86 | ):
87 | super().__init__()
88 | self.matting_convs = nn.Sequential(
89 | nn.Conv2d(in_chans, mid_chans, 3, 1, 1),
90 | nn.BatchNorm2d(mid_chans),
91 | nn.ReLU(True),
92 | nn.Conv2d(mid_chans, 1, 1, 1, 0)
93 | )
94 |
95 | def forward(self, x):
96 | x = self.matting_convs(x)
97 |
98 | return x
99 |
100 | class Detail_Capture(nn.Module):
101 | """
102 | Simple and Lightweight Detail Capture Module for ViT Matting.
103 | """
104 | def __init__(
105 | self,
106 | in_chans = 384,
107 | img_chans=4,
108 | convstream_out = [48, 96, 192],
109 | fusion_out = [256, 128, 64, 32],
110 | ):
111 | super().__init__()
112 | assert len(fusion_out) == len(convstream_out) + 1
113 |
114 | self.convstream = ConvStream(in_chans = img_chans)
115 | self.conv_chans = self.convstream.conv_chans
116 |
117 | self.fusion_blks = nn.ModuleList()
118 | self.fus_channs = fusion_out.copy()
119 | self.fus_channs.insert(0, in_chans)
120 | for i in range(len(self.fus_channs)-1):
121 | self.fusion_blks.append(
122 | Fusion_Block(
123 | in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)],
124 | out_chans = self.fus_channs[i+1],
125 | )
126 | )
127 |
128 | self.matting_head = Matting_Head(
129 | in_chans = fusion_out[-1],
130 | )
131 |
132 | def forward(self, features, images):
133 | detail_features = self.convstream(images)
134 | for i in range(len(self.fusion_blks)):
135 | d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
136 | features = self.fusion_blks[i](features, detail_features[d_name_])
137 |
138 | phas = torch.sigmoid(self.matting_head(features))
139 |
140 | return {'phas': phas}
141 |
142 | if __name__ == '__main__':
143 | detail_capture = Detail_Capture()
144 | features = torch.rand(2, 384, 32, 32)
145 | inputs = torch.randn(2, 4, 512, 512)
146 | print(detail_capture)
147 | output = detail_capture(features, inputs)
148 |     print(output['phas'].shape)
--------------------------------------------------------------------------------
/modeling/criterion/matting_criterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class MattingCriterion(nn.Module):
6 | def __init__(self,
7 | *,
8 | losses,
9 | ):
10 | super(MattingCriterion, self).__init__()
11 | self.losses = losses
12 |
13 |     def loss_gradient_penalty(self, sample_map, preds, targets):
14 | preds = preds['phas']
15 | targets = targets['phas']
16 |
17 |         # sample_map marks the unknown (trimap == 0.5) region
18 |         scale = sample_map.shape[0]*262144/torch.sum(sample_map)  # 262144 = 512*512, the training crop area
19 |
20 | #gradient in x
21 | sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type())
22 | delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
23 | delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)
24 |
25 | #gradient in y
26 | sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type())
27 | delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
28 | delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)
29 |
30 | #loss
31 | loss = (F.l1_loss(delta_pred_x*sample_map, delta_gt_x*sample_map)* scale + \
32 | F.l1_loss(delta_pred_y*sample_map, delta_gt_y*sample_map)* scale + \
33 | 0.01 * torch.mean(torch.abs(delta_pred_x*sample_map))* scale + \
34 | 0.01 * torch.mean(torch.abs(delta_pred_y*sample_map))* scale)
35 |
36 | return dict(loss_gradient_penalty=loss)
37 |
38 | def loss_pha_laplacian(self, preds, targets):
39 | assert 'phas' in preds and 'phas' in targets
40 | loss = laplacian_loss(preds['phas'], targets['phas'])
41 |
42 | return dict(loss_pha_laplacian=loss)
43 |
44 | def unknown_l1_loss(self, sample_map, preds, targets):
45 |
46 | scale = sample_map.shape[0]*262144/torch.sum(sample_map)
47 | # scale = 1
48 |
49 | loss = F.l1_loss(preds['phas']*sample_map, targets['phas']*sample_map)*scale
50 | return dict(unknown_l1_loss=loss)
51 |
52 | def known_l1_loss(self, sample_map, preds, targets):
53 | new_sample_map = torch.zeros_like(sample_map)
54 | new_sample_map[sample_map==0] = 1
55 |
56 | if torch.sum(new_sample_map) == 0:
57 | scale = 0
58 | else:
59 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map)
60 | # scale = 1
61 |
62 | loss = F.l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale
63 | return dict(known_l1_loss=loss)
64 |
65 |
66 | def forward(self, sample_map, preds, targets):
67 | losses = dict()
68 | for k in self.losses:
69 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty':
70 | losses.update(getattr(self, k)(sample_map, preds, targets))
71 | else:
72 | losses.update(getattr(self, k)(preds, targets))
73 | return losses
74 |
75 |
76 | #-----------------Laplacian Loss-------------------------#
77 | def laplacian_loss(pred, true, max_levels=5):
78 | kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
79 | pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
80 | true_pyramid = laplacian_pyramid(true, kernel, max_levels)
81 | loss = 0
82 | for level in range(max_levels):
83 | loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
84 | return loss / max_levels
85 |
86 | def laplacian_pyramid(img, kernel, max_levels):
87 | current = img
88 | pyramid = []
89 | for _ in range(max_levels):
90 | current = crop_to_even_size(current)
91 | down = downsample(current, kernel)
92 | up = upsample(down, kernel)
93 | diff = current - up
94 | pyramid.append(diff)
95 | current = down
96 | return pyramid
97 |
98 | def gauss_kernel(device='cpu', dtype=torch.float32):
99 | kernel = torch.tensor([[1, 4, 6, 4, 1],
100 | [4, 16, 24, 16, 4],
101 | [6, 24, 36, 24, 6],
102 | [4, 16, 24, 16, 4],
103 | [1, 4, 6, 4, 1]], device=device, dtype=dtype)
104 | kernel /= 256
105 | kernel = kernel[None, None, :, :]
106 | return kernel
107 |
108 | def gauss_convolution(img, kernel):
109 | B, C, H, W = img.shape
110 | img = img.reshape(B * C, 1, H, W)
111 | img = F.pad(img, (2, 2, 2, 2), mode='reflect')
112 | img = F.conv2d(img, kernel)
113 | img = img.reshape(B, C, H, W)
114 | return img
115 |
116 | def downsample(img, kernel):
117 | img = gauss_convolution(img, kernel)
118 | img = img[:, :, ::2, ::2]
119 | return img
120 |
121 | def upsample(img, kernel):
122 | B, C, H, W = img.shape
123 | out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
124 | out[:, :, ::2, ::2] = img * 4
125 | out = gauss_convolution(out, kernel)
126 | return out
127 |
128 | def crop_to_even_size(img):
129 | H, W = img.shape[2:]
130 | H = H - H % 2
131 | W = W - W % 2
132 | return img[:, :, :H, :W]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | """
4 | Training script using the new "LazyConfig" python config files.
5 |
6 | This script reads a given python config file and runs the training or evaluation.
7 | It can be used to train any model or dataset as long as they can be
8 | instantiated by the recursive construction defined in the given config file.
9 | 
10 | Besides lazy construction of models, dataloaders, etc., this script expects a
11 | few common configuration parameters currently defined in "configs/common/train.py".
12 | To add more complicated training logic, you can easily add other configs
13 | in the config file and implement a new train_net.py to handle them.
14 | """
15 | import logging
16 |
17 | from detectron2.checkpoint import DetectionCheckpointer
18 | from detectron2.config import LazyConfig, instantiate
19 | from detectron2.engine import (
20 | AMPTrainer,
21 | SimpleTrainer,
22 | default_argument_parser,
23 | default_setup,
24 | default_writers,
25 | hooks,
26 | launch,
27 | )
28 | from detectron2.engine.defaults import create_ddp_model
29 | from detectron2.evaluation import inference_on_dataset, print_csv_format
30 | from detectron2.utils import comm
31 | from utils.logger import power_default_writers
32 |
33 | from engine import MattingTrainer
34 |
35 | #running without warnings
36 | import warnings
37 | warnings.filterwarnings('ignore')
38 |
39 | logger = logging.getLogger("detectron2")
40 |
41 |
42 | def do_test(cfg, model):
43 | if "evaluator" in cfg.dataloader:
44 | ret = inference_on_dataset(
45 | model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
46 | )
47 | print_csv_format(ret)
48 | return ret
49 |
50 |
51 | def do_train(args, cfg):
52 | """
53 | Args:
54 | cfg: an object with the following attributes:
55 | model: instantiate to a module
56 | dataloader.{train,test}: instantiate to dataloaders
57 | dataloader.evaluator: instantiate to evaluator for test set
58 |             optimizer: instantiate to an optimizer
59 | lr_multiplier: instantiate to a fvcore scheduler
60 | train: other misc config defined in `configs/common/train.py`, including:
61 | output_dir (str)
62 | init_checkpoint (str)
63 | amp.enabled (bool)
64 | max_iter (int)
65 | eval_period, log_period (int)
66 | device (str)
67 | checkpointer (dict)
68 | ddp (dict)
69 | """
70 | model = instantiate(cfg.model)
71 | logger = logging.getLogger("detectron2")
72 | logger.info("Model:\n{}".format(model))
73 | model.to(cfg.train.device)
74 |
75 | for name, param in model.named_parameters():
76 | if "teacher" in name:
77 | param.requires_grad = False
78 |
79 | cfg.optimizer.params.model = model
80 | optim = instantiate(cfg.optimizer)
81 |
82 | train_dataset = instantiate(cfg.train_dataset)
83 | cfg.dataloader.train.dataset = train_dataset
84 | cfg.dataloader.train.sampler.dataset = train_dataset
85 | train_loader = instantiate(cfg.dataloader.train)
86 |
87 | model = create_ddp_model(model, **cfg.train.ddp)
88 | trainer = MattingTrainer(
89 | model = model,
90 | data_loader = train_loader,
91 | optimizer = optim,
92 | log_image_period = cfg.train.log_image_period
93 | )
94 | checkpointer = DetectionCheckpointer(
95 | model,
96 | cfg.train.output_dir,
97 | trainer=trainer,
98 | )
99 | trainer.register_hooks(
100 | [
101 | hooks.IterationTimer(),
102 | hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
103 | hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
104 | if comm.is_main_process()
105 | else None,
106 | hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
107 | hooks.PeriodicWriter(
108 | power_default_writers(cfg.train.output_dir, cfg.train.max_iter),
109 | period=cfg.train.log_loss_period,
110 | )
111 | if comm.is_main_process()
112 | else None,
113 | ]
114 | )
115 |
116 | checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
117 | if args.resume and checkpointer.has_checkpoint():
118 | # The checkpoint stores the training iteration that just finished, thus we start
119 | # at the next iteration
120 | start_iter = trainer.iter + 1
121 | else:
122 | start_iter = 0
123 | trainer.train(start_iter, cfg.train.max_iter)
124 |
125 |
126 | def main(args):
127 | cfg = LazyConfig.load(args.config_file)
128 | cfg = LazyConfig.apply_overrides(cfg, args.opts)
129 | default_setup(cfg, args)
130 |
131 | if args.eval_only:
132 | model = instantiate(cfg.model)
133 | model.to(cfg.train.device)
134 | model = create_ddp_model(model)
135 | DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
136 | print(do_test(cfg, model))
137 | else:
138 | do_train(args, cfg)
139 |
140 |
141 | if __name__ == "__main__":
142 | args = default_argument_parser().parse_args()
143 | launch(
144 | main,
145 | args.num_gpus,
146 | num_machines=args.num_machines,
147 | machine_rank=args.machine_rank,
148 | dist_url=args.dist_url,
149 | args=(args,),
150 | )
151 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Memory Efficient Matting with Adaptive Token Routing
4 |
5 | Yiheng Lin, Yihan Hu, Chenyi Zhang, Ting Liu, Xiaochao Qu, Luoqi Liu, Yao Zhao, Yunchao Wei
6 |
7 | Institute of Information Science, Beijing Jiaotong University
8 | Visual Intelligence + X International Joint Laboratory of the Ministry of Education
9 | Pengcheng Laboratory, Shenzhen, China
10 | MT Lab, Meitu Inc
11 |
25 | ## 📮 News
26 | - [2025.07] Released our interactive image matting model, [MattePro](https://github.com/ChenyiZhang007/MattePro).
27 |
28 | ## Introduction
29 | Transformer-based models have recently achieved outstanding performance in image matting. However, their application to high-resolution images remains challenging due to the quadratic complexity of global self-attention. To address this issue, we propose MEMatte, a memory-efficient matting framework for processing high-resolution images. MEMatte incorporates a router before each global attention block, directing informative tokens to the global attention while routing the remaining tokens to a Lightweight Token Refinement Module (LTRM). Specifically, the router employs a local-global strategy to predict the routing probability of each token, and the LTRM utilizes efficient modules to simulate global attention. Additionally, we introduce a Batch-constrained Adaptive Token Routing (BATR) mechanism, which allows each router to dynamically route tokens based on the image content and the stage of the attention block within the network.
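
Below is a minimal, self-contained sketch of the routing idea (purely illustrative: the module and parameter names are not those used in this repository; see `modeling/backbone/` for the actual implementation):

```python
import torch
import torch.nn as nn

class ToyRouterBlock(nn.Module):
    """Toy block: the top-k scoring tokens go through full self-attention,
    the remaining tokens only through a cheap per-token MLP."""
    def __init__(self, dim=384, num_heads=6, keep_ratio=0.25):
        super().__init__()
        self.keep_ratio = keep_ratio
        self.router = nn.Linear(dim, 1)  # per-token routing score
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.light = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):                                   # x: [B, N, C]
        B, N, _ = x.shape
        scores = self.router(x).squeeze(-1)                 # [B, N]
        k = max(1, int(N * self.keep_ratio))
        keep_idx = scores.topk(k, dim=1).indices            # indices of "informative" tokens
        out = self.light(x).clone()                         # cheap path for every token
        batch_idx = torch.arange(B, device=x.device).unsqueeze(-1)
        kept = x[batch_idx, keep_idx]                       # [B, k, C]
        attn_out, _ = self.attn(kept, kept, kept)           # expensive path for kept tokens
        out[batch_idx, keep_idx] = attn_out                 # merge the two paths
        return out

tokens = torch.randn(2, 1024, 384)      # e.g. a 512x512 crop patchified at stride 16
print(ToyRouterBlock()(tokens).shape)   # torch.Size([2, 1024, 384])
```

In MEMatte the cheap branch is the LTRM, and during training the fraction of tokens sent to global attention is steered towards `model.backbone.topk` by the token ratio loss.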
30 |
31 | ## Dataset
32 | Our proposed ultra high-resolution image matting datasets:
33 | [`huggingface: dafbgd/UHRIM`](https://huggingface.co/datasets/dafbgd/UHRIM)
34 |
35 |
36 | ## Quick Installation
37 | Run the following command to install required packages.
38 | ```
39 | pip install -r requirements.txt
40 | ```
41 | Install [detectron2](https://github.com/facebookresearch/detectron2) following its [documentation](https://detectron2.readthedocs.io/en/latest/), or run the following command:
42 | ```
43 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
44 | ```
45 |
46 | ## Results
47 | Quantitative Results on [Composition-1k](https://paperswithcode.com/dataset/composition-1k)
48 | | Model | SAD | MSE | Grad | Conn | checkpoints |
49 | | ---------- | ----- | --- | ---- | ----- | ----------- |
50 | | MEMatte-ViTS | 21.90 | 3.37 | 7.43 | 16.77 | [GoogleDrive](https://drive.google.com/file/d/122p3sdhJVb7vg4IXELeC9C3HEG9Mlh5z/view?usp=sharing) |
51 | | MEMatte-ViTB | 21.06 | 3.11 | 6.70 | 15.71 | [GoogleDrive](https://drive.google.com/file/d/1NOV64zMSFtoKPASqvEvxQKI_PRY9m5IA/view?usp=sharing) |
52 |
53 | We also train a robust model on mixed data for real-world images and evaluate it on AIM-500:
54 | | Model | SAD | MSE | Grad | Conn | checkpoints |
55 | | ---------- | ----- | --- | ---- | ----- | ----------- |
56 | | MEMatte-ViTS | 13.90 | 11.17 | 10.94 | 12.78 | [GoogleDrive](https://drive.google.com/file/d/1R5NbgIpOudKjvLz1V9M9SxXr1ovAmu3u/view?usp=drive_link) |
57 |
58 | ## Train
59 | 1. Download the official checkpoints of ViTMatte ([ViTMatte_S_Com.pth](https://drive.google.com/file/d/12VKhSwE_miF9lWQQCgK7mv83rJIls3Xe/view), [ViTMatte_B_Com.pth](https://drive.google.com/file/d/1mOO5MMU4kwhNX96AlfpwjAoMM4V5w3k-/view?pli=1)), and then process the checkpoints using `pretrained/preprocess.py` (see the example after the training command below).
60 | 2. Set `train.init_checkpoint` in the configs to specify the processed checkpoint.
61 | 3. Train the model with the following command:
62 | ```
63 | python main.py \
64 | --config-file configs/MEMatte_S_topk0.25_win_global_long.py \
65 | --num-gpus 2
66 | ```
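
For step 1, assuming the downloaded checkpoint is placed in `./pretrained/`, the processing script can be run as is; it reads `ViTMatte_S_Com.pth` from its working directory and writes `ViTMatte_S_Com_with_teacher.pth` (for the ViT-B checkpoint, adjust the filenames at the bottom of the script accordingly):
```
cd pretrained
python preprocess.py
cd ..
```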
67 |
68 |
69 | ## Inference
70 | ```
71 | python inference.py \
72 | --config-dir ./configs/CONFIG.py \
73 | --checkpoint-dir ./CHECKPOINT_PATH \
74 | --inference-dir ./SAVE_DIR \
75 | --data-dir /DataDir \
76 | --max-number-token Max_number_token
77 | ```
78 | For example:
79 | ```
80 | python inference.py \
81 | --config-dir ./configs/MEMatte_S_topk0.25_win_global_long.py \
82 | --checkpoint-dir ./checkpoints/MEMatte_ViTS_DIM.pth \
83 | --inference-dir ./predAlpha/test_aim500 \
84 | --data-dir ./Datasets/AIM-500 \
85 | --max-number-token 18000
86 | # Reducing the maximum number of tokens lowers memory usage.
87 | ```
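
`matting_inference` can also be called from Python; a minimal sketch using the same placeholder paths as the command above:
```python
from inference import matting_inference

matting_inference(
    config_dir='./configs/MEMatte_S_topk0.25_win_global_long.py',
    checkpoint_dir='./checkpoints/MEMatte_ViTS_DIM.pth',
    inference_dir='./predAlpha/test_aim500',
    data_dir='./Datasets/AIM-500',
    max_number_token=18000,
)
```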
88 |
89 | ## ToDo
90 | - [x] release UHRIM dataset
91 | - [x] release code and checkpoint
92 |
93 | ## Citation
94 | If you have any questions, please feel free to open an issue. If you find our method or dataset helpful, we would appreciate it if you could give our project a star ⭐️ on GitHub and cite our paper:
95 | ```bibtex
96 | @inproceedings{lin2025memory,
97 | title={Memory Efficient Matting with Adaptive Token Routing},
98 | author={Lin, Yiheng and Hu, Yihan and Zhang, Chenyi and Liu, Ting and Qu, Xiaochao and Liu, Luoqi and Zhao, Yao and Wei, Yunchao},
99 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
100 | volume={39},
101 | number={5},
102 | pages={5298--5306},
103 | year={2025}
104 | }
105 | ```
106 |
107 | ## License
108 | The code is released under the MIT License. It is a short, permissive software license. Basically, you can do whatever you want as long as you include the original copyright and license notice in any copy of the software/source.
109 |
110 | ## Acknowledgement
111 | Our project is developed based on [ViTMatte](https://github.com/hustvl/ViTMatte), [DynamicViT](https://github.com/raoyongming/DynamicViT), [Matteformer](https://github.com/webtoon/matteformer), [ToMe](https://github.com/facebookresearch/ToMe), [EViT](https://github.com/youweiliang/evit). Thanks for their wonderful work!
112 |
113 |
--------------------------------------------------------------------------------
/modeling/meta_arch/mematte.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | import os
6 |
7 | from detectron2.structures import ImageList
8 |
9 | class MEMatte(nn.Module):
10 | def __init__(self,
11 | *,
12 | teacher_backbone,
13 | backbone,
14 | criterion,
15 | pixel_mean,
16 | pixel_std,
17 | input_format,
18 | size_divisibility,
19 | decoder,
20 | distill = True,
21 | distill_loss_ratio = 1.,
22 | token_loss_ratio = 1.,
23 | balance_loss = "MSE",
24 | ):
25 | super(MEMatte, self).__init__()
26 | self.teacher_backbone = teacher_backbone
27 | self.backbone = backbone
28 | self.criterion = criterion
29 | self.input_format = input_format
30 | self.balance_loss = balance_loss
31 | self.size_divisibility = size_divisibility
32 | self.decoder = decoder
33 | self.distill = distill
34 | self.distill_loss_ratio = distill_loss_ratio
35 | self.token_loss_ratio = token_loss_ratio
36 | self.register_buffer(
37 | "pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False
38 | )
39 | self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
40 | assert (
41 | self.pixel_mean.shape == self.pixel_std.shape
42 | ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"
43 |
44 | @property
45 | def device(self):
46 | return self.pixel_mean.device
47 |
48 | def forward(self, batched_inputs, patch_decoder=True):
49 | images, targets, H, W = self.preprocess_inputs(batched_inputs)
50 |
51 | if self.training:
52 | if self.distill == True:
53 | self.teacher_backbone.eval()
54 | features, out_pred_prob = self.backbone(images)
55 | teacher_features = self.teacher_backbone(images)
56 | distill_loss = F.mse_loss(features, teacher_features)
57 | else:
58 | features, out_pred_prob = self.backbone(images)
59 | outputs = self.decoder(features, images)
60 | assert targets is not None
61 | trimap = images[:, 3:4]
62 | sample_map = torch.zeros_like(trimap)
63 | sample_map[trimap==0.5] = 1 # 2*B*dim*(2*ratio*hw*dim + ratio*hw*ratio*hw)
64 |             losses = self.criterion(sample_map, outputs, targets)
65 | if self.distill:
66 | losses['distill_loss'] = distill_loss * self.distill_loss_ratio
67 | total_ratio = sum([p.sum() / p.numel() for p in out_pred_prob]) / len(out_pred_prob)
68 | losses['mse_ratio_loss'] = self.token_loss_ratio * F.mse_loss(total_ratio, torch.tensor(self.backbone.topk).cuda())
69 |
70 | tensorboard_images = dict()
71 | tensorboard_images['pred_alpha'] = outputs['phas']
72 | return losses, tensorboard_images, out_pred_prob
73 | else:
74 | features, out_pred_prob, out_hard_keep_decision = self.backbone(images)
75 | if patch_decoder:
76 | outputs = self.patch_inference(features=features, images=images)
77 | else:
78 | outputs = self.decoder(features, images)
79 |
80 | outputs['phas'] = outputs['phas'][:,:,:H,:W]
81 |
82 | return outputs, out_pred_prob, out_hard_keep_decision
83 |
84 |         """Test FLOPs"""
85 | # features = self.backbone(images)
86 | # # outputs = self.decoder(features, images)
87 | # # outputs['phas'] = outputs['phas'][:,:,:H,:W]
88 | # return features
89 |
90 |         """Test decoder FLOPs"""
91 |
92 |
93 |
94 |
95 | def patch_inference(self, features, images):
96 | patch_size = 512
97 | overlap = 64
98 | image_size = patch_size + 2 * overlap
99 | feature_patch_size = patch_size // 16
100 | feature_overlap = overlap // 16
101 | features_size = feature_patch_size + 2 * feature_overlap
102 | B, C, H, W = images.shape
103 | pad_h = (patch_size - H % patch_size) % patch_size
104 | pad_w = (patch_size - W % patch_size) % patch_size
105 | pad_images = F.pad(images.permute(0,2,3,1), (0,0,0,pad_w,0,pad_h)).permute(0,3,1,2)
106 | _, _, pad_H, pad_W = pad_images.shape
107 |
108 | _, _, H_fea, W_fea = features.shape
109 | pad_fea_h = (feature_patch_size - H_fea % feature_patch_size) % feature_patch_size
110 | pad_fea_w = (feature_patch_size - W_fea % feature_patch_size) % feature_patch_size
111 | pad_features = F.pad(features.permute(0,2,3,1), (0,0,0,pad_fea_w,0,pad_fea_h)).permute(0,3,1,2)
112 | _, _, pad_fea_H, pad_fea_W = pad_features.shape
113 |
114 | h_patch_num = pad_images.shape[2] // patch_size
115 | w_patch_num = pad_images.shape[3] // patch_size
116 |
117 | outputs = torch.zeros_like(pad_images[:,0:1,:,:])
118 |
119 | for i in range(h_patch_num):
120 | for j in range(w_patch_num):
121 | start_top = i * patch_size
122 | end_bottom = start_top + patch_size
123 | start_left = j*patch_size
124 | end_right = start_left + patch_size
125 | coor_top = start_top if (start_top - overlap) < 0 else (start_top - overlap)
126 | coor_bottom = end_bottom if (end_bottom + overlap) > pad_H else (end_bottom + overlap)
127 | coor_left = start_left if (start_left - overlap) < 0 else (start_left - overlap)
128 | coor_right = end_right if (end_right + overlap) > pad_W else (end_right + overlap)
129 | selected_images = pad_images[:,:,coor_top:coor_bottom, coor_left:coor_right]
130 |
131 | fea_start_top = i * feature_patch_size
132 | fea_end_bottom = fea_start_top + feature_patch_size
133 | fea_start_left = j*feature_patch_size
134 | fea_end_right = fea_start_left + feature_patch_size
135 | coor_top_fea = fea_start_top if (fea_start_top - feature_overlap) < 0 else (fea_start_top - feature_overlap)
136 | coor_bottom_fea = fea_end_bottom if (fea_end_bottom + feature_overlap) > pad_fea_H else (fea_end_bottom + feature_overlap)
137 | coor_left_fea = fea_start_left if (fea_start_left - feature_overlap) < 0 else (fea_start_left - feature_overlap)
138 | coor_right_fea = fea_end_right if (fea_end_right + feature_overlap) > pad_fea_W else (fea_end_right + feature_overlap)
139 | selected_fea = pad_features[:,:,coor_top_fea:coor_bottom_fea, coor_left_fea:coor_right_fea]
140 |
141 |
142 |
143 | outputs_patch = self.decoder(selected_fea, selected_images)
144 |
145 | coor_top = start_top if (start_top - overlap) < 0 else (coor_top + overlap)
146 | coor_bottom = coor_top + patch_size
147 | coor_left = start_left if (start_left - overlap) < 0 else (coor_left + overlap)
148 | coor_right = coor_left + patch_size
149 |
150 | coor_out_top = 0 if (start_top - overlap) < 0 else overlap
151 | coor_out_bottom = coor_out_top + patch_size
152 | coor_out_left = 0 if (start_left - overlap) < 0 else overlap
153 | coor_out_right = coor_out_left + patch_size
154 |
155 | outputs[:, :, coor_top:coor_bottom, coor_left:coor_right] = outputs_patch['phas'][:,:,coor_out_top:coor_out_bottom,coor_out_left:coor_out_right]
156 |
157 | outputs = outputs[:,:,:H, :W]
158 | return {'phas':outputs}
159 |
160 | def preprocess_inputs(self, batched_inputs):
161 | """
162 | Normalize, pad and batch the input images.
163 | """
164 | images = batched_inputs["image"].to(self.device)
165 | trimap = batched_inputs['trimap'].to(self.device)
166 | images = (images - self.pixel_mean) / self.pixel_std
167 |
168 | if 'fg' in batched_inputs.keys():
169 | trimap[trimap < 85] = 0
170 | trimap[trimap >= 170] = 1
171 | trimap[trimap >= 85] = 0.5
172 |
173 | images = torch.cat((images, trimap), dim=1)
174 |
175 | B, C, H, W = images.shape
176 | if images.shape[-1]%32!=0 or images.shape[-2]%32!=0:
177 | new_H = (32-images.shape[-2]%32) + H
178 | new_W = (32-images.shape[-1]%32) + W
179 | new_images = torch.zeros((images.shape[0], images.shape[1], new_H, new_W)).to(self.device)
180 | new_images[:,:,:H,:W] = images[:,:,:,:]
181 | del images
182 | images = new_images
183 |
184 | if "alpha" in batched_inputs:
185 | phas = batched_inputs["alpha"].to(self.device)
186 | else:
187 | phas = None
188 |
189 | return images, dict(phas=phas), H, W
190 |
191 |
--------------------------------------------------------------------------------
/modeling/backbone/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange, repeat, pack, unpack
8 | from functools import partial
9 | from collections import namedtuple
10 | from torch import Tensor, nn, einsum
11 | import numpy as np
12 | from torch.cuda.amp import autocast
13 |
14 | __all__ = [
15 | "window_partition",
16 | "window_unpartition",
17 | "add_decomposed_rel_pos",
18 | "get_abs_pos",
19 | "PatchEmbed",
20 | "Router",
21 | ]
22 |
23 | RouterReturn = namedtuple('RouterReturn', ['indices', 'scores', 'routed_tokens', 'routed_mask'])
24 |
25 | def exists(val):
26 | return val is not None
27 |
28 | def default(val, d):
29 | return val if exists(val) else d
30 |
31 | def divisible_by(numer, denom):
32 | return (numer % denom) == 0
33 |
34 | def pack_one(t, pattern):
35 | return pack([t], pattern)
36 |
37 | def unpack_one(t, ps, pattern):
38 | return unpack(t, ps, pattern)[0]
39 |
40 | def pad_to_multiple(tensor, multiple, dim=-1, value=0):
41 | seq_len = tensor.shape[dim]
42 | m = seq_len / multiple
43 | if m.is_integer():
44 | return tensor, seq_len
45 |
46 | remainder = math.ceil(m) * multiple - seq_len
47 | pad_offset = (0,) * (-1 - dim) * 2
48 | padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value)
49 | return padded_tensor, seq_len
50 |
51 | def batched_gather(x, indices):
52 | batch_range = create_batch_range(indices, indices.ndim - 1)
53 | return x[batch_range, indices]
54 |
55 | def identity(t):
56 | return t
57 |
58 | def l2norm(t):
59 | return F.normalize(t, dim = -1)
60 |
61 | # tensor helpers
62 |
63 | def create_batch_range(t, right_pad_dims = 1):
64 | b, device = t.shape[0], t.device
65 | batch_range = torch.arange(b, device = device)
66 | pad_dims = ((1,) * right_pad_dims)
67 | return batch_range.reshape(-1, *pad_dims)
68 |
69 |
70 | def window_partition(x, window_size):
71 | """
72 | Partition into non-overlapping windows with padding if needed.
73 | Args:
74 | x (tensor): input tokens with [B, H, W, C].
75 | window_size (int): window size.
76 |
77 | Returns:
78 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
79 | (Hp, Wp): padded height and width before partition
80 | """
81 | B, H, W, C = x.shape
82 |
83 | pad_h = (window_size - H % window_size) % window_size
84 | pad_w = (window_size - W % window_size) % window_size
85 | if pad_h > 0 or pad_w > 0:
86 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
87 | Hp, Wp = H + pad_h, W + pad_w
88 |
89 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
90 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
91 | return windows, (Hp, Wp)
92 |
93 |
94 | def window_unpartition(windows, window_size, pad_hw, hw):
95 | """
 96 |     Window unpartition into original sequences and remove padding.
 97 |     Args:
 98 |         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
99 | window_size (int): window size.
100 | pad_hw (Tuple): padded height and width (Hp, Wp).
101 | hw (Tuple): original height and width (H, W) before padding.
102 |
103 | Returns:
104 | x: unpartitioned sequences with [B, H, W, C].
105 | """
106 | Hp, Wp = pad_hw
107 | H, W = hw
108 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
109 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
110 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
111 |
112 | if Hp > H or Wp > W:
113 | x = x[:, :H, :W, :].contiguous()
114 | return x
115 |
116 |
117 | def get_rel_pos(q_size, k_size, rel_pos):
118 | """
119 | Get relative positional embeddings according to the relative positions of
120 | query and key sizes.
121 | Args:
122 | q_size (int): size of query q.
123 | k_size (int): size of key k.
124 | rel_pos (Tensor): relative position embeddings (L, C).
125 |
126 | Returns:
127 | Extracted positional embeddings according to relative positions.
128 | """
129 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
130 | # Interpolate rel pos if needed.
131 | if rel_pos.shape[0] != max_rel_dist:
132 | # Interpolate rel pos.
133 | rel_pos_resized = F.interpolate(
134 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
135 | size=max_rel_dist,
136 | mode="linear",
137 | )
138 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
139 | else:
140 | rel_pos_resized = rel_pos
141 |
142 | # Scale the coords with short length if shapes for q and k are different.
143 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
144 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
145 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
146 |
147 | return rel_pos_resized[relative_coords.long()]
148 |
149 |
150 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
151 | """
152 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
153 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
154 | Args:
155 | attn (Tensor): attention map.
156 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
157 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
158 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
159 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
160 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
161 |
162 | Returns:
163 | attn (Tensor): attention map with added relative positional embeddings.
164 | """
165 | q_h, q_w = q_size # 80, 120
166 | k_h, k_w = k_size # 80, 120
167 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) # torch.Size([80, 80, 64])
168 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) # torch.Size([120, 120, 64])
169 |
170 | B, _, dim = q.shape
171 | r_q = q.reshape(B, q_h, q_w, dim)
172 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) # torch.Size([6, 80, 120, 80])
173 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) # torch.Size([6, 80, 120, 120])
174 |
175 |         attn = ( # OOMs for 2048*2048 inputs: rel_h: [6, 128, 128, 128], rel_w: [6, 128, 128, 128]; after the None expansion each gains a factor of 128: 6*128*128*128*128*4 = 6,442,450,944 bytes
176 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
177 | ).view(B, q_h * q_w, k_h * k_w)
178 |
179 | return attn
180 |
181 |
182 | def get_abs_pos(abs_pos, has_cls_token, hw):
183 | """
184 | Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
185 | dimension for the original embeddings.
186 | Args:
187 | abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
188 | has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
189 | hw (Tuple): size of input image tokens.
190 |
191 | Returns:
192 | Absolute positional embeddings after processing with shape (1, H, W, C)
193 | """
194 | h, w = hw
195 | if has_cls_token:
196 | abs_pos = abs_pos[:, 1:]
197 | xy_num = abs_pos.shape[1]
198 | size = int(math.sqrt(xy_num))
199 | assert size * size == xy_num
200 |
201 | if size != h or size != w:
202 | new_abs_pos = F.interpolate(
203 | abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
204 | size=(h, w),
205 | mode="bicubic",
206 | align_corners=False,
207 | )
208 |
209 | return new_abs_pos.permute(0, 2, 3, 1)
210 | else:
211 | return abs_pos.reshape(1, h, w, -1)
212 |
213 |
214 | class PatchEmbed(nn.Module):
215 | """
216 | Image to Patch Embedding.
217 | """
218 |
219 | def __init__(
220 | self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
221 | ):
222 | """
223 | Args:
224 | kernel_size (Tuple): kernel size of the projection layer.
225 | stride (Tuple): stride of the projection layer.
226 | padding (Tuple): padding size of the projection layer.
227 | in_chans (int): Number of input image channels.
228 |             embed_dim (int): Patch embedding dimension.
229 | """
230 | super().__init__()
231 |
232 | self.proj = nn.Conv2d(
233 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
234 | )
235 |
236 | def forward(self, x):
237 | x = self.proj(x)
238 | # B C H W -> B H W C
239 | x = x.permute(0, 2, 3, 1)
240 | return x
241 |
242 | class Router(nn.Module):
243 | def __init__(self, embed_dim=384):
244 | super().__init__()
245 | self.in_conv = nn.Sequential(
246 | nn.LayerNorm(embed_dim),
247 | nn.Linear(embed_dim, embed_dim),
248 | nn.GELU()
249 | )
250 |
251 | self.out_conv = nn.Sequential(
252 | nn.Linear(embed_dim, embed_dim // 2),
253 | nn.GELU(),
254 | nn.Linear(embed_dim // 2, embed_dim // 4),
255 | nn.GELU(),
256 | nn.Linear(embed_dim // 4, 2),
257 | nn.LogSoftmax(dim=-1)
258 | )
259 |
260 | def forward(self, x, policy=None):
261 | x = self.in_conv(x) # 16, 1024, 384
262 | B, N, C = x.size()
263 | local_x = x[:,:, :C//2] # 16, 1024, 192
264 | global_x = (x[:,:, C//2:]).sum(dim=1, keepdim=True) / N # 16, 1, 192
265 | x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1) # 16, 1024, 384
266 | return self.out_conv(x) # 16, 1024, 2
267 |
--------------------------------------------------------------------------------
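
A quick sanity check (a sketch, assuming the repo root is on the Python path so `modeling.backbone.utils` is importable) that `window_partition` and `window_unpartition` are exact inverses even when H and W are not multiples of the window size:

import torch
from modeling.backbone.utils import window_partition, window_unpartition

# Tokens laid out as [B, H, W, C]; H and W need not be multiples of the window size,
# window_partition pads on the bottom/right and returns the padded size.
x = torch.randn(2, 30, 45, 64)
windows, (Hp, Wp) = window_partition(x, window_size=14)
print(windows.shape, Hp, Wp)  # torch.Size([24, 14, 14, 64]) 42 56

# window_unpartition folds the windows back and crops the padding away,
# so the round trip reproduces the input exactly.
y = window_unpartition(windows, 14, (Hp, Wp), (30, 45))
assert torch.equal(x, y)
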
/modeling/backbone/vit_teacher.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import fvcore.nn.weight_init as weight_init
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | from detectron2.layers import CNNBlockBase, Conv2d, get_norm
8 | from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
9 | from fairscale.nn.checkpoint import checkpoint_wrapper
10 | from timm.models.layers import DropPath, Mlp, trunc_normal_
11 | from .backbone import Backbone
12 | from .utils import (
13 | PatchEmbed,
14 | add_decomposed_rel_pos,
15 | get_abs_pos,
16 | window_partition,
17 | window_unpartition,
18 | )
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | __all__ = ["ViT_Teacher"]
24 |
25 |
26 | class Attention(nn.Module):
27 | """Multi-head Attention block with relative position embeddings."""
28 |
29 | def __init__(
30 | self,
31 | dim,
32 | num_heads=8,
33 | qkv_bias=True,
34 | use_rel_pos=False,
35 | rel_pos_zero_init=True,
36 | input_size=None,
37 | ):
38 | """
39 | Args:
40 | dim (int): Number of input channels.
41 | num_heads (int): Number of attention heads.
42 |             qkv_bias (bool): If True, add a learnable bias to query, key, value.
43 |             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
44 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
45 | input_size (int or None): Input resolution for calculating the relative positional
46 | parameter size.
47 | """
48 | super().__init__()
49 | self.num_heads = num_heads
50 | head_dim = dim // num_heads
51 | self.scale = head_dim**-0.5
52 |
53 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
54 | self.proj = nn.Linear(dim, dim)
55 |
56 | self.use_rel_pos = use_rel_pos
57 | if self.use_rel_pos:
58 | # initialize relative positional embeddings
59 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
60 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
61 |
62 | if not rel_pos_zero_init:
63 | trunc_normal_(self.rel_pos_h, std=0.02)
64 | trunc_normal_(self.rel_pos_w, std=0.02)
65 |
66 | def forward(self, x):
67 | B, H, W, _ = x.shape
68 | # qkv with shape (3, B, nHead, H * W, C)
69 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
70 | # q, k, v with shape (B * nHead, H * W, C)
71 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
72 |
73 | attn = (q * self.scale) @ k.transpose(-2, -1)
74 |
75 | if self.use_rel_pos:
76 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
77 |
78 | attn = attn.softmax(dim=-1)
79 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
80 | x = self.proj(x)
81 |
82 | return x
83 |
84 | class LayerNorm(nn.Module):
85 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
86 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
87 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
88 | with shape (batch_size, channels, height, width).
89 | """
90 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
91 | super().__init__()
92 | self.weight = nn.Parameter(torch.ones(normalized_shape))
93 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
94 | self.eps = eps
95 | self.data_format = data_format
96 | if self.data_format not in ["channels_last", "channels_first"]:
97 | raise NotImplementedError
98 | self.normalized_shape = (normalized_shape, )
99 |
100 | def forward(self, x):
101 | if self.data_format == "channels_last":
102 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
103 | elif self.data_format == "channels_first":
104 | u = x.mean(1, keepdim=True)
105 | s = (x - u).pow(2).mean(1, keepdim=True)
106 | x = (x - u) / torch.sqrt(s + self.eps)
107 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
108 | return x
109 |
110 | class ResBottleneckBlock(CNNBlockBase):
111 | """
112 | The standard bottleneck residual block without the last activation layer.
113 | It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
114 | """
115 |
116 | def __init__(
117 | self,
118 | in_channels,
119 | out_channels,
120 | bottleneck_channels,
121 | norm="LN",
122 | act_layer=nn.GELU,
123 | conv_kernels=3,
124 | conv_paddings=1,
125 | ):
126 | """
127 | Args:
128 | in_channels (int): Number of input channels.
129 | out_channels (int): Number of output channels.
130 | bottleneck_channels (int): number of output channels for the 3x3
131 | "bottleneck" conv layers.
132 | norm (str or callable): normalization for all conv layers.
133 | See :func:`layers.get_norm` for supported format.
134 | act_layer (callable): activation for all conv layers.
135 | """
136 | super().__init__(in_channels, out_channels, 1)
137 |
138 | self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
139 | self.norm1 = get_norm(norm, bottleneck_channels)
140 | self.act1 = act_layer()
141 |
142 | self.conv2 = Conv2d(
143 | bottleneck_channels,
144 | bottleneck_channels,
145 | conv_kernels,
146 | padding=conv_paddings,
147 | bias=False,
148 | )
149 | self.norm2 = get_norm(norm, bottleneck_channels)
150 | self.act2 = act_layer()
151 |
152 | self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
153 | self.norm3 = get_norm(norm, out_channels)
154 |
155 | for layer in [self.conv1, self.conv2, self.conv3]:
156 | weight_init.c2_msra_fill(layer)
157 | for layer in [self.norm1, self.norm2]:
158 | layer.weight.data.fill_(1.0)
159 | layer.bias.data.zero_()
160 | # zero init last norm layer.
161 | self.norm3.weight.data.zero_()
162 | self.norm3.bias.data.zero_()
163 |
164 | def forward(self, x):
165 | out = x
166 | for layer in self.children():
167 | out = layer(out)
168 |
169 | out = x + out
170 | return out
171 |
172 |
173 | class Block(nn.Module):
174 | """Transformer blocks with support of window attention and residual propagation blocks"""
175 |
176 | def __init__(
177 | self,
178 | dim,
179 | num_heads,
180 | mlp_ratio=4.0,
181 | qkv_bias=True,
182 | drop_path=0.0,
183 | norm_layer=nn.LayerNorm,
184 | act_layer=nn.GELU,
185 | use_rel_pos=False,
186 | rel_pos_zero_init=True,
187 | window_size=0,
188 | use_cc_attn = False,
189 | use_residual_block=False,
190 | use_convnext_block=False,
191 | input_size=None,
192 | res_conv_kernel_size=3,
193 | res_conv_padding=1,
194 | ):
195 | """
196 | Args:
197 | dim (int): Number of input channels.
198 | num_heads (int): Number of attention heads in each ViT block.
199 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
200 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
201 | drop_path (float): Stochastic depth rate.
202 | norm_layer (nn.Module): Normalization layer.
203 | act_layer (nn.Module): Activation layer.
204 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
205 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
206 |             window_size (int): Window size for window attention blocks. If it equals 0, window
207 |                 attention is not used.
208 | use_residual_block (bool): If True, use a residual block after the MLP block.
209 | input_size (int or None): Input resolution for calculating the relative positional
210 | parameter size.
211 | """
212 | super().__init__()
213 | self.norm1 = norm_layer(dim)
214 | self.attn = Attention(
215 | dim,
216 | num_heads=num_heads,
217 | qkv_bias=qkv_bias,
218 | use_rel_pos=use_rel_pos,
219 | rel_pos_zero_init=rel_pos_zero_init,
220 | input_size=input_size if window_size == 0 else (window_size, window_size),
221 | )
222 |
223 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
224 | self.norm2 = norm_layer(dim)
225 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
226 |
227 | self.window_size = window_size
228 |
229 | self.use_residual_block = use_residual_block
230 | if use_residual_block:
231 | # Use a residual block with bottleneck channel as dim // 2
232 | self.residual = ResBottleneckBlock(
233 | in_channels=dim,
234 | out_channels=dim,
235 | bottleneck_channels=dim // 2,
236 | norm="LN",
237 | act_layer=act_layer,
238 | conv_kernels=res_conv_kernel_size,
239 | conv_paddings=res_conv_padding,
240 | )
241 | self.use_convnext_block = use_convnext_block
242 | if use_convnext_block:
243 | self.convnext = ConvNextBlock(dim = dim)
244 |
245 | if use_cc_attn:
246 | self.attn = CrissCrossAttention(dim)
247 |
248 |
249 | def forward(self, x):
250 | shortcut = x
251 | x = self.norm1(x)
252 | # Window partition
253 | if self.window_size > 0:
254 | H, W = x.shape[1], x.shape[2]
255 | x, pad_hw = window_partition(x, self.window_size)
256 |
257 | x = self.attn(x)
258 |
259 | # Reverse window partition
260 | if self.window_size > 0:
261 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
262 |
263 | x = shortcut + self.drop_path(x)
264 | x = x + self.drop_path(self.mlp(self.norm2(x)))
265 |
266 | if self.use_residual_block:
267 | x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
268 | if self.use_convnext_block:
269 | x = self.convnext(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
270 |
271 | return x
272 |
273 |
274 | class ViT_Teacher(Backbone):
275 | """
276 | This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
277 | "Exploring Plain Vision Transformer Backbones for Object Detection",
278 | https://arxiv.org/abs/2203.16527
279 | """
280 |
281 | def __init__(
282 | self,
283 | img_size=1024,
284 | patch_size=16,
285 | in_chans=3,
286 | embed_dim=768,
287 | depth=12,
288 | num_heads=12,
289 | mlp_ratio=4.0,
290 | qkv_bias=True,
291 | drop_path_rate=0.0,
292 | norm_layer=nn.LayerNorm,
293 | act_layer=nn.GELU,
294 | use_abs_pos=True,
295 | use_rel_pos=False,
296 | rel_pos_zero_init=True,
297 | window_size=0,
298 | window_block_indexes=(),
299 | residual_block_indexes=(),
300 | use_act_checkpoint=False,
301 | pretrain_img_size=224,
302 | pretrain_use_cls_token=True,
303 | out_feature="last_feat",
304 | res_conv_kernel_size=3,
305 | res_conv_padding=1,
306 | ):
307 | """
308 | Args:
309 | img_size (int): Input image size.
310 | patch_size (int): Patch size.
311 | in_chans (int): Number of input image channels.
312 | embed_dim (int): Patch embedding dimension.
313 | depth (int): Depth of ViT.
314 | num_heads (int): Number of attention heads in each ViT block.
315 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
316 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
317 | drop_path_rate (float): Stochastic depth rate.
318 | norm_layer (nn.Module): Normalization layer.
319 | act_layer (nn.Module): Activation layer.
320 | use_abs_pos (bool): If True, use absolute positional embeddings.
321 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
322 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
323 | window_size (int): Window size for window attention blocks.
324 | window_block_indexes (list): Indexes for blocks using window attention.
325 | residual_block_indexes (list): Indexes for blocks using conv propagation.
326 | use_act_checkpoint (bool): If True, use activation checkpointing.
327 | pretrain_img_size (int): input image size for pretraining models.
328 |             pretrain_use_cls_token (bool): If True, pretraining models use a class token.
329 | out_feature (str): name of the feature from the last block.
330 | """
331 | super().__init__()
332 | self.pretrain_use_cls_token = pretrain_use_cls_token
333 |
334 | self.patch_embed = PatchEmbed(
335 | kernel_size=(patch_size, patch_size),
336 | stride=(patch_size, patch_size),
337 | in_chans=in_chans,
338 | embed_dim=embed_dim,
339 | )
340 |
341 | if use_abs_pos:
342 | # Initialize absolute positional embedding with pretrain image size.
343 | num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
344 | num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
345 | self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
346 | else:
347 | self.pos_embed = None
348 |
349 | # stochastic depth decay rule
350 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
351 |
352 | self.blocks = nn.ModuleList()
353 | for i in range(depth):
354 | block = Block(
355 | dim=embed_dim,
356 | num_heads=num_heads,
357 | mlp_ratio=mlp_ratio,
358 | qkv_bias=qkv_bias,
359 | drop_path=dpr[i],
360 | norm_layer=norm_layer,
361 | act_layer=act_layer,
362 | use_rel_pos=use_rel_pos,
363 | rel_pos_zero_init=rel_pos_zero_init,
364 | window_size=window_size if i in window_block_indexes else 0,
365 | use_residual_block=i in residual_block_indexes,
366 | input_size=(img_size // patch_size, img_size // patch_size),
367 | res_conv_kernel_size=res_conv_kernel_size,
368 | res_conv_padding=res_conv_padding,
369 | )
370 | if use_act_checkpoint:
371 | block = checkpoint_wrapper(block)
372 | self.blocks.append(block)
373 |
374 | self._out_feature_channels = {out_feature: embed_dim}
375 | self._out_feature_strides = {out_feature: patch_size}
376 | self._out_features = [out_feature]
377 |
378 | if self.pos_embed is not None:
379 | trunc_normal_(self.pos_embed, std=0.02)
380 |
381 | self.apply(self._init_weights)
382 |
383 | def _init_weights(self, m):
384 | if isinstance(m, nn.Linear):
385 | trunc_normal_(m.weight, std=0.02)
386 | if isinstance(m, nn.Linear) and m.bias is not None:
387 | nn.init.constant_(m.bias, 0)
388 | elif isinstance(m, nn.LayerNorm):
389 | nn.init.constant_(m.bias, 0)
390 | nn.init.constant_(m.weight, 1.0)
391 |
392 | def forward(self, x):
393 | x = self.patch_embed(x)
394 | if self.pos_embed is not None:
395 | x = x + get_abs_pos(
396 | self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
397 | )
398 |
399 |
400 | for blk in self.blocks: # torch.Size([14, 32, 32, 384])
401 | x = blk(x)
402 |
403 | outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
404 |
405 | return outputs['last_feat']
--------------------------------------------------------------------------------
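
For reference, a small sketch (same import-path assumption as above) of how `get_abs_pos` adapts a pretrained positional embedding to the current token grid before it is added in `ViT_Teacher.forward`:

import torch
from modeling.backbone.utils import get_abs_pos

# Pretrained ViT pos-embed: 1 cls token + 14x14 patch tokens, dim 384 (hypothetical sizes).
abs_pos = torch.randn(1, 1 + 14 * 14, 384)

# Resize to a 32x32 token grid (e.g. a 512px image with 16px patches); the cls token is dropped
# and the rest is bicubically interpolated, returning a (1, H, W, C) tensor to add to patch tokens.
pos = get_abs_pos(abs_pos, has_cls_token=True, hw=(32, 32))
print(pos.shape)  # torch.Size([1, 32, 32, 384])
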
/modeling/backbone/vit.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import fvcore.nn.weight_init as weight_init
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | from detectron2.layers import CNNBlockBase, Conv2d, get_norm
8 | from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
9 | from fairscale.nn.checkpoint import checkpoint_wrapper
10 | from timm.models.layers import DropPath, Mlp, trunc_normal_
11 | from functools import partial
12 | from .backbone import Backbone
13 | from .utils import (
14 | PatchEmbed,
15 | add_decomposed_rel_pos,
16 | get_abs_pos,
17 | window_partition,
18 | window_unpartition,
19 | Router
20 | )
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | __all__ = ["ViT"]
26 |
27 |
28 | class Attention(nn.Module):
29 | """Multi-head Attention block with relative position embeddings."""
30 |
31 | def __init__(
32 | self,
33 | dim,
34 | num_heads=8,
35 | qkv_bias=True,
36 | use_rel_pos=False,
37 | rel_pos_zero_init=True,
38 | input_size=None,
39 | ):
40 | """
41 | Args:
42 | dim (int): Number of input channels.
43 | num_heads (int): Number of attention heads.
44 |             qkv_bias (bool): If True, add a learnable bias to query, key, value.
45 |             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
46 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
47 | input_size (int or None): Input resolution for calculating the relative positional
48 | parameter size.
49 | """
50 | super().__init__()
51 | self.num_heads = num_heads
52 | head_dim = dim // num_heads
53 | self.scale = head_dim**-0.5
54 |
55 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
56 | self.proj = nn.Linear(dim, dim)
57 |
58 | self.use_rel_pos = use_rel_pos
59 | if self.use_rel_pos:
60 | # initialize relative positional embeddings
61 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
62 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
63 |
64 | if not rel_pos_zero_init:
65 | trunc_normal_(self.rel_pos_h, std=0.02)
66 | trunc_normal_(self.rel_pos_w, std=0.02)
67 |
68 | def softmax_with_policy(self, attn, policy, eps=1e-6):
69 | B, N, _ = policy.size() # 128, 197, 1
70 | B, H, N, N = attn.size()
71 | attn_policy = policy.reshape(B, 1, 1, N) # * policy.reshape(B, 1, N, 1)
72 | eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N)
73 |         attn_policy = attn_policy + (1.0 - attn_policy) * eye # keep the diagonal so every token still attends to itself
74 | max_att = torch.max(attn, dim=-1, keepdim=True)[0]
75 | attn = attn - max_att
76 | # attn = attn.exp_() * attn_policy
77 | # return attn / attn.sum(dim=-1, keepdim=True)
78 |
79 | # for stable training
80 | attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)
81 | attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps)
82 | return attn.type_as(max_att)
83 |
84 |
85 | def forward(self, x, policy, H, W): # total FLOPs: 4 * B * hw * dim * dim + 2 * B * hw * hw * dim
86 | if x.ndim == 4:
87 | B, H, W, _ = x.shape
88 | N = H*W
89 | else:
90 | B, N, _ = x.shape
91 |
92 |
93 | # qkv with shape (3, B, nHead, H * W, C) self.qkv.flops: b * hw * dim * dim * 3
94 |         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 6.341787648 GFLOPs; reshape and permute add no FLOPs
95 | # q, k, v with shape (B * nHead, H * W, C)
96 | q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0)
97 |
98 | attn = (q * self.scale) @ k.transpose(-2, -1) # 5.637144576 (B * hw * hw * dim) 14 * 1024*1024*384
99 |
100 | if self.use_rel_pos:
101 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
102 |
103 | if policy is None:
104 | attn = attn.softmax(dim=-1)
105 | else:
106 | attn = self.softmax_with_policy(attn.reshape(B, self.num_heads, N, N), policy).reshape(B*self.num_heads, N, N)
107 |
108 | # # 5.637144576 (B * hw * hw * dim)
109 | if x.ndim == 4:
110 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
111 | else:
112 | x = (attn @ v).view(B, self.num_heads, N, -1).permute(0, 2, 1, 3).reshape(B, N, -1)
113 |
114 | x = self.proj(x) # 2.113929216 (B * hw * dim * dim) 14 * 1024 * 384 * 384
115 |
116 | return x
117 |
118 | class LayerNorm(nn.Module):
119 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
120 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
121 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
122 | with shape (batch_size, channels, height, width).
123 | """
124 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
125 | super().__init__()
126 | self.weight = nn.Parameter(torch.ones(normalized_shape))
127 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
128 | self.eps = eps
129 | self.data_format = data_format
130 | if self.data_format not in ["channels_last", "channels_first"]:
131 | raise NotImplementedError
132 | self.normalized_shape = (normalized_shape, )
133 |
134 | def forward(self, x):
135 | if self.data_format == "channels_last":
136 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
137 | elif self.data_format == "channels_first":
138 | u = x.mean(1, keepdim=True)
139 | s = (x - u).pow(2).mean(1, keepdim=True)
140 | x = (x - u) / torch.sqrt(s + self.eps)
141 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
142 | return x
143 |
144 | class ResBottleneckBlock(CNNBlockBase):
145 | """
146 | The standard bottleneck residual block without the last activation layer.
147 | It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
148 | """
149 |
150 | def __init__(
151 | self,
152 | in_channels,
153 | out_channels,
154 | bottleneck_channels,
155 | norm="LN",
156 | act_layer=nn.GELU,
157 | conv_kernels=3,
158 | conv_paddings=1,
159 | ):
160 | """
161 | Args:
162 | in_channels (int): Number of input channels.
163 | out_channels (int): Number of output channels.
164 | bottleneck_channels (int): number of output channels for the 3x3
165 | "bottleneck" conv layers.
166 | norm (str or callable): normalization for all conv layers.
167 | See :func:`layers.get_norm` for supported format.
168 | act_layer (callable): activation for all conv layers.
169 | """
170 | super().__init__(in_channels, out_channels, 1)
171 |
172 | self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
173 | self.norm1 = get_norm(norm, bottleneck_channels)
174 | self.act1 = act_layer()
175 |
176 | self.conv2 = Conv2d(
177 | bottleneck_channels,
178 | bottleneck_channels,
179 | conv_kernels,
180 | padding=conv_paddings,
181 | bias=False,
182 | )
183 | self.norm2 = get_norm(norm, bottleneck_channels)
184 | self.act2 = act_layer()
185 |
186 | self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
187 | self.norm3 = get_norm(norm, out_channels)
188 |
189 | for layer in [self.conv1, self.conv2, self.conv3]:
190 | weight_init.c2_msra_fill(layer)
191 | for layer in [self.norm1, self.norm2]:
192 | layer.weight.data.fill_(1.0)
193 | layer.bias.data.zero_()
194 | # zero init last norm layer.
195 | self.norm3.weight.data.zero_()
196 | self.norm3.bias.data.zero_()
197 |
198 | def forward(self, x):
199 | out = x
200 | for layer in self.children():
201 | out = layer(out)
202 |
203 | out = x + out
204 | return out
205 |
206 | class ECA(nn.Module):
207 | def __init__(self, channels, b=1, gamma=2):
208 | super(ECA, self).__init__()
209 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
210 | self.channels = channels
211 | self.b = b
212 | self.gamma = gamma
213 | self.conv = nn.Conv1d(
214 | 1,
215 | 1,
216 | kernel_size=self.kernel_size(),
217 | padding=(self.kernel_size() - 1) // 2,
218 | bias=False,
219 | )
220 | self.sigmoid = nn.Sigmoid()
221 |
222 | def kernel_size(self):
223 | k = int(abs((math.log2(self.channels) / self.gamma) + self.b / self.gamma))
224 | out = k if k % 2 else k + 1
225 | return out
226 |
227 | def forward(self, x):
228 |
229 | # feature descriptor on the global spatial information
230 | y = self.avg_pool(x)
231 |
232 | # Two different branches of ECA module
233 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
234 |
235 | # Multi-scale information fusion
236 | y = self.sigmoid(y)
237 |
238 | return x * y.expand_as(x)
239 |
240 |
241 |
242 | class LTRM(nn.Module):
243 | """Efficient Block to replace the Original Transformer Block"""
244 | def __init__(self, dim, expand_ratio = 2, kernel_size = 5):
245 | super(LTRM, self).__init__()
246 | self.fc1 = nn.Linear(dim, dim*expand_ratio)
247 | self.act1 = nn.GELU()
248 | self.dwconv = nn.Conv2d(dim*expand_ratio, dim*expand_ratio, kernel_size=(kernel_size, kernel_size), groups=dim*expand_ratio, padding=(kernel_size//2, kernel_size//2))
249 | self.act2 = nn.GELU()
250 | self.fc2 = nn.Linear(dim*expand_ratio, dim)
251 | self.eca = ECA(dim)
252 |
253 | def forward(self, x, prev_msa=None):
254 | x = self.act1(self.fc1(x))
255 | x = self.act2(self.dwconv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1))
256 | x = self.fc2(x)
257 | y = x + self.eca(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
258 | return y
259 |
260 |
261 |
262 |
263 |
264 | class EfficientBlock_ALR(nn.ModuleList): # from Rethinking Attention (AAAI)
265 | def __init__(self, model_dimension=128):
266 | super(EfficientBlock_ALR, self).__init__()
267 | # self.sentence_length=sentence_length
268 | self.model_dimension = model_dimension
269 | self.width = self.model_dimension
270 | self.layers=list()
271 | widths=[1,2,1]
272 | self.depth=len(widths)-1
273 | self.layers=nn.ModuleList()
274 | for i in range(self.depth):
275 | self.layers.extend([nn.LayerNorm(self.width * widths[i]),nn.Linear(self.width * widths[i], self.width * widths[i+1])])
276 |             if(i
336 |         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
338 | self.norm2 = norm_layer(dim)
339 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
340 |
341 | self.window_size = window_size
342 | self.use_efficient_block = use_efficient_block
343 | if use_efficient_block:
344 | self.efficient_block = LTRM(dim = dim)
345 |
346 | self.use_residual_block = use_residual_block
347 | if use_residual_block:
348 | # Use a residual block with bottleneck channel as dim // 2
349 | self.residual = ResBottleneckBlock(
350 | in_channels=dim,
351 | out_channels=dim,
352 | bottleneck_channels=dim // 2,
353 | norm="LN",
354 | act_layer=act_layer,
355 | conv_kernels=res_conv_kernel_size,
356 | conv_paddings=res_conv_padding,
357 | )
358 | self.use_convnext_block = use_convnext_block
359 | if use_convnext_block:
360 | self.convnext = ConvNextBlock(dim = dim)
361 |
362 | if use_cc_attn:
363 | self.attn = CrissCrossAttention(dim)
364 |
365 | def msa_forward(self, x, policy, H, W):
366 | # shortcut = x
367 | x = self.norm1(x) # 0.27525
368 | # Window partition
369 | if self.window_size > 0:
370 | H, W = x.shape[1], x.shape[2]
371 | x, pad_hw = window_partition(x, self.window_size)
372 |
373 | x = self.attn(x, policy, H = H, W = W)
374 |
375 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
376 | else:
377 | x = self.attn(x, policy, H = H, W = W)
378 |
379 |
380 | x = self.drop_path(x)
381 |
382 | return x
383 |
384 | def mlp_forward(self, x):
385 | return self.drop_path(self.mlp(self.norm2(x)))
386 |
387 | def forward(self, x, policy=None):
388 | B, H, W, C = x.shape
389 | N = H * W
390 | # B, N, C = x.shape
391 | shortcut = x
392 |
393 | if self.use_efficient_block:
394 | if self.training:
395 | fast_msa = self.efficient_block(x)
396 | slow_msa = self.msa_forward(x, policy, H, W).reshape(B, N, C)
397 | # slow_msa = slow_msa * (policy + (1. - policy).detach())
398 | # msa = torch.where(policy.bool(), slow_msa, fast_msa.reshape(B, N, C)).reshape(B, H, W, C)
399 | msa = (slow_msa * policy + fast_msa.reshape(B, -1, C) * (1. - policy)).reshape(B, H, W, C)
400 | else:
401 | msa = self.efficient_block(x)
402 | selected_indices = policy.squeeze(-1).bool()
403 | # if True in selected_indices:
404 | if torch.any(selected_indices == 1):
405 | selected_x = x.reshape(B, -1, C)[selected_indices].unsqueeze(0)
406 | slow_msa = self.msa_forward(selected_x, policy=None, H = H, W = W)
407 | msa.masked_scatter_(selected_indices.reshape(B, H, W, 1), slow_msa)
408 | else:
409 | msa = self.msa_forward(x, policy, H, W).reshape(B, H, W, C)
410 | x = shortcut + msa
411 | x = x + self.mlp_forward(x)
412 |
413 | if self.use_residual_block:
414 | x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
415 | if self.use_convnext_block:
416 | x = self.convnext(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
417 |
418 | return x
419 |
420 |
421 | class ViT(Backbone):
422 | """
423 | This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
424 | "Exploring Plain Vision Transformer Backbones for Object Detection",
425 | https://arxiv.org/abs/2203.16527
426 | """
427 |
428 | def __init__(
429 | self,
430 | img_size=1024,
431 | patch_size=16,
432 | in_chans=3,
433 | embed_dim=768,
434 | depth=12,
435 | num_heads=12,
436 | mlp_ratio=4.0,
437 | qkv_bias=True,
438 | drop_path_rate=0.0,
439 | norm_layer=nn.LayerNorm,
440 | act_layer=nn.GELU,
441 | use_abs_pos=True,
442 | use_rel_pos=False,
443 | rel_pos_zero_init=True,
444 | window_size=0,
445 | window_block_indexes=(),
446 | residual_block_indexes=(),
447 | use_act_checkpoint=False,
448 | pretrain_img_size=224,
449 | pretrain_use_cls_token=True,
450 | out_feature="last_feat",
451 | res_conv_kernel_size=3,
452 | res_conv_padding=1,
453 | topk = 1.,
454 | multi_score = False,
455 | router_module = "Ours",
456 | skip_all_block = False,
457 | max_number_token = 18500,
458 | ):
459 | """
460 | Args:
461 | img_size (int): Input image size.
462 | patch_size (int): Patch size.
463 | in_chans (int): Number of input image channels.
464 | embed_dim (int): Patch embedding dimension.
465 | depth (int): Depth of ViT.
466 | num_heads (int): Number of attention heads in each ViT block.
467 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
468 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
469 | drop_path_rate (float): Stochastic depth rate.
470 | norm_layer (nn.Module): Normalization layer.
471 | act_layer (nn.Module): Activation layer.
472 | use_abs_pos (bool): If True, use absolute positional embeddings.
473 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
474 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
475 | window_size (int): Window size for window attention blocks.
476 | window_block_indexes (list): Indexes for blocks using window attention.
477 | residual_block_indexes (list): Indexes for blocks using conv propagation.
478 | use_act_checkpoint (bool): If True, use activation checkpointing.
479 | pretrain_img_size (int): input image size for pretraining models.
480 |             pretrain_use_cls_token (bool): If True, pretraining models use a class token.
481 | out_feature (str): name of the feature from the last block.
482 | """
483 | super().__init__()
484 |
485 | self.pretrain_use_cls_token = pretrain_use_cls_token
486 | self.topk = topk
487 | self.multi_score = multi_score
488 | self.skip_all_block = skip_all_block
489 | self.window_block_indexes = window_block_indexes
490 | self.max_number_token = max_number_token
491 | self.router_module = router_module
492 | self.embed_dim = embed_dim
493 | self.patch_embed = PatchEmbed(
494 | kernel_size=(patch_size, patch_size),
495 | stride=(patch_size, patch_size),
496 | in_chans=in_chans,
497 | embed_dim=embed_dim,
498 | )
499 |
500 | if use_abs_pos:
501 | # Initialize absolute positional embedding with pretrain image size.
502 | num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
503 | num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
504 | self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
505 | else:
506 | self.pos_embed = None
507 |
508 | # stochastic depth decay rule
509 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
510 |
511 | self.blocks = nn.ModuleList()
512 | for i in range(depth):
513 | block = Block(
514 | dim=embed_dim,
515 | num_heads=num_heads,
516 | mlp_ratio=mlp_ratio,
517 | qkv_bias=qkv_bias,
518 | drop_path=dpr[i],
519 | norm_layer=norm_layer,
520 | act_layer=act_layer,
521 |                 use_rel_pos=(use_rel_pos and i in window_block_indexes),
522 | rel_pos_zero_init=rel_pos_zero_init,
523 | window_size=window_size if i in window_block_indexes else 0,
524 | use_residual_block=i in residual_block_indexes,
525 | input_size=(img_size // patch_size, img_size // patch_size),
526 | res_conv_kernel_size=res_conv_kernel_size,
527 | res_conv_padding=res_conv_padding,
528 |                 use_efficient_block=not (i == 0 or i in window_block_indexes),
529 | )
530 | if use_act_checkpoint:
531 | block = checkpoint_wrapper(block)
532 | self.blocks.append(block)
533 |
534 | self.routers = nn.ModuleList()
535 | for i in range(depth - len(window_block_indexes)):
536 | router = Router(embed_dim = embed_dim)
537 | self.routers.append(router)
538 |
539 |
540 | self._out_feature_channels = {out_feature: embed_dim}
541 | self._out_feature_strides = {out_feature: patch_size}
542 | self._out_features = [out_feature]
543 |
544 | if self.pos_embed is not None:
545 | trunc_normal_(self.pos_embed, std=0.02)
546 |
547 | self.apply(self._init_weights)
548 |
549 | def _init_weights(self, m):
550 | if isinstance(m, nn.Linear):
551 | trunc_normal_(m.weight, std=0.02)
552 | if isinstance(m, nn.Linear) and m.bias is not None:
553 | nn.init.constant_(m.bias, 0)
554 | elif isinstance(m, nn.LayerNorm):
555 | nn.init.constant_(m.bias, 0)
556 | nn.init.constant_(m.weight, 1.0)
557 |
558 | def forward(self, x):
559 | x = self.patch_embed(x)
560 | if self.pos_embed is not None:
561 | x = x + get_abs_pos(
562 | self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
563 | )
564 |
565 | B, H, W, C = x.shape
566 | N = H * W
567 | prev_decision = torch.ones(B, N, 1, dtype = x.dtype, device = x.device)
568 | out_pred_prob = []
569 | out_hard_keep_decision = []
570 | count = 0
571 |
572 | for i, blk in enumerate(self.blocks): # x.shape = torch.Size([16, 32, 32, 384])
573 | if i in self.window_block_indexes or i == 0:
574 | x = blk(x, policy = None)
575 | continue
576 | pred_score = self.routers[count](x.reshape(B, -1, C), prev_decision).reshape(B, -1, 2) # B, N, 2
577 | count = count + 1
578 | if self.training:
579 | hard_keep_decision = F.gumbel_softmax(pred_score, hard = True)[:, :, 0:1]
580 | out_pred_prob.append(hard_keep_decision.reshape(B, H*W))
581 | x = blk(x, policy = hard_keep_decision)
582 | else:
583 | """gumbel_softmax"""
584 | # hard_keep_decision = F.gumbel_softmax(pred_score, hard = True)[:, :, 0:1] # torch.Size([1, N, 1])
585 |
586 | """argmax"""
587 | hard_keep_decision = torch.zeros_like(pred_score[..., :1])
588 | hard_keep_decision[pred_score[..., 0] > pred_score[..., 1]] = 1.
589 |
590 | if hard_keep_decision.sum().item() > self.max_number_token:
591 | print("=================================================================")
592 | print(f"Decrease the number of tokens in Global attention: {self.max_number_token}")
593 | print("=================================================================")
594 | _, sort_index = torch.sort((pred_score[..., 0] - pred_score[..., 1]), descending=True)
595 | sort_index = sort_index[:, :self.max_number_token]
596 | hard_keep_decision = torch.zeros_like(pred_score[..., :1])
597 | hard_keep_decision[:, sort_index.squeeze(0), :] = 1.
598 |
599 | # threshold = torch.quantile(pred_score[..., :1], 0.85)
600 | # hard_keep_decision = (pred_score[..., :1] >= threshold).float()
601 |
602 | # out_pred_prob.append(pred_score[..., 0].reshape(B, x.shape[1], x.shape[2]))
603 |
604 | out_pred_prob.append(pred_score.reshape(B, x.shape[1], x.shape[2], 2))
605 | out_hard_keep_decision.append(hard_keep_decision.reshape(B, x.shape[1], x.shape[2], 1))
606 | x = blk(x, policy = hard_keep_decision)
607 |
608 | outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
609 |
610 | if self.training:
611 | return outputs['last_feat'], out_pred_prob
612 | else:
613 | return outputs['last_feat'], out_pred_prob, out_hard_keep_decision
614 |
615 |
616 | if __name__ == '__main__':
617 | model = ViT(
618 | in_chans=4,
619 | img_size=512,
620 | patch_size=16,
621 | embed_dim=384,
622 | depth=12,
623 | num_heads=6,
624 | drop_path_rate=0,
625 | window_size=0,
626 | mlp_ratio=4,
627 | qkv_bias=True,
628 | norm_layer = partial(nn.LayerNorm, eps=1e-6),
629 | window_block_indexes=[
630 | # 2, 5, 8 11 for global attention
631 | # 0,
632 | # 1,
633 | # 3,
634 | # 4,
635 | # 6,
636 | # 7,
637 | # 9,
638 | # 10,
639 | ],
640 | residual_block_indexes=[2, 5, 8, 11],
641 | use_rel_pos=True,
642 | out_feature="last_feat"
643 | )
644 | print(model)
645 |
646 | out, prob = model(torch.ones(2, 4, 512, 512)) # sum([p.sum() / p.numel() for p in prob]) / len(prob)
647 | print(out.shape)
--------------------------------------------------------------------------------
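
The inference branch of `ViT.forward` above turns the router's two logits into a hard keep/skip mask per token and, when too many tokens are kept, falls back to the top tokens ranked by the logit margin. A standalone sketch of that selection with toy shapes (hypothetical values, plain torch only):

import torch

B, N, max_number_token = 1, 8, 3
# Hypothetical router output: per-token scores for (keep, skip).
pred_score = torch.randn(B, N, 2)

# Argmax decision: route a token through full attention when the "keep" score wins.
hard_keep_decision = torch.zeros_like(pred_score[..., :1])
hard_keep_decision[pred_score[..., 0] > pred_score[..., 1]] = 1.

# If too many tokens are kept, keep only the top-k ranked by the score margin,
# mirroring the max_number_token cap in ViT.forward.
if hard_keep_decision.sum().item() > max_number_token:
    _, sort_index = torch.sort(pred_score[..., 0] - pred_score[..., 1], descending=True)
    sort_index = sort_index[:, :max_number_token]
    hard_keep_decision = torch.zeros_like(pred_score[..., :1])
    hard_keep_decision[:, sort_index.squeeze(0), :] = 1.

print(hard_keep_decision.squeeze(-1))  # e.g. tensor([[1., 0., 0., 1., 0., 1., 0., 0.]])
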
/data/dim_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Dataloader to process Adobe Image Matting Dataset.
3 |
4 | From GCA_Matting(https://github.com/Yaoyi-Li/GCA-Matting/tree/master/dataloader)
5 | '''
6 | import os
7 | import glob
8 | import logging
9 | import os.path as osp
10 | import functools
11 | import numpy as np
12 | import torch
13 | import cv2
14 | import math
15 | import numbers
16 | import random
17 | import pickle
18 | from torch.utils.data import Dataset, DataLoader
19 | from torch.nn import functional as F
20 | from torchvision import transforms
21 | from easydict import EasyDict
22 | from detectron2.utils.logger import setup_logger
23 | from detectron2.utils import comm
24 |
25 | # Base default config
26 | CONFIG = EasyDict({})
27 |
28 | # Model config
29 | CONFIG.model = EasyDict({})
30 | # one-hot or class, choice: [3, 1]
31 | CONFIG.model.trimap_channel = 1
32 |
33 | # Dataloader config
34 | CONFIG.data = EasyDict({})
35 | # feed forward image size (untested)
36 | CONFIG.data.crop_size = 512
37 | # composition of two foregrounds, affine transform, crop and HSV jitter
38 | CONFIG.data.cutmask_prob = 0.25
39 | CONFIG.data.augmentation = True
40 | CONFIG.data.random_interp = True
41 |
42 | class Prefetcher():
43 | """
44 | Modified from the data_prefetcher in https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
45 | """
46 | def __init__(self, loader):
47 | self.orig_loader = loader
48 | self.stream = torch.cuda.Stream()
49 | self.next_sample = None
50 |
51 | def preload(self):
52 | try:
53 | self.next_sample = next(self.loader)
54 | except StopIteration:
55 | self.next_sample = None
56 | return
57 |
58 | with torch.cuda.stream(self.stream):
59 | for key, value in self.next_sample.items():
60 | if isinstance(value, torch.Tensor):
61 | self.next_sample[key] = value.cuda(non_blocking=True)
62 |
63 | def __next__(self):
64 | torch.cuda.current_stream().wait_stream(self.stream)
65 | sample = self.next_sample
66 | if sample is not None:
67 | for key, value in sample.items():
68 | if isinstance(value, torch.Tensor):
69 | sample[key].record_stream(torch.cuda.current_stream())
70 | self.preload()
71 | else:
72 |             # raise StopIteration when there is no more data, matching the behaviour of a default dataloader
73 | raise StopIteration("No samples in loader. example: `iterator = iter(Prefetcher(loader)); "
74 | "data = next(iterator)`")
75 | return sample
76 |
77 | def __iter__(self):
78 | self.loader = iter(self.orig_loader)
79 | self.preload()
80 | return self
81 |
82 |
83 | class ImageFile(object):
84 | def __init__(self, phase='train'):
85 | self.phase = phase
86 | self.rng = np.random.RandomState(0)
87 |
88 | def _get_valid_names(self, *dirs, shuffle=True):
89 | name_sets = [self._get_name_set(d) for d in dirs]
90 |
91 | def _join_and(a, b):
92 | return a & b
93 |
94 | valid_names = list(functools.reduce(_join_and, name_sets))
95 | if shuffle:
96 | self.rng.shuffle(valid_names)
97 |
98 | return valid_names
99 |
100 | @staticmethod
101 | def _get_name_set(dir_name):
102 | path_list = glob.glob(os.path.join(dir_name, '*'))
103 | name_set = set()
104 | for path in path_list:
105 | name = os.path.basename(path)
106 | name = os.path.splitext(name)[0]
107 | name_set.add(name)
108 | return name_set
109 |
110 | @staticmethod
111 | def _list_abspath(data_dir, ext, data_list):
112 | return [os.path.join(data_dir, name + ext)
113 | for name in data_list]
114 |
115 | class ImageFileTrain(ImageFile):
116 | def __init__(self,
117 | alpha_dir="train_alpha",
118 | fg_dir="train_fg",
119 | bg_dir="train_bg",
120 | alpha_ext=".jpg",
121 | fg_ext=".jpg",
122 | bg_ext=".jpg",
123 | root='',
124 | ):
125 | super(ImageFileTrain, self).__init__(phase="train")
126 |
127 | self.alpha_dir = alpha_dir
128 | self.fg_dir = fg_dir
129 | self.bg_dir = bg_dir
130 | self.alpha_ext = alpha_ext
131 | self.fg_ext = fg_ext
132 | self.bg_ext = bg_ext
133 | logger = setup_logger(name=__name__)
134 |
135 | self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir)
136 | self.valid_bg_list = [os.path.splitext(name)[0] for name in os.listdir(self.bg_dir)]
137 |
138 | self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_fg_list)
139 | self.fg = self._list_abspath(self.fg_dir, self.fg_ext, self.valid_fg_list)
140 | self.bg = self._list_abspath(self.bg_dir, self.bg_ext, self.valid_bg_list)
141 |
142 |
143 |
144 | def __len__(self):
145 | return len(self.alpha)
146 |
147 |
148 | class ImageFileTest(ImageFile):
149 | def __init__(self,
150 | alpha_dir="test_alpha",
151 | merged_dir="test_merged",
152 | trimap_dir="test_trimap",
153 | alpha_ext=".png",
154 | merged_ext=".png",
155 | trimap_ext=".png"):
156 | super(ImageFileTest, self).__init__(phase="test")
157 |
158 | self.alpha_dir = alpha_dir
159 | self.merged_dir = merged_dir
160 | self.trimap_dir = trimap_dir
161 | self.alpha_ext = alpha_ext
162 | self.merged_ext = merged_ext
163 | self.trimap_ext = trimap_ext
164 |
165 | self.valid_image_list = self._get_valid_names(self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False)
166 |
167 | self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_image_list)
168 | self.merged = self._list_abspath(self.merged_dir, self.merged_ext, self.valid_image_list)
169 | self.trimap = self._list_abspath(self.trimap_dir, self.trimap_ext, self.valid_image_list)
170 |
171 | def __len__(self):
172 | return len(self.alpha)
173 |
174 | interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
175 |
176 |
177 | def maybe_random_interp(cv2_interp):
178 | if CONFIG.data.random_interp:
179 | return np.random.choice(interp_list)
180 | else:
181 | return cv2_interp
182 |
183 |
184 | class ToTensor(object):
185 | """
186 | Convert ndarrays in sample to Tensors with normalization.
187 | """
188 | def __init__(self, phase="test"):
189 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
190 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
191 | self.phase = phase
192 |
193 | def __call__(self, sample):
194 | image, alpha, trimap, mask = sample['image'][:,:,::-1], sample['alpha'], sample['trimap'], sample['mask']
195 |
196 | alpha[alpha < 0 ] = 0
197 | alpha[alpha > 1] = 1
198 |
199 | image = image.transpose((2, 0, 1)).astype(np.float32)
200 | alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
201 |
202 | mask = np.expand_dims(mask.astype(np.float32), axis=0)
203 |
204 | image /= 255.
205 |
206 | if self.phase == "train":
207 | fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
208 | sample['fg'] = torch.from_numpy(fg)
209 | bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
210 | sample['bg'] = torch.from_numpy(bg)
211 |
212 | sample['image'], sample['alpha'], sample['trimap'] = \
213 | torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)
214 | sample['image'] = sample['image']
215 |
216 | if CONFIG.model.trimap_channel == 3:
217 | sample['trimap'] = F.one_hot(sample['trimap'], num_classes=3).permute(2,0,1).float()
218 | elif CONFIG.model.trimap_channel == 1:
219 | sample['trimap'] = sample['trimap'][None,...].float()
220 | else:
221 | raise NotImplementedError("CONFIG.model.trimap_channel can only be 3 or 1")
222 |
223 | sample['mask'] = torch.from_numpy(mask).float()
224 |
225 | return sample
226 |
227 |
228 | class RandomAffine(object):
229 | """
230 | Random affine translation
231 | """
232 | def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
233 | if isinstance(degrees, numbers.Number):
234 | if degrees < 0:
235 | raise ValueError("If degrees is a single number, it must be positive.")
236 | self.degrees = (-degrees, degrees)
237 | else:
238 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
239 | "degrees should be a list or tuple and it must be of length 2."
240 | self.degrees = degrees
241 |
242 | if translate is not None:
243 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
244 | "translate should be a list or tuple and it must be of length 2."
245 | for t in translate:
246 | if not (0.0 <= t <= 1.0):
247 | raise ValueError("translation values should be between 0 and 1")
248 | self.translate = translate
249 |
250 | if scale is not None:
251 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
252 | "scale should be a list or tuple and it must be of length 2."
253 | for s in scale:
254 | if s <= 0:
255 | raise ValueError("scale values should be positive")
256 | self.scale = scale
257 |
258 | if shear is not None:
259 | if isinstance(shear, numbers.Number):
260 | if shear < 0:
261 | raise ValueError("If shear is a single number, it must be positive.")
262 | self.shear = (-shear, shear)
263 | else:
264 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
265 | "shear should be a list or tuple and it must be of length 2."
266 | self.shear = shear
267 | else:
268 | self.shear = shear
269 |
270 | self.resample = resample
271 | self.fillcolor = fillcolor
272 | self.flip = flip
273 |
274 | @staticmethod
275 | def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
276 | """Get parameters for affine transformation
277 |
278 | Returns:
279 | sequence: params to be passed to the affine transformation
280 | """
281 | angle = random.uniform(degrees[0], degrees[1])
282 | if translate is not None:
283 | max_dx = translate[0] * img_size[0]
284 | max_dy = translate[1] * img_size[1]
285 | translations = (np.round(random.uniform(-max_dx, max_dx)),
286 | np.round(random.uniform(-max_dy, max_dy)))
287 | else:
288 | translations = (0, 0)
289 |
290 | if scale_ranges is not None:
291 | scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
292 | random.uniform(scale_ranges[0], scale_ranges[1]))
293 | else:
294 | scale = (1.0, 1.0)
295 |
296 | if shears is not None:
297 | shear = random.uniform(shears[0], shears[1])
298 | else:
299 | shear = 0.0
300 |
301 | if flip is not None:
302 | flip = (np.random.rand(2) < flip).astype(int) * 2 - 1
303 |
304 | return angle, translations, scale, shear, flip
305 |
306 | def __call__(self, sample):
307 | fg, alpha = sample['fg'], sample['alpha']
308 | rows, cols, ch = fg.shape
309 | if np.maximum(rows, cols) < 1024:
310 | params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, fg.size)
311 | else:
312 | params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size)
313 |
314 | center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
315 | M = self._get_inverse_affine_matrix(center, *params)
316 | M = np.array(M).reshape((2, 3))
317 |
318 | fg = cv2.warpAffine(fg, M, (cols, rows),
319 | flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
320 | alpha = cv2.warpAffine(alpha, M, (cols, rows),
321 | flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
322 |
323 | sample['fg'], sample['alpha'] = fg, alpha
324 |
325 | return sample
326 |
327 |
328 |     @staticmethod
329 | def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
330 |
331 | angle = math.radians(angle)
332 | shear = math.radians(shear)
333 | scale_x = 1.0 / scale[0] * flip[0]
334 | scale_y = 1.0 / scale[1] * flip[1]
335 |
336 | # Inverted rotation matrix with scale and shear
337 | d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
338 | matrix = [
339 | math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
340 | -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
341 | ]
342 | matrix = [m / d for m in matrix]
343 |
344 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
345 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
346 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
347 |
348 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1
349 | matrix[2] += center[0]
350 | matrix[5] += center[1]
351 |
352 | return matrix
353 |
354 |
355 | class RandomJitter(object):
356 | """
357 | Random change the hue of the image
358 | """
359 |
360 | def __call__(self, sample):
361 | sample_ori = sample.copy()
362 | fg, alpha = sample['fg'], sample['alpha']
363 | # if alpha is all 0 skip
364 | if np.all(alpha==0):
365 | return sample_ori
366 | # convert to HSV space, convert to float32 image to keep precision during space conversion.
367 | fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
368 | # Hue noise
369 | hue_jitter = np.random.randint(-40, 40)
370 | fg[:, :, 0] = np.remainder(fg[:, :, 0].astype(np.float32) + hue_jitter, 360)
371 | # Saturation noise
372 | sat_bar = fg[:, :, 1][alpha > 0].mean()
373 | if np.isnan(sat_bar):
374 | return sample_ori
375 | sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
376 | sat = fg[:, :, 1]
377 | sat = np.abs(sat + sat_jitter)
378 | sat[sat>1] = 2 - sat[sat>1]
379 | fg[:, :, 1] = sat
380 | # Value noise
381 | val_bar = fg[:, :, 2][alpha > 0].mean()
382 | if np.isnan(val_bar):
383 | return sample_ori
384 | val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
385 | val = fg[:, :, 2]
386 | val = np.abs(val + val_jitter)
387 | val[val>1] = 2 - val[val>1]
388 | fg[:, :, 2] = val
389 | # convert back to BGR space
390 | fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR)
391 | sample['fg'] = fg*255
392 |
393 | return sample
394 |
395 |
396 | class RandomHorizontalFlip(object):
397 | """
398 |     Randomly flip the foreground image and alpha matte horizontally.
399 | """
400 | def __init__(self, prob=0.5):
401 | self.prob = prob
402 | def __call__(self, sample):
403 | fg, alpha = sample['fg'], sample['alpha']
404 | if np.random.uniform(0, 1) < self.prob:
405 | fg = cv2.flip(fg, 1)
406 | alpha = cv2.flip(alpha, 1)
407 | sample['fg'], sample['alpha'] = fg, alpha
408 |
409 | return sample
410 |
411 |
412 | class RandomCrop(object):
413 | """
414 |     Randomly crop the sample to 'output_size', centering the crop on the unknown region of the trimap; images smaller than the crop size are resized up first.
415 |
416 | :param output_size (tuple or int): Desired output size. If int, square crop
417 | is made.
418 | """
419 |
420 |     def __init__(self, output_size=(CONFIG.data.crop_size, CONFIG.data.crop_size)):
421 | assert isinstance(output_size, (int, tuple))
422 | if isinstance(output_size, int):
423 | self.output_size = (output_size, output_size)
424 | else:
425 | assert len(output_size) == 2
426 | self.output_size = output_size
427 | self.margin = output_size[0] // 2
428 | self.logger = logging.getLogger("Logger")
429 |
430 | def __call__(self, sample):
431 | fg, alpha, trimap, mask, name = sample['fg'], sample['alpha'], sample['trimap'], sample['mask'], sample['image_name']
432 | bg = sample['bg']
433 | h, w = trimap.shape
434 | bg = cv2.resize(bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
435 | if w < self.output_size[0]+1 or h < self.output_size[1]+1:
436 | ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w
437 | # self.logger.warning("Size of {} is {}.".format(name, (h, w)))
438 | while h < self.output_size[0]+1 or w < self.output_size[1]+1:
439 | fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
440 | alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
441 | interpolation=maybe_random_interp(cv2.INTER_NEAREST))
442 | trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
443 | bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
444 | mask = cv2.resize(mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
445 | h, w = trimap.shape
446 | small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
447 | unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
448 | self.margin//4:(w-self.margin)//4] == 128)))
449 | unknown_num = len(unknown_list)
450 | if len(unknown_list) < 10:
451 | left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
452 | else:
453 | idx = np.random.randint(unknown_num)
454 | left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
455 |
456 | fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
457 | alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
458 | bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
459 | trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
460 | mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
461 |
462 | if len(np.where(trimap==128)[0]) == 0:
463 |             self.logger.error("{} does not have enough unknown area for crop. Resized to target size. "
464 |                               "left_top: {}".format(name, left_top))
465 | fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
466 | alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
467 | trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
468 | bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC))
469 | mask_crop = cv2.resize(mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
470 |
471 | sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop})
472 | return sample
473 |
474 |
475 | class OriginScale(object):
476 | def __call__(self, sample):
477 | h, w = sample["alpha_shape"]
478 |
479 | if h % 32 == 0 and w % 32 == 0:
480 | return sample
481 |
482 | target_h = 32 * ((h - 1) // 32 + 1)
483 | target_w = 32 * ((w - 1) // 32 + 1)
484 | pad_h = target_h - h
485 | pad_w = target_w - w
486 |
487 | padded_image = np.pad(sample['image'], ((0,pad_h), (0, pad_w), (0,0)), mode="reflect")
488 | padded_trimap = np.pad(sample['trimap'], ((0,pad_h), (0, pad_w)), mode="reflect")
489 | padded_mask = np.pad(sample['mask'], ((0,pad_h), (0, pad_w)), mode="reflect")
490 |
491 | sample['image'] = padded_image
492 | sample['trimap'] = padded_trimap
493 | sample['mask'] = padded_mask
494 |
495 | return sample
496 |
497 |
498 | class GenMask(object):
499 | def __init__(self):
500 | self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)]
501 |
502 | def __call__(self, sample):
503 | alpha_ori = sample['alpha']
504 | h, w = alpha_ori.shape
505 |
506 | max_kernel_size = 30
507 | alpha = cv2.resize(alpha_ori, (640,640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
508 |
509 |         ### generate trimap
510 |         # Binarize alpha into rough foreground / background masks, erode each with a
511 |         # randomly sized elliptical kernel, and label the band left uncovered by both
512 |         # eroded masks as the unknown (128) region of the trimap.
513 |         # fg_width / bg_width below index into self.erosion_kernels (sizes 1..29).
514 |
515 | fg_width = np.random.randint(1, 30)
516 | bg_width = np.random.randint(1, 30)
517 | fg_mask = (alpha + 1e-5).astype(int).astype(np.uint8)
518 | bg_mask = (1 - alpha + 1e-5).astype(int).astype(np.uint8)
519 | fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width])
520 | bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width])
521 |
522 | trimap = np.ones_like(alpha) * 128
523 | trimap[fg_mask == 1] = 255
524 | trimap[bg_mask == 1] = 0
525 |
526 | trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
527 | sample['trimap'] = trimap
528 |
529 | ### generate mask
530 | low = 0.01
531 | high = 1.0
532 | thres = random.random() * (high - low) + low
533 | seg_mask = (alpha >= thres).astype(int).astype(np.uint8)
534 | random_num = random.randint(0,3)
535 | if random_num == 0:
536 | seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
537 | elif random_num == 1:
538 | seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
539 | elif random_num == 2:
540 | seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
541 | seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
542 | elif random_num == 3:
543 | seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
544 | seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
545 |
546 | seg_mask = cv2.resize(seg_mask, (w,h), interpolation=cv2.INTER_NEAREST)
547 | sample['mask'] = seg_mask
548 |
549 | return sample
550 |
551 |
552 | class Composite(object):
553 | def __call__(self, sample):
554 | fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
555 |         alpha[alpha < 0] = 0
556 |         alpha[alpha > 1] = 1
557 |         fg[fg < 0] = 0
558 |         fg[fg > 255] = 255
559 |         bg[bg < 0] = 0
560 |         bg[bg > 255] = 255
561 |
562 | image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])
563 | sample['image'] = image
564 | return sample
565 |
566 |
567 | class CutMask(object):
568 |     def __init__(self, perturb_prob=0):
569 | self.perturb_prob = perturb_prob
570 |
571 |     def __call__(self, sample):
572 |         # Leave the mask untouched with probability `perturb_prob`.
573 |         if np.random.rand() < self.perturb_prob:
574 |             return sample
575 | mask = sample['mask'] # H x W, trimap 0--255, segmask 0--1, alpha 0--1
576 | h, w = mask.shape
577 | perturb_size_h, perturb_size_w = random.randint(h // 4, h // 2), random.randint(w // 4, w // 2)
578 | x = random.randint(0, h - perturb_size_h)
579 | y = random.randint(0, w - perturb_size_w)
580 | x1 = random.randint(0, h - perturb_size_h)
581 | y1 = random.randint(0, w - perturb_size_w)
582 |
583 | mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1+perturb_size_h, y1:y1+perturb_size_w].copy()
584 |
585 | sample['mask'] = mask
586 | return sample
587 |
588 |
589 | class DataGenerator(Dataset):
590 | def __init__(self, data, phase="train"):
591 | self.phase = phase
592 | self.crop_size = CONFIG.data.crop_size
593 | self.alpha = data.alpha
594 |
595 | if self.phase == "train":
596 | self.fg = data.fg
597 | self.bg = data.bg
598 | self.merged = []
599 | self.trimap = []
600 |
601 | else:
602 | self.fg = []
603 | self.bg = []
604 | self.merged = data.merged
605 | self.trimap = data.trimap
606 |
607 | train_trans = [
608 | RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
609 | GenMask(),
610 | CutMask(perturb_prob=CONFIG.data.cutmask_prob),
611 | RandomCrop((self.crop_size, self.crop_size)),
612 | RandomJitter(),
613 | Composite(),
614 | ToTensor(phase="train") ]
615 |
616 | test_trans = [ OriginScale(), ToTensor() ]
617 |
618 | self.transform = {
619 | 'train':
620 | transforms.Compose(train_trans),
621 | 'val':
622 | transforms.Compose([
623 | OriginScale(),
624 | ToTensor()
625 | ]),
626 | 'test':
627 | transforms.Compose(test_trans)
628 | }[phase]
629 |
630 | self.fg_num = len(self.fg)
631 |
632 | def __getitem__(self, idx):
633 | if self.phase == "train":
634 | fg = cv2.imread(self.fg[idx % self.fg_num])
635 | alpha = cv2.imread(self.alpha[idx % self.fg_num], 0).astype(np.float32)/255
636 | bg = cv2.imread(self.bg[idx], 1)
637 |
638 | fg, alpha = self._composite_fg(fg, alpha, idx)
639 |
640 | image_name = os.path.split(self.fg[idx % self.fg_num])[-1]
641 | sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name}
642 |
643 | else:
644 | image = cv2.imread(self.merged[idx])
645 | alpha = cv2.imread(self.alpha[idx], 0)/255.
646 | trimap = cv2.imread(self.trimap[idx], 0)
647 | mask = (trimap >= 170).astype(np.float32)
648 | image_name = os.path.split(self.merged[idx])[-1]
649 |
650 | sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape}
651 |
652 | sample = self.transform(sample)
653 |
654 | return sample
655 |
656 | def _composite_fg(self, fg, alpha, idx):
657 |
658 | if np.random.rand() < 0.5:
659 | idx2 = np.random.randint(self.fg_num) + idx
660 | fg2 = cv2.imread(self.fg[idx2 % self.fg_num])
661 | alpha2 = cv2.imread(self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255.
662 | h, w = alpha.shape
663 | fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
664 | alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
665 |
666 | alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
667 | if np.any(alpha_tmp < 1):
668 | fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
669 | # The overlap of two 50% transparency should be 25%
670 | alpha = alpha_tmp
671 | fg = fg.astype(np.uint8)
672 |
673 | if np.random.rand() < 0.25:
674 | fg = cv2.resize(fg, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
675 | alpha = cv2.resize(alpha, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
676 |
677 | return fg, alpha
678 |
679 | def __len__(self):
680 | if self.phase == "train":
681 | return len(self.bg)
682 | else:
683 | return len(self.alpha)
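684 | 
685 | 
686 | if __name__ == "__main__":
687 |     # Minimal smoke-test sketch, not one of the repository's entry points. It assumes the
688 |     # module-level CONFIG consumed by the transforms (crop_size, cutmask_prob, ...) has
689 |     # already been initialized, and the placeholder paths below must be replaced with
690 |     # real foreground / alpha / background images before running.
691 |     from types import SimpleNamespace
692 |     from torch.utils.data import DataLoader
693 | 
694 |     train_files = SimpleNamespace(
695 |         fg=["/path/to/fg/0001.png"],        # hypothetical paths, for illustration only
696 |         alpha=["/path/to/alpha/0001.png"],
697 |         bg=["/path/to/bg/0001.jpg"],
698 |     )
699 |     dataset = DataGenerator(train_files, phase="train")
700 |     loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
701 |     batch = next(iter(loader))
702 |     # Inspect the tensors produced by the train-time transform pipeline.
703 |     print({k: tuple(v.shape) for k, v in batch.items() if hasattr(v, "shape")})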
--------------------------------------------------------------------------------