├── data ├── __init__.py └── dim_dataset.py ├── modeling ├── meta_arch │ ├── __init__.py │ └── mematte.py ├── decoder │ ├── __init__.py │ └── detail_capture.py ├── criterion │ ├── __init__.py │ └── matting_criterion.py ├── backbone │ ├── __init__.py │ ├── backbone.py │ ├── utils.py │ ├── vit_teacher.py │ └── vit.py └── __init__.py ├── engine ├── __init__.py └── mattingtrainer.py ├── requirements.txt ├── configs ├── common │ ├── scheduler.py │ ├── train.py │ ├── dataloader.py │ ├── optimizer.py │ └── model.py ├── MEMatte_S_topk0.25_win_global_long.py └── MEMatte_B_topk0.25_win_global_long.py ├── pretrained └── preprocess.py ├── .gitignore ├── utils └── logger.py ├── inference.py ├── main.py └── README.md /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dim_dataset import * -------------------------------------------------------------------------------- /modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .mematte import MEMatte -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .mattingtrainer import MattingTrainer 2 | -------------------------------------------------------------------------------- /modeling/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .detail_capture import Detail_Capture -------------------------------------------------------------------------------- /modeling/criterion/__init__.py: -------------------------------------------------------------------------------- 1 | from .matting_criterion import MattingCriterion -------------------------------------------------------------------------------- /modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import * 2 | from .vit import * 3 | from .vit_teacher import ViT_Teacher -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import * 2 | from .criterion import * 3 | from .decoder import * 4 | from .meta_arch import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision 3 | tensorboard 4 | timm==0.5.4 5 | opencv-python==4.5.3.56 6 | setuptools==58.2.0 7 | easydict 8 | wget 9 | scikit-image 10 | fairscale -------------------------------------------------------------------------------- /configs/common/scheduler.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import LazyCall as L 2 | from detectron2.solver import WarmupParamScheduler 3 | from fvcore.common.param_scheduler import MultiStepParamScheduler 4 | 5 | lr_multiplier = L(WarmupParamScheduler)( 6 | scheduler=L(MultiStepParamScheduler)( 7 | values=[1.0, 0.1, 0.01], 8 | milestones=[96778, 103579], 9 | num_updates=100, 10 | ), 11 | warmup_length=250 / 100, 12 | warmup_factor=0.001, 13 | ) -------------------------------------------------------------------------------- /configs/common/train.py: -------------------------------------------------------------------------------- 1 | train = dict( 2 | 
output_dir="./output", 3 | init_checkpoint="", 4 | max_iter=90000, 5 | amp=dict(enabled=False), # options for Automatic Mixed Precision 6 | ddp=dict( # options for DistributedDataParallel 7 | broadcast_buffers=True, 8 | find_unused_parameters=False, 9 | fp16_compression=True, 10 | ), 11 | checkpointer=dict(period=5000, max_to_keep=100), # options for PeriodicCheckpointer 12 | eval_period=5000, 13 | # log_period=20, 14 | log_loss_period = 10, 15 | log_image_period = 2000, 16 | device="cuda" 17 | # ... 18 | ) -------------------------------------------------------------------------------- /pretrained/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def add_teacher(model, name): 3 | new_model = {} 4 | for k in model.keys(): 5 | new_model[k] = model[k] 6 | if 'backbone' in k: 7 | new_model['teacher_'+k] = model[k] 8 | print(f'teacher_{k}') 9 | 10 | torch.save(new_model, name + '.pth') 11 | 12 | if __name__ == "__main__": 13 | # Downloading the official checkpoint of ViTMatte, and then process the checkpoint with the script. 14 | # ViTMatte_S_Com.pth: https://drive.google.com/file/d/12VKhSwE_miF9lWQQCgK7mv83rJIls3Xe/view 15 | # ViTMatte_B_Com.pth: https://drive.google.com/file/d/1mOO5MMU4kwhNX96AlfpwjAoMM4V5w3k-/view?pli=1 16 | teacher_model = torch.load('ViTMatte_S_Com.pth')['model'] 17 | add_teacher(teacher_model, 'ViTMatte_S_Com_with_teacher') -------------------------------------------------------------------------------- /configs/common/dataloader.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from torch.utils.data import DataLoader 3 | from detectron2.config import LazyCall as L 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | from data import ImageFileTrain, DataGenerator 7 | 8 | #Dataloader 9 | train_dataset = L(DataGenerator)( 10 | data = L(ImageFileTrain)( 11 | alpha_dir='/opt/data/private/lyh/Datasets/AdobeImageMatting/Train/alpha', 12 | fg_dir='/opt/data/private/lyh/Datasets/AdobeImageMatting/Train/fg', 13 | bg_dir='/opt/data/private/lyh/Datasets/coco2014/raw/train2014', 14 | root='/opt/data/private/lyh/Datasets/AdobeImageMatting' 15 | ), 16 | phase = 'train' 17 | ) 18 | # 19 | dataloader = OmegaConf.create() 20 | dataloader.train = L(DataLoader)( 21 | dataset = train_dataset, 22 | batch_size=15, 23 | shuffle=False, 24 | num_workers=4, 25 | pin_memory=True, 26 | sampler=L(DistributedSampler)( 27 | dataset = train_dataset, 28 | ), 29 | drop_last=True 30 | ) -------------------------------------------------------------------------------- /configs/common/optimizer.py: -------------------------------------------------------------------------------- 1 | from detectron2 import model_zoo 2 | from functools import partial 3 | 4 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): 5 | """ 6 | Calculate lr decay rate for different ViT blocks. 7 | Args: 8 | name (string): parameter name. 9 | lr_decay_rate (float): base lr decay rate. 10 | num_layers (int): number of ViT blocks. 11 | 12 | Returns: 13 | lr decay rate for the given parameter. 14 | """ 15 | layer_id = num_layers + 1 16 | if name.startswith("backbone"): 17 | if ".pos_embed" in name or ".patch_embed" in name: 18 | layer_id = 0 19 | elif ".blocks." in name and ".residual." 
not in name: 20 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 21 | return lr_decay_rate ** (num_layers + 1 - layer_id) 22 | 23 | # Optimizer 24 | optimizer = model_zoo.get_config("common/optim.py").AdamW 25 | optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.65) 26 | optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}} -------------------------------------------------------------------------------- /configs/MEMatte_S_topk0.25_win_global_long.py: -------------------------------------------------------------------------------- 1 | from .common.train import train 2 | from .common.model import model 3 | from .common.optimizer import optimizer 4 | from .common.scheduler import lr_multiplier 5 | from .common.dataloader import dataloader, train_dataset 6 | 7 | 8 | train.max_iter = int(43100 / 16 / 2 * 30) 9 | train.checkpointer.period = int(43100 / 16 / 2 * 2) 10 | 11 | model.backbone.use_rel_pos = True 12 | model.backbone.topk = 0.25 13 | model.backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention 14 | model.backbone.multi_score = True 15 | model.distill = True 16 | 17 | model.teacher_backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention 18 | 19 | optimizer.lr=5e-4 20 | lr_multiplier.scheduler.values=[1.0, 0.1, 0.05] 21 | lr_multiplier.scheduler.milestones=[int(43100 / 16 / 2 * 6), int(43100 / 16 / 2 * 26)] 22 | lr_multiplier.scheduler.num_updates = train.max_iter 23 | lr_multiplier.warmup_length = 250 / train.max_iter 24 | 25 | train.init_checkpoint = './pretrained/ViTMatte_S_Com_with_teacher.pth' 26 | train.output_dir = './output_of_train/MEMatte_S_topk0.25_win_global_long' 27 | 28 | dataloader.train.batch_size=16 29 | dataloader.train.num_workers=2 30 | -------------------------------------------------------------------------------- /configs/MEMatte_B_topk0.25_win_global_long.py: -------------------------------------------------------------------------------- 1 | from .common.train import train 2 | from .common.model import model 3 | from .common.optimizer import optimizer 4 | from .common.scheduler import lr_multiplier 5 | from .common.dataloader import dataloader, train_dataset 6 | 7 | model.backbone.embed_dim = 768 8 | model.backbone.num_heads = 12 9 | model.decoder.in_chans = 768 10 | 11 | model.teacher_backbone.embed_dim = 768 12 | model.teacher_backbone.num_heads = 12 13 | 14 | train.max_iter = int(43100 / 10 / 4 * 60) 15 | train.checkpointer.period = int(43100 / 10 / 4 * 4) 16 | 17 | model.backbone.use_rel_pos = True 18 | model.backbone.topk = 0.25 19 | model.backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention 20 | model.backbone.multi_score = True 21 | model.distill = True 22 | 23 | model.teacher_backbone.window_block_indexes=[0,1,3,4,6,7,9,10,] # 2, 5, 8 11 for global attention 24 | 25 | optimizer.lr=5e-4 26 | lr_multiplier.scheduler.values=[1.0, 0.1, 0.05] 27 | lr_multiplier.scheduler.milestones=[int(43100 / 10 / 4 * 30), int(43100 / 10 / 4 * 52)] 28 | lr_multiplier.scheduler.num_updates = train.max_iter 29 | lr_multiplier.warmup_length = 250 / train.max_iter 30 | 31 | train.init_checkpoint = './pretrained/ViTMatte_B_Com_with_teacher.pth' 32 | train.output_dir = './output_of_train/MEMatte_B_topk0.25_win_global_long' 33 | 34 | dataloader.train.batch_size=10 35 | dataloader.train.num_workers=4 36 | -------------------------------------------------------------------------------- /configs/common/model.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from functools import partial 3 | from detectron2.config import LazyCall as L 4 | from modeling import MEMatte, MattingCriterion, Detail_Capture, ViT, ViT_Teacher 5 | 6 | # Base 7 | embed_dim, num_heads = 384, 6 8 | 9 | model = L(MEMatte)( 10 | teacher_backbone = L(ViT_Teacher)( 11 | in_chans=4, 12 | img_size=512, 13 | patch_size=16, 14 | embed_dim=embed_dim, 15 | depth=12, 16 | num_heads=num_heads, 17 | drop_path_rate=0, 18 | window_size=14, 19 | mlp_ratio=4, 20 | qkv_bias=True, 21 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 22 | window_block_indexes=[], 23 | residual_block_indexes=[2, 5, 8, 11], 24 | use_rel_pos=True, 25 | out_feature="last_feat", 26 | ), 27 | backbone = L(ViT)( # Single-scale ViT backbone 28 | in_chans=4, 29 | img_size=512, 30 | patch_size=16, 31 | embed_dim=embed_dim, 32 | depth=12, 33 | num_heads=num_heads, 34 | drop_path_rate=0, 35 | window_size=14, 36 | mlp_ratio=4, 37 | qkv_bias=True, 38 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 39 | window_block_indexes=[ 40 | # # 2, 5, 8 11 for global attention 41 | # 0, 42 | # 1, 43 | # 3, 44 | # 4, 45 | # 6, 46 | # 7, 47 | # 9, 48 | # 10, 49 | ], 50 | residual_block_indexes=[2, 5, 8, 11], 51 | use_rel_pos=True, 52 | out_feature="last_feat", 53 | topk = 1, 54 | ), 55 | criterion=L(MattingCriterion)( 56 | losses = ['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian', 'loss_gradient_penalty'] 57 | ), 58 | pixel_mean = [123.675 / 255., 116.280 / 255., 103.530 / 255.], 59 | pixel_std = [58.395 / 255., 57.120 / 255., 57.375 / 255.], 60 | input_format = "RGB", 61 | size_divisibility=32, 62 | decoder=L(Detail_Capture)(), 63 | distill = True, 64 | distill_loss_ratio = 1, 65 | token_loss_ratio = 1, 66 | ) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /*.sh 2 | change_submit.py 3 | cluster_submit.yaml 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | work_dirs 15 | test 16 | val 17 | .Python 18 | build/ 19 | ckpts/ 20 | ckpts 21 | test/ 22 | val/ 23 | work_dirs/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | cluster.sh 144 | 145 | *.pth 146 | 147 | predAlpha 148 | evaluation_log.txt -------------------------------------------------------------------------------- /modeling/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Dict 4 | import torch.nn as nn 5 | 6 | from detectron2.layers import ShapeSpec 7 | 8 | __all__ = ["Backbone"] 9 | 10 | 11 | class Backbone(nn.Module, metaclass=ABCMeta): 12 | """ 13 | Abstract base class for network backbones. 14 | """ 15 | 16 | def __init__(self): 17 | """ 18 | The `__init__` method of any subclass can specify its own set of arguments. 19 | """ 20 | super().__init__() 21 | 22 | @abstractmethod 23 | def forward(self): 24 | """ 25 | Subclasses must override this method, but adhere to the same return type. 26 | 27 | Returns: 28 | dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor 29 | """ 30 | pass 31 | 32 | @property 33 | def size_divisibility(self) -> int: 34 | """ 35 | Some backbones require the input height and width to be divisible by a 36 | specific integer. This is typically true for encoder / decoder type networks 37 | with lateral connection (e.g., FPN) for which feature maps need to match 38 | dimension in the "bottom up" and "top down" paths. Set to 0 if no specific 39 | input size divisibility is required. 40 | """ 41 | return 0 42 | 43 | @property 44 | def padding_constraints(self) -> Dict[str, int]: 45 | """ 46 | This property is a generalization of size_divisibility. 
Some backbones and training
47 | recipes require specific padding constraints, such as enforcing divisibility by a specific
48 | integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
49 | in :paper:vitdet). `padding_constraints` contains these optional items like:
50 | {
51 | "size_divisibility": int,
52 | "square_size": int,
53 | # Future options are possible
54 | }
55 | `size_divisibility` will be read from here if present, and `square_size` indicates the
56 | square padding size if `square_size` > 0.
57 | 
58 | TODO: use type of Dict[str, int] to avoid torchscript issues. The type of padding_constraints
59 | could be generalized as TypedDict (Python 3.8+) to support more types in the future.
60 | """
61 | return {}
62 | 
63 | def output_shape(self):
64 | """
65 | Returns:
66 | dict[str->ShapeSpec]
67 | """
68 | # this is a backward-compatible default
69 | return {
70 | name: ShapeSpec(
71 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
72 | )
73 | for name in self._out_features
74 | }
75 | 
--------------------------------------------------------------------------------
/engine/mattingtrainer.py:
--------------------------------------------------------------------------------
1 | from detectron2.engine import AMPTrainer
2 | import torch
3 | import time
4 | import detectron2.utils.comm as comm
5 | import logging
6 | 
7 | from detectron2.utils.events import EventWriter, get_event_storage
8 | 
9 | def cycle(iterable):
10 | while True:
11 | for x in iterable:
12 | yield x
13 | 
14 | class MattingTrainer(AMPTrainer):
15 | def __init__(self, model, data_loader, optimizer, grad_scaler=None, log_image_period = 2000):
16 | super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)  # forward the caller's scaler instead of discarding it
17 | self.data_loader_iter = iter(cycle(self.data_loader))
18 | self.log_image_period = log_image_period
19 | 
20 | 
21 | def run_step(self):
22 | """
23 | Implement the AMP training logic.
24 | """
25 | assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
26 | assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
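# One AMP iteration below: pull a batch from the infinite `cycle` iterator, run the
# forward pass under autocast, scale the loss before backward, then step the
# optimizer through the grad scaler (which skips the step on gradient overflow).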
27 | from torch.cuda.amp import autocast 28 | 29 | #matting pass 30 | start = time.perf_counter() 31 | data = next(self.data_loader_iter) 32 | data_time = time.perf_counter() - start 33 | 34 | with autocast(): 35 | loss_dict, output_images, out_pred_prob = self.model(data) 36 | if isinstance(loss_dict, torch.Tensor): 37 | losses = loss_dict 38 | loss_dict = {"total_loss": loss_dict} 39 | else: 40 | losses = sum(loss_dict.values()) 41 | 42 | self.optimizer.zero_grad() 43 | self.grad_scaler.scale(losses).backward() 44 | 45 | if self.iter % 20 == 0: 46 | self.write_ratios(out_pred_prob) 47 | self._write_metrics(loss_dict, data_time) 48 | self._write_images(output_images, data) 49 | 50 | self.grad_scaler.step(self.optimizer) 51 | self.grad_scaler.update() 52 | 53 | def write_ratios(self, out_pred_prob): 54 | storage = get_event_storage() 55 | for i in range(len(out_pred_prob)): 56 | storage.put_scalar(f"{i}_block_token_ratio", out_pred_prob[i].sum() / out_pred_prob[i].numel(), cur_iter = self.iter) 57 | storage.put_scalar("total_token_ratio", sum([p.sum() / p.numel() for p in out_pred_prob]) / len(out_pred_prob), cur_iter = self.iter) 58 | 59 | def _write_images(self, output_images: torch.Tensor, data: torch.Tensor, iter: int = None): 60 | logger = logging.getLogger(__name__) 61 | iter = self.iter if iter is None else iter 62 | if (iter + 1) % self.log_image_period == 0: 63 | try: 64 | MattingTrainer.write_images(output_images, data, iter) 65 | except Exception: 66 | logger.exception("Exception in writing images: ") 67 | raise 68 | 69 | @staticmethod 70 | def write_images(output_images: torch.Tensor, data: torch.Tensor, cur_iter:int = None): 71 | # output_images = output_images.detach().cpu() 72 | if comm.is_main_process(): 73 | storage = get_event_storage() 74 | storage.put_image("fg", data["fg"]) 75 | storage.put_image("alpha_gt", data["alpha"]) 76 | storage.put_image("bg", data["bg"]) 77 | storage.put_image("trimap", data["trimap"]) 78 | storage.put_image("image", data["image"]) 79 | # storage._block_ratio = (block_ratio, storage.iter) 80 | for key in output_images.keys(): 81 | storage.put_image(key, output_images[key]) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | 2 | from detectron2.engine.train_loop import HookBase 3 | from detectron2.utils.events import EventWriter, get_event_storage 4 | from functools import cached_property 5 | from typing import Optional 6 | from detectron2.utils.file_io import PathManager 7 | from detectron2.engine.defaults import CommonMetricPrinter, TensorboardXWriter, JSONWriter 8 | 9 | import os 10 | import torch 11 | 12 | def power_default_writers(output_dir: str, max_iter: Optional[int] = None): 13 | """ 14 | Build a list of :class:`EventWriter` to be used. 15 | It now consists of a :class:`CommonMetricPrinter`, 16 | :class:`TensorboardXWriter` and :class:`JSONWriter`. 17 | 18 | Args: 19 | output_dir: directory to store JSON metrics and tensorboard events 20 | max_iter: the total number of iterations 21 | 22 | Returns: 23 | list[EventWriter]: a list of :class:`EventWriter` objects. 24 | """ 25 | PathManager.mkdirs(output_dir) 26 | return [ 27 | # It may not always print what you want to see, since it prints "common" metrics only. 
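# ("Common" covers ETA, iteration time, losses, and learning rate; custom
# scalars such as the token ratios only reach the JSON/TensorBoard writers.)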
28 | CommonMetricPrinter(max_iter),
29 | JSONWriter(os.path.join(output_dir, "metrics.json")),
30 | PowerTensorboardXWriter(output_dir),
31 | ]
32 | 
33 | 
34 | class PowerTensorboardXWriter(TensorboardXWriter):
35 | """
36 | Write all scalars to a tensorboard file.
37 | Compared to the official code, we replace add_image with add_images.
38 | """
39 | 
40 | def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
41 | """
42 | Args:
43 | log_dir (str): the directory to save the output events
44 | window_size (int): the scalars will be median-smoothed by this window size
45 | 
46 | kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
47 | """
48 | super().__init__(log_dir, window_size, **kwargs)
49 | 
50 | def write(self):
51 | storage = get_event_storage()
52 | new_last_write = self._last_write
53 | for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
54 | if iter > self._last_write:
55 | self._writer.add_scalar(k, v, iter)
56 | new_last_write = max(new_last_write, iter)
57 | self._last_write = new_last_write
58 | 
59 | # storage.put_{image,histogram} is only meant to be used by
60 | # tensorboard writer. So we access its internal fields directly from here.
61 | if len(storage._vis_data) >= 1:
62 | for img_name, img, step_num in storage._vis_data:
63 | self._writer.add_images(img_name, img, step_num)
64 | # Storage stores all image data and relies on this writer to clear them.
65 | # As a result it assumes only one writer will use its image data.
66 | # An alternative design is to let storage store limited recent
67 | # data (e.g. only the most recent image) that all writers can access.
68 | # In that case a writer may not see all image data if its period is long.
69 | storage.clear_images()
70 | 
71 | # if len(storage._block_ratio) == 2:
72 | # block_ratio, step_num = storage._block_ratio
73 | # self._writer.add_text("block_token_ratio", block_ratio, step_num)
74 | # storage._block_ratio = ()
75 | 
76 | if len(storage._histograms) >= 1:
77 | for params in storage._histograms:
78 | self._writer.add_histogram_raw(**params)
79 | storage.clear_histograms()
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from PIL import Image
4 | from tqdm import tqdm
5 | from torch.utils.data import Dataset, DataLoader
6 | from torchvision.transforms import functional as F
7 | from os.path import join as opj
8 | from detectron2.checkpoint import DetectionCheckpointer
9 | from detectron2.config import LazyConfig, instantiate
10 | from detectron2.engine import default_argument_parser
11 | 
12 | import warnings
13 | warnings.filterwarnings('ignore')
14 | 
15 | #Dataset and Dataloader
16 | def collate_fn(batched_inputs):
17 | rets = dict()
18 | for k in batched_inputs[0].keys():
19 | rets[k] = torch.stack([_[k] for _ in batched_inputs])
20 | return rets
21 | 
22 | class Composition_1k(Dataset):
23 | def __init__(self, data_dir, finished_list = None):
24 | self.data_dir = data_dir
25 | if "AIM" in data_dir:
26 | self.file_names = sorted(os.listdir(opj(self.data_dir, 'original')))
27 | else:
28 | self.file_names = sorted(os.listdir(opj(self.data_dir, 'merged')), reverse=True)
29 | 
30 | # self.file_names = list(set(self.file_names).difference(set(finished_list))) # difference
31 | 
32 | def __len__(self):
33 | return len(self.file_names)
34 | 
35 | def __getitem__(self, idx):
36 | if "AIM" in self.data_dir:
37 | tris = Image.open(opj(self.data_dir, 'trimap', self.file_names[idx].replace('jpg','png')))
38 | imgs = Image.open(opj(self.data_dir, 'original', self.file_names[idx]))
39 | else:
40 | tris = Image.open(opj(self.data_dir, 'trimaps', self.file_names[idx].replace('jpeg','png').replace('jpg','png')))
41 | imgs = Image.open(opj(self.data_dir, 'merged', self.file_names[idx]))
42 | 
43 | sample = {}
44 | 
45 | sample['trimap'] = F.to_tensor(tris)[0:1, :, :]
46 | sample['image'] = F.to_tensor(imgs)
47 | sample['image_name'] = self.file_names[idx]
48 | return sample
49 | 
50 | 
51 | #model and output
52 | def matting_inference(
53 | config_dir='',
54 | checkpoint_dir='',
55 | inference_dir='',
56 | data_dir='',
57 | rank=None,
58 | max_number_token = 18500,
59 | ):
60 | 
61 | #initializing model
62 | cfg = LazyConfig.load(config_dir)
63 | cfg.model.teacher_backbone = None
64 | cfg.model.backbone.max_number_token = max_number_token
65 | model = instantiate(cfg.model)
66 | model.to(cfg.train.device if rank is None else rank)
67 | model.eval()
68 | DetectionCheckpointer(model).load(checkpoint_dir)
69 | 
70 | #initializing dataset
71 | composition_1k_dataloader = DataLoader(
72 | dataset = Composition_1k(
73 | data_dir = data_dir,
74 | ),
75 | shuffle = False,
76 | batch_size = 1,
77 | )
78 | 
79 | #inference
80 | os.makedirs(inference_dir, exist_ok=True)
81 | 
82 | for data in tqdm(composition_1k_dataloader):
83 | with torch.no_grad():
84 | for k in data.keys():
85 | if k == 'image_name':
86 | continue
87 | else:
88 | data[k] = data[k].to(model.device)  # Tensor.to() is not in-place; assign the result back
89 | 
90 | output, _, _ = model(data, patch_decoder=True)
91 | output = output['phas'].flatten(0, 2)
92 | trimap = data['trimap'].squeeze(0).squeeze(0)
93 | output[trimap == 0] = 0
94 | output[trimap == 1] = 1
95 | output = F.to_pil_image(output)
96 | output.save(opj(inference_dir, data['image_name'][0].replace('.jpg', '.png')))
97 | torch.cuda.empty_cache()
98 | 
99 | if __name__ == '__main__':
100 | #add arguments we need:
101 | parser = default_argument_parser()
102 | parser.add_argument('--config-dir', type=str, required=True)
103 | parser.add_argument('--checkpoint-dir', type=str, required=True)
104 | parser.add_argument('--inference-dir', type=str, required=True)
105 | parser.add_argument('--data-dir', type=str, required=True)
106 | parser.add_argument('--max-number-token', type=int, default=18500)  # required=True would make argparse ignore this default
107 | 
108 | args = parser.parse_args()
109 | matting_inference(
110 | config_dir = args.config_dir,
111 | checkpoint_dir = args.checkpoint_dir,
112 | inference_dir = args.inference_dir,
113 | data_dir = args.data_dir,
114 | max_number_token = args.max_number_token
115 | )
116 | 
--------------------------------------------------------------------------------
/modeling/decoder/detail_capture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | class Basic_Conv3x3(nn.Module):
6 | """
7 | Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
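Note: the default stride is 2, so each block also halves the spatial resolution.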
8 | """ 9 | def __init__( 10 | self, 11 | in_chans, 12 | out_chans, 13 | stride=2, 14 | padding=1, 15 | ): 16 | super().__init__() 17 | self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False) 18 | self.bn = nn.BatchNorm2d(out_chans) 19 | self.relu = nn.ReLU(True) 20 | 21 | def forward(self, x): 22 | x = self.conv(x) 23 | x = self.bn(x) 24 | x = self.relu(x) 25 | 26 | return x 27 | 28 | class ConvStream(nn.Module): 29 | """ 30 | Simple ConvStream containing a series of basic conv3x3 layers to extract detail features. 31 | """ 32 | def __init__( 33 | self, 34 | in_chans = 4, 35 | out_chans = [48, 96, 192], 36 | ): 37 | super().__init__() 38 | self.convs = nn.ModuleList() 39 | 40 | self.conv_chans = out_chans.copy() 41 | self.conv_chans.insert(0, in_chans) 42 | 43 | for i in range(len(self.conv_chans)-1): 44 | in_chan_ = self.conv_chans[i] 45 | out_chan_ = self.conv_chans[i+1] 46 | self.convs.append( 47 | Basic_Conv3x3(in_chan_, out_chan_) 48 | ) 49 | 50 | def forward(self, x): 51 | out_dict = {'D0': x} 52 | for i in range(len(self.convs)): 53 | x = self.convs[i](x) 54 | name_ = 'D'+str(i+1) 55 | out_dict[name_] = x 56 | 57 | return out_dict 58 | 59 | class Fusion_Block(nn.Module): 60 | """ 61 | Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer. 62 | """ 63 | def __init__( 64 | self, 65 | in_chans, 66 | out_chans, 67 | ): 68 | super().__init__() 69 | self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1) 70 | 71 | def forward(self, x, D): 72 | F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 73 | out = torch.cat([D, F_up], dim=1) 74 | out = self.conv(out) 75 | 76 | return out 77 | 78 | class Matting_Head(nn.Module): 79 | """ 80 | Simple Matting Head, containing only conv3x3 and conv1x1 layers. 81 | """ 82 | def __init__( 83 | self, 84 | in_chans = 32, 85 | mid_chans = 16, 86 | ): 87 | super().__init__() 88 | self.matting_convs = nn.Sequential( 89 | nn.Conv2d(in_chans, mid_chans, 3, 1, 1), 90 | nn.BatchNorm2d(mid_chans), 91 | nn.ReLU(True), 92 | nn.Conv2d(mid_chans, 1, 1, 1, 0) 93 | ) 94 | 95 | def forward(self, x): 96 | x = self.matting_convs(x) 97 | 98 | return x 99 | 100 | class Detail_Capture(nn.Module): 101 | """ 102 | Simple and Lightweight Detail Capture Module for ViT Matting. 
103 | """ 104 | def __init__( 105 | self, 106 | in_chans = 384, 107 | img_chans=4, 108 | convstream_out = [48, 96, 192], 109 | fusion_out = [256, 128, 64, 32], 110 | ): 111 | super().__init__() 112 | assert len(fusion_out) == len(convstream_out) + 1 113 | 114 | self.convstream = ConvStream(in_chans = img_chans) 115 | self.conv_chans = self.convstream.conv_chans 116 | 117 | self.fusion_blks = nn.ModuleList() 118 | self.fus_channs = fusion_out.copy() 119 | self.fus_channs.insert(0, in_chans) 120 | for i in range(len(self.fus_channs)-1): 121 | self.fusion_blks.append( 122 | Fusion_Block( 123 | in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)], 124 | out_chans = self.fus_channs[i+1], 125 | ) 126 | ) 127 | 128 | self.matting_head = Matting_Head( 129 | in_chans = fusion_out[-1], 130 | ) 131 | 132 | def forward(self, features, images): 133 | detail_features = self.convstream(images) 134 | for i in range(len(self.fusion_blks)): 135 | d_name_ = 'D'+str(len(self.fusion_blks)-i-1) 136 | features = self.fusion_blks[i](features, detail_features[d_name_]) 137 | 138 | phas = torch.sigmoid(self.matting_head(features)) 139 | 140 | return {'phas': phas} 141 | 142 | if __name__ == '__main__': 143 | detail_capture = Detail_Capture() 144 | features = torch.rand(2, 384, 32, 32) 145 | inputs = torch.randn(2, 4, 512, 512) 146 | print(detail_capture) 147 | output = detail_capture(features, inputs) 148 | print(output.shape) -------------------------------------------------------------------------------- /modeling/criterion/matting_criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MattingCriterion(nn.Module): 6 | def __init__(self, 7 | *, 8 | losses, 9 | ): 10 | super(MattingCriterion, self).__init__() 11 | self.losses = losses 12 | 13 | def loss_gradient_penalty(self, sample_map ,preds, targets): 14 | preds = preds['phas'] 15 | targets = targets['phas'] 16 | 17 | #sample_map for unknown area 18 | scale = sample_map.shape[0]*262144/torch.sum(sample_map) 19 | 20 | #gradient in x 21 | sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]]).type(dtype=preds.type()) 22 | delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1) 23 | delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1) 24 | 25 | #gradient in y 26 | sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]]).type(dtype=preds.type()) 27 | delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1) 28 | delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1) 29 | 30 | #loss 31 | loss = (F.l1_loss(delta_pred_x*sample_map, delta_gt_x*sample_map)* scale + \ 32 | F.l1_loss(delta_pred_y*sample_map, delta_gt_y*sample_map)* scale + \ 33 | 0.01 * torch.mean(torch.abs(delta_pred_x*sample_map))* scale + \ 34 | 0.01 * torch.mean(torch.abs(delta_pred_y*sample_map))* scale) 35 | 36 | return dict(loss_gradient_penalty=loss) 37 | 38 | def loss_pha_laplacian(self, preds, targets): 39 | assert 'phas' in preds and 'phas' in targets 40 | loss = laplacian_loss(preds['phas'], targets['phas']) 41 | 42 | return dict(loss_pha_laplacian=loss) 43 | 44 | def unknown_l1_loss(self, sample_map, preds, targets): 45 | 46 | scale = sample_map.shape[0]*262144/torch.sum(sample_map) 47 | # scale = 1 48 | 49 | loss = F.l1_loss(preds['phas']*sample_map, targets['phas']*sample_map)*scale 50 | return dict(unknown_l1_loss=loss) 51 | 52 | def known_l1_loss(self, sample_map, preds, targets): 53 | 
53 | new_sample_map = torch.zeros_like(sample_map)
54 | new_sample_map[sample_map==0] = 1
55 | 
56 | if torch.sum(new_sample_map) == 0:
57 | scale = 0
58 | else:
59 | scale = new_sample_map.shape[0]*262144/torch.sum(new_sample_map)
60 | # scale = 1
61 | 
62 | loss = F.l1_loss(preds['phas']*new_sample_map, targets['phas']*new_sample_map)*scale
63 | return dict(known_l1_loss=loss)
64 | 
65 | 
66 | def forward(self, sample_map, preds, targets):
67 | losses = dict()
68 | for k in self.losses:
69 | if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty':
70 | losses.update(getattr(self, k)(sample_map, preds, targets))
71 | else:
72 | losses.update(getattr(self, k)(preds, targets))
73 | return losses
74 | 
75 | 
76 | #-----------------Laplacian Loss-------------------------#
77 | def laplacian_loss(pred, true, max_levels=5):
78 | kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
79 | pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
80 | true_pyramid = laplacian_pyramid(true, kernel, max_levels)
81 | loss = 0
82 | for level in range(max_levels):
83 | loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
84 | return loss / max_levels
85 | 
86 | def laplacian_pyramid(img, kernel, max_levels):
87 | current = img
88 | pyramid = []
89 | for _ in range(max_levels):
90 | current = crop_to_even_size(current)
91 | down = downsample(current, kernel)
92 | up = upsample(down, kernel)
93 | diff = current - up
94 | pyramid.append(diff)
95 | current = down
96 | return pyramid
97 | 
98 | def gauss_kernel(device='cpu', dtype=torch.float32):
99 | kernel = torch.tensor([[1, 4, 6, 4, 1],
100 | [4, 16, 24, 16, 4],
101 | [6, 24, 36, 24, 6],
102 | [4, 16, 24, 16, 4],
103 | [1, 4, 6, 4, 1]], device=device, dtype=dtype)
104 | kernel /= 256
105 | kernel = kernel[None, None, :, :]
106 | return kernel
107 | 
108 | def gauss_convolution(img, kernel):
109 | B, C, H, W = img.shape
110 | img = img.reshape(B * C, 1, H, W)
111 | img = F.pad(img, (2, 2, 2, 2), mode='reflect')
112 | img = F.conv2d(img, kernel)
113 | img = img.reshape(B, C, H, W)
114 | return img
115 | 
116 | def downsample(img, kernel):
117 | img = gauss_convolution(img, kernel)
118 | img = img[:, :, ::2, ::2]
119 | return img
120 | 
121 | def upsample(img, kernel):
122 | B, C, H, W = img.shape
123 | out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
124 | out[:, :, ::2, ::2] = img * 4
125 | out = gauss_convolution(out, kernel)
126 | return out
127 | 
128 | def crop_to_even_size(img):
129 | H, W = img.shape[2:]
130 | H = H - H % 2
131 | W = W - W % 2
132 | return img[:, :, :H, :W]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | """
4 | Training script using the new "LazyConfig" python config files.
5 | 
6 | This script reads a given python config file and runs the training or evaluation.
7 | It can be used to train any model or dataset as long as they can be
8 | instantiated by the recursive construction defined in the given config file.
9 | 
10 | Besides lazy construction of models, dataloader, etc., this script expects a
11 | few common configuration parameters currently defined in "configs/common/train.py".
12 | To add more complicated training logic, you can easily add other configs
13 | in the config file and implement a new train_net.py to handle them.
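Example (from the README):
    python main.py --config-file configs/MEMatte_S_topk0.25_win_global_long.py --num-gpus 2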
14 | """ 15 | import logging 16 | 17 | from detectron2.checkpoint import DetectionCheckpointer 18 | from detectron2.config import LazyConfig, instantiate 19 | from detectron2.engine import ( 20 | AMPTrainer, 21 | SimpleTrainer, 22 | default_argument_parser, 23 | default_setup, 24 | default_writers, 25 | hooks, 26 | launch, 27 | ) 28 | from detectron2.engine.defaults import create_ddp_model 29 | from detectron2.evaluation import inference_on_dataset, print_csv_format 30 | from detectron2.utils import comm 31 | from utils.logger import power_default_writers 32 | 33 | from engine import MattingTrainer 34 | 35 | #running without warnings 36 | import warnings 37 | warnings.filterwarnings('ignore') 38 | 39 | logger = logging.getLogger("detectron2") 40 | 41 | 42 | def do_test(cfg, model): 43 | if "evaluator" in cfg.dataloader: 44 | ret = inference_on_dataset( 45 | model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) 46 | ) 47 | print_csv_format(ret) 48 | return ret 49 | 50 | 51 | def do_train(args, cfg): 52 | """ 53 | Args: 54 | cfg: an object with the following attributes: 55 | model: instantiate to a module 56 | dataloader.{train,test}: instantiate to dataloaders 57 | dataloader.evaluator: instantiate to evaluator for test set 58 | optimizer: instantaite to an optimizer 59 | lr_multiplier: instantiate to a fvcore scheduler 60 | train: other misc config defined in `configs/common/train.py`, including: 61 | output_dir (str) 62 | init_checkpoint (str) 63 | amp.enabled (bool) 64 | max_iter (int) 65 | eval_period, log_period (int) 66 | device (str) 67 | checkpointer (dict) 68 | ddp (dict) 69 | """ 70 | model = instantiate(cfg.model) 71 | logger = logging.getLogger("detectron2") 72 | logger.info("Model:\n{}".format(model)) 73 | model.to(cfg.train.device) 74 | 75 | for name, param in model.named_parameters(): 76 | if "teacher" in name: 77 | param.requires_grad = False 78 | 79 | cfg.optimizer.params.model = model 80 | optim = instantiate(cfg.optimizer) 81 | 82 | train_dataset = instantiate(cfg.train_dataset) 83 | cfg.dataloader.train.dataset = train_dataset 84 | cfg.dataloader.train.sampler.dataset = train_dataset 85 | train_loader = instantiate(cfg.dataloader.train) 86 | 87 | model = create_ddp_model(model, **cfg.train.ddp) 88 | trainer = MattingTrainer( 89 | model = model, 90 | data_loader = train_loader, 91 | optimizer = optim, 92 | log_image_period = cfg.train.log_image_period 93 | ) 94 | checkpointer = DetectionCheckpointer( 95 | model, 96 | cfg.train.output_dir, 97 | trainer=trainer, 98 | ) 99 | trainer.register_hooks( 100 | [ 101 | hooks.IterationTimer(), 102 | hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), 103 | hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) 104 | if comm.is_main_process() 105 | else None, 106 | hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), 107 | hooks.PeriodicWriter( 108 | power_default_writers(cfg.train.output_dir, cfg.train.max_iter), 109 | period=cfg.train.log_loss_period, 110 | ) 111 | if comm.is_main_process() 112 | else None, 113 | ] 114 | ) 115 | 116 | checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) 117 | if args.resume and checkpointer.has_checkpoint(): 118 | # The checkpoint stores the training iteration that just finished, thus we start 119 | # at the next iteration 120 | start_iter = trainer.iter + 1 121 | else: 122 | start_iter = 0 123 | trainer.train(start_iter, cfg.train.max_iter) 124 | 125 | 126 | def main(args): 127 | cfg = 
LazyConfig.load(args.config_file) 128 | cfg = LazyConfig.apply_overrides(cfg, args.opts) 129 | default_setup(cfg, args) 130 | 131 | if args.eval_only: 132 | model = instantiate(cfg.model) 133 | model.to(cfg.train.device) 134 | model = create_ddp_model(model) 135 | DetectionCheckpointer(model).load(cfg.train.init_checkpoint) 136 | print(do_test(cfg, model)) 137 | else: 138 | do_train(args, cfg) 139 | 140 | 141 | if __name__ == "__main__": 142 | args = default_argument_parser().parse_args() 143 | launch( 144 | main, 145 | args.num_gpus, 146 | num_machines=args.num_machines, 147 | machine_rank=args.machine_rank, 148 | dist_url=args.dist_url, 149 | args=(args,), 150 | ) 151 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Memory Efficient Matting with Adaptive Token Routing
3 | 
4 | Yiheng Lin, Yihan Hu, Chenyi Zhang, Ting Liu, Xiaochao Qu, Luoqi Liu, Yao Zhao, Yunchao Wei
5 | 
6 | Institute of Information Science, Beijing Jiaotong University
7 | Visual Intelligence + X International Joint Laboratory of the Ministry of Education
8 | Pengcheng Laboratory, Shenzhen, China
9 | MT Lab, Meitu Inc
10 | 
24 | 
25 | ## 📮 News
26 | - [2025.07] Released our interactive image matting model, [MattePro](https://github.com/ChenyiZhang007/MattePro).
27 | 
28 | ## Introduction
29 | Transformer-based models have recently achieved outstanding performance in image matting. However, their application to high-resolution images remains challenging due to the quadratic complexity of global self-attention. To address this issue, we propose MEMatte, a memory-efficient matting framework for processing high-resolution images. MEMatte incorporates a router before each global attention block, directing informative tokens to the global attention while routing other tokens to a Lightweight Token Refinement Module (LTRM). Specifically, the router employs a local-global strategy to predict the routing probability of each token, and the LTRM utilizes efficient modules to simulate global attention. Additionally, we introduce a Batch-constrained Adaptive Token Routing (BATR) mechanism, which allows each router to dynamically route tokens based on image content and on the stage of the attention block within the network. (A minimal sketch of the routing idea is given at the end of this README.)
30 | 
31 | ## Dataset
32 | Our proposed ultra high-resolution image matting dataset:
33 | [`huggingface: dafbgd/UHRIM`](https://huggingface.co/datasets/dafbgd/UHRIM)
34 | 
35 | 
36 | ## Quick Installation
37 | Run the following command to install the required packages.
38 | ```
39 | pip install -r requirements.txt
40 | ```
41 | Install [detectron2](https://github.com/facebookresearch/detectron2) by following its [documentation](https://detectron2.readthedocs.io/en/latest/), or run the following command:
42 | ```
43 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
44 | ```
45 | 
46 | ## Results
47 | Quantitative results on [Composition-1k](https://paperswithcode.com/dataset/composition-1k):
48 | | Model | SAD | MSE | Grad | Conn | checkpoints |
49 | | ---------- | ----- | --- | ---- | ----- | ----------- |
50 | | MEMatte-ViTS | 21.90 | 3.37 | 7.43 | 16.77 | [GoogleDrive](https://drive.google.com/file/d/122p3sdhJVb7vg4IXELeC9C3HEG9Mlh5z/view?usp=sharing) |
51 | | MEMatte-ViTB | 21.06 | 3.11 | 6.70 | 15.71 | [GoogleDrive](https://drive.google.com/file/d/1NOV64zMSFtoKPASqvEvxQKI_PRY9m5IA/view?usp=sharing) |
52 | 
53 | We also train a robust model for real-world images, evaluated on AIM-500, using mixed training data:
54 | | Model | SAD | MSE | Grad | Conn | checkpoints |
55 | | ---------- | ----- | --- | ---- | ----- | ----------- |
56 | | MEMatte-ViTS | 13.90 | 11.17 | 10.94 | 12.78 | [GoogleDrive](https://drive.google.com/file/d/1R5NbgIpOudKjvLz1V9M9SxXr1ovAmu3u/view?usp=drive_link) |
57 | 
58 | ## Train
59 | 1. Download the official checkpoints of ViTMatte ([ViTMatte_S_Com.pth](https://drive.google.com/file/d/12VKhSwE_miF9lWQQCgK7mv83rJIls3Xe/view), [ViTMatte_B_Com.pth](https://drive.google.com/file/d/1mOO5MMU4kwhNX96AlfpwjAoMM4V5w3k-/view?pli=1)), and then process the checkpoints using `pretrained/preprocess.py`.
60 | 2. Set `train.init_checkpoint` in the configs to specify the processed checkpoint.
61 | 3. 
Train the model with the following command: 62 | ``` 63 | python main.py \ 64 | --config-file configs/MEMatte_S_topk0.25_win_global_long.py \ 65 | --num-gpus 2 66 | ``` 67 | 68 | 69 | ## Inference 70 | ``` 71 | python inference.py \ 72 | --config-dir ./configs/CONFIG.py \ 73 | --checkpoint-dir ./CHECKPOINT_PATH \ 74 | --inference-dir ./SAVE_DIR \ 75 | --data-dir /DataDir \ 76 | --max-number-token Max_number_token 77 | ``` 78 | For example: 79 | ``` 80 | python inference.py \ 81 | --config-dir ./configs/MEMatte_S_topk0.25_win_global_long.py \ 82 | --checkpoint-dir ./checkpoints/MEMatte_ViTS_DIM.pth \ 83 | --inference-dir ./predAlpha/test_aim500 \ 84 | --data-dir ./Datasets/AIM-500 \ 85 | --max-number-token 18000 86 | # Reducing the maximum number of tokens lowers memory usage. 87 | ``` 88 | 89 | ## ToDo 90 | - [x] release UHRIM dataset 91 | - [x] release code and checkpoint 92 | 93 | ## Citation 94 | If you have any questions, please feel free to open an issue. If you find our method or dataset helpful, we would appreciate it if you could give our project a star ⭐️ on GitHub and cite our paper: 95 | ```bibtex 96 | @inproceedings{lin2025memory, 97 | title={Memory Efficient Matting with Adaptive Token Routing}, 98 | author={Lin, Yiheng and Hu, Yihan and Zhang, Chenyi and Liu, Ting and Qu, Xiaochao and Liu, Luoqi and Zhao, Yao and Wei, Yunchao}, 99 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 100 | volume={39}, 101 | number={5}, 102 | pages={5298--5306}, 103 | year={2025} 104 | } 105 | ``` 106 | 107 | ## License 108 | The code is released under the MIT License. It is a short, permissive software license. Basically, you can do whatever you want as long as you include the original copyright and license notice in any copy of the software/source. 109 | 110 | ## Acknowledgement 111 | Our project is developed based on [ViTMatte](https://github.com/hustvl/ViTMatte), [DynamicViT](https://github.com/raoyongming/DynamicViT), [Matteformer](https://github.com/webtoon/matteformer), [ToMe](https://github.com/facebookresearch/ToMe), [EViT](https://github.com/youweiliang/evit). Thanks for their wonderful work!
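## Token Routing Sketch
For readers who want the gist of the router described in the Introduction, here is a minimal, illustrative sketch of top-k token routing. The function and parameter names (`route_tokens`, `topk_ratio`) are ours, not this repo's API; the actual `Router` is declared in `modeling/backbone/utils.py` and used inside the ViT blocks.
```
import torch

def route_tokens(tokens, scores, topk_ratio=0.25):
    # tokens: (B, N, C) patch embeddings; scores: (B, N) routing probabilities
    # from a router. Keep the top-k informative tokens for the expensive global
    # attention and send the rest to a lightweight refinement module (LTRM).
    B, N, C = tokens.shape
    k = max(1, int(topk_ratio * N))
    keep = torch.zeros(B, N, dtype=torch.bool, device=tokens.device)
    keep.scatter_(1, scores.topk(k, dim=1).indices, True)
    attn_tokens = tokens[keep].view(B, k, C)       # routed to global attention
    ltrm_tokens = tokens[~keep].view(B, N - k, C)  # routed to the cheap LTRM
    return attn_tokens, ltrm_tokens, keep
```
At inference time, `--max-number-token` additionally caps how many tokens may enter an attention block at once; as noted above, reducing it lowers memory usage.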
112 | 
113 | 
--------------------------------------------------------------------------------
/modeling/meta_arch/mematte.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | import os
6 | 
7 | from detectron2.structures import ImageList
8 | 
9 | class MEMatte(nn.Module):
10 | def __init__(self,
11 | *,
12 | teacher_backbone,
13 | backbone,
14 | criterion,
15 | pixel_mean,
16 | pixel_std,
17 | input_format,
18 | size_divisibility,
19 | decoder,
20 | distill = True,
21 | distill_loss_ratio = 1.,
22 | token_loss_ratio = 1.,
23 | balance_loss = "MSE",
24 | ):
25 | super(MEMatte, self).__init__()
26 | self.teacher_backbone = teacher_backbone
27 | self.backbone = backbone
28 | self.criterion = criterion
29 | self.input_format = input_format
30 | self.balance_loss = balance_loss
31 | self.size_divisibility = size_divisibility
32 | self.decoder = decoder
33 | self.distill = distill
34 | self.distill_loss_ratio = distill_loss_ratio
35 | self.token_loss_ratio = token_loss_ratio
36 | self.register_buffer(
37 | "pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False
38 | )
39 | self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
40 | assert (
41 | self.pixel_mean.shape == self.pixel_std.shape
42 | ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"
43 | 
44 | @property
45 | def device(self):
46 | return self.pixel_mean.device
47 | 
48 | def forward(self, batched_inputs, patch_decoder=True):
49 | images, targets, H, W = self.preprocess_inputs(batched_inputs)
50 | 
51 | if self.training:
52 | if self.distill:
53 | self.teacher_backbone.eval()
54 | features, out_pred_prob = self.backbone(images)
55 | teacher_features = self.teacher_backbone(images)
56 | distill_loss = F.mse_loss(features, teacher_features)
57 | else:
58 | features, out_pred_prob = self.backbone(images)
59 | outputs = self.decoder(features, images)
60 | assert targets is not None
61 | trimap = images[:, 3:4]
62 | sample_map = torch.zeros_like(trimap)
63 | sample_map[trimap==0.5] = 1 # 2*B*dim*(2*ratio*hw*dim + ratio*hw*ratio*hw)
64 | losses = self.criterion(sample_map, outputs, targets)
65 | if self.distill:
66 | losses['distill_loss'] = distill_loss * self.distill_loss_ratio
67 | total_ratio = sum([p.sum() / p.numel() for p in out_pred_prob]) / len(out_pred_prob)
68 | losses['mse_ratio_loss'] = self.token_loss_ratio * F.mse_loss(total_ratio, torch.tensor(self.backbone.topk).cuda())
69 | 
70 | tensorboard_images = dict()
71 | tensorboard_images['pred_alpha'] = outputs['phas']
72 | return losses, tensorboard_images, out_pred_prob
73 | else:
74 | features, out_pred_prob, out_hard_keep_decision = self.backbone(images)
75 | if patch_decoder:
76 | outputs = self.patch_inference(features=features, images=images)
77 | else:
78 | outputs = self.decoder(features, images)
79 | 
80 | outputs['phas'] = outputs['phas'][:,:,:H,:W]
81 | 
82 | return outputs, out_pred_prob, out_hard_keep_decision
83 | 
84 | """Test FLOPs"""
85 | # features = self.backbone(images)
86 | # # outputs = self.decoder(features, images)
87 | # # outputs['phas'] = outputs['phas'][:,:,:H,:W]
88 | # return features
89 | 
90 | """Test decoder FLOPs"""
91 | 
92 | 
93 | 
94 | 
95 | def patch_inference(self, features, images):
96 | patch_size = 512
97 | overlap = 64
98 | image_size = patch_size + 2 * overlap
99 | feature_patch_size = patch_size // 16
100 | feature_overlap = overlap // 16
101 | features_size = 
feature_patch_size + 2 * feature_overlap 102 | B, C, H, W = images.shape 103 | pad_h = (patch_size - H % patch_size) % patch_size 104 | pad_w = (patch_size - W % patch_size) % patch_size 105 | pad_images = F.pad(images.permute(0,2,3,1), (0,0,0,pad_w,0,pad_h)).permute(0,3,1,2) 106 | _, _, pad_H, pad_W = pad_images.shape 107 | 108 | _, _, H_fea, W_fea = features.shape 109 | pad_fea_h = (feature_patch_size - H_fea % feature_patch_size) % feature_patch_size 110 | pad_fea_w = (feature_patch_size - W_fea % feature_patch_size) % feature_patch_size 111 | pad_features = F.pad(features.permute(0,2,3,1), (0,0,0,pad_fea_w,0,pad_fea_h)).permute(0,3,1,2) 112 | _, _, pad_fea_H, pad_fea_W = pad_features.shape 113 | 114 | h_patch_num = pad_images.shape[2] // patch_size 115 | w_patch_num = pad_images.shape[3] // patch_size 116 | 117 | outputs = torch.zeros_like(pad_images[:,0:1,:,:]) 118 | 119 | for i in range(h_patch_num): 120 | for j in range(w_patch_num): 121 | start_top = i * patch_size 122 | end_bottom = start_top + patch_size 123 | start_left = j*patch_size 124 | end_right = start_left + patch_size 125 | coor_top = start_top if (start_top - overlap) < 0 else (start_top - overlap) 126 | coor_bottom = end_bottom if (end_bottom + overlap) > pad_H else (end_bottom + overlap) 127 | coor_left = start_left if (start_left - overlap) < 0 else (start_left - overlap) 128 | coor_right = end_right if (end_right + overlap) > pad_W else (end_right + overlap) 129 | selected_images = pad_images[:,:,coor_top:coor_bottom, coor_left:coor_right] 130 | 131 | fea_start_top = i * feature_patch_size 132 | fea_end_bottom = fea_start_top + feature_patch_size 133 | fea_start_left = j*feature_patch_size 134 | fea_end_right = fea_start_left + feature_patch_size 135 | coor_top_fea = fea_start_top if (fea_start_top - feature_overlap) < 0 else (fea_start_top - feature_overlap) 136 | coor_bottom_fea = fea_end_bottom if (fea_end_bottom + feature_overlap) > pad_fea_H else (fea_end_bottom + feature_overlap) 137 | coor_left_fea = fea_start_left if (fea_start_left - feature_overlap) < 0 else (fea_start_left - feature_overlap) 138 | coor_right_fea = fea_end_right if (fea_end_right + feature_overlap) > pad_fea_W else (fea_end_right + feature_overlap) 139 | selected_fea = pad_features[:,:,coor_top_fea:coor_bottom_fea, coor_left_fea:coor_right_fea] 140 | 141 | 142 | 143 | outputs_patch = self.decoder(selected_fea, selected_images) 144 | 145 | coor_top = start_top if (start_top - overlap) < 0 else (coor_top + overlap) 146 | coor_bottom = coor_top + patch_size 147 | coor_left = start_left if (start_left - overlap) < 0 else (coor_left + overlap) 148 | coor_right = coor_left + patch_size 149 | 150 | coor_out_top = 0 if (start_top - overlap) < 0 else overlap 151 | coor_out_bottom = coor_out_top + patch_size 152 | coor_out_left = 0 if (start_left - overlap) < 0 else overlap 153 | coor_out_right = coor_out_left + patch_size 154 | 155 | outputs[:, :, coor_top:coor_bottom, coor_left:coor_right] = outputs_patch['phas'][:,:,coor_out_top:coor_out_bottom,coor_out_left:coor_out_right] 156 | 157 | outputs = outputs[:,:,:H, :W] 158 | return {'phas':outputs} 159 | 160 | def preprocess_inputs(self, batched_inputs): 161 | """ 162 | Normalize, pad and batch the input images. 
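The trimap is mapped to {0, 0.5, 1} (background / unknown / foreground) during
training and concatenated as a fourth input channel; H and W are zero-padded
up to the next multiple of 32.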
163 | """ 164 | images = batched_inputs["image"].to(self.device) 165 | trimap = batched_inputs['trimap'].to(self.device) 166 | images = (images - self.pixel_mean) / self.pixel_std 167 | 168 | if 'fg' in batched_inputs.keys(): 169 | trimap[trimap < 85] = 0 170 | trimap[trimap >= 170] = 1 171 | trimap[trimap >= 85] = 0.5 172 | 173 | images = torch.cat((images, trimap), dim=1) 174 | 175 | B, C, H, W = images.shape 176 | if images.shape[-1]%32!=0 or images.shape[-2]%32!=0: 177 | new_H = (32-images.shape[-2]%32) + H 178 | new_W = (32-images.shape[-1]%32) + W 179 | new_images = torch.zeros((images.shape[0], images.shape[1], new_H, new_W)).to(self.device) 180 | new_images[:,:,:H,:W] = images[:,:,:,:] 181 | del images 182 | images = new_images 183 | 184 | if "alpha" in batched_inputs: 185 | phas = batched_inputs["alpha"].to(self.device) 186 | else: 187 | phas = None 188 | 189 | return images, dict(phas=phas), H, W 190 | 191 | -------------------------------------------------------------------------------- /modeling/backbone/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange, repeat, pack, unpack 8 | from functools import partial 9 | from collections import namedtuple 10 | from torch import Tensor, nn, einsum 11 | import numpy as np 12 | from torch.cuda.amp import autocast 13 | 14 | __all__ = [ 15 | "window_partition", 16 | "window_unpartition", 17 | "add_decomposed_rel_pos", 18 | "get_abs_pos", 19 | "PatchEmbed", 20 | "Router", 21 | ] 22 | 23 | RouterReturn = namedtuple('RouterReturn', ['indices', 'scores', 'routed_tokens', 'routed_mask']) 24 | 25 | def exists(val): 26 | return val is not None 27 | 28 | def default(val, d): 29 | return val if exists(val) else d 30 | 31 | def divisible_by(numer, denom): 32 | return (numer % denom) == 0 33 | 34 | def pack_one(t, pattern): 35 | return pack([t], pattern) 36 | 37 | def unpack_one(t, ps, pattern): 38 | return unpack(t, ps, pattern)[0] 39 | 40 | def pad_to_multiple(tensor, multiple, dim=-1, value=0): 41 | seq_len = tensor.shape[dim] 42 | m = seq_len / multiple 43 | if m.is_integer(): 44 | return tensor, seq_len 45 | 46 | remainder = math.ceil(m) * multiple - seq_len 47 | pad_offset = (0,) * (-1 - dim) * 2 48 | padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value = value) 49 | return padded_tensor, seq_len 50 | 51 | def batched_gather(x, indices): 52 | batch_range = create_batch_range(indices, indices.ndim - 1) 53 | return x[batch_range, indices] 54 | 55 | def identity(t): 56 | return t 57 | 58 | def l2norm(t): 59 | return F.normalize(t, dim = -1) 60 | 61 | # tensor helpers 62 | 63 | def create_batch_range(t, right_pad_dims = 1): 64 | b, device = t.shape[0], t.device 65 | batch_range = torch.arange(b, device = device) 66 | pad_dims = ((1,) * right_pad_dims) 67 | return batch_range.reshape(-1, *pad_dims) 68 | 69 | 70 | def window_partition(x, window_size): 71 | """ 72 | Partition into non-overlapping windows with padding if needed. 73 | Args: 74 | x (tensor): input tokens with [B, H, W, C]. 75 | window_size (int): window size. 76 | 77 | Returns: 78 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
79 | (Hp, Wp): padded height and width before partition
80 | """
81 | B, H, W, C = x.shape
82 | 
83 | pad_h = (window_size - H % window_size) % window_size
84 | pad_w = (window_size - W % window_size) % window_size
85 | if pad_h > 0 or pad_w > 0:
86 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
87 | Hp, Wp = H + pad_h, W + pad_w
88 | 
89 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
90 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
91 | return windows, (Hp, Wp)
92 | 
93 | 
94 | def window_unpartition(windows, window_size, pad_hw, hw):
95 | """
96 | Reverse window partition into original sequences and remove padding.
97 | Args:
98 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
99 | window_size (int): window size.
100 | pad_hw (Tuple): padded height and width (Hp, Wp).
101 | hw (Tuple): original height and width (H, W) before padding.
102 | 
103 | Returns:
104 | x: unpartitioned sequences with [B, H, W, C].
105 | """
106 | Hp, Wp = pad_hw
107 | H, W = hw
108 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
109 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
110 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
111 | 
112 | if Hp > H or Wp > W:
113 | x = x[:, :H, :W, :].contiguous()
114 | return x
115 | 
116 | 
117 | def get_rel_pos(q_size, k_size, rel_pos):
118 | """
119 | Get relative positional embeddings according to the relative positions of
120 | query and key sizes.
121 | Args:
122 | q_size (int): size of query q.
123 | k_size (int): size of key k.
124 | rel_pos (Tensor): relative position embeddings (L, C).
125 | 
126 | Returns:
127 | Extracted positional embeddings according to relative positions.
128 | """
129 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
130 | # Interpolate rel pos if needed.
131 | if rel_pos.shape[0] != max_rel_dist:
132 | # Interpolate rel pos.
133 | rel_pos_resized = F.interpolate(
134 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
135 | size=max_rel_dist,
136 | mode="linear",
137 | )
138 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
139 | else:
140 | rel_pos_resized = rel_pos
141 | 
142 | # Scale the coords with short length if shapes for q and k are different.
143 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
144 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
145 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
146 | 
147 | return rel_pos_resized[relative_coords.long()]
148 | 
149 | 
150 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
151 | """
152 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
153 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
154 | Args:
155 | attn (Tensor): attention map.
156 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
157 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
158 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
159 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
160 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
161 | 
162 | Returns:
163 | attn (Tensor): attention map with added relative positional embeddings.
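In effect (an illustrative formula matching the two einsums below): for each query
position (h, w) and key position (kh, kw),

    attn[b, (h, w), (kh, kw)] += <q[b, h, w, :], Rh[h, kh, :]> + <q[b, h, w, :], Rw[w, kw, :]>

so the 2-D relative bias decomposes into independent height and width terms, and the
full (q_h*q_w) x (k_h*k_w) relative table is never materialized.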
164 | """ 165 | q_h, q_w = q_size # 80, 120 166 | k_h, k_w = k_size # 80, 120 167 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) # torch.Size([80, 80, 64]) 168 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) # torch.Size([120, 120, 64]) 169 | 170 | B, _, dim = q.shape 171 | r_q = q.reshape(B, q_h, q_w, dim) 172 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) # torch.Size([6, 80, 120, 80]) 173 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) # torch.Size([6, 80, 120, 120]) 174 | 175 | attn = ( # 2048*2048这里会爆, rel_h: [6, 128, 128, 128], rel_w:[6, 128, 128, 128], None扩充后都乘128:6*128*128*128*128*4 = 6,442,450,944 176 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 177 | ).view(B, q_h * q_w, k_h * k_w) 178 | 179 | return attn 180 | 181 | 182 | def get_abs_pos(abs_pos, has_cls_token, hw): 183 | """ 184 | Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token 185 | dimension for the original embeddings. 186 | Args: 187 | abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). 188 | has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. 189 | hw (Tuple): size of input image tokens. 190 | 191 | Returns: 192 | Absolute positional embeddings after processing with shape (1, H, W, C) 193 | """ 194 | h, w = hw 195 | if has_cls_token: 196 | abs_pos = abs_pos[:, 1:] 197 | xy_num = abs_pos.shape[1] 198 | size = int(math.sqrt(xy_num)) 199 | assert size * size == xy_num 200 | 201 | if size != h or size != w: 202 | new_abs_pos = F.interpolate( 203 | abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), 204 | size=(h, w), 205 | mode="bicubic", 206 | align_corners=False, 207 | ) 208 | 209 | return new_abs_pos.permute(0, 2, 3, 1) 210 | else: 211 | return abs_pos.reshape(1, h, w, -1) 212 | 213 | 214 | class PatchEmbed(nn.Module): 215 | """ 216 | Image to Patch Embedding. 217 | """ 218 | 219 | def __init__( 220 | self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768 221 | ): 222 | """ 223 | Args: 224 | kernel_size (Tuple): kernel size of the projection layer. 225 | stride (Tuple): stride of the projection layer. 226 | padding (Tuple): padding size of the projection layer. 227 | in_chans (int): Number of input image channels. 228 | embed_dim (int): embed_dim (int): Patch embedding dimension. 
229 | """ 230 | super().__init__() 231 | 232 | self.proj = nn.Conv2d( 233 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 234 | ) 235 | 236 | def forward(self, x): 237 | x = self.proj(x) 238 | # B C H W -> B H W C 239 | x = x.permute(0, 2, 3, 1) 240 | return x 241 | 242 | class Router(nn.Module): 243 | def __init__(self, embed_dim=384): 244 | super().__init__() 245 | self.in_conv = nn.Sequential( 246 | nn.LayerNorm(embed_dim), 247 | nn.Linear(embed_dim, embed_dim), 248 | nn.GELU() 249 | ) 250 | 251 | self.out_conv = nn.Sequential( 252 | nn.Linear(embed_dim, embed_dim // 2), 253 | nn.GELU(), 254 | nn.Linear(embed_dim // 2, embed_dim // 4), 255 | nn.GELU(), 256 | nn.Linear(embed_dim // 4, 2), 257 | nn.LogSoftmax(dim=-1) 258 | ) 259 | 260 | def forward(self, x, policy=None): 261 | x = self.in_conv(x) # 16, 1024, 384 262 | B, N, C = x.size() 263 | local_x = x[:,:, :C//2] # 16, 1024, 192 264 | global_x = (x[:,:, C//2:]).sum(dim=1, keepdim=True) / N # 16, 1, 192 265 | x = torch.cat([local_x, global_x.expand(B, N, C//2)], dim=-1) # 16, 1024, 384 266 | return self.out_conv(x) # 16, 1024, 2 267 | -------------------------------------------------------------------------------- /modeling/backbone/vit_teacher.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import fvcore.nn.weight_init as weight_init 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from detectron2.layers import CNNBlockBase, Conv2d, get_norm 8 | from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous 9 | from fairscale.nn.checkpoint import checkpoint_wrapper 10 | from timm.models.layers import DropPath, Mlp, trunc_normal_ 11 | from .backbone import Backbone 12 | from .utils import ( 13 | PatchEmbed, 14 | add_decomposed_rel_pos, 15 | get_abs_pos, 16 | window_partition, 17 | window_unpartition, 18 | ) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | __all__ = ["ViT"] 24 | 25 | 26 | class Attention(nn.Module): 27 | """Multi-head Attention block with relative position embeddings.""" 28 | 29 | def __init__( 30 | self, 31 | dim, 32 | num_heads=8, 33 | qkv_bias=True, 34 | use_rel_pos=False, 35 | rel_pos_zero_init=True, 36 | input_size=None, 37 | ): 38 | """ 39 | Args: 40 | dim (int): Number of input channels. 41 | num_heads (int): Number of attention heads. 42 | qkv_bias (bool: If True, add a learnable bias to query, key, value. 43 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 44 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 45 | input_size (int or None): Input resolution for calculating the relative positional 46 | parameter size. 
47 | """ 48 | super().__init__() 49 | self.num_heads = num_heads 50 | head_dim = dim // num_heads 51 | self.scale = head_dim**-0.5 52 | 53 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 54 | self.proj = nn.Linear(dim, dim) 55 | 56 | self.use_rel_pos = use_rel_pos 57 | if self.use_rel_pos: 58 | # initialize relative positional embeddings 59 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 60 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 61 | 62 | if not rel_pos_zero_init: 63 | trunc_normal_(self.rel_pos_h, std=0.02) 64 | trunc_normal_(self.rel_pos_w, std=0.02) 65 | 66 | def forward(self, x): 67 | B, H, W, _ = x.shape 68 | # qkv with shape (3, B, nHead, H * W, C) 69 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 70 | # q, k, v with shape (B * nHead, H * W, C) 71 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 72 | 73 | attn = (q * self.scale) @ k.transpose(-2, -1) 74 | 75 | if self.use_rel_pos: 76 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 77 | 78 | attn = attn.softmax(dim=-1) 79 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 80 | x = self.proj(x) 81 | 82 | return x 83 | 84 | class LayerNorm(nn.Module): 85 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 86 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 87 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 88 | with shape (batch_size, channels, height, width). 89 | """ 90 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 91 | super().__init__() 92 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 93 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 94 | self.eps = eps 95 | self.data_format = data_format 96 | if self.data_format not in ["channels_last", "channels_first"]: 97 | raise NotImplementedError 98 | self.normalized_shape = (normalized_shape, ) 99 | 100 | def forward(self, x): 101 | if self.data_format == "channels_last": 102 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 103 | elif self.data_format == "channels_first": 104 | u = x.mean(1, keepdim=True) 105 | s = (x - u).pow(2).mean(1, keepdim=True) 106 | x = (x - u) / torch.sqrt(s + self.eps) 107 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 108 | return x 109 | 110 | class ResBottleneckBlock(CNNBlockBase): 111 | """ 112 | The standard bottleneck residual block without the last activation layer. 113 | It contains 3 conv layers with kernels 1x1, 3x3, 1x1. 114 | """ 115 | 116 | def __init__( 117 | self, 118 | in_channels, 119 | out_channels, 120 | bottleneck_channels, 121 | norm="LN", 122 | act_layer=nn.GELU, 123 | conv_kernels=3, 124 | conv_paddings=1, 125 | ): 126 | """ 127 | Args: 128 | in_channels (int): Number of input channels. 129 | out_channels (int): Number of output channels. 130 | bottleneck_channels (int): number of output channels for the 3x3 131 | "bottleneck" conv layers. 132 | norm (str or callable): normalization for all conv layers. 133 | See :func:`layers.get_norm` for supported format. 134 | act_layer (callable): activation for all conv layers. 
135 | """ 136 | super().__init__(in_channels, out_channels, 1) 137 | 138 | self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False) 139 | self.norm1 = get_norm(norm, bottleneck_channels) 140 | self.act1 = act_layer() 141 | 142 | self.conv2 = Conv2d( 143 | bottleneck_channels, 144 | bottleneck_channels, 145 | conv_kernels, 146 | padding=conv_paddings, 147 | bias=False, 148 | ) 149 | self.norm2 = get_norm(norm, bottleneck_channels) 150 | self.act2 = act_layer() 151 | 152 | self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False) 153 | self.norm3 = get_norm(norm, out_channels) 154 | 155 | for layer in [self.conv1, self.conv2, self.conv3]: 156 | weight_init.c2_msra_fill(layer) 157 | for layer in [self.norm1, self.norm2]: 158 | layer.weight.data.fill_(1.0) 159 | layer.bias.data.zero_() 160 | # zero init last norm layer. 161 | self.norm3.weight.data.zero_() 162 | self.norm3.bias.data.zero_() 163 | 164 | def forward(self, x): 165 | out = x 166 | for layer in self.children(): 167 | out = layer(out) 168 | 169 | out = x + out 170 | return out 171 | 172 | 173 | class Block(nn.Module): 174 | """Transformer blocks with support of window attention and residual propagation blocks""" 175 | 176 | def __init__( 177 | self, 178 | dim, 179 | num_heads, 180 | mlp_ratio=4.0, 181 | qkv_bias=True, 182 | drop_path=0.0, 183 | norm_layer=nn.LayerNorm, 184 | act_layer=nn.GELU, 185 | use_rel_pos=False, 186 | rel_pos_zero_init=True, 187 | window_size=0, 188 | use_cc_attn = False, 189 | use_residual_block=False, 190 | use_convnext_block=False, 191 | input_size=None, 192 | res_conv_kernel_size=3, 193 | res_conv_padding=1, 194 | ): 195 | """ 196 | Args: 197 | dim (int): Number of input channels. 198 | num_heads (int): Number of attention heads in each ViT block. 199 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 200 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 201 | drop_path (float): Stochastic depth rate. 202 | norm_layer (nn.Module): Normalization layer. 203 | act_layer (nn.Module): Activation layer. 204 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 205 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 206 | window_size (int): Window size for window attention blocks. If it equals 0, then not 207 | use window attention. 208 | use_residual_block (bool): If True, use a residual block after the MLP block. 209 | input_size (int or None): Input resolution for calculating the relative positional 210 | parameter size. 
211 | """ 212 | super().__init__() 213 | self.norm1 = norm_layer(dim) 214 | self.attn = Attention( 215 | dim, 216 | num_heads=num_heads, 217 | qkv_bias=qkv_bias, 218 | use_rel_pos=use_rel_pos, 219 | rel_pos_zero_init=rel_pos_zero_init, 220 | input_size=input_size if window_size == 0 else (window_size, window_size), 221 | ) 222 | 223 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 224 | self.norm2 = norm_layer(dim) 225 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) 226 | 227 | self.window_size = window_size 228 | 229 | self.use_residual_block = use_residual_block 230 | if use_residual_block: 231 | # Use a residual block with bottleneck channel as dim // 2 232 | self.residual = ResBottleneckBlock( 233 | in_channels=dim, 234 | out_channels=dim, 235 | bottleneck_channels=dim // 2, 236 | norm="LN", 237 | act_layer=act_layer, 238 | conv_kernels=res_conv_kernel_size, 239 | conv_paddings=res_conv_padding, 240 | ) 241 | self.use_convnext_block = use_convnext_block 242 | if use_convnext_block: 243 | self.convnext = ConvNextBlock(dim = dim) 244 | 245 | if use_cc_attn: 246 | self.attn = CrissCrossAttention(dim) 247 | 248 | 249 | def forward(self, x): 250 | shortcut = x 251 | x = self.norm1(x) 252 | # Window partition 253 | if self.window_size > 0: 254 | H, W = x.shape[1], x.shape[2] 255 | x, pad_hw = window_partition(x, self.window_size) 256 | 257 | x = self.attn(x) 258 | 259 | # Reverse window partition 260 | if self.window_size > 0: 261 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 262 | 263 | x = shortcut + self.drop_path(x) 264 | x = x + self.drop_path(self.mlp(self.norm2(x))) 265 | 266 | if self.use_residual_block: 267 | x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 268 | if self.use_convnext_block: 269 | x = self.convnext(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 270 | 271 | return x 272 | 273 | 274 | class ViT_Teacher(Backbone): 275 | """ 276 | This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. 277 | "Exploring Plain Vision Transformer Backbones for Object Detection", 278 | https://arxiv.org/abs/2203.16527 279 | """ 280 | 281 | def __init__( 282 | self, 283 | img_size=1024, 284 | patch_size=16, 285 | in_chans=3, 286 | embed_dim=768, 287 | depth=12, 288 | num_heads=12, 289 | mlp_ratio=4.0, 290 | qkv_bias=True, 291 | drop_path_rate=0.0, 292 | norm_layer=nn.LayerNorm, 293 | act_layer=nn.GELU, 294 | use_abs_pos=True, 295 | use_rel_pos=False, 296 | rel_pos_zero_init=True, 297 | window_size=0, 298 | window_block_indexes=(), 299 | residual_block_indexes=(), 300 | use_act_checkpoint=False, 301 | pretrain_img_size=224, 302 | pretrain_use_cls_token=True, 303 | out_feature="last_feat", 304 | res_conv_kernel_size=3, 305 | res_conv_padding=1, 306 | ): 307 | """ 308 | Args: 309 | img_size (int): Input image size. 310 | patch_size (int): Patch size. 311 | in_chans (int): Number of input image channels. 312 | embed_dim (int): Patch embedding dimension. 313 | depth (int): Depth of ViT. 314 | num_heads (int): Number of attention heads in each ViT block. 315 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 316 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 317 | drop_path_rate (float): Stochastic depth rate. 318 | norm_layer (nn.Module): Normalization layer. 319 | act_layer (nn.Module): Activation layer. 320 | use_abs_pos (bool): If True, use absolute positional embeddings. 
321 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
322 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
323 | window_size (int): Window size for window attention blocks.
324 | window_block_indexes (list): Indexes for blocks using window attention.
325 | residual_block_indexes (list): Indexes for blocks using conv propagation.
326 | use_act_checkpoint (bool): If True, use activation checkpointing.
327 | pretrain_img_size (int): input image size for pretraining models.
328 | pretrain_use_cls_token (bool): If True, pretraining models use class token.
329 | out_feature (str): name of the feature from the last block.
330 | """
331 | super().__init__()
332 | self.pretrain_use_cls_token = pretrain_use_cls_token
333 | 
334 | self.patch_embed = PatchEmbed(
335 | kernel_size=(patch_size, patch_size),
336 | stride=(patch_size, patch_size),
337 | in_chans=in_chans,
338 | embed_dim=embed_dim,
339 | )
340 | 
341 | if use_abs_pos:
342 | # Initialize absolute positional embedding with pretrain image size.
343 | num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
344 | num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
345 | self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
346 | else:
347 | self.pos_embed = None
348 | 
349 | # stochastic depth decay rule
350 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
351 | 
352 | self.blocks = nn.ModuleList()
353 | for i in range(depth):
354 | block = Block(
355 | dim=embed_dim,
356 | num_heads=num_heads,
357 | mlp_ratio=mlp_ratio,
358 | qkv_bias=qkv_bias,
359 | drop_path=dpr[i],
360 | norm_layer=norm_layer,
361 | act_layer=act_layer,
362 | use_rel_pos=use_rel_pos,
363 | rel_pos_zero_init=rel_pos_zero_init,
364 | window_size=window_size if i in window_block_indexes else 0,
365 | use_residual_block=i in residual_block_indexes,
366 | input_size=(img_size // patch_size, img_size // patch_size),
367 | res_conv_kernel_size=res_conv_kernel_size,
368 | res_conv_padding=res_conv_padding,
369 | )
370 | if use_act_checkpoint:
371 | block = checkpoint_wrapper(block)
372 | self.blocks.append(block)
373 | 
374 | self._out_feature_channels = {out_feature: embed_dim}
375 | self._out_feature_strides = {out_feature: patch_size}
376 | self._out_features = [out_feature]
377 | 
378 | if self.pos_embed is not None:
379 | trunc_normal_(self.pos_embed, std=0.02)
380 | 
381 | self.apply(self._init_weights)
382 | 
383 | def _init_weights(self, m):
384 | if isinstance(m, nn.Linear):
385 | trunc_normal_(m.weight, std=0.02)
386 | if isinstance(m, nn.Linear) and m.bias is not None:
387 | nn.init.constant_(m.bias, 0)
388 | elif isinstance(m, nn.LayerNorm):
389 | nn.init.constant_(m.bias, 0)
390 | nn.init.constant_(m.weight, 1.0)
391 | 
392 | def forward(self, x):
393 | x = self.patch_embed(x)
394 | if self.pos_embed is not None:
395 | x = x + get_abs_pos(
396 | self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
397 | )
398 | 
399 | 
400 | for blk in self.blocks: # torch.Size([14, 32, 32, 384])
401 | x = blk(x)
402 | 
403 | outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
404 | 
405 | return outputs[self._out_features[0]]  # use the configured key instead of hard-coding 'last_feat'
--------------------------------------------------------------------------------
/modeling/backbone/vit.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import fvcore.nn.weight_init as weight_init
4 | import torch
5 | import
torch.nn as nn
6 | from torch.nn import functional as F
7 | from detectron2.layers import CNNBlockBase, Conv2d, get_norm
8 | from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
9 | from fairscale.nn.checkpoint import checkpoint_wrapper
10 | from timm.models.layers import DropPath, Mlp, trunc_normal_
11 | from functools import partial
12 | from .backbone import Backbone
13 | from .utils import (
14 | PatchEmbed,
15 | add_decomposed_rel_pos,
16 | get_abs_pos,
17 | window_partition,
18 | window_unpartition,
19 | Router
20 | )
21 | 
22 | logger = logging.getLogger(__name__)
23 | 
24 | 
25 | __all__ = ["ViT"]
26 | 
27 | 
28 | class Attention(nn.Module):
29 | """Multi-head Attention block with relative position embeddings."""
30 | 
31 | def __init__(
32 | self,
33 | dim,
34 | num_heads=8,
35 | qkv_bias=True,
36 | use_rel_pos=False,
37 | rel_pos_zero_init=True,
38 | input_size=None,
39 | ):
40 | """
41 | Args:
42 | dim (int): Number of input channels.
43 | num_heads (int): Number of attention heads.
44 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
45 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
46 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
47 | input_size (int or None): Input resolution for calculating the relative positional
48 | parameter size.
49 | """
50 | super().__init__()
51 | self.num_heads = num_heads
52 | head_dim = dim // num_heads
53 | self.scale = head_dim**-0.5
54 | 
55 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
56 | self.proj = nn.Linear(dim, dim)
57 | 
58 | self.use_rel_pos = use_rel_pos
59 | if self.use_rel_pos:
60 | # initialize relative positional embeddings
61 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
62 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
63 | 
64 | if not rel_pos_zero_init:
65 | trunc_normal_(self.rel_pos_h, std=0.02)
66 | trunc_normal_(self.rel_pos_w, std=0.02)
67 | 
68 | def softmax_with_policy(self, attn, policy, eps=1e-6):
69 | B, N, _ = policy.size() # 128, 197, 1
70 | B, H, N, N = attn.size() # here H is the number of heads; the two trailing dims are both N
71 | attn_policy = policy.reshape(B, 1, 1, N) # * policy.reshape(B, 1, N, 1)
72 | eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N)
73 | attn_policy = attn_policy + (1.0 - attn_policy) * eye # keep the diagonal so every token still attends to itself
74 | max_att = torch.max(attn, dim=-1, keepdim=True)[0]
75 | attn = attn - max_att
76 | # attn = attn.exp_() * attn_policy
77 | # return attn / attn.sum(dim=-1, keepdim=True)
78 | 
79 | # for stable training
80 | attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)
81 | attn = (attn + eps/N) / (attn.sum(dim=-1, keepdim=True) + eps)
82 | return attn.type_as(max_att)
83 | 
84 | 
85 | def forward(self, x, policy, H, W): # total FLOPs: 4 * B * hw * dim * dim + 2 * B * hw * hw * dim
86 | if x.ndim == 4:
87 | B, H, W, _ = x.shape
88 | N = H*W
89 | else:
90 | B, N, _ = x.shape
91 | 
92 | 
93 | # qkv with shape (3, B, nHead, H * W, C) self.qkv.flops: b * hw * dim * dim * 3
94 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 6.341787648 GFLOPs; reshape and permute cost no FLOPs
95 | # q, k, v with shape (B * nHead, H * W, C)
96 | q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0)
97 | 
98 | attn = (q * self.scale) @ k.transpose(-2, -1) # 5.637144576 (B * hw * hw * dim) 14 * 1024*1024*384
99 | 
100 | if self.use_rel_pos:
101 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w,
(H, W), (H, W)) 102 | 103 | if policy is None: 104 | attn = attn.softmax(dim=-1) 105 | else: 106 | attn = self.softmax_with_policy(attn.reshape(B, self.num_heads, N, N), policy).reshape(B*self.num_heads, N, N) 107 | 108 | # # 5.637144576 (B * hw * hw * dim) 109 | if x.ndim == 4: 110 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 111 | else: 112 | x = (attn @ v).view(B, self.num_heads, N, -1).permute(0, 2, 1, 3).reshape(B, N, -1) 113 | 114 | x = self.proj(x) # 2.113929216 (B * hw * dim * dim) 14 * 1024 * 384 * 384 115 | 116 | return x 117 | 118 | class LayerNorm(nn.Module): 119 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 120 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 121 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 122 | with shape (batch_size, channels, height, width). 123 | """ 124 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 125 | super().__init__() 126 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 127 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 128 | self.eps = eps 129 | self.data_format = data_format 130 | if self.data_format not in ["channels_last", "channels_first"]: 131 | raise NotImplementedError 132 | self.normalized_shape = (normalized_shape, ) 133 | 134 | def forward(self, x): 135 | if self.data_format == "channels_last": 136 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 137 | elif self.data_format == "channels_first": 138 | u = x.mean(1, keepdim=True) 139 | s = (x - u).pow(2).mean(1, keepdim=True) 140 | x = (x - u) / torch.sqrt(s + self.eps) 141 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 142 | return x 143 | 144 | class ResBottleneckBlock(CNNBlockBase): 145 | """ 146 | The standard bottleneck residual block without the last activation layer. 147 | It contains 3 conv layers with kernels 1x1, 3x3, 1x1. 148 | """ 149 | 150 | def __init__( 151 | self, 152 | in_channels, 153 | out_channels, 154 | bottleneck_channels, 155 | norm="LN", 156 | act_layer=nn.GELU, 157 | conv_kernels=3, 158 | conv_paddings=1, 159 | ): 160 | """ 161 | Args: 162 | in_channels (int): Number of input channels. 163 | out_channels (int): Number of output channels. 164 | bottleneck_channels (int): number of output channels for the 3x3 165 | "bottleneck" conv layers. 166 | norm (str or callable): normalization for all conv layers. 167 | See :func:`layers.get_norm` for supported format. 168 | act_layer (callable): activation for all conv layers. 169 | """ 170 | super().__init__(in_channels, out_channels, 1) 171 | 172 | self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False) 173 | self.norm1 = get_norm(norm, bottleneck_channels) 174 | self.act1 = act_layer() 175 | 176 | self.conv2 = Conv2d( 177 | bottleneck_channels, 178 | bottleneck_channels, 179 | conv_kernels, 180 | padding=conv_paddings, 181 | bias=False, 182 | ) 183 | self.norm2 = get_norm(norm, bottleneck_channels) 184 | self.act2 = act_layer() 185 | 186 | self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False) 187 | self.norm3 = get_norm(norm, out_channels) 188 | 189 | for layer in [self.conv1, self.conv2, self.conv3]: 190 | weight_init.c2_msra_fill(layer) 191 | for layer in [self.norm1, self.norm2]: 192 | layer.weight.data.fill_(1.0) 193 | layer.bias.data.zero_() 194 | # zero init last norm layer. 
195 | self.norm3.weight.data.zero_()
196 | self.norm3.bias.data.zero_()
197 | 
198 | def forward(self, x):
199 | out = x
200 | for layer in self.children():
201 | out = layer(out)
202 | 
203 | out = x + out
204 | return out
205 | 
206 | class ECA(nn.Module):
207 | def __init__(self, channels, b=1, gamma=2):
208 | super(ECA, self).__init__()
209 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
210 | self.channels = channels
211 | self.b = b
212 | self.gamma = gamma
213 | self.conv = nn.Conv1d(
214 | 1,
215 | 1,
216 | kernel_size=self.kernel_size(),
217 | padding=(self.kernel_size() - 1) // 2,
218 | bias=False,
219 | )
220 | self.sigmoid = nn.Sigmoid()
221 | 
222 | def kernel_size(self):
223 | k = int(abs((math.log2(self.channels) / self.gamma) + self.b / self.gamma)) # ECA-Net's adaptive kernel size: k ~ |log2(C)/gamma + b/gamma|
224 | out = k if k % 2 else k + 1 # force an odd kernel so 'same' padding is symmetric
225 | return out
226 | 
227 | def forward(self, x):
228 | 
229 | # feature descriptor on the global spatial information
230 | y = self.avg_pool(x)
231 | 
232 | # Two different branches of ECA module
233 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
234 | 
235 | # Multi-scale information fusion
236 | y = self.sigmoid(y)
237 | 
238 | return x * y.expand_as(x)
239 | 
240 | 
241 | 
242 | class LTRM(nn.Module):
243 | """Efficient Block to replace the Original Transformer Block"""
244 | def __init__(self, dim, expand_ratio = 2, kernel_size = 5):
245 | super(LTRM, self).__init__()
246 | self.fc1 = nn.Linear(dim, dim*expand_ratio)
247 | self.act1 = nn.GELU()
248 | self.dwconv = nn.Conv2d(dim*expand_ratio, dim*expand_ratio, kernel_size=(kernel_size, kernel_size), groups=dim*expand_ratio, padding=(kernel_size//2, kernel_size//2))
249 | self.act2 = nn.GELU()
250 | self.fc2 = nn.Linear(dim*expand_ratio, dim)
251 | self.eca = ECA(dim)
252 | 
253 | def forward(self, x, prev_msa=None):
254 | x = self.act1(self.fc1(x))
255 | x = self.act2(self.dwconv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1))
256 | x = self.fc2(x)
257 | y = x + self.eca(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
258 | return y
259 | 
260 | 
261 | 
262 | 
263 | 
264 | class EfficientBlock_ALR(nn.ModuleList): # from "Rethinking Attention" (AAAI)
265 | def __init__(self, model_dimension=128):
266 | super(EfficientBlock_ALR, self).__init__()
267 | # self.sentence_length=sentence_length
268 | self.model_dimension = model_dimension
269 | self.width = self.model_dimension
270 | self.layers=list()
271 | widths=[1,2,1]
272 | self.depth=len(widths)-1
273 | self.layers=nn.ModuleList()
274 | for i in range(self.depth):
275 | self.layers.extend([nn.LayerNorm(self.width * widths[i]),nn.Linear(self.width * widths[i], self.width * widths[i+1])])
276 | if(i