├── .gitignore ├── LICENSE ├── README.md ├── assets ├── harmonization.jpg ├── matanyone_logo.png ├── pipeline.jpg ├── teaser.jpg └── teaser_demo.gif ├── hugging_face ├── app.py ├── matanyone_wrapper.py ├── requirements.txt └── tools │ ├── __init__.py │ ├── base_segmenter.py │ ├── download_util.py │ ├── interact_tools.py │ ├── mask_painter.py │ ├── misc.py │ └── painter.py ├── inference_hf.py ├── inference_matanyone.py ├── inputs ├── mask │ ├── test-sample0_1.png │ ├── test-sample0_2.png │ ├── test-sample1.png │ ├── test-sample2.png │ └── test-sample3.png └── video │ ├── test-sample0 │ ├── 0000.jpg │ ├── 0001.jpg │ ├── 0002.jpg │ ├── 0003.jpg │ ├── 0004.jpg │ ├── 0005.jpg │ ├── 0006.jpg │ ├── 0007.jpg │ ├── 0008.jpg │ ├── 0009.jpg │ ├── 0010.jpg │ ├── 0011.jpg │ ├── 0012.jpg │ ├── 0013.jpg │ ├── 0014.jpg │ ├── 0015.jpg │ ├── 0016.jpg │ ├── 0017.jpg │ ├── 0018.jpg │ ├── 0019.jpg │ ├── 0020.jpg │ ├── 0021.jpg │ ├── 0022.jpg │ ├── 0023.jpg │ ├── 0024.jpg │ ├── 0025.jpg │ ├── 0026.jpg │ ├── 0027.jpg │ ├── 0028.jpg │ ├── 0029.jpg │ ├── 0030.jpg │ ├── 0031.jpg │ ├── 0032.jpg │ ├── 0033.jpg │ ├── 0034.jpg │ ├── 0035.jpg │ ├── 0036.jpg │ ├── 0037.jpg │ ├── 0038.jpg │ ├── 0039.jpg │ ├── 0040.jpg │ ├── 0041.jpg │ ├── 0042.jpg │ ├── 0043.jpg │ ├── 0044.jpg │ ├── 0045.jpg │ ├── 0046.jpg │ ├── 0047.jpg │ ├── 0048.jpg │ ├── 0049.jpg │ ├── 0050.jpg │ ├── 0051.jpg │ ├── 0052.jpg │ ├── 0053.jpg │ ├── 0054.jpg │ ├── 0055.jpg │ ├── 0056.jpg │ ├── 0057.jpg │ ├── 0058.jpg │ ├── 0059.jpg │ ├── 0060.jpg │ ├── 0061.jpg │ ├── 0062.jpg │ ├── 0063.jpg │ ├── 0064.jpg │ ├── 0065.jpg │ ├── 0066.jpg │ ├── 0067.jpg │ ├── 0068.jpg │ ├── 0069.jpg │ ├── 0070.jpg │ └── 0071.jpg │ ├── test-sample1.mp4 │ ├── test-sample2.mp4 │ └── test-sample3.mp4 ├── matanyone ├── __init__.py ├── config │ ├── __init__.py │ ├── eval_matanyone_config.yaml │ ├── hydra │ │ └── job_logging │ │ │ ├── custom-no-rank.yaml │ │ │ └── custom.yaml │ └── model │ │ └── base.yaml ├── inference │ ├── __init__.py │ ├── image_feature_store.py │ ├── inference_core.py │ ├── kv_memory_store.py │ ├── memory_manager.py │ ├── object_info.py │ ├── object_manager.py │ └── utils │ │ ├── __init__.py │ │ └── args_utils.py ├── model │ ├── __init__.py │ ├── aux_modules.py │ ├── big_modules.py │ ├── channel_attn.py │ ├── group_modules.py │ ├── matanyone.py │ ├── modules.py │ ├── transformer │ │ ├── __init__.py │ │ ├── object_summarizer.py │ │ ├── object_transformer.py │ │ ├── positional_encoding.py │ │ └── transformer_layers.py │ └── utils │ │ ├── __init__.py │ │ ├── memory_utils.py │ │ ├── parameter_groups.py │ │ └── resnet.py └── utils │ ├── __init__.py │ ├── get_default_model.py │ ├── inference_utils.py │ └── tensor_utils.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | .DS_Store 4 | hugging_face/assets/ 5 | results/ 6 | test_sample/ 7 | pretrained_models/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # S-Lab License 1.0 2 | 3 | Copyright 2023 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 36 | 37 | 38 | --- 39 | For inquiries permission for commercial use, please consult our team: 40 | Peiqing Yang (peiqingyang99@outlook.com), 41 | Dr. Shangchen Zhou (shangchenzhou@gmail.com), 42 | Prof. Chen Change Loy (ccloy@ntu.edu.sg) 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
MatAnyone Logo

Stable Video Matting with Consistent Memory Propagation

Peiqing Yang1   Shangchen Zhou1   Jixin Zhao1   Qingyi Tao2   Chen Change Loy1
1S-Lab, Nanyang Technological University   2SenseTime Research, Singapore

MatAnyone is a practical human video matting framework supporting target assignment, with stable performance in both the semantics of core regions and the fine-grained boundary details.

:movie_camera: For more visual results, check out our project page.

---
48 | 49 | 50 | ## 📮 Update 51 | - [2025.03] Release our evaluation benchmark - [YouTubeMatte](https://github.com/pq-yang/MatAnyone?tab=readme-ov-file#-evaluation-benchmark). 52 | - [2025.03] Integrate MatAnyone with Hugging Face 🤗 53 | - [2025.02] Release inference codes and gradio demo. 54 | - [2025.02] This repo is created. 55 | 56 | ## 🔎 Overview 57 | ![overall_structure](assets/pipeline.jpg) 58 | 59 | ## 🔧 Installation 60 | 1. Clone Repo 61 | ```bash 62 | git clone https://github.com/pq-yang/MatAnyone 63 | cd MatAnyone 64 | ``` 65 | 66 | 2. Create Conda Environment and Install Dependencies 67 | ```bash 68 | # create new conda env 69 | conda create -n matanyone python=3.8 -y 70 | conda activate matanyone 71 | 72 | # install python dependencies 73 | pip install -e . 74 | # [optional] install python dependencies for gradio demo 75 | pip3 install -r hugging_face/requirements.txt 76 | ``` 77 | 78 | ## 🤗 Load from Hugging Face 79 | Alternatively, models can be directly loaded from [Hugging Face](https://huggingface.co/PeiqingYang/MatAnyone) to make inference. 80 | 81 | ```shell 82 | pip install -q git+https://github.com/pq-yang/MatAnyone 83 | ``` 84 | 85 | To extract the foreground and the alpha video you can directly run the following lines. Please refer to [inference_hf.py](https://github.com/pq-yang/MatAnyone/blob/main/inference_hf.py) for more arguments. 86 | ```python 87 | from matanyone import InferenceCore 88 | processor = InferenceCore("PeiqingYang/MatAnyone") 89 | 90 | foreground_path, alpha_path = processor.process_video( 91 | input_path = "inputs/video/test-sample1.mp4", 92 | mask_path = "inputs/mask/test-sample1.png", 93 | output_path = "outputs" 94 | ) 95 | ``` 96 | 97 | ## 🔥 Inference 98 | 99 | ### Download Model 100 | Download our pretrained model from [MatAnyone v1.0.0](https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth) to the `pretrained_models` folder (pretrained model can also be automatically downloaded during the first inference). 101 | 102 | The directory structure will be arranged as: 103 | ``` 104 | pretrained_models 105 | |- matanyone.pth 106 | ``` 107 | 108 | ### Quick Test 109 | We provide some examples in the [`inputs`](./inputs) folder. **For each run, we take a video and its first-frame segmenatation mask as input.** The segmenation mask could be obtained from interactive segmentation models such as [SAM2 demo](https://huggingface.co/spaces/fffiloni/SAM2-Image-Predictor). 
For example, the directory structure can be arranged as: 110 | ``` 111 | inputs 112 | |- video 113 | |- test-sample0 # folder containing all frames 114 | |- test-sample1.mp4 # .mp4, .mov, .avi 115 | |- mask 116 | |- test-sample0_1.png # mask for person 1 117 | |- test-sample0_2.png # mask for person 2 118 | |- test-sample1.png 119 | ``` 120 | Run the following command to try it out: 121 | 122 | ```shell 123 | ## single target 124 | # short video; 720p 125 | python inference_matanyone.py -i inputs/video/test-sample1.mp4 -m inputs/mask/test-sample1.png 126 | # short video; 1080p 127 | python inference_matanyone.py -i inputs/video/test-sample2.mp4 -m inputs/mask/test-sample2.png 128 | # long video; 1080p 129 | python inference_matanyone.py -i inputs/video/test-sample3.mp4 -m inputs/mask/test-sample3.png 130 | 131 | ## multiple targets (control by mask) 132 | # obtain matte for target 1 133 | python inference_matanyone.py -i inputs/video/test-sample0 -m inputs/mask/test-sample0_1.png --suffix target1 134 | # obtain matte for target 2 135 | python inference_matanyone.py -i inputs/video/test-sample0 -m inputs/mask/test-sample0_2.png --suffix target2 136 | ``` 137 | The results will be saved in the `results` folder, including the foreground output video and the alpha output video. 138 | - If you want to save the results as per-frame images, you can set `--save_image`. 139 | - If you want to set a limit for the maximum input resolution, you can set `--max_size`, and the video will be downsampled if min(w, h) exceeds. By default, we don't set the limit. 140 | 141 | ## 🎪 Interactive Demo 142 | To get rid of the preparation for first-frame segmentation mask, we prepare a gradio demo on [hugging face](https://huggingface.co/spaces/PeiqingYang/MatAnyone) and could also **launch locally**. Just drop your video/image, assign the target masks with a few clicks, and get the the matting results! 143 | ```shell 144 | cd hugging_face 145 | 146 | # install python dependencies 147 | pip3 install -r requirements.txt # FFmpeg required 148 | 149 | # launch the demo 150 | python app.py 151 | ``` 152 | 153 | By launching, an interactive interface will appear as follow: 154 | 155 | ![overall_teaser](assets/teaser_demo.gif) 156 | 157 | 158 | ## 📊 Evaluation Benchmark 159 | 160 | We provide a synthetic benchmark **[YouTubeMatte](https://drive.google.com/file/d/1IEH0RaimT_hSp38AWF6wuwNJzzNSHpJ4/view?usp=sharing)** to enlarge the commonly-used [VideoMatte240K-Test](https://github.com/PeterL1n/RobustVideoMatting/blob/master/documentation/training.md#evaluation). A comparison between them is summarized in the table below. 161 | 162 | | Dataset | #Foregrounds | Source | Harmonized | 163 | | :------------------ | :----------: | :----------------: | :--------: | 164 | | VideoMatte240K-Test | 5 | Purchased Footage | ❌ | 165 | | **YouTubeMatte** | **32** | **YouTube Videos** | ✅ | 166 | 167 | It is noteworthy that we applied **harmonization** (using [Harmonizer](https://github.com/ZHKKKe/Harmonizer)) when compositing the foreground on a background. Such an operation effectively makes YouTubeMatte a more *challenging* benchmark that is closer to the *real* distribution. As shown in the figure below, while [RVM](https://github.com/PeterL1n/RobustVideoMatting) is confused by the harmonized frame, our method still yields robust performance. 
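For reference, both the benchmark construction and the inference scripts composite by plain alpha blending, `I = alpha * F + (1 - alpha) * B`, with harmonization then applied on top for YouTubeMatte. A minimal sketch of this blending step with a predicted matte (illustrative file paths, OpenCV/NumPy only):

```python
import cv2
import numpy as np

# illustrative paths: an input frame, its predicted alpha matte, and a new background
frame = cv2.imread("inputs/video/test-sample0/0000.jpg").astype(np.float32) / 255.0
pha = cv2.imread("results/test-sample0/pha/00000.png", cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
bg = cv2.imread("my_background.png").astype(np.float32) / 255.0
bg = cv2.resize(bg, (frame.shape[1], frame.shape[0]))

# I = alpha * F + (1 - alpha) * B, broadcast over the color channels
pha = pha[..., None]
com = pha * frame + (1.0 - pha) * bg
cv2.imwrite("composite.png", (com * 255.0).round().astype(np.uint8))
```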
168 | 169 | ![harmonization](assets/harmonization.jpg) 170 | 171 | 172 | ## 📑 Citation 173 | 174 | If you find our repo useful for your research, please consider citing our paper: 175 | 176 | ```bibtex 177 | @inProceedings{yang2025matanyone, 178 | title = {{MatAnyone}: Stable Video Matting with Consistent Memory Propagation}, 179 | author = {Yang, Peiqing and Zhou, Shangchen and Zhao, Jixin and Tao, Qingyi and Loy, Chen Change}, 180 | booktitle = {CVPR}, 181 | year = {2025} 182 | } 183 | ``` 184 | 185 | ## 📝 License 186 | 187 | This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license. 188 | 189 | ## 👏 Acknowledgement 190 | 191 | This project is built upon [Cutie](https://github.com/hkchengrex/Cutie), with the interactive demo adapted from [ProPainter](https://github.com/sczhou/ProPainter), leveraging segmentation capabilities from [Segment Anything Model](https://github.com/facebookresearch/segment-anything) and [Segment Anything Model 2](https://github.com/facebookresearch/sam2). Thanks for their awesome works! 192 | 193 | --- 194 | 195 | This study is supported under the RIE2020 Industry Alignment Fund – Industry Collaboration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s). 196 | 197 | ## 📧 Contact 198 | 199 | If you have any questions, please feel free to reach us at `peiqingyang99@outlook.com`. 200 | -------------------------------------------------------------------------------- /assets/harmonization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/harmonization.jpg -------------------------------------------------------------------------------- /assets/matanyone_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/matanyone_logo.png -------------------------------------------------------------------------------- /assets/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/pipeline.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/teaser.jpg -------------------------------------------------------------------------------- /assets/teaser_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/assets/teaser_demo.gif -------------------------------------------------------------------------------- /hugging_face/matanyone_wrapper.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | from torchvision.transforms.functional import to_tensor 4 | import numpy as np 5 | import random 6 | import cv2 7 | 8 | def gen_dilate(alpha, min_kernel_size, max_kernel_size): 9 | kernel_size = random.randint(min_kernel_size, max_kernel_size) 10 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size)) 11 | fg_and_unknown = 
np.array(np.not_equal(alpha, 0).astype(np.float32)) 12 | dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255 13 | return dilate.astype(np.float32) 14 | 15 | def gen_erosion(alpha, min_kernel_size, max_kernel_size): 16 | kernel_size = random.randint(min_kernel_size, max_kernel_size) 17 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size)) 18 | fg = np.array(np.equal(alpha, 255).astype(np.float32)) 19 | erode = cv2.erode(fg, kernel, iterations=1)*255 20 | return erode.astype(np.float32) 21 | 22 | @torch.inference_mode() 23 | @torch.amp.autocast("cuda") 24 | def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10): 25 | """ 26 | Args: 27 | frames_np: [(H,W,C)]*n, uint8 28 | mask: (H,W), uint8 29 | Outputs: 30 | com: [(H,W,C)]*n, uint8 31 | pha: [(H,W,C)]*n, uint8 32 | """ 33 | 34 | # print(f'===== [r_erode] {r_erode}; [r_dilate] {r_dilate} =====') 35 | bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3)) 36 | objects = [1] 37 | 38 | # [optional] erode & dilate on given seg mask 39 | if r_dilate > 0: 40 | mask = gen_dilate(mask, r_dilate, r_dilate) 41 | if r_erode > 0: 42 | mask = gen_erosion(mask, r_erode, r_erode) 43 | 44 | mask = torch.from_numpy(mask).cuda() 45 | 46 | frames_np = [frames_np[0]]* n_warmup + frames_np 47 | 48 | frames = [] 49 | phas = [] 50 | for ti, frame_single in tqdm.tqdm(enumerate(frames_np)): 51 | image = to_tensor(frame_single).cuda().float() 52 | 53 | if ti == 0: 54 | output_prob = processor.step(image, mask, objects=objects) # encode given mask 55 | output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames 56 | else: 57 | if ti <= n_warmup: 58 | output_prob = processor.step(image, first_frame_pred=True) # clear past memory for warmup frames 59 | else: 60 | output_prob = processor.step(image) 61 | 62 | # convert output probabilities to an object mask 63 | mask = processor.output_prob_to_mask(output_prob) 64 | 65 | pha = mask.unsqueeze(2).cpu().numpy() 66 | com_np = frame_single / 255. 
* pha + bgr * (1 - pha) 67 | 68 | # DONOT save the warmup frames 69 | if ti > (n_warmup-1): 70 | frames.append((com_np*255).astype(np.uint8)) 71 | phas.append((pha*255).astype(np.uint8)) 72 | 73 | return frames, phas -------------------------------------------------------------------------------- /hugging_face/requirements.txt: -------------------------------------------------------------------------------- 1 | progressbar2 2 | gdown >= 4.7.1 3 | gitpython >= 3.1 4 | git+https://github.com/cheind/py-thin-plate-spline 5 | hickle >= 5.0 6 | tensorboard >= 2.11 7 | numpy >= 1.21 8 | git+https://github.com/facebookresearch/segment-anything.git 9 | gradio==4.31.0 10 | fastapi==0.111.0 11 | pydantic==2.7.1 12 | opencv-python >= 4.8 13 | matplotlib 14 | pyyaml 15 | av >= 0.5.2 16 | openmim 17 | tqdm >= 4.66.1 18 | psutil 19 | ffmpeg-python 20 | cython 21 | Pillow >= 9.5 22 | scipy >= 1.7 23 | pycocotools >= 2.0.7 24 | einops >= 0.6 25 | hydra-core >= 1.3.2 26 | PySide6 >= 6.2.0 27 | charset-normalizer >= 3.1.0 28 | netifaces >= 0.11.0 29 | cchardet >= 2.1.7 30 | easydict 31 | requests 32 | pyqtdarktheme 33 | imageio == 2.25.0 34 | imageio[ffmpeg] 35 | ffmpeg-python -------------------------------------------------------------------------------- /hugging_face/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/hugging_face/tools/__init__.py -------------------------------------------------------------------------------- /hugging_face/tools/base_segmenter.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import cv2 4 | from PIL import Image, ImageDraw, ImageOps 5 | import numpy as np 6 | from typing import Union 7 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator 8 | import matplotlib.pyplot as plt 9 | import PIL 10 | from .mask_painter import mask_painter 11 | 12 | 13 | class BaseSegmenter: 14 | def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): 15 | """ 16 | device: model device 17 | SAM_checkpoint: path of SAM checkpoint 18 | model_type: vit_b, vit_l, vit_h 19 | """ 20 | print(f"Initializing BaseSegmenter to {device}") 21 | assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h' 22 | 23 | self.device = device 24 | self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 25 | self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint) 26 | self.model.to(device=self.device) 27 | self.predictor = SamPredictor(self.model) 28 | self.embedded = False 29 | 30 | @torch.no_grad() 31 | def set_image(self, image: np.ndarray): 32 | # PIL.open(image_path) 3channel: RGB 33 | # image embedding: avoid encode the same image multiple times 34 | self.orignal_image = image 35 | if self.embedded: 36 | print('repeat embedding, please reset_image.') 37 | return 38 | self.predictor.set_image(image) 39 | self.embedded = True 40 | return 41 | 42 | @torch.no_grad() 43 | def reset_image(self): 44 | # reset image embeding 45 | self.predictor.reset_image() 46 | self.embedded = False 47 | 48 | def predict(self, prompts, mode, multimask=True): 49 | """ 50 | image: numpy array, h, w, 3 51 | prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input' 52 | prompts['point_coords']: numpy array [N,2] 53 | prompts['point_labels']: numpy array [1,N] 54 | prompts['mask_input']: numpy array 
[1,256,256] 55 | mode: 'point' (points only), 'mask' (mask only), 'both' (consider both) 56 | mask_outputs: True (return 3 masks), False (return 1 mask only) 57 | whem mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :] 58 | """ 59 | assert self.embedded, 'prediction is called before set_image (feature embedding).' 60 | assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both' 61 | 62 | if mode == 'point': 63 | masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], 64 | point_labels=prompts['point_labels'], 65 | multimask_output=multimask) 66 | elif mode == 'mask': 67 | masks, scores, logits = self.predictor.predict(mask_input=prompts['mask_input'], 68 | multimask_output=multimask) 69 | elif mode == 'both': # both 70 | masks, scores, logits = self.predictor.predict(point_coords=prompts['point_coords'], 71 | point_labels=prompts['point_labels'], 72 | mask_input=prompts['mask_input'], 73 | multimask_output=multimask) 74 | else: 75 | raise("Not implement now!") 76 | # masks (n, h, w), scores (n,), logits (n, 256, 256) 77 | return masks, scores, logits 78 | 79 | 80 | if __name__ == "__main__": 81 | # load and show an image 82 | image = cv2.imread('/hhd3/gaoshang/truck.jpg') 83 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # numpy array (h, w, 3) 84 | 85 | # initialise BaseSegmenter 86 | SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' 87 | model_type = 'vit_h' 88 | device = "cuda:4" 89 | base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device) 90 | 91 | # image embedding (once embedded, multiple prompts can be applied) 92 | base_segmenter.set_image(image) 93 | 94 | # examples 95 | # point only ------------------------ 96 | mode = 'point' 97 | prompts = { 98 | 'point_coords': np.array([[500, 375], [1125, 625]]), 99 | 'point_labels': np.array([1, 1]), 100 | } 101 | masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False) # masks (n, h, w), scores (n,), logits (n, 256, 256) 102 | painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) 103 | painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) 104 | cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) 105 | 106 | # both ------------------------ 107 | mode = 'both' 108 | mask_input = logits[np.argmax(scores), :, :] 109 | prompts = {'mask_input': mask_input [None, :, :]} 110 | prompts = { 111 | 'point_coords': np.array([[500, 375], [1125, 625]]), 112 | 'point_labels': np.array([1, 0]), 113 | 'mask_input': mask_input[None, :, :] 114 | } 115 | masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) 116 | painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) 117 | painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) 118 | cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image) 119 | 120 | # mask only ------------------------ 121 | mode = 'mask' 122 | mask_input = logits[np.argmax(scores), :, :] 123 | 124 | prompts = {'mask_input': mask_input[None, :, :]} 125 | 126 | masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True) # masks (n, h, w), scores (n,), logits (n, 256, 256) 127 | painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8) 128 | painted_image = cv2.cvtColor(painted_image, 
cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) 129 | cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image) 130 | -------------------------------------------------------------------------------- /hugging_face/tools/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | def sizeof_fmt(size, suffix='B'): 9 | """Get human readable file size. 10 | 11 | Args: 12 | size (int): File size. 13 | suffix (str): Suffix. Default: 'B'. 14 | 15 | Return: 16 | str: Formated file siz. 17 | """ 18 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 19 | if abs(size) < 1024.0: 20 | return f'{size:3.1f} {unit}{suffix}' 21 | size /= 1024.0 22 | return f'{size:3.1f} Y{suffix}' 23 | 24 | 25 | def download_file_from_google_drive(file_id, save_path): 26 | """Download files from google drive. 27 | Ref: 28 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 29 | Args: 30 | file_id (str): File id. 31 | save_path (str): Save path. 32 | """ 33 | 34 | session = requests.Session() 35 | URL = 'https://docs.google.com/uc?export=download' 36 | params = {'id': file_id} 37 | 38 | response = session.get(URL, params=params, stream=True) 39 | token = get_confirm_token(response) 40 | if token: 41 | params['confirm'] = token 42 | response = session.get(URL, params=params, stream=True) 43 | 44 | # get file size 45 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 46 | print(response_file_size) 47 | if 'Content-Range' in response_file_size.headers: 48 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 49 | else: 50 | file_size = None 51 | 52 | save_response_content(response, save_path, file_size) 53 | 54 | 55 | def get_confirm_token(response): 56 | for key, value in response.cookies.items(): 57 | if key.startswith('download_warning'): 58 | return value 59 | return None 60 | 61 | 62 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 63 | if file_size is not None: 64 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 65 | 66 | readable_file_size = sizeof_fmt(file_size) 67 | else: 68 | pbar = None 69 | 70 | with open(destination, 'wb') as f: 71 | downloaded_size = 0 72 | for chunk in response.iter_content(chunk_size): 73 | downloaded_size += chunk_size 74 | if pbar is not None: 75 | pbar.update(1) 76 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 77 | if chunk: # filter out keep-alive new chunks 78 | f.write(chunk) 79 | if pbar is not None: 80 | pbar.close() 81 | 82 | 83 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 84 | """Load file form http url, will download models if necessary. 85 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 86 | Args: 87 | url (str): URL to be downloaded. 88 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 89 | Default: None. 90 | progress (bool): Whether to show the download progress. Default: True. 91 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 92 | Returns: 93 | str: The path to the downloaded file. 
94 | """ 95 | if model_dir is None: # use the pytorch hub_dir 96 | hub_dir = get_dir() 97 | model_dir = os.path.join(hub_dir, 'checkpoints') 98 | 99 | os.makedirs(model_dir, exist_ok=True) 100 | 101 | parts = urlparse(url) 102 | filename = os.path.basename(parts.path) 103 | if file_name is not None: 104 | filename = file_name 105 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 106 | if not os.path.exists(cached_file): 107 | print(f'Downloading: "{url}" to {cached_file}\n') 108 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 109 | return cached_file -------------------------------------------------------------------------------- /hugging_face/tools/interact_tools.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import cv2 4 | from PIL import Image, ImageDraw, ImageOps 5 | import numpy as np 6 | from typing import Union 7 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator 8 | import matplotlib.pyplot as plt 9 | import PIL 10 | from .mask_painter import mask_painter as mask_painter2 11 | from .base_segmenter import BaseSegmenter 12 | from .painter import mask_painter, point_painter 13 | import os 14 | import requests 15 | import sys 16 | 17 | 18 | mask_color = 3 19 | mask_alpha = 0.7 20 | contour_color = 1 21 | contour_width = 5 22 | point_color_ne = 8 23 | point_color_ps = 50 24 | point_alpha = 0.9 25 | point_radius = 15 26 | contour_color = 2 27 | contour_width = 5 28 | 29 | 30 | class SamControler(): 31 | def __init__(self, SAM_checkpoint, model_type, device): 32 | ''' 33 | initialize sam controler 34 | ''' 35 | self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) 36 | 37 | 38 | # def seg_again(self, image: np.ndarray): 39 | # ''' 40 | # it is used when interact in video 41 | # ''' 42 | # self.sam_controler.reset_image() 43 | # self.sam_controler.set_image(image) 44 | # return 45 | 46 | 47 | def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3): 48 | ''' 49 | it is used in first frame in video 50 | return: mask, logit, painted image(mask+point) 51 | ''' 52 | # self.sam_controler.set_image(image) 53 | origal_image = self.sam_controler.orignal_image 54 | neg_flag = labels[-1] 55 | if neg_flag==1: 56 | #find neg 57 | prompts = { 58 | 'point_coords': points, 59 | 'point_labels': labels, 60 | } 61 | masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) 62 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] 63 | prompts = { 64 | 'point_coords': points, 65 | 'point_labels': labels, 66 | 'mask_input': logit[None, :, :] 67 | } 68 | masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) 69 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] 70 | else: 71 | #find positive 72 | prompts = { 73 | 'point_coords': points, 74 | 'point_labels': labels, 75 | } 76 | masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) 77 | mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] 78 | 79 | 80 | assert len(points)==len(labels) 81 | 82 | painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) 83 | painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) 84 | painted_image = 
point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) 85 | painted_image = Image.fromarray(painted_image) 86 | 87 | return mask, logit, painted_image 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /hugging_face/tools/mask_painter.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import copy 6 | import time 7 | 8 | 9 | def colormap(rgb=True): 10 | color_list = np.array( 11 | [ 12 | 0.000, 0.000, 0.000, 13 | 1.000, 1.000, 1.000, 14 | 1.000, 0.498, 0.313, 15 | 0.392, 0.581, 0.929, 16 | 0.000, 0.447, 0.741, 17 | 0.850, 0.325, 0.098, 18 | 0.929, 0.694, 0.125, 19 | 0.494, 0.184, 0.556, 20 | 0.466, 0.674, 0.188, 21 | 0.301, 0.745, 0.933, 22 | 0.635, 0.078, 0.184, 23 | 0.300, 0.300, 0.300, 24 | 0.600, 0.600, 0.600, 25 | 1.000, 0.000, 0.000, 26 | 1.000, 0.500, 0.000, 27 | 0.749, 0.749, 0.000, 28 | 0.000, 1.000, 0.000, 29 | 0.000, 0.000, 1.000, 30 | 0.667, 0.000, 1.000, 31 | 0.333, 0.333, 0.000, 32 | 0.333, 0.667, 0.000, 33 | 0.333, 1.000, 0.000, 34 | 0.667, 0.333, 0.000, 35 | 0.667, 0.667, 0.000, 36 | 0.667, 1.000, 0.000, 37 | 1.000, 0.333, 0.000, 38 | 1.000, 0.667, 0.000, 39 | 1.000, 1.000, 0.000, 40 | 0.000, 0.333, 0.500, 41 | 0.000, 0.667, 0.500, 42 | 0.000, 1.000, 0.500, 43 | 0.333, 0.000, 0.500, 44 | 0.333, 0.333, 0.500, 45 | 0.333, 0.667, 0.500, 46 | 0.333, 1.000, 0.500, 47 | 0.667, 0.000, 0.500, 48 | 0.667, 0.333, 0.500, 49 | 0.667, 0.667, 0.500, 50 | 0.667, 1.000, 0.500, 51 | 1.000, 0.000, 0.500, 52 | 1.000, 0.333, 0.500, 53 | 1.000, 0.667, 0.500, 54 | 1.000, 1.000, 0.500, 55 | 0.000, 0.333, 1.000, 56 | 0.000, 0.667, 1.000, 57 | 0.000, 1.000, 1.000, 58 | 0.333, 0.000, 1.000, 59 | 0.333, 0.333, 1.000, 60 | 0.333, 0.667, 1.000, 61 | 0.333, 1.000, 1.000, 62 | 0.667, 0.000, 1.000, 63 | 0.667, 0.333, 1.000, 64 | 0.667, 0.667, 1.000, 65 | 0.667, 1.000, 1.000, 66 | 1.000, 0.000, 1.000, 67 | 1.000, 0.333, 1.000, 68 | 1.000, 0.667, 1.000, 69 | 0.167, 0.000, 0.000, 70 | 0.333, 0.000, 0.000, 71 | 0.500, 0.000, 0.000, 72 | 0.667, 0.000, 0.000, 73 | 0.833, 0.000, 0.000, 74 | 1.000, 0.000, 0.000, 75 | 0.000, 0.167, 0.000, 76 | 0.000, 0.333, 0.000, 77 | 0.000, 0.500, 0.000, 78 | 0.000, 0.667, 0.000, 79 | 0.000, 0.833, 0.000, 80 | 0.000, 1.000, 0.000, 81 | 0.000, 0.000, 0.167, 82 | 0.000, 0.000, 0.333, 83 | 0.000, 0.000, 0.500, 84 | 0.000, 0.000, 0.667, 85 | 0.000, 0.000, 0.833, 86 | 0.000, 0.000, 1.000, 87 | 0.143, 0.143, 0.143, 88 | 0.286, 0.286, 0.286, 89 | 0.429, 0.429, 0.429, 90 | 0.571, 0.571, 0.571, 91 | 0.714, 0.714, 0.714, 92 | 0.857, 0.857, 0.857 93 | ] 94 | ).astype(np.float32) 95 | color_list = color_list.reshape((-1, 3)) * 255 96 | if not rgb: 97 | color_list = color_list[:, ::-1] 98 | return color_list 99 | 100 | 101 | color_list = colormap() 102 | color_list = color_list.astype('uint8').tolist() 103 | 104 | 105 | def vis_add_mask(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha): 106 | background_color = np.array(background_color) 107 | contour_color = np.array(contour_color) 108 | 109 | # background_mask = 1 - background_mask 110 | # contour_mask = 1 - contour_mask 111 | 112 | for i in range(3): 113 | image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \ 114 | + background_color[i] * 
(background_alpha-background_mask*background_alpha) 115 | 116 | image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \ 117 | + contour_color[i] * (contour_alpha-contour_mask*contour_alpha) 118 | 119 | return image.astype('uint8') 120 | 121 | 122 | def mask_generator_00(mask, background_radius, contour_radius): 123 | # no background width when '00' 124 | # distance map 125 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) 126 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) 127 | dist_map = dist_transform_fore - dist_transform_back 128 | # ...:::!!!:::... 129 | contour_radius += 2 130 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) 131 | contour_mask = contour_mask / np.max(contour_mask) 132 | contour_mask[contour_mask>0.5] = 1. 133 | 134 | return mask, contour_mask 135 | 136 | 137 | def mask_generator_01(mask, background_radius, contour_radius): 138 | # no background width when '00' 139 | # distance map 140 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) 141 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) 142 | dist_map = dist_transform_fore - dist_transform_back 143 | # ...:::!!!:::... 144 | contour_radius += 2 145 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) 146 | contour_mask = contour_mask / np.max(contour_mask) 147 | return mask, contour_mask 148 | 149 | 150 | def mask_generator_10(mask, background_radius, contour_radius): 151 | # distance map 152 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) 153 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) 154 | dist_map = dist_transform_fore - dist_transform_back 155 | # .....:::::!!!!! 156 | background_mask = np.clip(dist_map, -background_radius, background_radius) 157 | background_mask = (background_mask - np.min(background_mask)) 158 | background_mask = background_mask / np.max(background_mask) 159 | # ...:::!!!:::... 160 | contour_radius += 2 161 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) 162 | contour_mask = contour_mask / np.max(contour_mask) 163 | contour_mask[contour_mask>0.5] = 1. 164 | return background_mask, contour_mask 165 | 166 | 167 | def mask_generator_11(mask, background_radius, contour_radius): 168 | # distance map 169 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) 170 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) 171 | dist_map = dist_transform_fore - dist_transform_back 172 | # .....:::::!!!!! 173 | background_mask = np.clip(dist_map, -background_radius, background_radius) 174 | background_mask = (background_mask - np.min(background_mask)) 175 | background_mask = background_mask / np.max(background_mask) 176 | # ...:::!!!:::... 
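# the clipped |signed distance| computed below is ~0 on the mask boundary and ~1 far from it after normalization, # so vis_add_mask paints the contour color exactly where this map stays small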
177 | contour_radius += 2 178 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) 179 | contour_mask = contour_mask / np.max(contour_mask) 180 | return background_mask, contour_mask 181 | 182 | 183 | def mask_painter(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'): 184 | """ 185 | Input: 186 | input_image: numpy array 187 | input_mask: numpy array 188 | background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing 189 | background_blur_radius: radius of background blur, must be odd number 190 | contour_width: width of mask contour, must be odd number 191 | contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others 192 | contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted 193 | mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both 194 | 195 | Output: 196 | painted_image: numpy array 197 | """ 198 | assert input_image.shape[:2] == input_mask.shape, 'different shape' 199 | assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD' 200 | assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11' 201 | 202 | # downsample input image and mask 203 | width, height = input_image.shape[0], input_image.shape[1] 204 | res = 1024 205 | ratio = min(1.0 * res / max(width, height), 1.0) 206 | input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio))) 207 | input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio))) 208 | 209 | # 0: background, 1: foreground 210 | msk = np.clip(input_mask, 0, 1) 211 | 212 | # generate masks for background and contour pixels 213 | background_radius = (background_blur_radius - 1) // 2 214 | contour_radius = (contour_width - 1) // 2 215 | generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11} 216 | background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius) 217 | 218 | # paint 219 | painted_image = vis_add_mask\ 220 | (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background 221 | 222 | return painted_image 223 | 224 | 225 | if __name__ == '__main__': 226 | 227 | background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing 228 | background_blur_radius = 31 # radius of background blur, must be odd number 229 | contour_width = 11 # contour width, must be odd number 230 | contour_color = 3 # id in color map, 0: black, 1: white, >1: others 231 | contour_alpha = 1 # transparency of background, 0: no contour highlighted 232 | 233 | # load input image and mask 234 | input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB')) 235 | input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P')) 236 | 237 | # paint 238 | overall_time_1 = 0 239 | overall_time_2 = 0 240 | overall_time_3 = 0 241 | overall_time_4 = 0 242 | overall_time_5 = 0 243 | 244 | for i in range(50): 245 | t2 = time.time() 246 | painted_image_00 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00') 247 | e2 = time.time() 248 | 249 | t3 = time.time() 250 | painted_image_10 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, 
contour_width, contour_color, contour_alpha, mode='10') 251 | e3 = time.time() 252 | 253 | t1 = time.time() 254 | painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha) 255 | e1 = time.time() 256 | 257 | t4 = time.time() 258 | painted_image_01 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01') 259 | e4 = time.time() 260 | 261 | t5 = time.time() 262 | painted_image_11 = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11') 263 | e5 = time.time() 264 | 265 | overall_time_1 += (e1 - t1) 266 | overall_time_2 += (e2 - t2) 267 | overall_time_3 += (e3 - t3) 268 | overall_time_4 += (e4 - t4) 269 | overall_time_5 += (e5 - t5) 270 | 271 | print(f'average time w gaussian: {overall_time_1/50}') 272 | print(f'average time w/o gaussian00: {overall_time_2/50}') 273 | print(f'average time w/o gaussian10: {overall_time_3/50}') 274 | print(f'average time w/o gaussian01: {overall_time_4/50}') 275 | print(f'average time w/o gaussian11: {overall_time_5/50}') 276 | 277 | # save 278 | painted_image_00 = Image.fromarray(painted_image_00) 279 | painted_image_00.save('./test_img/painter_output_image_00.png') 280 | 281 | painted_image_10 = Image.fromarray(painted_image_10) 282 | painted_image_10.save('./test_img/painter_output_image_10.png') 283 | 284 | painted_image_01 = Image.fromarray(painted_image_01) 285 | painted_image_01.save('./test_img/painter_output_image_01.png') 286 | 287 | painted_image_11 = Image.fromarray(painted_image_11) 288 | painted_image_11.save('./test_img/painter_output_image_11.png') 289 | -------------------------------------------------------------------------------- /hugging_face/tools/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import logging 8 | import numpy as np 9 | from os import path as osp 10 | 11 | def constant_init(module, val, bias=0): 12 | if hasattr(module, 'weight') and module.weight is not None: 13 | nn.init.constant_(module.weight, val) 14 | if hasattr(module, 'bias') and module.bias is not None: 15 | nn.init.constant_(module.bias, bias) 16 | 17 | initialized_logger = {} 18 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 19 | """Get the root logger. 20 | The logger will be initialized if it has not been initialized. By default a 21 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 22 | also be added. 23 | Args: 24 | logger_name (str): root logger name. Default: 'basicsr'. 25 | log_file (str | None): The log filename. If specified, a FileHandler 26 | will be added to the root logger. 27 | log_level (int): The root logger level. Note that only the process of 28 | rank 0 is affected, while other processes will set the level to 29 | "Error" and be silent most of the time. 30 | Returns: 31 | logging.Logger: The root logger. 
32 | """ 33 | logger = logging.getLogger(logger_name) 34 | # if the logger has been initialized, just return it 35 | if logger_name in initialized_logger: 36 | return logger 37 | 38 | format_str = '%(asctime)s %(levelname)s: %(message)s' 39 | stream_handler = logging.StreamHandler() 40 | stream_handler.setFormatter(logging.Formatter(format_str)) 41 | logger.addHandler(stream_handler) 42 | logger.propagate = False 43 | 44 | if log_file is not None: 45 | logger.setLevel(log_level) 46 | # add file handler 47 | # file_handler = logging.FileHandler(log_file, 'w') 48 | file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log 49 | file_handler.setFormatter(logging.Formatter(format_str)) 50 | file_handler.setLevel(log_level) 51 | logger.addHandler(file_handler) 52 | initialized_logger[logger_name] = True 53 | return logger 54 | 55 | 56 | IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ 57 | torch.__version__)[0][:3])] >= [1, 12, 0] 58 | 59 | def gpu_is_available(): 60 | if IS_HIGH_VERSION: 61 | if torch.backends.mps.is_available(): 62 | return True 63 | return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False 64 | 65 | def get_device(gpu_id=None): 66 | if gpu_id is None: 67 | gpu_str = '' 68 | elif isinstance(gpu_id, int): 69 | gpu_str = f':{gpu_id}' 70 | else: 71 | raise TypeError('Input should be int value.') 72 | 73 | if IS_HIGH_VERSION: 74 | if torch.backends.mps.is_available(): 75 | return torch.device('mps'+gpu_str) 76 | return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') 77 | 78 | 79 | def set_random_seed(seed): 80 | """Set random seeds.""" 81 | random.seed(seed) 82 | np.random.seed(seed) 83 | torch.manual_seed(seed) 84 | torch.cuda.manual_seed(seed) 85 | torch.cuda.manual_seed_all(seed) 86 | 87 | 88 | def get_time_str(): 89 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 90 | 91 | 92 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 93 | """Scan a directory to find the interested files. 94 | 95 | Args: 96 | dir_path (str): Path of the directory. 97 | suffix (str | tuple(str), optional): File suffix that we are 98 | interested in. Default: None. 99 | recursive (bool, optional): If set to True, recursively scan the 100 | directory. Default: False. 101 | full_path (bool, optional): If set to True, include the dir_path. 102 | Default: False. 103 | 104 | Returns: 105 | A generator for all the interested files with relative pathes. 
106 | """ 107 | 108 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 109 | raise TypeError('"suffix" must be a string or tuple of strings') 110 | 111 | root = dir_path 112 | 113 | def _scandir(dir_path, suffix, recursive): 114 | for entry in os.scandir(dir_path): 115 | if not entry.name.startswith('.') and entry.is_file(): 116 | if full_path: 117 | return_path = entry.path 118 | else: 119 | return_path = osp.relpath(entry.path, root) 120 | 121 | if suffix is None: 122 | yield return_path 123 | elif return_path.endswith(suffix): 124 | yield return_path 125 | else: 126 | if recursive: 127 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 128 | else: 129 | continue 130 | 131 | return _scandir(dir_path, suffix=suffix, recursive=recursive) -------------------------------------------------------------------------------- /hugging_face/tools/painter.py: -------------------------------------------------------------------------------- 1 | # paint masks, contours, or points on images, with specified colors 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import copy 7 | import time 8 | 9 | 10 | def colormap(rgb=True): 11 | color_list = np.array( 12 | [ 13 | 0.000, 0.000, 0.000, 14 | 1.000, 1.000, 1.000, 15 | 1.000, 0.498, 0.313, 16 | 0.392, 0.581, 0.929, 17 | 0.000, 0.447, 0.741, 18 | 0.850, 0.325, 0.098, 19 | 0.929, 0.694, 0.125, 20 | 0.494, 0.184, 0.556, 21 | 0.466, 0.674, 0.188, 22 | 0.301, 0.745, 0.933, 23 | 0.635, 0.078, 0.184, 24 | 0.300, 0.300, 0.300, 25 | 0.600, 0.600, 0.600, 26 | 1.000, 0.000, 0.000, 27 | 1.000, 0.500, 0.000, 28 | 0.749, 0.749, 0.000, 29 | 0.000, 1.000, 0.000, 30 | 0.000, 0.000, 1.000, 31 | 0.667, 0.000, 1.000, 32 | 0.333, 0.333, 0.000, 33 | 0.333, 0.667, 0.000, 34 | 0.333, 1.000, 0.000, 35 | 0.667, 0.333, 0.000, 36 | 0.667, 0.667, 0.000, 37 | 0.667, 1.000, 0.000, 38 | 1.000, 0.333, 0.000, 39 | 1.000, 0.667, 0.000, 40 | 1.000, 1.000, 0.000, 41 | 0.000, 0.333, 0.500, 42 | 0.000, 0.667, 0.500, 43 | 0.000, 1.000, 0.500, 44 | 0.333, 0.000, 0.500, 45 | 0.333, 0.333, 0.500, 46 | 0.333, 0.667, 0.500, 47 | 0.333, 1.000, 0.500, 48 | 0.667, 0.000, 0.500, 49 | 0.667, 0.333, 0.500, 50 | 0.667, 0.667, 0.500, 51 | 0.667, 1.000, 0.500, 52 | 1.000, 0.000, 0.500, 53 | 1.000, 0.333, 0.500, 54 | 1.000, 0.667, 0.500, 55 | 1.000, 1.000, 0.500, 56 | 0.000, 0.333, 1.000, 57 | 0.000, 0.667, 1.000, 58 | 0.000, 1.000, 1.000, 59 | 0.333, 0.000, 1.000, 60 | 0.333, 0.333, 1.000, 61 | 0.333, 0.667, 1.000, 62 | 0.333, 1.000, 1.000, 63 | 0.667, 0.000, 1.000, 64 | 0.667, 0.333, 1.000, 65 | 0.667, 0.667, 1.000, 66 | 0.667, 1.000, 1.000, 67 | 1.000, 0.000, 1.000, 68 | 1.000, 0.333, 1.000, 69 | 1.000, 0.667, 1.000, 70 | 0.167, 0.000, 0.000, 71 | 0.333, 0.000, 0.000, 72 | 0.500, 0.000, 0.000, 73 | 0.667, 0.000, 0.000, 74 | 0.833, 0.000, 0.000, 75 | 1.000, 0.000, 0.000, 76 | 0.000, 0.167, 0.000, 77 | 0.000, 0.333, 0.000, 78 | 0.000, 0.500, 0.000, 79 | 0.000, 0.667, 0.000, 80 | 0.000, 0.833, 0.000, 81 | 0.000, 1.000, 0.000, 82 | 0.000, 0.000, 0.167, 83 | 0.000, 0.000, 0.333, 84 | 0.000, 0.000, 0.500, 85 | 0.000, 0.000, 0.667, 86 | 0.000, 0.000, 0.833, 87 | 0.000, 0.000, 1.000, 88 | 0.143, 0.143, 0.143, 89 | 0.286, 0.286, 0.286, 90 | 0.429, 0.429, 0.429, 91 | 0.571, 0.571, 0.571, 92 | 0.714, 0.714, 0.714, 93 | 0.857, 0.857, 0.857 94 | ] 95 | ).astype(np.float32) 96 | color_list = color_list.reshape((-1, 3)) * 255 97 | if not rgb: 98 | color_list = color_list[:, ::-1] 99 | return color_list 100 | 101 | 102 | color_list = colormap() 103 | 
color_list = color_list.astype('uint8').tolist() 104 | 105 | 106 | def vis_add_mask(image, mask, color, alpha): 107 | color = np.array(color_list[color]) 108 | mask = mask > 0.5 109 | image[mask] = image[mask] * (1-alpha) + color * alpha 110 | return image.astype('uint8') 111 | 112 | def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, point_radius=15, contour_color=2, contour_width=5): 113 | h, w = input_image.shape[:2] 114 | point_mask = np.zeros((h, w)).astype('uint8') 115 | for point in input_points: 116 | point_mask[point[1], point[0]] = 1 117 | 118 | kernel = cv2.getStructuringElement(2, (point_radius, point_radius)) 119 | point_mask = cv2.dilate(point_mask, kernel) 120 | 121 | contour_radius = (contour_width - 1) // 2 122 | dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3) 123 | dist_transform_back = cv2.distanceTransform(1-point_mask, cv2.DIST_L2, 3) 124 | dist_map = dist_transform_fore - dist_transform_back 125 | # ...:::!!!:::... 126 | contour_radius += 2 127 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) 128 | contour_mask = contour_mask / np.max(contour_mask) 129 | contour_mask[contour_mask>0.5] = 1. 130 | 131 | # paint mask 132 | painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha) 133 | # paint contour 134 | painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) 135 | return painted_image 136 | 137 | def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3): 138 | assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' 139 | # 0: background, 1: foreground 140 | mask = np.clip(input_mask, 0, 1) 141 | contour_radius = (contour_width - 1) // 2 142 | 143 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) 144 | dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3) 145 | dist_map = dist_transform_fore - dist_transform_back 146 | # ...:::!!!:::... 147 | contour_radius += 2 148 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) 149 | contour_mask = contour_mask / np.max(contour_mask) 150 | contour_mask[contour_mask>0.5] = 1. 
151 | 152 | # paint mask 153 | painted_image = vis_add_mask(input_image.copy(), mask.copy(), mask_color, mask_alpha) 154 | # paint contour 155 | painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) 156 | 157 | return painted_image 158 | 159 | def background_remover(input_image, input_mask): 160 | """ 161 | input_image: H, W, 3, np.array 162 | input_mask: H, W, np.array 163 | 164 | image_wo_background: PIL.Image 165 | """ 166 | assert input_image.shape[:2] == input_mask.shape, 'different shape between image and mask' 167 | # 0: background, 1: foreground 168 | mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2)*255 169 | image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4 170 | image_wo_background = Image.fromarray(image_wo_background).convert('RGBA') 171 | 172 | return image_wo_background 173 | 174 | if __name__ == '__main__': 175 | input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) 176 | input_mask = np.array(Image.open('images/painter_input_mask.jpg').convert('P')) 177 | 178 | # example of mask painter 179 | mask_color = 3 180 | mask_alpha = 0.7 181 | contour_color = 1 182 | contour_width = 5 183 | 184 | # save 185 | painted_image = Image.fromarray(input_image) 186 | painted_image.save('images/original.png') 187 | 188 | painted_image = mask_painter(input_image, input_mask, mask_color, mask_alpha, contour_color, contour_width) 189 | # save 190 | painted_image = Image.fromarray(input_image) 191 | painted_image.save('images/original1.png') 192 | 193 | # example of point painter 194 | input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) 195 | input_points = np.array([[500, 375], [70, 600]]) # x, y 196 | point_color = 5 197 | point_alpha = 0.9 198 | point_radius = 15 199 | contour_color = 2 200 | contour_width = 5 201 | painted_image_1 = point_painter(input_image, input_points, point_color, point_alpha, point_radius, contour_color, contour_width) 202 | # save 203 | painted_image = Image.fromarray(painted_image_1) 204 | painted_image.save('images/point_painter_1.png') 205 | 206 | input_image = np.array(Image.open('images/painter_input_image.jpg').convert('RGB')) 207 | painted_image_2 = point_painter(input_image, input_points, point_color=9, point_radius=20, contour_color=29) 208 | # save 209 | painted_image = Image.fromarray(painted_image_2) 210 | painted_image.save('images/point_painter_2.png') 211 | 212 | # example of background remover 213 | input_image = np.array(Image.open('images/original.png').convert('RGB')) 214 | image_wo_background = background_remover(input_image, input_mask) # return PIL.Image 215 | image_wo_background.save('images/image_wo_background.png') 216 | -------------------------------------------------------------------------------- /inference_hf.py: -------------------------------------------------------------------------------- 1 | from matanyone import InferenceCore 2 | 3 | 4 | def main( 5 | input_path, 6 | mask_path, 7 | output_path, 8 | n_warmup=10, 9 | r_erode=10, 10 | r_dilate=10, 11 | suffix="", 12 | save_image=False, 13 | max_size=-1, 14 | ): 15 | processor = InferenceCore("PeiqingYang/MatAnyone") 16 | fgr, alpha = processor.process_video( 17 | input_path=input_path, 18 | mask_path=mask_path, 19 | output_path=output_path, 20 | n_warmup=n_warmup, 21 | r_erode=r_erode, 22 | r_dilate=r_dilate, 23 | suffix=suffix, 24 | save_image=save_image, 25 | max_size=max_size, 26 | ) 27 | return fgr, alpha 
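# Usage sketch (hypothetical driver, not part of the original file): main() above can be called
# directly with the bundled sample inputs; process_video() writes the foreground and alpha videos
# under output_path and returns their paths.
if __name__ == "__main__":
    foreground_path, alpha_path = main(
        input_path="inputs/video/test-sample1.mp4",   # sample video shipped in inputs/
        mask_path="inputs/mask/test-sample1.png",     # its first-frame segmentation mask
        output_path="outputs",
    )
    print(foreground_path, alpha_path)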
-------------------------------------------------------------------------------- /inference_matanyone.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import tqdm 4 | import imageio 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from hugging_face.tools.download_util import load_file_from_url 12 | from matanyone.utils.inference_utils import gen_dilate, gen_erosion, read_frame_from_videos 13 | 14 | from matanyone.inference.inference_core import InferenceCore 15 | from matanyone.utils.get_default_model import get_matanyone_model 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | @torch.inference_mode() 21 | @torch.amp.autocast("cuda") 22 | def main(input_path, mask_path, output_path, ckpt_path, n_warmup=10, r_erode=10, r_dilate=10, suffix="", save_image=False, max_size=-1): 23 | # download ckpt for the first inference 24 | pretrain_model_url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth" 25 | ckpt_path = load_file_from_url(pretrain_model_url, 'pretrained_models') 26 | 27 | # load MatAnyone model 28 | matanyone = get_matanyone_model(ckpt_path) 29 | 30 | # init inference processor 31 | processor = InferenceCore(matanyone, cfg=matanyone.cfg) 32 | 33 | # inference parameters 34 | r_erode = int(r_erode) 35 | r_dilate = int(r_dilate) 36 | n_warmup = int(n_warmup) 37 | max_size = int(max_size) 38 | 39 | # load input frames 40 | vframes, fps, length, video_name = read_frame_from_videos(input_path) 41 | repeated_frames = vframes[0].unsqueeze(0).repeat(n_warmup, 1, 1, 1) # repeat the first frame for warmup 42 | vframes = torch.cat([repeated_frames, vframes], dim=0).float() 43 | length += n_warmup # update length 44 | 45 | # resize if needed 46 | if max_size > 0: 47 | h, w = vframes.shape[-2:] 48 | min_side = min(h, w) 49 | if min_side > max_size: 50 | new_h = int(h / min_side * max_size) 51 | new_w = int(w / min_side * max_size) 52 | 53 | vframes = F.interpolate(vframes, size=(new_h, new_w), mode="area") 54 | 55 | # set output paths 56 | os.makedirs(output_path, exist_ok=True) 57 | if suffix != "": 58 | video_name = f'{video_name}_{suffix}' 59 | if save_image: 60 | os.makedirs(f'{output_path}/{video_name}', exist_ok=True) 61 | os.makedirs(f'{output_path}/{video_name}/pha', exist_ok=True) 62 | os.makedirs(f'{output_path}/{video_name}/fgr', exist_ok=True) 63 | 64 | # load the first-frame mask 65 | mask = Image.open(mask_path).convert('L') 66 | mask = np.array(mask) 67 | 68 | bgr = (np.array([120, 255, 155], dtype=np.float32)/255).reshape((1, 1, 3)) # green screen to paste fgr 69 | objects = [1] 70 | 71 | # [optional] erode & dilate 72 | if r_dilate > 0: 73 | mask = gen_dilate(mask, r_dilate, r_dilate) 74 | if r_erode > 0: 75 | mask = gen_erosion(mask, r_erode, r_erode) 76 | 77 | mask = torch.from_numpy(mask).cuda() 78 | 79 | if max_size > 0: # resize needed 80 | mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(new_h, new_w), mode="nearest") 81 | mask = mask[0,0] 82 | 83 | # inference start 84 | phas = [] 85 | fgrs = [] 86 | for ti in tqdm.tqdm(range(length)): 87 | # load the image as RGB; normalization is done within the model 88 | image = vframes[ti] 89 | 90 | image_np = np.array(image.permute(1,2,0)) # for output visualize 91 | image = (image / 255.).cuda().float() # for network input 92 | 93 | if ti == 0: 94 | output_prob = processor.step(image, mask, objects=objects) # encode given mask 95 | 
output_prob = processor.step(image, first_frame_pred=True) # first frame for prediction 96 | else: 97 | if ti <= n_warmup: 98 | output_prob = processor.step(image, first_frame_pred=True) # reinit as the first frame for prediction 99 | else: 100 | output_prob = processor.step(image) 101 | 102 | # convert output probabilities to alpha matte 103 | mask = processor.output_prob_to_mask(output_prob) 104 | 105 | # visualize prediction 106 | pha = mask.unsqueeze(2).cpu().numpy() 107 | com_np = image_np / 255. * pha + bgr * (1 - pha) 108 | 109 | # DONOT save the warmup frame 110 | if ti > (n_warmup-1): 111 | com_np = (com_np*255).astype(np.uint8) 112 | pha = (pha*255).astype(np.uint8) 113 | fgrs.append(com_np) 114 | phas.append(pha) 115 | if save_image: 116 | cv2.imwrite(f'{output_path}/{video_name}/pha/{str(ti-n_warmup).zfill(5)}.png', pha) 117 | cv2.imwrite(f'{output_path}/{video_name}/fgr/{str(ti-n_warmup).zfill(5)}.png', com_np[...,[2,1,0]]) 118 | 119 | phas = np.array(phas) 120 | fgrs = np.array(fgrs) 121 | 122 | imageio.mimwrite(f'{output_path}/{video_name}_fgr.mp4', fgrs, fps=fps, quality=7) 123 | imageio.mimwrite(f'{output_path}/{video_name}_pha.mp4', phas, fps=fps, quality=7) 124 | 125 | if __name__ == '__main__': 126 | import argparse 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('-i', '--input_path', type=str, default="inputs/video/test-sample1.mp4", help='Path of the input video or frame folder.') 129 | parser.add_argument('-m', '--mask_path', type=str, default="inputs/mask/test-sample1.png", help='Path of the first-frame segmentation mask.') 130 | parser.add_argument('-o', '--output_path', type=str, default="results/", help='Output folder. Default: results') 131 | parser.add_argument('-c', '--ckpt_path', type=str, default="pretrained_models/matanyone.pth", help='Path of the MatAnyone model.') 132 | parser.add_argument('-w', '--warmup', type=str, default="10", help='Number of warmup iterations for the first frame alpha prediction.') 133 | parser.add_argument('-e', '--erode_kernel', type=str, default="10", help='Erosion kernel on the input mask.') 134 | parser.add_argument('-d', '--dilate_kernel', type=str, default="10", help='Dilation kernel on the input mask.') 135 | parser.add_argument('--suffix', type=str, default="", help='Suffix to specify different target when saving, e.g., target1.') 136 | parser.add_argument('--save_image', action='store_true', default=False, help='Save output frames. Default: False') 137 | parser.add_argument('--max_size', type=str, default="-1", help='When positive, the video will be downsampled if min(w, h) exceeds. 
Default: -1 (means no limit)') 138 | 139 | 140 | args = parser.parse_args() 141 | 142 | main(input_path=args.input_path, \ 143 | mask_path=args.mask_path, \ 144 | output_path=args.output_path, \ 145 | ckpt_path=args.ckpt_path, \ 146 | n_warmup=args.warmup, \ 147 | r_erode=args.erode_kernel, \ 148 | r_dilate=args.dilate_kernel, \ 149 | suffix=args.suffix, \ 150 | save_image=args.save_image, \ 151 | max_size=args.max_size) 152 | -------------------------------------------------------------------------------- /inputs/mask/test-sample0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample0_1.png -------------------------------------------------------------------------------- /inputs/mask/test-sample0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample0_2.png -------------------------------------------------------------------------------- /inputs/mask/test-sample1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample1.png -------------------------------------------------------------------------------- /inputs/mask/test-sample2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample2.png -------------------------------------------------------------------------------- /inputs/mask/test-sample3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/mask/test-sample3.png -------------------------------------------------------------------------------- /inputs/video/test-sample0/0000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0000.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0001.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0002.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0003.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0004.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0004.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0005.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0006.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0007.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0008.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0009.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0010.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0011.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0012.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0013.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0014.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0015.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0015.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0016.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0016.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0017.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0018.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0018.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0019.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0020.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0021.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0022.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0022.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0023.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0024.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0025.jpg 
-------------------------------------------------------------------------------- /inputs/video/test-sample0/0026.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0026.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0027.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0027.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0028.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0028.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0029.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0029.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0030.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0030.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0031.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0032.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0032.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0033.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0034.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0034.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0035.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0035.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0036.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0036.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0037.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0037.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0038.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0038.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0039.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0039.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0040.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0040.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0041.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0042.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0042.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0043.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0043.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0044.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0044.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0045.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0045.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0046.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0046.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0047.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0047.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0048.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0049.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0049.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0050.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0050.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0051.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0051.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0052.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0052.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0053.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0053.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0054.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0054.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0055.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0055.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0056.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0056.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0057.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0057.jpg 
-------------------------------------------------------------------------------- /inputs/video/test-sample0/0058.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0058.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0059.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0059.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0060.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0060.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0061.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0062.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0062.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0063.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0064.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0064.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0065.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0065.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0066.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0066.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0067.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0067.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0068.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0068.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0069.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0069.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0070.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0070.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample0/0071.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample0/0071.jpg -------------------------------------------------------------------------------- /inputs/video/test-sample1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample1.mp4 -------------------------------------------------------------------------------- /inputs/video/test-sample2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample2.mp4 -------------------------------------------------------------------------------- /inputs/video/test-sample3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/inputs/video/test-sample3.mp4 -------------------------------------------------------------------------------- /matanyone/__init__.py: -------------------------------------------------------------------------------- 1 | from matanyone.inference.inference_core import InferenceCore 2 | from matanyone.model.matanyone import MatAnyone 3 | -------------------------------------------------------------------------------- /matanyone/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/config/__init__.py -------------------------------------------------------------------------------- /matanyone/config/eval_matanyone_config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - model: base 4 | - override hydra/job_logging: custom-no-rank.yaml 5 | 6 | hydra: 7 | run: 8 | dir: ../output/${exp_id}/${dataset} 9 | output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra 10 | 11 | amp: False 12 | weights: pretrained_models/matanyone.pth # default (can be modified from outside) 13 | output_dir: null # defaults to run_dir; specify this to override 14 | flip_aug: False 15 | 16 | 17 | # maximum shortest side of the input; -1 means no resizing 18 | # With eval_vos.py, we usually just use the dataset's size (resizing done in dataloader) 19 | # this parameter is added for the sole purpose for the GUI in the current 
codebase 20 | # InferenceCore will downsize the input and restore the output to the original size if needed 21 | # if you are using this code for some other project, you can also utilize this parameter 22 | max_internal_size: -1 23 | 24 | # these parameters, when set, override the dataset's default; useful for debugging 25 | save_all: True 26 | use_all_masks: False 27 | use_long_term: False 28 | mem_every: 5 29 | 30 | # only relevant when long_term is not enabled 31 | max_mem_frames: 5 32 | 33 | # only relevant when long_term is enabled 34 | long_term: 35 | count_usage: True 36 | max_mem_frames: 10 37 | min_mem_frames: 5 38 | num_prototypes: 128 39 | max_num_tokens: 10000 40 | buffer_tokens: 2000 41 | 42 | top_k: 30 43 | stagger_updates: 5 44 | chunk_size: -1 # number of objects to process in parallel; -1 means unlimited 45 | save_scores: False 46 | save_aux: False 47 | visualize: False 48 | -------------------------------------------------------------------------------- /matanyone/config/hydra/job_logging/custom-no-rank.yaml: -------------------------------------------------------------------------------- 1 | # python logging configuration for tasks 2 | version: 1 3 | formatters: 4 | simple: 5 | format: '[%(asctime)s][%(levelname)s] - %(message)s' 6 | datefmt: '%Y-%m-%d %H:%M:%S' 7 | handlers: 8 | console: 9 | class: logging.StreamHandler 10 | formatter: simple 11 | stream: ext://sys.stdout 12 | file: 13 | class: logging.FileHandler 14 | formatter: simple 15 | # absolute file path 16 | filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log 17 | mode: w 18 | root: 19 | level: INFO 20 | handlers: [console, file] 21 | 22 | disable_existing_loggers: false -------------------------------------------------------------------------------- /matanyone/config/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | # python logging configuration for tasks 2 | version: 1 3 | formatters: 4 | simple: 5 | format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s' 6 | datefmt: '%Y-%m-%d %H:%M:%S' 7 | handlers: 8 | console: 9 | class: logging.StreamHandler 10 | formatter: simple 11 | stream: ext://sys.stdout 12 | file: 13 | class: logging.FileHandler 14 | formatter: simple 15 | # absolute file path 16 | filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log 17 | mode: w 18 | root: 19 | level: INFO 20 | handlers: [console, file] 21 | 22 | disable_existing_loggers: false -------------------------------------------------------------------------------- /matanyone/config/model/base.yaml: -------------------------------------------------------------------------------- 1 | pixel_mean: [0.485, 0.456, 0.406] 2 | pixel_std: [0.229, 0.224, 0.225] 3 | 4 | pixel_dim: 256 5 | key_dim: 64 6 | value_dim: 256 7 | sensory_dim: 256 8 | embed_dim: 256 9 | 10 | pixel_encoder: 11 | type: resnet50 12 | ms_dims: [1024, 512, 256, 64, 3] # f16, f8, f4, f2, f1 13 | 14 | mask_encoder: 15 | type: resnet18 16 | final_dim: 256 17 | 18 | pixel_pe_scale: 32 19 | pixel_pe_temperature: 128 20 | 21 | object_transformer: 22 | embed_dim: ${model.embed_dim} 23 | ff_dim: 2048 24 | num_heads: 8 25 | num_blocks: 3 26 | num_queries: 16 27 | read_from_pixel: 28 | input_norm: False 29 | input_add_pe: False 30 | add_pe_to_qkv: [True, True, False] 31 | read_from_past: 32 | add_pe_to_qkv: [True, True, False] 33 | read_from_memory: 34 | add_pe_to_qkv: [True, True, False] 35 | read_from_query: 36 | add_pe_to_qkv: [True, True, 
False] 37 | output_norm: False 38 | query_self_attention: 39 | add_pe_to_qkv: [True, True, False] 40 | pixel_self_attention: 41 | add_pe_to_qkv: [True, True, False] 42 | 43 | object_summarizer: 44 | embed_dim: ${model.object_transformer.embed_dim} 45 | num_summaries: ${model.object_transformer.num_queries} 46 | add_pe: True 47 | 48 | aux_loss: 49 | sensory: 50 | enabled: True 51 | weight: 0.01 52 | query: 53 | enabled: True 54 | weight: 0.01 55 | 56 | mask_decoder: 57 | # first value must equal embed_dim 58 | up_dims: [256, 128, 128, 64, 16] 59 | -------------------------------------------------------------------------------- /matanyone/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/inference/__init__.py -------------------------------------------------------------------------------- /matanyone/inference/image_feature_store.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Iterable 3 | import torch 4 | from matanyone.model.matanyone import MatAnyone 5 | 6 | 7 | class ImageFeatureStore: 8 | """ 9 | A cache for image features. 10 | These features might be reused at different parts of the inference pipeline. 11 | This class provide an interface for reusing these features. 12 | It is the user's responsibility to delete redundant features. 13 | 14 | Feature of a frame should be associated with a unique index -- typically the frame id. 15 | """ 16 | def __init__(self, network: MatAnyone, no_warning: bool = False): 17 | self.network = network 18 | self._store = {} 19 | self.no_warning = no_warning 20 | 21 | def _encode_feature(self, index: int, image: torch.Tensor, last_feats=None) -> None: 22 | ms_features, pix_feat = self.network.encode_image(image, last_feats=last_feats) 23 | key, shrinkage, selection = self.network.transform_key(ms_features[0]) 24 | self._store[index] = (ms_features, pix_feat, key, shrinkage, selection) 25 | 26 | def get_all_features(self, images: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor): 27 | seq_length = images.shape[0] 28 | ms_features, pix_feat = self.network.encode_image(images, seq_length) 29 | key, shrinkage, selection = self.network.transform_key(ms_features[0]) 30 | for index in range(seq_length): 31 | self._store[index] = ([f[index].unsqueeze(0) for f in ms_features], pix_feat[index].unsqueeze(0), key[index].unsqueeze(0), shrinkage[index].unsqueeze(0), selection[index].unsqueeze(0)) 32 | 33 | def get_features(self, index: int, 34 | image: torch.Tensor, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): 35 | if index not in self._store: 36 | self._encode_feature(index, image, last_feats) 37 | 38 | return self._store[index][:2] 39 | 40 | def get_key(self, index: int, 41 | image: torch.Tensor, last_feats=None) -> (torch.Tensor, torch.Tensor, torch.Tensor): 42 | if index not in self._store: 43 | self._encode_feature(index, image, last_feats) 44 | 45 | return self._store[index][2:] 46 | 47 | def delete(self, index: int) -> None: 48 | if index in self._store: 49 | del self._store[index] 50 | 51 | def __len__(self): 52 | return len(self._store) 53 | 54 | def __del__(self): 55 | if len(self._store) > 0 and not self.no_warning: 56 | warnings.warn(f'Leaking {self._store.keys()} in the image feature store') 57 | -------------------------------------------------------------------------------- 
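To make the caching contract of ImageFeatureStore concrete, a rough usage sketch follows; `matanyone_net` (a MatAnyone network) and `frames` (an iterable of 1*3*H*W tensors) are placeholders, not objects defined in that file.

store = ImageFeatureStore(matanyone_net)
for ti, frame in enumerate(frames):
    ms_feat, pix_feat = store.get_features(ti, frame)     # encodes frame ti once and caches it
    key, shrinkage, selection = store.get_key(ti, frame)  # served from the same cache entry
    # ... memory read-out / decoding would consume these features here ...
    store.delete(ti)  # the caller frees each entry; otherwise __del__ warns about leaked indices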
/matanyone/inference/kv_memory_store.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Literal 2 | from collections import defaultdict 3 | import torch 4 | 5 | 6 | def _add_last_dim(dictionary, key, new_value, prepend=False): 7 | # append/prepend a new value to the last dimension of a tensor in a dictionary 8 | # if the key does not exist, put the new value in 9 | # append by default 10 | if key in dictionary: 11 | dictionary[key] = torch.cat([dictionary[key], new_value], -1) 12 | else: 13 | dictionary[key] = new_value 14 | 15 | 16 | class KeyValueMemoryStore: 17 | """ 18 | Works for key/value pairs type storage 19 | e.g., working and long-term memory 20 | """ 21 | def __init__(self, save_selection: bool = False, save_usage: bool = False): 22 | """ 23 | We store keys and values of objects that first appear in the same frame in a bucket. 24 | Each bucket contains a set of object ids. 25 | Each bucket is associated with a single key tensor 26 | and a dictionary of value tensors indexed by object id. 27 | 28 | The keys and values are stored as the concatenation of a permanent part and a temporary part. 29 | """ 30 | self.save_selection = save_selection 31 | self.save_usage = save_usage 32 | 33 | self.global_bucket_id = 0 # does not reduce even if buckets are removed 34 | self.buckets: Dict[int, List[int]] = {} # indexed by bucket id 35 | self.k: Dict[int, torch.Tensor] = {} # indexed by bucket id 36 | self.v: Dict[int, torch.Tensor] = {} # indexed by object id 37 | 38 | # indexed by bucket id; the end point of permanent memory 39 | self.perm_end_pt: Dict[int, int] = defaultdict(int) 40 | 41 | # shrinkage and selection are just like the keys 42 | self.s = {} 43 | if self.save_selection: 44 | self.e = {} # does not contain the permanent memory part 45 | 46 | # usage 47 | if self.save_usage: 48 | self.use_cnt = {} # indexed by bucket id, does not contain the permanent memory part 49 | self.life_cnt = {} # indexed by bucket id, does not contain the permanent memory part 50 | 51 | def add(self, 52 | key: torch.Tensor, 53 | values: Dict[int, torch.Tensor], 54 | shrinkage: torch.Tensor, 55 | selection: torch.Tensor, 56 | supposed_bucket_id: int = -1, 57 | as_permanent: Literal['no', 'first', 'all'] = 'no') -> None: 58 | """ 59 | key: (1/2)*C*N 60 | values: dict of values ((1/2)*C*N), object ids are used as keys 61 | shrinkage: (1/2)*1*N 62 | selection: (1/2)*C*N 63 | 64 | supposed_bucket_id: used to sync the bucket id between working and long-term memory 65 | if provided, the input should all be in a single bucket indexed by this id 66 | as_permanent: whether to store the input as permanent memory 67 | 'no': don't 68 | 'first': only store it as permanent memory if the bucket is empty 69 | 'all': always store it as permanent memory 70 | """ 71 | bs = key.shape[0] 72 | ne = key.shape[-1] 73 | assert len(key.shape) == 3 74 | assert len(shrinkage.shape) == 3 75 | assert not self.save_selection or len(selection.shape) == 3 76 | assert as_permanent in ['no', 'first', 'all'] 77 | 78 | # add the value and create new buckets if necessary 79 | if supposed_bucket_id >= 0: 80 | enabled_buckets = [supposed_bucket_id] 81 | bucket_exist = supposed_bucket_id in self.buckets 82 | for obj, value in values.items(): 83 | if bucket_exist: 84 | assert obj in self.v 85 | assert obj in self.buckets[supposed_bucket_id] 86 | _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) 87 | else: 88 | assert obj not in self.v 89 | 
self.v[obj] = value 90 | self.buckets[supposed_bucket_id] = list(values.keys()) 91 | else: 92 | new_bucket_id = None 93 | enabled_buckets = set() 94 | for obj, value in values.items(): 95 | assert len(value.shape) == 3 96 | if obj in self.v: 97 | _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all')) 98 | bucket_used = [ 99 | bucket_id for bucket_id, object_ids in self.buckets.items() 100 | if obj in object_ids 101 | ] 102 | assert len(bucket_used) == 1 # each object should only be in one bucket 103 | enabled_buckets.add(bucket_used[0]) 104 | else: 105 | self.v[obj] = value 106 | if new_bucket_id is None: 107 | # create new bucket 108 | new_bucket_id = self.global_bucket_id 109 | self.global_bucket_id += 1 110 | self.buckets[new_bucket_id] = [] 111 | # put the new object into the corresponding bucket 112 | self.buckets[new_bucket_id].append(obj) 113 | enabled_buckets.add(new_bucket_id) 114 | 115 | # increment the permanent size if necessary 116 | add_as_permanent = {} # indexed by bucket id 117 | for bucket_id in enabled_buckets: 118 | add_as_permanent[bucket_id] = False 119 | if as_permanent == 'all': 120 | self.perm_end_pt[bucket_id] += ne 121 | add_as_permanent[bucket_id] = True 122 | elif as_permanent == 'first': 123 | if self.perm_end_pt[bucket_id] == 0: 124 | self.perm_end_pt[bucket_id] = ne 125 | add_as_permanent[bucket_id] = True 126 | 127 | # create new counters for usage if necessary 128 | if self.save_usage and as_permanent != 'all': 129 | new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) 130 | new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7 131 | 132 | # add the key to every bucket 133 | for bucket_id in self.buckets: 134 | if bucket_id not in enabled_buckets: 135 | # if we are not adding new values to a bucket, we should skip it 136 | continue 137 | 138 | _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id]) 139 | _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id]) 140 | if not add_as_permanent[bucket_id]: 141 | if self.save_selection: 142 | _add_last_dim(self.e, bucket_id, selection) 143 | if self.save_usage: 144 | _add_last_dim(self.use_cnt, bucket_id, new_count) 145 | _add_last_dim(self.life_cnt, bucket_id, new_life) 146 | 147 | def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None: 148 | # increase all life count by 1 149 | # increase use of indexed elements 150 | if not self.save_usage: 151 | return 152 | 153 | usage = usage[:, self.perm_end_pt[bucket_id]:] 154 | if usage.shape[-1] == 0: 155 | # if there is no temporary memory, we don't need to update 156 | return 157 | self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id]) 158 | self.life_cnt[bucket_id] += 1 159 | 160 | def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None: 161 | # keep only the temporary elements *outside* of this range (with some boundary conditions) 162 | # the permanent elements are ignored in this computation 163 | # i.e., concat (a[:start], a[end:]) 164 | # bucket with size <= min_size are not modified 165 | 166 | assert start >= 0 167 | assert end <= 0 168 | 169 | object_ids = self.buckets[bucket_id] 170 | bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id] 171 | if bucket_num_elements <= min_size: 172 | return 173 | 174 | if end == 0: 175 | # negative 0 would not work as the end index! 
176 | # effectively make the second part an empty slice 177 | end = self.k[bucket_id].shape[-1] + 1 178 | 179 | p_size = self.perm_end_pt[bucket_id] 180 | start = start + p_size 181 | 182 | k = self.k[bucket_id] 183 | s = self.s[bucket_id] 184 | if self.save_selection: 185 | e = self.e[bucket_id] 186 | if self.save_usage: 187 | use_cnt = self.use_cnt[bucket_id] 188 | life_cnt = self.life_cnt[bucket_id] 189 | 190 | self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1) 191 | self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1) 192 | if self.save_selection: 193 | self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1) 194 | if self.save_usage: 195 | self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1) 196 | self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]], 197 | -1) 198 | for obj_id in object_ids: 199 | v = self.v[obj_id] 200 | self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1) 201 | 202 | def remove_old_memory(self, bucket_id: int, max_len: int) -> None: 203 | self.sieve_by_range(bucket_id, 0, -max_len, max_len) 204 | 205 | def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None: 206 | # for long-term memory only 207 | object_ids = self.buckets[bucket_id] 208 | 209 | assert self.perm_end_pt[bucket_id] == 0 # permanent memory should be empty in LT memory 210 | 211 | # normalize with life duration 212 | usage = self.get_usage(bucket_id) 213 | bs = usage.shape[0] 214 | 215 | survivals = [] 216 | 217 | for bi in range(bs): 218 | _, survived = torch.topk(usage[bi], k=max_size) 219 | survivals.append(survived.flatten()) 220 | assert survived.shape[-1] == survivals[0].shape[-1] 221 | 222 | self.k[bucket_id] = torch.stack( 223 | [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) 224 | self.s[bucket_id] = torch.stack( 225 | [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) 226 | 227 | if self.save_selection: 228 | # Long-term memory does not store selection so this should not be needed 229 | self.e[bucket_id] = torch.stack( 230 | [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) 231 | for obj_id in object_ids: 232 | self.v[obj_id] = torch.stack( 233 | [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0) 234 | 235 | self.use_cnt[bucket_id] = torch.stack( 236 | [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) 237 | self.life_cnt[bucket_id] = torch.stack( 238 | [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0) 239 | 240 | def get_usage(self, bucket_id: int) -> torch.Tensor: 241 | # return normalized usage 242 | if not self.save_usage: 243 | raise RuntimeError('I did not count usage!') 244 | else: 245 | usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id] 246 | return usage 247 | 248 | def get_all_sliced( 249 | self, bucket_id: int, start: int, end: int 250 | ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor): 251 | # return k, sk, ek, value, normalized usage in order, sliced by start and end 252 | # this only queries the temporary memory 253 | 254 | assert start >= 0 255 | assert end <= 0 256 | 257 | p_size = self.perm_end_pt[bucket_id] 258 | start = start + p_size 259 | 260 | if end == 0: 261 | # negative 0 would not work as the end index! 
262 | k = self.k[bucket_id][:, :, start:] 263 | sk = self.s[bucket_id][:, :, start:] 264 | ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None 265 | value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]} 266 | usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None 267 | else: 268 | k = self.k[bucket_id][:, :, start:end] 269 | sk = self.s[bucket_id][:, :, start:end] 270 | ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None 271 | value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]} 272 | usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None 273 | 274 | return k, sk, ek, value, usage 275 | 276 | def purge_except(self, obj_keep_idx: List[int]): 277 | # purge certain objects from the memory except the one listed 278 | obj_keep_idx = set(obj_keep_idx) 279 | 280 | # remove objects that are not in the keep list from the buckets 281 | buckets_to_remove = [] 282 | for bucket_id, object_ids in self.buckets.items(): 283 | self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx] 284 | if len(self.buckets[bucket_id]) == 0: 285 | buckets_to_remove.append(bucket_id) 286 | 287 | # remove object values that are not in the keep list 288 | self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx} 289 | 290 | # remove buckets that are empty 291 | for bucket_id in buckets_to_remove: 292 | del self.buckets[bucket_id] 293 | del self.k[bucket_id] 294 | del self.s[bucket_id] 295 | if self.save_selection: 296 | del self.e[bucket_id] 297 | if self.save_usage: 298 | del self.use_cnt[bucket_id] 299 | del self.life_cnt[bucket_id] 300 | 301 | def clear_non_permanent_memory(self): 302 | # clear all non-permanent memory 303 | for bucket_id in self.buckets: 304 | self.sieve_by_range(bucket_id, 0, 0, 0) 305 | 306 | def get_v_size(self, obj_id: int) -> int: 307 | return self.v[obj_id].shape[-1] 308 | 309 | def size(self, bucket_id: int) -> int: 310 | if bucket_id not in self.k: 311 | return 0 312 | else: 313 | return self.k[bucket_id].shape[-1] 314 | 315 | def perm_size(self, bucket_id: int) -> int: 316 | return self.perm_end_pt[bucket_id] 317 | 318 | def non_perm_size(self, bucket_id: int) -> int: 319 | return self.size(bucket_id) - self.perm_size(bucket_id) 320 | 321 | def engaged(self, bucket_id: Optional[int] = None) -> bool: 322 | if bucket_id is None: 323 | return len(self.buckets) > 0 324 | else: 325 | return bucket_id in self.buckets 326 | 327 | @property 328 | def num_objects(self) -> int: 329 | return len(self.v) 330 | 331 | @property 332 | def key(self) -> Dict[int, torch.Tensor]: 333 | return self.k 334 | 335 | @property 336 | def value(self) -> Dict[int, torch.Tensor]: 337 | return self.v 338 | 339 | @property 340 | def shrinkage(self) -> Dict[int, torch.Tensor]: 341 | return self.s 342 | 343 | @property 344 | def selection(self) -> Dict[int, torch.Tensor]: 345 | return self.e 346 | 347 | def __contains__(self, key): 348 | return key in self.v 349 | -------------------------------------------------------------------------------- /matanyone/inference/object_info.py: -------------------------------------------------------------------------------- 1 | class ObjectInfo: 2 | """ 3 | Store meta information for an object 4 | """ 5 | def __init__(self, id: int): 6 | self.id = id 7 | self.poke_count = 0 # count number of detections missed 8 | 9 | def poke(self) -> None: 10 | self.poke_count += 1 11 | 12 | def 
unpoke(self) -> None: 13 | self.poke_count = 0 14 | 15 | def __hash__(self): 16 | return hash(self.id) 17 | 18 | def __eq__(self, other): 19 | if type(other) == int: 20 | return self.id == other 21 | return self.id == other.id 22 | 23 | def __repr__(self): 24 | return f'(ID: {self.id})' 25 | -------------------------------------------------------------------------------- /matanyone/inference/object_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Dict 2 | 3 | import torch 4 | from matanyone.inference.object_info import ObjectInfo 5 | 6 | 7 | class ObjectManager: 8 | """ 9 | Object IDs are immutable. The same ID always represent the same object. 10 | Temporary IDs are the positions of each object in the tensor. It changes as objects get removed. 11 | Temporary IDs start from 1. 12 | """ 13 | 14 | def __init__(self): 15 | self.obj_to_tmp_id: Dict[ObjectInfo, int] = {} 16 | self.tmp_id_to_obj: Dict[int, ObjectInfo] = {} 17 | self.obj_id_to_obj: Dict[int, ObjectInfo] = {} 18 | 19 | self.all_historical_object_ids: List[int] = [] 20 | 21 | def _recompute_obj_id_to_obj_mapping(self) -> None: 22 | self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id} 23 | 24 | def add_new_objects( 25 | self, objects: Union[List[ObjectInfo], ObjectInfo, 26 | List[int]]) -> (List[int], List[int]): 27 | if not isinstance(objects, list): 28 | objects = [objects] 29 | 30 | corresponding_tmp_ids = [] 31 | corresponding_obj_ids = [] 32 | for obj in objects: 33 | if isinstance(obj, int): 34 | obj = ObjectInfo(id=obj) 35 | 36 | if obj in self.obj_to_tmp_id: 37 | # old object 38 | corresponding_tmp_ids.append(self.obj_to_tmp_id[obj]) 39 | corresponding_obj_ids.append(obj.id) 40 | else: 41 | # new object 42 | new_obj = ObjectInfo(id=obj.id) 43 | 44 | # new object 45 | new_tmp_id = len(self.obj_to_tmp_id) + 1 46 | self.obj_to_tmp_id[new_obj] = new_tmp_id 47 | self.tmp_id_to_obj[new_tmp_id] = new_obj 48 | self.all_historical_object_ids.append(new_obj.id) 49 | corresponding_tmp_ids.append(new_tmp_id) 50 | corresponding_obj_ids.append(new_obj.id) 51 | 52 | self._recompute_obj_id_to_obj_mapping() 53 | assert corresponding_tmp_ids == sorted(corresponding_tmp_ids) 54 | return corresponding_tmp_ids, corresponding_obj_ids 55 | 56 | def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None: 57 | # delete an object or a list of objects 58 | # re-sort the tmp ids 59 | if isinstance(obj_ids_to_remove, int): 60 | obj_ids_to_remove = [obj_ids_to_remove] 61 | 62 | new_tmp_id = 1 63 | total_num_id = len(self.obj_to_tmp_id) 64 | 65 | local_obj_to_tmp_id = {} 66 | local_tmp_to_obj_id = {} 67 | 68 | for tmp_iter in range(1, total_num_id + 1): 69 | obj = self.tmp_id_to_obj[tmp_iter] 70 | if obj.id not in obj_ids_to_remove: 71 | local_obj_to_tmp_id[obj] = new_tmp_id 72 | local_tmp_to_obj_id[new_tmp_id] = obj 73 | new_tmp_id += 1 74 | 75 | self.obj_to_tmp_id = local_obj_to_tmp_id 76 | self.tmp_id_to_obj = local_tmp_to_obj_id 77 | self._recompute_obj_id_to_obj_mapping() 78 | 79 | def purge_inactive_objects(self, 80 | max_missed_detection_count: int) -> (bool, List[int], List[int]): 81 | # remove tmp ids of objects that are removed 82 | obj_id_to_be_deleted = [] 83 | tmp_id_to_be_deleted = [] 84 | tmp_id_to_keep = [] 85 | obj_id_to_keep = [] 86 | 87 | for obj in self.obj_to_tmp_id: 88 | if obj.poke_count > max_missed_detection_count: 89 | obj_id_to_be_deleted.append(obj.id) 90 | tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj]) 91 | else: 
92 | tmp_id_to_keep.append(self.obj_to_tmp_id[obj]) 93 | obj_id_to_keep.append(obj.id) 94 | 95 | purge_activated = len(obj_id_to_be_deleted) > 0 96 | if purge_activated: 97 | self.delete_objects(obj_id_to_be_deleted) 98 | return purge_activated, tmp_id_to_keep, obj_id_to_keep 99 | 100 | def tmp_to_obj_cls(self, mask) -> torch.Tensor: 101 | # remap tmp id cls representation to the true object id representation 102 | new_mask = torch.zeros_like(mask) 103 | for tmp_id, obj in self.tmp_id_to_obj.items(): 104 | new_mask[mask == tmp_id] = obj.id 105 | return new_mask 106 | 107 | def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]: 108 | # returns the mapping in a dict format for saving it with pickle 109 | return {obj.id: tmp_id for obj, tmp_id in self.tmp_id_to_obj.items()} 110 | 111 | def realize_dict(self, obj_dict, dim=1) -> torch.Tensor: 112 | # turns a dict indexed by obj id into a tensor, ordered by tmp IDs 113 | output = [] 114 | for _, obj in self.tmp_id_to_obj.items(): 115 | if obj.id not in obj_dict: 116 | raise NotImplementedError 117 | output.append(obj_dict[obj.id]) 118 | output = torch.stack(output, dim=dim) 119 | return output 120 | 121 | def make_one_hot(self, cls_mask) -> torch.Tensor: 122 | output = [] 123 | for _, obj in self.tmp_id_to_obj.items(): 124 | output.append(cls_mask == obj.id) 125 | if len(output) == 0: 126 | output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device) 127 | else: 128 | output = torch.stack(output, dim=0) 129 | return output 130 | 131 | @property 132 | def all_obj_ids(self) -> List[int]: 133 | return [k.id for k in self.obj_to_tmp_id] 134 | 135 | @property 136 | def num_obj(self) -> int: 137 | return len(self.obj_to_tmp_id) 138 | 139 | def has_all(self, objects: List[int]) -> bool: 140 | for obj in objects: 141 | if obj not in self.obj_to_tmp_id: 142 | return False 143 | return True 144 | 145 | def find_object_by_id(self, obj_id) -> ObjectInfo: 146 | return self.obj_id_to_obj[obj_id] 147 | 148 | def find_tmp_by_id(self, obj_id) -> int: 149 | return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]] 150 | -------------------------------------------------------------------------------- /matanyone/inference/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/inference/utils/__init__.py -------------------------------------------------------------------------------- /matanyone/inference/utils/args_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from omegaconf import DictConfig 3 | 4 | log = logging.getLogger() 5 | 6 | 7 | def get_dataset_cfg(cfg: DictConfig): 8 | dataset_name = cfg.dataset 9 | data_cfg = cfg.datasets[dataset_name] 10 | 11 | potential_overrides = [ 12 | 'image_directory', 13 | 'mask_directory', 14 | 'json_directory', 15 | 'size', 16 | 'save_all', 17 | 'use_all_masks', 18 | 'use_long_term', 19 | 'mem_every', 20 | ] 21 | 22 | for override in potential_overrides: 23 | if cfg[override] is not None: 24 | log.info(f'Overriding config {override} from {data_cfg[override]} to {cfg[override]}') 25 | data_cfg[override] = cfg[override] 26 | # escalte all potential overrides to the top-level config 27 | if override in data_cfg: 28 | cfg[override] = data_cfg[override] 29 | 30 | return data_cfg 31 | -------------------------------------------------------------------------------- /matanyone/model/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/model/__init__.py -------------------------------------------------------------------------------- /matanyone/model/aux_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | For computing auxiliary outputs for auxiliary losses 3 | """ 4 | from typing import Dict 5 | from omegaconf import DictConfig 6 | import torch 7 | import torch.nn as nn 8 | 9 | from matanyone.model.group_modules import GConv2d 10 | from matanyone.utils.tensor_utils import aggregate 11 | 12 | 13 | class LinearPredictor(nn.Module): 14 | def __init__(self, x_dim: int, pix_dim: int): 15 | super().__init__() 16 | self.projection = GConv2d(x_dim, pix_dim + 1, kernel_size=1) 17 | 18 | def forward(self, pix_feat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 19 | # pixel_feat: B*pix_dim*H*W 20 | # x: B*num_objects*x_dim*H*W 21 | num_objects = x.shape[1] 22 | x = self.projection(x) 23 | 24 | pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) 25 | logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1] 26 | return logits 27 | 28 | 29 | class DirectPredictor(nn.Module): 30 | def __init__(self, x_dim: int): 31 | super().__init__() 32 | self.projection = GConv2d(x_dim, 1, kernel_size=1) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | # x: B*num_objects*x_dim*H*W 36 | logits = self.projection(x).squeeze(2) 37 | return logits 38 | 39 | 40 | class AuxComputer(nn.Module): 41 | def __init__(self, cfg: DictConfig): 42 | super().__init__() 43 | 44 | use_sensory_aux = cfg.model.aux_loss.sensory.enabled 45 | self.use_query_aux = cfg.model.aux_loss.query.enabled 46 | self.use_sensory_aux = use_sensory_aux 47 | 48 | sensory_dim = cfg.model.sensory_dim 49 | embed_dim = cfg.model.embed_dim 50 | 51 | if use_sensory_aux: 52 | self.sensory_aux = LinearPredictor(sensory_dim, embed_dim) 53 | 54 | def _aggregate_with_selector(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.Tensor: 55 | prob = torch.sigmoid(logits) 56 | if selector is not None: 57 | prob = prob * selector 58 | logits = aggregate(prob, dim=1) 59 | return logits 60 | 61 | def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor], 62 | selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]: 63 | sensory = aux_input['sensory'] 64 | q_logits = aux_input['q_logits'] 65 | 66 | aux_output = {} 67 | aux_output['attn_mask'] = aux_input['attn_mask'] 68 | 69 | if self.use_sensory_aux: 70 | # B*num_objects*H*W 71 | logits = self.sensory_aux(pix_feat, sensory) 72 | aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector) 73 | if self.use_query_aux: 74 | # B*num_objects*num_levels*H*W 75 | aux_output['q_logits'] = self._aggregate_with_selector( 76 | torch.stack(q_logits, dim=2), 77 | selector.unsqueeze(2) if selector is not None else None) 78 | 79 | return aux_output 80 | 81 | def compute_mask(self, aux_input: Dict[str, torch.Tensor], 82 | selector: torch.Tensor) -> Dict[str, torch.Tensor]: 83 | # sensory = aux_input['sensory'] 84 | q_logits = aux_input['q_logits'] 85 | 86 | aux_output = {} 87 | 88 | # B*num_objects*num_levels*H*W 89 | aux_output['q_logits'] = self._aggregate_with_selector( 90 | torch.stack(q_logits, dim=2), 91 | selector.unsqueeze(2) if selector is not None else None) 92 | 93 | return aux_output 
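# --- Illustrative aside (not part of any repository file) ---
# A minimal shape sketch of LinearPredictor's dot-product-plus-bias readout. The plain
# nn.Conv2d below stands in for GConv2d (which applies the same convolution per object);
# all sizes here are made up for illustration.
import torch
import torch.nn as nn

B, num_objects, x_dim, pix_dim, H, W = 2, 3, 16, 8, 4, 4
projection = nn.Conv2d(x_dim, pix_dim + 1, kernel_size=1)

x = torch.randn(B, num_objects, x_dim, H, W)
pix_feat = torch.randn(B, pix_dim, H, W)

x = projection(x.flatten(0, 1)).view(B, num_objects, pix_dim + 1, H, W)
pix_feat = pix_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
# per-pixel dot product against the first pix_dim channels; the last channel acts as a bias
logits = (pix_feat * x[:, :, :-1]).sum(dim=2) + x[:, :, -1]
print(logits.shape)  # torch.Size([2, 3, 4, 4]) -> B * num_objects * H * W
# --- end of aside ---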
-------------------------------------------------------------------------------- /matanyone/model/big_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | big_modules.py - This file stores higher-level network blocks. 3 | 4 | x - usually denotes features that are shared between objects. 5 | g - usually denotes features that are not shared between objects 6 | with an extra "num_objects" dimension (batch_size * num_objects * num_channels * H * W). 7 | 8 | The trailing number of a variable usually denotes the stride 9 | """ 10 | 11 | from typing import Iterable 12 | from omegaconf import DictConfig 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from matanyone.model.group_modules import MainToGroupDistributor, GroupFeatureFusionBlock, GConv2d 18 | from matanyone.model.utils import resnet 19 | from matanyone.model.modules import SensoryDeepUpdater, SensoryUpdater_fullscale, DecoderFeatureProcessor, MaskUpsampleBlock 20 | 21 | class UncertPred(nn.Module): 22 | def __init__(self, model_cfg: DictConfig): 23 | super().__init__() 24 | self.conv1x1_v2 = nn.Conv2d(model_cfg.pixel_dim*2 + 1 + model_cfg.value_dim, 64, kernel_size=1, stride=1, bias=False) 25 | self.bn1 = nn.BatchNorm2d(64) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv3x3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1) 28 | self.bn2 = nn.BatchNorm2d(32) 29 | self.conv3x3_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, groups=1, bias=False, dilation=1) 30 | 31 | def forward(self, last_frame_feat: torch.Tensor, cur_frame_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor): 32 | last_mask = F.interpolate(last_mask, size=last_frame_feat.shape[-2:], mode='area') 33 | x = torch.cat([last_frame_feat, cur_frame_feat, last_mask, mem_val_diff], dim=1) 34 | x = self.conv1x1_v2(x) 35 | x = self.bn1(x) 36 | x = self.relu(x) 37 | x = self.conv3x3(x) 38 | x = self.bn2(x) 39 | x = self.relu(x) 40 | x = self.conv3x3_out(x) 41 | return x 42 | 43 | # override the default train() to freeze BN statistics 44 | def train(self, mode=True): 45 | self.training = False 46 | for module in self.children(): 47 | module.train(False) 48 | return self 49 | 50 | class PixelEncoder(nn.Module): 51 | def __init__(self, model_cfg: DictConfig): 52 | super().__init__() 53 | 54 | self.is_resnet = 'resnet' in model_cfg.pixel_encoder.type 55 | # if model_cfg.pretrained_resnet is set in the model_cfg we get the value 56 | # else default to True 57 | is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True) 58 | if self.is_resnet: 59 | if model_cfg.pixel_encoder.type == 'resnet18': 60 | network = resnet.resnet18(pretrained=is_pretrained_resnet) 61 | elif model_cfg.pixel_encoder.type == 'resnet50': 62 | network = resnet.resnet50(pretrained=is_pretrained_resnet) 63 | else: 64 | raise NotImplementedError 65 | self.conv1 = network.conv1 66 | self.bn1 = network.bn1 67 | self.relu = network.relu 68 | self.maxpool = network.maxpool 69 | 70 | self.res2 = network.layer1 71 | self.layer2 = network.layer2 72 | self.layer3 = network.layer3 73 | else: 74 | raise NotImplementedError 75 | 76 | def forward(self, x: torch.Tensor, seq_length=None) -> (torch.Tensor, torch.Tensor, torch.Tensor): 77 | f1 = x 78 | x = self.conv1(x) 79 | x = self.bn1(x) 80 | x = self.relu(x) 81 | f2 = x 82 | x = self.maxpool(x) 83 | f4 = self.res2(x) 84 | f8 = self.layer2(f4) 85 | f16 = self.layer3(f8) 86 | 87 | return f16, f8, f4, f2, f1 
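# NOTE (added comment, not part of the original source): the tuple above is a feature pyramid
# at strides 16/8/4/2/1 of the input resolution -- f2 comes out of the stride-2 conv1, f4 after
# maxpool + layer1, f8 after layer2, f16 after layer3, and f1 is the raw input. For a
# torchvision ResNet-50 backbone the channel counts would be 1024/512/256/64/3, and for
# ResNet-18 256/128/64/64/3 (the exact values come from model_cfg.pixel_encoder.ms_dims).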
88 | 89 | # override the default train() to freeze BN statistics 90 | def train(self, mode=True): 91 | self.training = False 92 | for module in self.children(): 93 | module.train(False) 94 | return self 95 | 96 | 97 | class KeyProjection(nn.Module): 98 | def __init__(self, model_cfg: DictConfig): 99 | super().__init__() 100 | in_dim = model_cfg.pixel_encoder.ms_dims[0] 101 | mid_dim = model_cfg.pixel_dim 102 | key_dim = model_cfg.key_dim 103 | 104 | self.pix_feat_proj = nn.Conv2d(in_dim, mid_dim, kernel_size=1) 105 | self.key_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) 106 | # shrinkage 107 | self.d_proj = nn.Conv2d(mid_dim, 1, kernel_size=3, padding=1) 108 | # selection 109 | self.e_proj = nn.Conv2d(mid_dim, key_dim, kernel_size=3, padding=1) 110 | 111 | nn.init.orthogonal_(self.key_proj.weight.data) 112 | nn.init.zeros_(self.key_proj.bias.data) 113 | 114 | def forward(self, x: torch.Tensor, *, need_s: bool, 115 | need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor): 116 | x = self.pix_feat_proj(x) 117 | shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None 118 | selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None 119 | 120 | return self.key_proj(x), shrinkage, selection 121 | 122 | 123 | class MaskEncoder(nn.Module): 124 | def __init__(self, model_cfg: DictConfig, single_object=False): 125 | super().__init__() 126 | pixel_dim = model_cfg.pixel_dim 127 | value_dim = model_cfg.value_dim 128 | sensory_dim = model_cfg.sensory_dim 129 | final_dim = model_cfg.mask_encoder.final_dim 130 | 131 | self.single_object = single_object 132 | extra_dim = 1 if single_object else 2 133 | 134 | # if model_cfg.pretrained_resnet is set in the model_cfg we get the value 135 | # else default to True 136 | is_pretrained_resnet = getattr(model_cfg,"pretrained_resnet",True) 137 | if model_cfg.mask_encoder.type == 'resnet18': 138 | network = resnet.resnet18(pretrained=is_pretrained_resnet, extra_dim=extra_dim) 139 | elif model_cfg.mask_encoder.type == 'resnet50': 140 | network = resnet.resnet50(pretrained=is_pretrained_resnet, extra_dim=extra_dim) 141 | else: 142 | raise NotImplementedError 143 | self.conv1 = network.conv1 144 | self.bn1 = network.bn1 145 | self.relu = network.relu 146 | self.maxpool = network.maxpool 147 | 148 | self.layer1 = network.layer1 149 | self.layer2 = network.layer2 150 | self.layer3 = network.layer3 151 | 152 | self.distributor = MainToGroupDistributor() 153 | self.fuser = GroupFeatureFusionBlock(pixel_dim, final_dim, value_dim) 154 | 155 | self.sensory_update = SensoryDeepUpdater(value_dim, sensory_dim) 156 | 157 | def forward(self, 158 | image: torch.Tensor, 159 | pix_feat: torch.Tensor, 160 | sensory: torch.Tensor, 161 | masks: torch.Tensor, 162 | others: torch.Tensor, 163 | *, 164 | deep_update: bool = True, 165 | chunk_size: int = -1) -> (torch.Tensor, torch.Tensor): 166 | # ms_features are from the key encoder 167 | # we only use the first one (lowest resolution), following XMem 168 | if self.single_object: 169 | g = masks.unsqueeze(2) 170 | else: 171 | g = torch.stack([masks, others], dim=2) 172 | 173 | g = self.distributor(image, g) 174 | 175 | batch_size, num_objects = g.shape[:2] 176 | if chunk_size < 1 or chunk_size >= num_objects: 177 | chunk_size = num_objects 178 | fast_path = True 179 | new_sensory = sensory 180 | else: 181 | if deep_update: 182 | new_sensory = torch.empty_like(sensory) 183 | else: 184 | new_sensory = sensory 185 | fast_path = False 186 | 187 | # chunk-by-chunk inference 188 | all_g = [] 189 | for i in range(0, 
num_objects, chunk_size): 190 | if fast_path: 191 | g_chunk = g 192 | else: 193 | g_chunk = g[:, i:i + chunk_size] 194 | actual_chunk_size = g_chunk.shape[1] 195 | g_chunk = g_chunk.flatten(start_dim=0, end_dim=1) 196 | 197 | g_chunk = self.conv1(g_chunk) 198 | g_chunk = self.bn1(g_chunk) # 1/2, 64 199 | g_chunk = self.maxpool(g_chunk) # 1/4, 64 200 | g_chunk = self.relu(g_chunk) 201 | 202 | g_chunk = self.layer1(g_chunk) # 1/4 203 | g_chunk = self.layer2(g_chunk) # 1/8 204 | g_chunk = self.layer3(g_chunk) # 1/16 205 | 206 | g_chunk = g_chunk.view(batch_size, actual_chunk_size, *g_chunk.shape[1:]) 207 | g_chunk = self.fuser(pix_feat, g_chunk) 208 | all_g.append(g_chunk) 209 | if deep_update: 210 | if fast_path: 211 | new_sensory = self.sensory_update(g_chunk, sensory) 212 | else: 213 | new_sensory[:, i:i + chunk_size] = self.sensory_update( 214 | g_chunk, sensory[:, i:i + chunk_size]) 215 | g = torch.cat(all_g, dim=1) 216 | 217 | return g, new_sensory 218 | 219 | # override the default train() to freeze BN statistics 220 | def train(self, mode=True): 221 | self.training = False 222 | for module in self.children(): 223 | module.train(False) 224 | return self 225 | 226 | 227 | class PixelFeatureFuser(nn.Module): 228 | def __init__(self, model_cfg: DictConfig, single_object=False): 229 | super().__init__() 230 | value_dim = model_cfg.value_dim 231 | sensory_dim = model_cfg.sensory_dim 232 | pixel_dim = model_cfg.pixel_dim 233 | embed_dim = model_cfg.embed_dim 234 | self.single_object = single_object 235 | 236 | self.fuser = GroupFeatureFusionBlock(pixel_dim, value_dim, embed_dim) 237 | if self.single_object: 238 | self.sensory_compress = GConv2d(sensory_dim + 1, value_dim, kernel_size=1) 239 | else: 240 | self.sensory_compress = GConv2d(sensory_dim + 2, value_dim, kernel_size=1) 241 | 242 | def forward(self, 243 | pix_feat: torch.Tensor, 244 | pixel_memory: torch.Tensor, 245 | sensory_memory: torch.Tensor, 246 | last_mask: torch.Tensor, 247 | last_others: torch.Tensor, 248 | *, 249 | chunk_size: int = -1) -> torch.Tensor: 250 | batch_size, num_objects = pixel_memory.shape[:2] 251 | 252 | if self.single_object: 253 | last_mask = last_mask.unsqueeze(2) 254 | else: 255 | last_mask = torch.stack([last_mask, last_others], dim=2) 256 | 257 | if chunk_size < 1: 258 | chunk_size = num_objects 259 | 260 | # chunk-by-chunk inference 261 | all_p16 = [] 262 | for i in range(0, num_objects, chunk_size): 263 | sensory_readout = self.sensory_compress( 264 | torch.cat([sensory_memory[:, i:i + chunk_size], last_mask[:, i:i + chunk_size]], 2)) 265 | p16 = pixel_memory[:, i:i + chunk_size] + sensory_readout 266 | p16 = self.fuser(pix_feat, p16) 267 | all_p16.append(p16) 268 | p16 = torch.cat(all_p16, dim=1) 269 | 270 | return p16 271 | 272 | 273 | class MaskDecoder(nn.Module): 274 | def __init__(self, model_cfg: DictConfig): 275 | super().__init__() 276 | embed_dim = model_cfg.embed_dim 277 | sensory_dim = model_cfg.sensory_dim 278 | ms_image_dims = model_cfg.pixel_encoder.ms_dims 279 | up_dims = model_cfg.mask_decoder.up_dims 280 | 281 | assert embed_dim == up_dims[0] 282 | 283 | self.sensory_update = SensoryUpdater_fullscale([up_dims[0], up_dims[1], up_dims[2], up_dims[3], up_dims[4] + 1], sensory_dim, 284 | sensory_dim) 285 | 286 | self.decoder_feat_proc = DecoderFeatureProcessor(ms_image_dims[1:], up_dims[:-1]) 287 | self.up_16_8 = MaskUpsampleBlock(up_dims[0], up_dims[1]) 288 | self.up_8_4 = MaskUpsampleBlock(up_dims[1], up_dims[2]) 289 | # newly add for alpha matte 290 | self.up_4_2 = 
MaskUpsampleBlock(up_dims[2], up_dims[3]) 291 | self.up_2_1 = MaskUpsampleBlock(up_dims[3], up_dims[4]) 292 | 293 | self.pred_seg = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) 294 | self.pred_mat = nn.Conv2d(up_dims[-1], 1, kernel_size=3, padding=1) 295 | 296 | def forward(self, 297 | ms_image_feat: Iterable[torch.Tensor], 298 | memory_readout: torch.Tensor, 299 | sensory: torch.Tensor, 300 | *, 301 | chunk_size: int = -1, 302 | update_sensory: bool = True, 303 | seg_pass: bool = False, 304 | last_mask=None, 305 | sigmoid_residual=False) -> (torch.Tensor, torch.Tensor): 306 | 307 | batch_size, num_objects = memory_readout.shape[:2] 308 | f8, f4, f2, f1 = self.decoder_feat_proc(ms_image_feat[1:]) 309 | if chunk_size < 1 or chunk_size >= num_objects: 310 | chunk_size = num_objects 311 | fast_path = True 312 | new_sensory = sensory 313 | else: 314 | if update_sensory: 315 | new_sensory = torch.empty_like(sensory) 316 | else: 317 | new_sensory = sensory 318 | fast_path = False 319 | 320 | # chunk-by-chunk inference 321 | all_logits = [] 322 | for i in range(0, num_objects, chunk_size): 323 | if fast_path: 324 | p16 = memory_readout 325 | else: 326 | p16 = memory_readout[:, i:i + chunk_size] 327 | actual_chunk_size = p16.shape[1] 328 | 329 | p8 = self.up_16_8(p16, f8) 330 | p4 = self.up_8_4(p8, f4) 331 | p2 = self.up_4_2(p4, f2) 332 | p1 = self.up_2_1(p2, f1) 333 | with torch.amp.autocast("cuda",enabled=False): 334 | if seg_pass: 335 | if last_mask is not None: 336 | res = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) 337 | if sigmoid_residual: 338 | res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask 339 | logits = last_mask + res 340 | else: 341 | logits = self.pred_seg(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) 342 | else: 343 | if last_mask is not None: 344 | res = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) 345 | if sigmoid_residual: 346 | res = (torch.sigmoid(res) - 0.5) * 2 # regularization: (-1, 1) change on last mask 347 | logits = last_mask + res 348 | else: 349 | logits = self.pred_mat(F.relu(p1.flatten(start_dim=0, end_dim=1).float())) 350 | ## SensoryUpdater_fullscale 351 | if update_sensory: 352 | p1 = torch.cat( 353 | [p1, logits.view(batch_size, actual_chunk_size, 1, *logits.shape[-2:])], 2) 354 | if fast_path: 355 | new_sensory = self.sensory_update([p16, p8, p4, p2, p1], sensory) 356 | else: 357 | new_sensory[:, 358 | i:i + chunk_size] = self.sensory_update([p16, p8, p4, p2, p1], 359 | sensory[:, 360 | i:i + chunk_size]) 361 | all_logits.append(logits) 362 | logits = torch.cat(all_logits, dim=0) 363 | logits = logits.view(batch_size, num_objects, *logits.shape[-2:]) 364 | 365 | return new_sensory, logits 366 | -------------------------------------------------------------------------------- /matanyone/model/channel_attn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class CAResBlock(nn.Module): 8 | def __init__(self, in_dim: int, out_dim: int, residual: bool = True): 9 | super().__init__() 10 | self.residual = residual 11 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) 12 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) 13 | 14 | t = int((abs(math.log2(out_dim)) + 1) // 2) 15 | k = t if t % 2 else t + 1 16 | self.pool = nn.AdaptiveAvgPool2d(1) 17 | self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) 
// 2, bias=False) 18 | 19 | if self.residual: 20 | if in_dim == out_dim: 21 | self.downsample = nn.Identity() 22 | else: 23 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | r = x 27 | x = self.conv1(F.relu(x)) 28 | x = self.conv2(F.relu(x)) 29 | 30 | b, c = x.shape[:2] 31 | w = self.pool(x).view(b, 1, c) 32 | w = self.conv(w).transpose(-1, -2).unsqueeze(-1).sigmoid() # B*C*1*1 33 | 34 | if self.residual: 35 | x = x * w + self.downsample(r) 36 | else: 37 | x = x * w 38 | 39 | return x 40 | -------------------------------------------------------------------------------- /matanyone/model/group_modules.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from matanyone.model.channel_attn import CAResBlock 6 | 7 | def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, 8 | align_corners: bool) -> torch.Tensor: 9 | batch_size, num_objects = g.shape[:2] 10 | g = F.interpolate(g.flatten(start_dim=0, end_dim=1), 11 | scale_factor=ratio, 12 | mode=mode, 13 | align_corners=align_corners) 14 | g = g.view(batch_size, num_objects, *g.shape[1:]) 15 | return g 16 | 17 | 18 | def upsample_groups(g: torch.Tensor, 19 | ratio: float = 2, 20 | mode: str = 'bilinear', 21 | align_corners: bool = False) -> torch.Tensor: 22 | return interpolate_groups(g, ratio, mode, align_corners) 23 | 24 | 25 | def downsample_groups(g: torch.Tensor, 26 | ratio: float = 1 / 2, 27 | mode: str = 'area', 28 | align_corners: bool = None) -> torch.Tensor: 29 | return interpolate_groups(g, ratio, mode, align_corners) 30 | 31 | 32 | class GConv2d(nn.Conv2d): 33 | def forward(self, g: torch.Tensor) -> torch.Tensor: 34 | batch_size, num_objects = g.shape[:2] 35 | g = super().forward(g.flatten(start_dim=0, end_dim=1)) 36 | return g.view(batch_size, num_objects, *g.shape[1:]) 37 | 38 | 39 | class GroupResBlock(nn.Module): 40 | def __init__(self, in_dim: int, out_dim: int): 41 | super().__init__() 42 | 43 | if in_dim == out_dim: 44 | self.downsample = nn.Identity() 45 | else: 46 | self.downsample = GConv2d(in_dim, out_dim, kernel_size=1) 47 | 48 | self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1) 49 | self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1) 50 | 51 | def forward(self, g: torch.Tensor) -> torch.Tensor: 52 | out_g = self.conv1(F.relu(g)) 53 | out_g = self.conv2(F.relu(out_g)) 54 | 55 | g = self.downsample(g) 56 | 57 | return out_g + g 58 | 59 | 60 | class MainToGroupDistributor(nn.Module): 61 | def __init__(self, 62 | x_transform: Optional[nn.Module] = None, 63 | g_transform: Optional[nn.Module] = None, 64 | method: str = 'cat', 65 | reverse_order: bool = False): 66 | super().__init__() 67 | 68 | self.x_transform = x_transform 69 | self.g_transform = g_transform 70 | self.method = method 71 | self.reverse_order = reverse_order 72 | 73 | def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor: 74 | num_objects = g.shape[1] 75 | 76 | if self.x_transform is not None: 77 | x = self.x_transform(x) 78 | 79 | if self.g_transform is not None: 80 | g = self.g_transform(g) 81 | 82 | if not skip_expand: 83 | x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) 84 | if self.method == 'cat': 85 | if self.reverse_order: 86 | g = torch.cat([g, x], 2) 87 | else: 88 | g = torch.cat([x, g], 2) 89 | elif self.method == 'add': 90 | g = x + g 91 | elif 
self.method == 'mulcat': 92 | g = torch.cat([x * g, g], dim=2) 93 | elif self.method == 'muladd': 94 | g = x * g + g 95 | else: 96 | raise NotImplementedError 97 | 98 | return g 99 | 100 | 101 | class GroupFeatureFusionBlock(nn.Module): 102 | def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int): 103 | super().__init__() 104 | 105 | x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1) 106 | g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1) 107 | 108 | self.distributor = MainToGroupDistributor(x_transform=x_transform, 109 | g_transform=g_transform, 110 | method='add') 111 | self.block1 = CAResBlock(out_dim, out_dim) 112 | self.block2 = CAResBlock(out_dim, out_dim) 113 | 114 | def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: 115 | batch_size, num_objects = g.shape[:2] 116 | 117 | g = self.distributor(x, g) 118 | 119 | g = g.flatten(start_dim=0, end_dim=1) 120 | 121 | g = self.block1(g) 122 | g = self.block2(g) 123 | 124 | g = g.view(batch_size, num_objects, *g.shape[1:]) 125 | 126 | return g -------------------------------------------------------------------------------- /matanyone/model/matanyone.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Iterable, Tuple 2 | import logging 3 | from omegaconf import DictConfig 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from omegaconf import OmegaConf 8 | from huggingface_hub import PyTorchModelHubMixin 9 | 10 | from matanyone.model.big_modules import PixelEncoder, UncertPred, KeyProjection, MaskEncoder, PixelFeatureFuser, MaskDecoder 11 | from matanyone.model.aux_modules import AuxComputer 12 | from matanyone.model.utils.memory_utils import get_affinity, readout 13 | from matanyone.model.transformer.object_transformer import QueryTransformer 14 | from matanyone.model.transformer.object_summarizer import ObjectSummarizer 15 | from matanyone.utils.tensor_utils import aggregate 16 | 17 | log = logging.getLogger() 18 | class MatAnyone(nn.Module, 19 | PyTorchModelHubMixin, 20 | library_name="matanyone", 21 | repo_url="https://github.com/pq-yang/MatAnyone", 22 | coders={ 23 | DictConfig: ( 24 | lambda x: OmegaConf.to_container(x), 25 | lambda data: OmegaConf.create(data), 26 | ) 27 | }, 28 | ): 29 | 30 | def __init__(self, cfg: DictConfig, *, single_object=False): 31 | super().__init__() 32 | self.cfg = cfg 33 | model_cfg = cfg.model 34 | self.ms_dims = model_cfg.pixel_encoder.ms_dims 35 | self.key_dim = model_cfg.key_dim 36 | self.value_dim = model_cfg.value_dim 37 | self.sensory_dim = model_cfg.sensory_dim 38 | self.pixel_dim = model_cfg.pixel_dim 39 | self.embed_dim = model_cfg.embed_dim 40 | self.single_object = single_object 41 | 42 | log.info(f'Single object: {self.single_object}') 43 | 44 | self.pixel_encoder = PixelEncoder(model_cfg) 45 | self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1) 46 | self.key_proj = KeyProjection(model_cfg) 47 | self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object) 48 | self.mask_decoder = MaskDecoder(model_cfg) 49 | self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object) 50 | self.object_transformer = QueryTransformer(model_cfg) 51 | self.object_summarizer = ObjectSummarizer(model_cfg) 52 | self.aux_computer = AuxComputer(cfg) 53 | self.temp_sparity = UncertPred(model_cfg) 54 | 55 | self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False) 56 | 
self.register_buffer("pixel_std", torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1), False) 57 | 58 | def _get_others(self, masks: torch.Tensor) -> torch.Tensor: 59 | # for each object, return the sum of masks of all other objects 60 | if self.single_object: 61 | return None 62 | 63 | num_objects = masks.shape[1] 64 | if num_objects >= 1: 65 | others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1) 66 | else: 67 | others = torch.zeros_like(masks) 68 | return others 69 | 70 | def pred_uncertainty(self, last_pix_feat: torch.Tensor, cur_pix_feat: torch.Tensor, last_mask: torch.Tensor, mem_val_diff:torch.Tensor): 71 | logits = self.temp_sparity(last_frame_feat=last_pix_feat, 72 | cur_frame_feat=cur_pix_feat, 73 | last_mask=last_mask, 74 | mem_val_diff=mem_val_diff) 75 | 76 | prob = torch.sigmoid(logits) 77 | mask = (prob > 0) + 0 78 | 79 | uncert_output = {"logits": logits, 80 | "prob": prob, 81 | "mask": mask} 82 | 83 | return uncert_output 84 | 85 | def encode_image(self, image: torch.Tensor, seq_length=None, last_feats=None) -> (Iterable[torch.Tensor], torch.Tensor): # type: ignore 86 | image = (image - self.pixel_mean) / self.pixel_std 87 | ms_image_feat = self.pixel_encoder(image, seq_length) # f16, f8, f4, f2, f1 88 | return ms_image_feat, self.pix_feat_proj(ms_image_feat[0]) 89 | 90 | def encode_mask( 91 | self, 92 | image: torch.Tensor, 93 | ms_features: List[torch.Tensor], 94 | sensory: torch.Tensor, 95 | masks: torch.Tensor, 96 | *, 97 | deep_update: bool = True, 98 | chunk_size: int = -1, 99 | need_weights: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 100 | image = (image - self.pixel_mean) / self.pixel_std 101 | others = self._get_others(masks) 102 | mask_value, new_sensory = self.mask_encoder(image, 103 | ms_features, 104 | sensory, 105 | masks, 106 | others, 107 | deep_update=deep_update, 108 | chunk_size=chunk_size) 109 | object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights) 110 | return mask_value, new_sensory, object_summaries, object_logits 111 | 112 | def transform_key(self, 113 | final_pix_feat: torch.Tensor, 114 | *, 115 | need_sk: bool = True, 116 | need_ek: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 117 | key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek) 118 | return key, shrinkage, selection 119 | 120 | # Used in training only. 
121 | # This step is replaced by MemoryManager in test time 122 | def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor, 123 | memory_key: torch.Tensor, memory_shrinkage: torch.Tensor, 124 | msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor, 125 | sensory: torch.Tensor, last_mask: torch.Tensor, 126 | selector: torch.Tensor, uncert_output=None, seg_pass=False, 127 | last_pix_feat=None, last_pred_mask=None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 128 | """ 129 | query_key : B * CK * H * W 130 | query_selection : B * CK * H * W 131 | memory_key : B * CK * T * H * W 132 | memory_shrinkage: B * 1 * T * H * W 133 | msk_value : B * num_objects * CV * T * H * W 134 | obj_memory : B * num_objects * T * num_summaries * C 135 | pixel_feature : B * C * H * W 136 | """ 137 | batch_size, num_objects = msk_value.shape[:2] 138 | 139 | uncert_mask = uncert_output["mask"] if uncert_output is not None else None 140 | 141 | # read using visual attention 142 | with torch.amp.autocast("cuda",enabled=False): 143 | affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(), 144 | query_selection.float(), uncert_mask=uncert_mask) 145 | 146 | msk_value = msk_value.flatten(start_dim=1, end_dim=2).float() 147 | 148 | # B * (num_objects*CV) * H * W 149 | pixel_readout = readout(affinity, msk_value, uncert_mask) 150 | pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim, 151 | *pixel_readout.shape[-2:]) 152 | 153 | uncert_output = self.pred_uncertainty(last_pix_feat, pix_feat, last_pred_mask, pixel_readout[:,0]-msk_value[:,:,-1]) 154 | uncert_prob = uncert_output["prob"].unsqueeze(1) # b n 1 h w 155 | pixel_readout = pixel_readout*uncert_prob + msk_value[:,:,-1].unsqueeze(1)*(1-uncert_prob) 156 | 157 | pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) 158 | 159 | 160 | # read from query transformer 161 | mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass) 162 | 163 | aux_output = { 164 | 'sensory': sensory, 165 | 'q_logits': aux_features['logits'] if aux_features else None, 166 | 'attn_mask': aux_features['attn_mask'] if aux_features else None, 167 | } 168 | 169 | return mem_readout, aux_output, uncert_output 170 | 171 | def read_first_frame_memory(self, pixel_readout, 172 | obj_memory: torch.Tensor, pix_feat: torch.Tensor, 173 | sensory: torch.Tensor, last_mask: torch.Tensor, 174 | selector: torch.Tensor, seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 175 | """ 176 | query_key : B * CK * H * W 177 | query_selection : B * CK * H * W 178 | memory_key : B * CK * T * H * W 179 | memory_shrinkage: B * 1 * T * H * W 180 | msk_value : B * num_objects * CV * T * H * W 181 | obj_memory : B * num_objects * T * num_summaries * C 182 | pixel_feature : B * C * H * W 183 | """ 184 | 185 | pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask) 186 | 187 | # read from query transformer 188 | mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector, seg_pass=seg_pass) 189 | 190 | aux_output = { 191 | 'sensory': sensory, 192 | 'q_logits': aux_features['logits'] if aux_features else None, 193 | 'attn_mask': aux_features['attn_mask'] if aux_features else None, 194 | } 195 | 196 | return mem_readout, aux_output 197 | 198 | def pixel_fusion(self, 199 | pix_feat: torch.Tensor, 200 | pixel: torch.Tensor, 201 | sensory: torch.Tensor, 202 | last_mask: torch.Tensor, 203 | *, 
204 | chunk_size: int = -1) -> torch.Tensor: 205 | last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area') 206 | last_others = self._get_others(last_mask) 207 | fused = self.pixel_fuser(pix_feat, 208 | pixel, 209 | sensory, 210 | last_mask, 211 | last_others, 212 | chunk_size=chunk_size) 213 | return fused 214 | 215 | def readout_query(self, 216 | pixel_readout, 217 | obj_memory, 218 | *, 219 | selector=None, 220 | need_weights=False, 221 | seg_pass=False) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 222 | return self.object_transformer(pixel_readout, 223 | obj_memory, 224 | selector=selector, 225 | need_weights=need_weights, 226 | seg_pass=seg_pass) 227 | 228 | def segment(self, 229 | ms_image_feat: List[torch.Tensor], 230 | memory_readout: torch.Tensor, 231 | sensory: torch.Tensor, 232 | *, 233 | selector: bool = None, 234 | chunk_size: int = -1, 235 | update_sensory: bool = True, 236 | seg_pass: bool = False, 237 | clamp_mat: bool = True, 238 | last_mask=None, 239 | sigmoid_residual=False, 240 | seg_mat=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 241 | """ 242 | multi_scale_features is from the key encoder for skip-connection 243 | memory_readout is from working/long-term memory 244 | sensory is the sensory memory 245 | last_mask is the mask from the last frame, supplementing sensory memory 246 | selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects 247 | during training. 248 | """ 249 | #### use mat head for seg data 250 | if seg_mat: 251 | assert seg_pass 252 | seg_pass = False 253 | #### 254 | sensory, logits = self.mask_decoder(ms_image_feat, 255 | memory_readout, 256 | sensory, 257 | chunk_size=chunk_size, 258 | update_sensory=update_sensory, 259 | seg_pass = seg_pass, 260 | last_mask=last_mask, 261 | sigmoid_residual=sigmoid_residual) 262 | if seg_pass: 263 | prob = torch.sigmoid(logits) 264 | if selector is not None: 265 | prob = prob * selector 266 | 267 | # Softmax over all objects[] 268 | logits = aggregate(prob, dim=1) 269 | prob = F.softmax(logits, dim=1) 270 | else: 271 | if clamp_mat: 272 | logits = logits.clamp(0.0, 1.0) 273 | logits = torch.cat([torch.prod(1 - logits, dim=1, keepdim=True), logits], 1) 274 | prob = logits 275 | 276 | return sensory, logits, prob 277 | 278 | def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor], 279 | selector: torch.Tensor, seg_pass=False) -> Dict[str, torch.Tensor]: 280 | return self.aux_computer(pix_feat, aux_inputs, selector, seg_pass=seg_pass) 281 | 282 | def forward(self, *args, **kwargs): 283 | raise NotImplementedError 284 | 285 | def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None: 286 | if not self.single_object: 287 | # Map single-object weight to multi-object weight (4->5 out channels in conv1) 288 | for k in list(src_dict.keys()): 289 | if k == 'mask_encoder.conv1.weight': 290 | if src_dict[k].shape[1] == 4: 291 | log.info(f'Converting {k} from single object to multiple objects.') 292 | pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device) 293 | if not init_as_zero_if_needed: 294 | nn.init.orthogonal_(pads) 295 | log.info(f'Randomly initialized padding for {k}.') 296 | else: 297 | log.info(f'Zero-initialized padding for {k}.') 298 | src_dict[k] = torch.cat([src_dict[k], pads], 1) 299 | elif k == 'pixel_fuser.sensory_compress.weight': 300 | if src_dict[k].shape[1] == self.sensory_dim + 1: 301 | log.info(f'Converting {k} from single object to multiple objects.') 302 | pads = torch.zeros((self.value_dim, 1, 
1, 1), device=src_dict[k].device) 303 | if not init_as_zero_if_needed: 304 | nn.init.orthogonal_(pads) 305 | log.info(f'Randomly initialized padding for {k}.') 306 | else: 307 | log.info(f'Zero-initialized padding for {k}.') 308 | src_dict[k] = torch.cat([src_dict[k], pads], 1) 309 | elif self.single_object: 310 | """ 311 | If the model is multiple-object and we are training in single-object, 312 | we strip the last channel of conv1. 313 | This is not supposed to happen in standard training except when users are trying to 314 | finetune a trained model with single object datasets. 315 | """ 316 | if src_dict['mask_encoder.conv1.weight'].shape[1] == 5: 317 | log.warning('Converting mask_encoder.conv1.weight from multiple objects to single object.' 318 | 'This is not supposed to happen in standard training.') 319 | src_dict['mask_encoder.conv1.weight'] = src_dict['mask_encoder.conv1.weight'][:, :-1] 320 | src_dict['pixel_fuser.sensory_compress.weight'] = src_dict['pixel_fuser.sensory_compress.weight'][:, :-1] 321 | 322 | for k in src_dict: 323 | if k not in self.state_dict(): 324 | log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!') 325 | for k in self.state_dict(): 326 | if k not in src_dict: 327 | log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!') 328 | 329 | self.load_state_dict(src_dict, strict=False) 330 | 331 | @property 332 | def device(self) -> torch.device: 333 | return self.pixel_mean.device 334 | -------------------------------------------------------------------------------- /matanyone/model/modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterable 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from matanyone.model.group_modules import MainToGroupDistributor, GroupResBlock, upsample_groups, GConv2d, downsample_groups 7 | 8 | 9 | class UpsampleBlock(nn.Module): 10 | def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): 11 | super().__init__() 12 | self.out_conv = ResBlock(in_dim, out_dim) 13 | self.scale_factor = scale_factor 14 | 15 | def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: 16 | g = F.interpolate(in_g, 17 | scale_factor=self.scale_factor, 18 | mode='bilinear') 19 | g = self.out_conv(g) 20 | g = g + skip_f 21 | return g 22 | 23 | class MaskUpsampleBlock(nn.Module): 24 | def __init__(self, in_dim: int, out_dim: int, scale_factor: int = 2): 25 | super().__init__() 26 | self.distributor = MainToGroupDistributor(method='add') 27 | self.out_conv = GroupResBlock(in_dim, out_dim) 28 | self.scale_factor = scale_factor 29 | 30 | def forward(self, in_g: torch.Tensor, skip_f: torch.Tensor) -> torch.Tensor: 31 | g = upsample_groups(in_g, ratio=self.scale_factor) 32 | g = self.distributor(skip_f, g) 33 | g = self.out_conv(g) 34 | return g 35 | 36 | 37 | class DecoderFeatureProcessor(nn.Module): 38 | def __init__(self, decoder_dims: List[int], out_dims: List[int]): 39 | super().__init__() 40 | self.transforms = nn.ModuleList([ 41 | nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims) 42 | ]) 43 | 44 | def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]: 45 | outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)] 46 | return outputs 47 | 48 | 49 | # @torch.jit.script 50 | def _recurrent_update(h: torch.Tensor, values: torch.Tensor) -> torch.Tensor: 51 | # h: batch_size * num_objects * hidden_dim * h * w 
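# (added comment, not part of the original source) GRU-style gated update: `values` is split
# into three equal chunks -- a sigmoid forget gate, a sigmoid update gate, and a tanh
# candidate -- and the new hidden state is forget * h * (1 - update) + update * candidate.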
52 | # values: batch_size * num_objects * (hidden_dim*3) * h * w 53 | dim = values.shape[2] // 3 54 | forget_gate = torch.sigmoid(values[:, :, :dim]) 55 | update_gate = torch.sigmoid(values[:, :, dim:dim * 2]) 56 | new_value = torch.tanh(values[:, :, dim * 2:]) 57 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value 58 | return new_h 59 | 60 | 61 | class SensoryUpdater_fullscale(nn.Module): 62 | # Used in the decoder, multi-scale feature + GRU 63 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): 64 | super().__init__() 65 | self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) 66 | self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) 67 | self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) 68 | self.g2_conv = GConv2d(g_dims[3], mid_dim, kernel_size=1) 69 | self.g1_conv = GConv2d(g_dims[4], mid_dim, kernel_size=1) 70 | 71 | self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) 72 | 73 | nn.init.xavier_normal_(self.transform.weight) 74 | 75 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: 76 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ 77 | self.g4_conv(downsample_groups(g[2], ratio=1/4)) + \ 78 | self.g2_conv(downsample_groups(g[3], ratio=1/8)) + \ 79 | self.g1_conv(downsample_groups(g[4], ratio=1/16)) 80 | 81 | with torch.amp.autocast("cuda",enabled=False): 82 | g = g.float() 83 | h = h.float() 84 | values = self.transform(torch.cat([g, h], dim=2)) 85 | new_h = _recurrent_update(h, values) 86 | 87 | return new_h 88 | 89 | class SensoryUpdater(nn.Module): 90 | # Used in the decoder, multi-scale feature + GRU 91 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): 92 | super().__init__() 93 | self.g16_conv = GConv2d(g_dims[0], mid_dim, kernel_size=1) 94 | self.g8_conv = GConv2d(g_dims[1], mid_dim, kernel_size=1) 95 | self.g4_conv = GConv2d(g_dims[2], mid_dim, kernel_size=1) 96 | 97 | self.transform = GConv2d(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) 98 | 99 | nn.init.xavier_normal_(self.transform.weight) 100 | 101 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: 102 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ 103 | self.g4_conv(downsample_groups(g[2], ratio=1/4)) 104 | 105 | with torch.amp.autocast("cuda",enabled=False): 106 | g = g.float() 107 | h = h.float() 108 | values = self.transform(torch.cat([g, h], dim=2)) 109 | new_h = _recurrent_update(h, values) 110 | 111 | return new_h 112 | 113 | 114 | class SensoryDeepUpdater(nn.Module): 115 | def __init__(self, f_dim: int, sensory_dim: int): 116 | super().__init__() 117 | self.transform = GConv2d(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) 118 | 119 | nn.init.xavier_normal_(self.transform.weight) 120 | 121 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: 122 | with torch.amp.autocast("cuda",enabled=False): 123 | g = g.float() 124 | h = h.float() 125 | values = self.transform(torch.cat([g, h], dim=2)) 126 | new_h = _recurrent_update(h, values) 127 | 128 | return new_h 129 | 130 | 131 | class ResBlock(nn.Module): 132 | def __init__(self, in_dim: int, out_dim: int): 133 | super().__init__() 134 | 135 | if in_dim == out_dim: 136 | self.downsample = nn.Identity() 137 | else: 138 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=1) 139 | 140 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) 141 | self.conv2 = nn.Conv2d(out_dim, 
out_dim, kernel_size=3, padding=1) 142 | 143 | def forward(self, g: torch.Tensor) -> torch.Tensor: 144 | out_g = self.conv1(F.relu(g)) 145 | out_g = self.conv2(F.relu(out_g)) 146 | 147 | g = self.downsample(g) 148 | 149 | return out_g + g -------------------------------------------------------------------------------- /matanyone/model/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/model/transformer/__init__.py -------------------------------------------------------------------------------- /matanyone/model/transformer/object_summarizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from omegaconf import DictConfig 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from matanyone.model.transformer.positional_encoding import PositionalEncoding 8 | 9 | 10 | # @torch.jit.script 11 | def _weighted_pooling(masks: torch.Tensor, value: torch.Tensor, 12 | logits: torch.Tensor) -> (torch.Tensor, torch.Tensor): 13 | # value: B*num_objects*H*W*value_dim 14 | # logits: B*num_objects*H*W*num_summaries 15 | # masks: B*num_objects*H*W*num_summaries: 1 if allowed 16 | weights = logits.sigmoid() * masks 17 | # B*num_objects*num_summaries*value_dim 18 | sums = torch.einsum('bkhwq,bkhwc->bkqc', weights, value) 19 | # B*num_objects*H*W*num_summaries -> B*num_objects*num_summaries*1 20 | area = weights.flatten(start_dim=2, end_dim=3).sum(2).unsqueeze(-1) 21 | 22 | # B*num_objects*num_summaries*value_dim 23 | return sums, area 24 | 25 | 26 | class ObjectSummarizer(nn.Module): 27 | def __init__(self, model_cfg: DictConfig): 28 | super().__init__() 29 | 30 | this_cfg = model_cfg.object_summarizer 31 | self.value_dim = model_cfg.value_dim 32 | self.embed_dim = this_cfg.embed_dim 33 | self.num_summaries = this_cfg.num_summaries 34 | self.add_pe = this_cfg.add_pe 35 | self.pixel_pe_scale = model_cfg.pixel_pe_scale 36 | self.pixel_pe_temperature = model_cfg.pixel_pe_temperature 37 | 38 | if self.add_pe: 39 | self.pos_enc = PositionalEncoding(self.embed_dim, 40 | scale=self.pixel_pe_scale, 41 | temperature=self.pixel_pe_temperature) 42 | 43 | self.input_proj = nn.Linear(self.value_dim, self.embed_dim) 44 | self.feature_pred = nn.Sequential( 45 | nn.Linear(self.embed_dim, self.embed_dim), 46 | nn.ReLU(inplace=True), 47 | nn.Linear(self.embed_dim, self.embed_dim), 48 | ) 49 | self.weights_pred = nn.Sequential( 50 | nn.Linear(self.embed_dim, self.embed_dim), 51 | nn.ReLU(inplace=True), 52 | nn.Linear(self.embed_dim, self.num_summaries), 53 | ) 54 | 55 | def forward(self, 56 | masks: torch.Tensor, 57 | value: torch.Tensor, 58 | need_weights: bool = False) -> (torch.Tensor, Optional[torch.Tensor]): 59 | # masks: B*num_objects*(H0)*(W0) 60 | # value: B*num_objects*value_dim*H*W 61 | # -> B*num_objects*H*W*value_dim 62 | h, w = value.shape[-2:] 63 | masks = F.interpolate(masks, size=(h, w), mode='area') 64 | masks = masks.unsqueeze(-1) 65 | inv_masks = 1 - masks 66 | repeated_masks = torch.cat([ 67 | masks.expand(-1, -1, -1, -1, self.num_summaries // 2), 68 | inv_masks.expand(-1, -1, -1, -1, self.num_summaries // 2), 69 | ], 70 | dim=-1) 71 | 72 | value = value.permute(0, 1, 3, 4, 2) 73 | value = self.input_proj(value) 74 | if self.add_pe: 75 | pe = self.pos_enc(value) 76 | value = value + pe 77 | 78 | with torch.amp.autocast("cuda",enabled=False): 79 | value 
= value.float() 80 | feature = self.feature_pred(value) 81 | logits = self.weights_pred(value) 82 | sums, area = _weighted_pooling(repeated_masks, feature, logits) 83 | 84 | summaries = torch.cat([sums, area], dim=-1) 85 | 86 | if need_weights: 87 | return summaries, logits 88 | else: 89 | return summaries, None -------------------------------------------------------------------------------- /matanyone/model/transformer/object_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | from omegaconf import DictConfig 3 | 4 | import torch 5 | import torch.nn as nn 6 | from matanyone.model.group_modules import GConv2d 7 | from matanyone.utils.tensor_utils import aggregate 8 | from matanyone.model.transformer.positional_encoding import PositionalEncoding 9 | from matanyone.model.transformer.transformer_layers import CrossAttention, SelfAttention, FFN, PixelFFN 10 | 11 | 12 | class QueryTransformerBlock(nn.Module): 13 | def __init__(self, model_cfg: DictConfig): 14 | super().__init__() 15 | 16 | this_cfg = model_cfg.object_transformer 17 | self.embed_dim = this_cfg.embed_dim 18 | self.num_heads = this_cfg.num_heads 19 | self.num_queries = this_cfg.num_queries 20 | self.ff_dim = this_cfg.ff_dim 21 | 22 | self.read_from_pixel = CrossAttention(self.embed_dim, 23 | self.num_heads, 24 | add_pe_to_qkv=this_cfg.read_from_pixel.add_pe_to_qkv) 25 | self.self_attn = SelfAttention(self.embed_dim, 26 | self.num_heads, 27 | add_pe_to_qkv=this_cfg.query_self_attention.add_pe_to_qkv) 28 | self.ffn = FFN(self.embed_dim, self.ff_dim) 29 | self.read_from_query = CrossAttention(self.embed_dim, 30 | self.num_heads, 31 | add_pe_to_qkv=this_cfg.read_from_query.add_pe_to_qkv, 32 | norm=this_cfg.read_from_query.output_norm) 33 | self.pixel_ffn = PixelFFN(self.embed_dim) 34 | 35 | def forward( 36 | self, 37 | x: torch.Tensor, 38 | pixel: torch.Tensor, 39 | query_pe: torch.Tensor, 40 | pixel_pe: torch.Tensor, 41 | attn_mask: torch.Tensor, 42 | need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): 43 | # x: (bs*num_objects)*num_queries*embed_dim 44 | # pixel: bs*num_objects*C*H*W 45 | # query_pe: (bs*num_objects)*num_queries*embed_dim 46 | # pixel_pe: (bs*num_objects)*(H*W)*C 47 | # attn_mask: (bs*num_objects*num_heads)*num_queries*(H*W) 48 | 49 | # bs*num_objects*C*H*W -> (bs*num_objects)*(H*W)*C 50 | pixel_flat = pixel.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() 51 | x, q_weights = self.read_from_pixel(x, 52 | pixel_flat, 53 | query_pe, 54 | pixel_pe, 55 | attn_mask=attn_mask, 56 | need_weights=need_weights) 57 | x = self.self_attn(x, query_pe) 58 | x = self.ffn(x) 59 | 60 | pixel_flat, p_weights = self.read_from_query(pixel_flat, 61 | x, 62 | pixel_pe, 63 | query_pe, 64 | need_weights=need_weights) 65 | pixel = self.pixel_ffn(pixel, pixel_flat) 66 | 67 | if need_weights: 68 | bs, num_objects, _, h, w = pixel.shape 69 | q_weights = q_weights.view(bs, num_objects, self.num_heads, self.num_queries, h, w) 70 | p_weights = p_weights.transpose(2, 3).view(bs, num_objects, self.num_heads, 71 | self.num_queries, h, w) 72 | 73 | return x, pixel, q_weights, p_weights 74 | 75 | 76 | class QueryTransformer(nn.Module): 77 | def __init__(self, model_cfg: DictConfig): 78 | super().__init__() 79 | 80 | this_cfg = model_cfg.object_transformer 81 | self.value_dim = model_cfg.value_dim 82 | self.embed_dim = this_cfg.embed_dim 83 | self.num_heads = this_cfg.num_heads 84 | self.num_queries = this_cfg.num_queries 
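# (added comment, not part of the original source) The learned query embeddings below are
# shared by all objects; in forward() they are offset by per-object conditioning computed
# from the ObjectSummarizer outputs (summary vectors divided by their accumulated area).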
85 | 86 | # query initialization and embedding 87 | self.query_init = nn.Embedding(self.num_queries, self.embed_dim) 88 | self.query_emb = nn.Embedding(self.num_queries, self.embed_dim) 89 | 90 | # projection from object summaries to query initialization and embedding 91 | self.summary_to_query_init = nn.Linear(self.embed_dim, self.embed_dim) 92 | self.summary_to_query_emb = nn.Linear(self.embed_dim, self.embed_dim) 93 | 94 | self.pixel_pe_scale = model_cfg.pixel_pe_scale 95 | self.pixel_pe_temperature = model_cfg.pixel_pe_temperature 96 | self.pixel_init_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) 97 | self.pixel_emb_proj = GConv2d(self.embed_dim, self.embed_dim, kernel_size=1) 98 | self.spatial_pe = PositionalEncoding(self.embed_dim, 99 | scale=self.pixel_pe_scale, 100 | temperature=self.pixel_pe_temperature, 101 | channel_last=False, 102 | transpose_output=True) 103 | 104 | # transformer blocks 105 | self.num_blocks = this_cfg.num_blocks 106 | self.blocks = nn.ModuleList( 107 | QueryTransformerBlock(model_cfg) for _ in range(self.num_blocks)) 108 | self.mask_pred = nn.ModuleList( 109 | nn.Sequential(nn.ReLU(), GConv2d(self.embed_dim, 1, kernel_size=1)) 110 | for _ in range(self.num_blocks + 1)) 111 | 112 | self.act = nn.ReLU(inplace=True) 113 | 114 | def forward(self, 115 | pixel: torch.Tensor, 116 | obj_summaries: torch.Tensor, 117 | selector: Optional[torch.Tensor] = None, 118 | need_weights: bool = False, 119 | seg_pass=False) -> (torch.Tensor, Dict[str, torch.Tensor]): 120 | # pixel: B*num_objects*embed_dim*H*W 121 | # obj_summaries: B*num_objects*T*num_queries*embed_dim 122 | T = obj_summaries.shape[2] 123 | bs, num_objects, _, H, W = pixel.shape 124 | 125 | # normalize object values 126 | # the last channel is the cumulative area of the object 127 | obj_summaries = obj_summaries.view(bs * num_objects, T, self.num_queries, 128 | self.embed_dim + 1) 129 | # sum over time 130 | # during inference, T=1 as we already did streaming average in memory_manager 131 | obj_sums = obj_summaries[:, :, :, :-1].sum(dim=1) 132 | obj_area = obj_summaries[:, :, :, -1:].sum(dim=1) 133 | obj_values = obj_sums / (obj_area + 1e-4) 134 | obj_init = self.summary_to_query_init(obj_values) 135 | obj_emb = self.summary_to_query_emb(obj_values) 136 | 137 | # positional embeddings for object queries 138 | query = self.query_init.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_init 139 | query_emb = self.query_emb.weight.unsqueeze(0).expand(bs * num_objects, -1, -1) + obj_emb 140 | 141 | # positional embeddings for pixel features 142 | pixel_init = self.pixel_init_proj(pixel) 143 | pixel_emb = self.pixel_emb_proj(pixel) 144 | pixel_pe = self.spatial_pe(pixel.flatten(0, 1)) 145 | pixel_emb = pixel_emb.flatten(3, 4).flatten(0, 1).transpose(1, 2).contiguous() 146 | pixel_pe = pixel_pe.flatten(1, 2) + pixel_emb 147 | 148 | pixel = pixel_init 149 | 150 | # run the transformer 151 | aux_features = {'logits': []} 152 | 153 | # first aux output 154 | aux_logits = self.mask_pred[0](pixel).squeeze(2) 155 | attn_mask = self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass) 156 | aux_features['logits'].append(aux_logits) 157 | for i in range(self.num_blocks): 158 | query, pixel, q_weights, p_weights = self.blocks[i](query, 159 | pixel, 160 | query_emb, 161 | pixel_pe, 162 | attn_mask, 163 | need_weights=need_weights) 164 | 165 | if self.training or i <= self.num_blocks - 1 or need_weights: 166 | aux_logits = self.mask_pred[i + 1](pixel).squeeze(2) 167 | attn_mask = 
self._get_aux_mask(aux_logits, selector, seg_pass=seg_pass) 168 | aux_features['logits'].append(aux_logits) 169 | 170 | aux_features['q_weights'] = q_weights # last layer only 171 | aux_features['p_weights'] = p_weights # last layer only 172 | 173 | if self.training: 174 | # no need to save all heads 175 | aux_features['attn_mask'] = attn_mask.view(bs, num_objects, self.num_heads, 176 | self.num_queries, H, W)[:, :, 0] 177 | 178 | return pixel, aux_features 179 | 180 | def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor, seg_pass=False) -> torch.Tensor: 181 | # logits: batch_size*num_objects*H*W 182 | # selector: batch_size*num_objects*1*1 183 | # returns a mask of shape (batch_size*num_objects*num_heads)*num_queries*(H*W) 184 | # where True means the attention is blocked 185 | 186 | if selector is None: 187 | prob = logits.sigmoid() 188 | else: 189 | prob = logits.sigmoid() * selector 190 | logits = aggregate(prob, dim=1) 191 | 192 | is_foreground = (logits[:, 1:] >= logits.max(dim=1, keepdim=True)[0]) 193 | foreground_mask = is_foreground.bool().flatten(start_dim=2) 194 | inv_foreground_mask = ~foreground_mask 195 | inv_background_mask = foreground_mask 196 | 197 | aux_foreground_mask = inv_foreground_mask.unsqueeze(2).unsqueeze(2).repeat( 198 | 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) 199 | aux_background_mask = inv_background_mask.unsqueeze(2).unsqueeze(2).repeat( 200 | 1, 1, self.num_heads, self.num_queries // 2, 1).flatten(start_dim=0, end_dim=2) 201 | 202 | aux_mask = torch.cat([aux_foreground_mask, aux_background_mask], dim=1) 203 | 204 | aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False 205 | 206 | return aux_mask -------------------------------------------------------------------------------- /matanyone/model/transformer/positional_encoding.py: -------------------------------------------------------------------------------- 1 | # Reference: 2 | # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/transformer_decoder/position_encoding.py 3 | # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py 4 | 5 | import math 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | 12 | def get_emb(sin_inp: torch.Tensor) -> torch.Tensor: 13 | """ 14 | Gets a base embedding for one dimension with sin and cos intertwined 15 | """ 16 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) 17 | return torch.flatten(emb, -2, -1) 18 | 19 | 20 | class PositionalEncoding(nn.Module): 21 | def __init__(self, 22 | dim: int, 23 | scale: float = math.pi * 2, 24 | temperature: float = 10000, 25 | normalize: bool = True, 26 | channel_last: bool = True, 27 | transpose_output: bool = False): 28 | super().__init__() 29 | dim = int(np.ceil(dim / 4) * 2) 30 | self.dim = dim 31 | inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim)) 32 | self.register_buffer("inv_freq", inv_freq) 33 | self.normalize = normalize 34 | self.scale = scale 35 | self.eps = 1e-6 36 | self.channel_last = channel_last 37 | self.transpose_output = transpose_output 38 | 39 | self.cached_penc = None # the cache is irrespective of the number of objects 40 | 41 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 42 | """ 43 | :param tensor: A 4/5d tensor of size 44 | channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c) 45 | channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w) 46 | :return: positional 
encoding tensor that has the same shape as the input if the input is 4d 47 | if the input is 5d, the output is broadcastable along the k-dimension 48 | """ 49 | if len(tensor.shape) != 4 and len(tensor.shape) != 5: 50 | raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!') 51 | 52 | if len(tensor.shape) == 5: 53 | # take a sample from the k dimension 54 | num_objects = tensor.shape[1] 55 | tensor = tensor[:, 0] 56 | else: 57 | num_objects = None 58 | 59 | if self.channel_last: 60 | batch_size, h, w, c = tensor.shape 61 | else: 62 | batch_size, c, h, w = tensor.shape 63 | 64 | if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: 65 | if num_objects is None: 66 | return self.cached_penc 67 | else: 68 | return self.cached_penc.unsqueeze(1) 69 | 70 | self.cached_penc = None 71 | 72 | pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype) 73 | pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype) 74 | if self.normalize: 75 | pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale 76 | pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale 77 | 78 | sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) 79 | sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) 80 | emb_y = get_emb(sin_inp_y).unsqueeze(1) 81 | emb_x = get_emb(sin_inp_x) 82 | 83 | emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype) 84 | emb[:, :, :self.dim] = emb_x 85 | emb[:, :, self.dim:] = emb_y 86 | 87 | if not self.channel_last and self.transpose_output: 88 | # cancelled out 89 | pass 90 | elif (not self.channel_last) or (self.transpose_output): 91 | emb = emb.permute(2, 0, 1) 92 | 93 | self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1) 94 | if num_objects is None: 95 | return self.cached_penc 96 | else: 97 | return self.cached_penc.unsqueeze(1) 98 | 99 | 100 | if __name__ == '__main__': 101 | pe = PositionalEncoding(8).cuda() 102 | input = torch.ones((1, 8, 8, 8)).cuda() 103 | output = pe(input) 104 | # print(output) 105 | print(output[0, :, 0, 0]) 106 | print(output[0, :, 0, 5]) 107 | print(output[0, 0, :, 0]) 108 | print(output[0, 0, 0, :]) 109 | -------------------------------------------------------------------------------- /matanyone/model/transformer/transformer_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from PyTorch nn.Transformer 2 | 3 | from typing import List, Callable 4 | 5 | import torch 6 | from torch import Tensor 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from matanyone.model.channel_attn import CAResBlock 10 | 11 | 12 | class SelfAttention(nn.Module): 13 | def __init__(self, 14 | dim: int, 15 | nhead: int, 16 | dropout: float = 0.0, 17 | batch_first: bool = True, 18 | add_pe_to_qkv: List[bool] = [True, True, False]): 19 | super().__init__() 20 | self.self_attn = nn.MultiheadAttention(dim, nhead, dropout=dropout, batch_first=batch_first) 21 | self.norm = nn.LayerNorm(dim) 22 | self.dropout = nn.Dropout(dropout) 23 | self.add_pe_to_qkv = add_pe_to_qkv 24 | 25 | def forward(self, 26 | x: torch.Tensor, 27 | pe: torch.Tensor, 28 | attn_mask: bool = None, 29 | key_padding_mask: bool = None) -> torch.Tensor: 30 | x = self.norm(x) 31 | if any(self.add_pe_to_qkv): 32 | x_with_pe = x + pe 33 | q = x_with_pe if self.add_pe_to_qkv[0] else x 34 | k = x_with_pe if self.add_pe_to_qkv[1] else x 35 | v = x_with_pe if self.add_pe_to_qkv[2] else x 36 | else: 37 | q = k = v = x 38 | 39 | r = x 40 | x = self.self_attn(q, 
k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask)[0] 41 | return r + self.dropout(x) 42 | 43 | 44 | # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention 45 | class CrossAttention(nn.Module): 46 | def __init__(self, 47 | dim: int, 48 | nhead: int, 49 | dropout: float = 0.0, 50 | batch_first: bool = True, 51 | add_pe_to_qkv: List[bool] = [True, True, False], 52 | residual: bool = True, 53 | norm: bool = True): 54 | super().__init__() 55 | self.cross_attn = nn.MultiheadAttention(dim, 56 | nhead, 57 | dropout=dropout, 58 | batch_first=batch_first) 59 | if norm: 60 | self.norm = nn.LayerNorm(dim) 61 | else: 62 | self.norm = nn.Identity() 63 | self.dropout = nn.Dropout(dropout) 64 | self.add_pe_to_qkv = add_pe_to_qkv 65 | self.residual = residual 66 | 67 | def forward(self, 68 | x: torch.Tensor, 69 | mem: torch.Tensor, 70 | x_pe: torch.Tensor, 71 | mem_pe: torch.Tensor, 72 | attn_mask: bool = None, 73 | *, 74 | need_weights: bool = False) -> (torch.Tensor, torch.Tensor): 75 | x = self.norm(x) 76 | if self.add_pe_to_qkv[0]: 77 | q = x + x_pe 78 | else: 79 | q = x 80 | 81 | if any(self.add_pe_to_qkv[1:]): 82 | mem_with_pe = mem + mem_pe 83 | k = mem_with_pe if self.add_pe_to_qkv[1] else mem 84 | v = mem_with_pe if self.add_pe_to_qkv[2] else mem 85 | else: 86 | k = v = mem 87 | r = x 88 | x, weights = self.cross_attn(q, 89 | k, 90 | v, 91 | attn_mask=attn_mask, 92 | need_weights=need_weights, 93 | average_attn_weights=False) 94 | 95 | if self.residual: 96 | return r + self.dropout(x), weights 97 | else: 98 | return self.dropout(x), weights 99 | 100 | 101 | class FFN(nn.Module): 102 | def __init__(self, dim_in: int, dim_ff: int, activation=F.relu): 103 | super().__init__() 104 | self.linear1 = nn.Linear(dim_in, dim_ff) 105 | self.linear2 = nn.Linear(dim_ff, dim_in) 106 | self.norm = nn.LayerNorm(dim_in) 107 | 108 | if isinstance(activation, str): 109 | self.activation = _get_activation_fn(activation) 110 | else: 111 | self.activation = activation 112 | 113 | def forward(self, x: torch.Tensor) -> torch.Tensor: 114 | r = x 115 | x = self.norm(x) 116 | x = self.linear2(self.activation(self.linear1(x))) 117 | x = r + x 118 | return x 119 | 120 | 121 | class PixelFFN(nn.Module): 122 | def __init__(self, dim: int): 123 | super().__init__() 124 | self.dim = dim 125 | self.conv = CAResBlock(dim, dim) 126 | 127 | def forward(self, pixel: torch.Tensor, pixel_flat: torch.Tensor) -> torch.Tensor: 128 | # pixel: batch_size * num_objects * dim * H * W 129 | # pixel_flat: (batch_size*num_objects) * (H*W) * dim 130 | bs, num_objects, _, h, w = pixel.shape 131 | pixel_flat = pixel_flat.view(bs * num_objects, h, w, self.dim) 132 | pixel_flat = pixel_flat.permute(0, 3, 1, 2).contiguous() 133 | 134 | x = self.conv(pixel_flat) 135 | x = x.view(bs, num_objects, self.dim, h, w) 136 | return x 137 | 138 | 139 | class OutputFFN(nn.Module): 140 | def __init__(self, dim_in: int, dim_out: int, activation=F.relu): 141 | super().__init__() 142 | self.linear1 = nn.Linear(dim_in, dim_out) 143 | self.linear2 = nn.Linear(dim_out, dim_out) 144 | 145 | if isinstance(activation, str): 146 | self.activation = _get_activation_fn(activation) 147 | else: 148 | self.activation = activation 149 | 150 | def forward(self, x: torch.Tensor) -> torch.Tensor: 151 | x = self.linear2(self.activation(self.linear1(x))) 152 | return x 153 | 154 | 155 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: 156 | if activation == 
"relu": 157 | return F.relu 158 | elif activation == "gelu": 159 | return F.gelu 160 | 161 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 162 | -------------------------------------------------------------------------------- /matanyone/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/model/utils/__init__.py -------------------------------------------------------------------------------- /matanyone/model/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Optional, Union, Tuple 4 | 5 | 6 | # @torch.jit.script 7 | def get_similarity(mk: torch.Tensor, 8 | ms: torch.Tensor, 9 | qk: torch.Tensor, 10 | qe: torch.Tensor, 11 | add_batch_dim: bool = False, 12 | uncert_mask = None) -> torch.Tensor: 13 | # used for training/inference and memory reading/memory potentiation 14 | # mk: B x CK x [N] - Memory keys 15 | # ms: B x 1 x [N] - Memory shrinkage 16 | # qk: B x CK x [HW/P] - Query keys 17 | # qe: B x CK x [HW/P] - Query selection 18 | # Dimensions in [] are flattened 19 | # Return: B*N*HW 20 | if add_batch_dim: 21 | mk, ms = mk.unsqueeze(0), ms.unsqueeze(0) 22 | qk, qe = qk.unsqueeze(0), qe.unsqueeze(0) 23 | 24 | CK = mk.shape[1] 25 | 26 | mk = mk.flatten(start_dim=2) 27 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None 28 | qk = qk.flatten(start_dim=2) 29 | qe = qe.flatten(start_dim=2) if qe is not None else None 30 | 31 | # query token selection based on temporal sparsity 32 | if uncert_mask is not None: 33 | uncert_mask = uncert_mask.flatten(start_dim=2) 34 | uncert_mask = uncert_mask.expand(-1, 64, -1) 35 | qk = qk * uncert_mask 36 | qe = qe * uncert_mask 37 | 38 | if qe is not None: 39 | # See XMem's appendix for derivation 40 | mk = mk.transpose(1, 2) 41 | a_sq = (mk.pow(2) @ qe) 42 | two_ab = 2 * (mk @ (qk * qe)) 43 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) 44 | similarity = (-a_sq + two_ab - b_sq) 45 | else: 46 | # similar to STCN if we don't have the selection term 47 | a_sq = mk.pow(2).sum(1).unsqueeze(2) 48 | two_ab = 2 * (mk.transpose(1, 2) @ qk) 49 | similarity = (-a_sq + two_ab) 50 | 51 | if ms is not None: 52 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW 53 | else: 54 | similarity = similarity / math.sqrt(CK) # B*N*HW 55 | 56 | return similarity 57 | 58 | 59 | def do_softmax( 60 | similarity: torch.Tensor, 61 | top_k: Optional[int] = None, 62 | inplace: bool = False, 63 | return_usage: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 64 | # normalize similarity with top-k softmax 65 | # similarity: B x N x [HW/P] 66 | # use inplace with care 67 | if top_k is not None: 68 | values, indices = torch.topk(similarity, k=top_k, dim=1) 69 | 70 | x_exp = values.exp_() 71 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True) 72 | if inplace: 73 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW 74 | affinity = similarity 75 | else: 76 | affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW 77 | else: 78 | maxes = torch.max(similarity, dim=1, keepdim=True)[0] 79 | x_exp = torch.exp(similarity - maxes) 80 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) 81 | affinity = x_exp / x_exp_sum 82 | indices = None 83 | 84 | if return_usage: 85 | return affinity, affinity.sum(dim=2) 86 | 87 | return affinity 88 | 89 
| 90 | def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor, 91 | qe: torch.Tensor, uncert_mask = None) -> torch.Tensor: 92 | # shorthand used in training with no top-k 93 | similarity = get_similarity(mk, ms, qk, qe, uncert_mask=uncert_mask) 94 | affinity = do_softmax(similarity) 95 | return affinity 96 | 97 | def readout(affinity: torch.Tensor, mv: torch.Tensor, uncert_mask: torch.Tensor=None) -> torch.Tensor: 98 | B, CV, T, H, W = mv.shape 99 | 100 | mo = mv.view(B, CV, T * H * W) 101 | mem = torch.bmm(mo, affinity) 102 | if uncert_mask is not None: 103 | uncert_mask = uncert_mask.flatten(start_dim=2).expand(-1, CV, -1) 104 | mem = mem * uncert_mask 105 | mem = mem.view(B, CV, H, W) 106 | 107 | return mem 108 | -------------------------------------------------------------------------------- /matanyone/model/utils/parameter_groups.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | log = logging.getLogger() 4 | 5 | 6 | def get_parameter_groups(model, stage_cfg, print_log=False): 7 | """ 8 | Assign different weight decays and learning rates to different parameters. 9 | Returns a parameter group which can be passed to the optimizer. 10 | """ 11 | weight_decay = stage_cfg.weight_decay 12 | embed_weight_decay = stage_cfg.embed_weight_decay 13 | backbone_lr_ratio = stage_cfg.backbone_lr_ratio 14 | base_lr = stage_cfg.learning_rate 15 | 16 | backbone_params = [] 17 | embed_params = [] 18 | other_params = [] 19 | 20 | embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe'] 21 | embedding_names = [e + '.weight' for e in embedding_names] 22 | 23 | # inspired by detectron2 24 | memo = set() 25 | for name, param in model.named_parameters(): 26 | if not param.requires_grad: 27 | continue 28 | # Avoid duplicating parameters 29 | if param in memo: 30 | continue 31 | memo.add(param) 32 | 33 | if name.startswith('module'): 34 | name = name[7:] 35 | 36 | inserted = False 37 | if name.startswith('pixel_encoder.'): 38 | backbone_params.append(param) 39 | inserted = True 40 | if print_log: 41 | log.info(f'{name} counted as a backbone parameter.') 42 | else: 43 | for e in embedding_names: 44 | if name.endswith(e): 45 | embed_params.append(param) 46 | inserted = True 47 | if print_log: 48 | log.info(f'{name} counted as an embedding parameter.') 49 | break 50 | 51 | if not inserted: 52 | other_params.append(param) 53 | 54 | parameter_groups = [ 55 | { 56 | 'params': backbone_params, 57 | 'lr': base_lr * backbone_lr_ratio, 58 | 'weight_decay': weight_decay 59 | }, 60 | { 61 | 'params': embed_params, 62 | 'lr': base_lr, 63 | 'weight_decay': embed_weight_decay 64 | }, 65 | { 66 | 'params': other_params, 67 | 'lr': base_lr, 68 | 'weight_decay': weight_decay 69 | }, 70 | ] 71 | 72 | return parameter_groups -------------------------------------------------------------------------------- /matanyone/model/utils/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | resnet.py - A modified ResNet structure 3 | We append extra channels to the first conv by some network surgery 4 | """ 5 | 6 | from collections import OrderedDict 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils import model_zoo 12 | 13 | 14 | def load_weights_add_extra_dim(target, source_state, extra_dim=1): 15 | new_dict = OrderedDict() 16 | 17 | for k1, v1 in target.state_dict().items(): 18 | if 'num_batches_tracked' not in k1: 19 | if k1 in source_state: 20 | tar_v = 
source_state[k1] 21 | 22 | if v1.shape != tar_v.shape: 23 | # Init the new segmentation channel with zeros 24 | # print(v1.shape, tar_v.shape) 25 | c, _, w, h = v1.shape 26 | pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device) 27 | nn.init.orthogonal_(pads) 28 | tar_v = torch.cat([tar_v, pads], 1) 29 | 30 | new_dict[k1] = tar_v 31 | 32 | target.load_state_dict(new_dict) 33 | 34 | 35 | model_urls = { 36 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 37 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 38 | } 39 | 40 | 41 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 42 | return nn.Conv2d(in_planes, 43 | out_planes, 44 | kernel_size=3, 45 | stride=stride, 46 | padding=dilation, 47 | dilation=dilation, 48 | bias=False) 49 | 50 | 51 | class BasicBlock(nn.Module): 52 | expansion = 1 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 55 | super(BasicBlock, self).__init__() 56 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | 74 | if self.downsample is not None: 75 | residual = self.downsample(x) 76 | 77 | out += residual 78 | out = self.relu(out) 79 | 80 | return out 81 | 82 | 83 | class Bottleneck(nn.Module): 84 | expansion = 4 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 87 | super(Bottleneck, self).__init__() 88 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 89 | self.bn1 = nn.BatchNorm2d(planes) 90 | self.conv2 = nn.Conv2d(planes, 91 | planes, 92 | kernel_size=3, 93 | stride=stride, 94 | dilation=dilation, 95 | padding=dilation, 96 | bias=False) 97 | self.bn2 = nn.BatchNorm2d(planes) 98 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 99 | self.bn3 = nn.BatchNorm2d(planes * 4) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | residual = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | residual = self.downsample(x) 120 | 121 | out += residual 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): 129 | self.inplanes = 64 130 | super(ResNet, self).__init__() 131 | self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) 132 | self.bn1 = nn.BatchNorm2d(64) 133 | self.relu = nn.ReLU(inplace=True) 134 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 135 | self.layer1 = self._make_layer(block, 64, layers[0]) 136 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 137 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 138 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 139 | 140 | for m in self.modules(): 141 | if 
isinstance(m, nn.Conv2d): 142 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 143 | m.weight.data.normal_(0, math.sqrt(2. / n)) 144 | elif isinstance(m, nn.BatchNorm2d): 145 | m.weight.data.fill_(1) 146 | m.bias.data.zero_() 147 | 148 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 149 | downsample = None 150 | if stride != 1 or self.inplanes != planes * block.expansion: 151 | downsample = nn.Sequential( 152 | nn.Conv2d(self.inplanes, 153 | planes * block.expansion, 154 | kernel_size=1, 155 | stride=stride, 156 | bias=False), 157 | nn.BatchNorm2d(planes * block.expansion), 158 | ) 159 | 160 | layers = [block(self.inplanes, planes, stride, downsample)] 161 | self.inplanes = planes * block.expansion 162 | for i in range(1, blocks): 163 | layers.append(block(self.inplanes, planes, dilation=dilation)) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def resnet18(pretrained=True, extra_dim=0): 169 | model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) 170 | if pretrained: 171 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) 172 | return model 173 | 174 | 175 | def resnet50(pretrained=True, extra_dim=0): 176 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) 177 | if pretrained: 178 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) 179 | return model 180 | -------------------------------------------------------------------------------- /matanyone/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pq-yang/MatAnyone/6472a66d4495e15cc3dd1d6e6b1c69c8c016796f/matanyone/utils/__init__.py -------------------------------------------------------------------------------- /matanyone/utils/get_default_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | A helper function to get a default model for quick testing 3 | """ 4 | from omegaconf import open_dict 5 | from hydra import compose, initialize 6 | 7 | import torch 8 | from matanyone.model.matanyone import MatAnyone 9 | 10 | def get_matanyone_model(ckpt_path, device=None) -> MatAnyone: 11 | initialize(version_base='1.3.2', config_path="../config", job_name="eval_our_config") 12 | cfg = compose(config_name="eval_matanyone_config") 13 | 14 | with open_dict(cfg): 15 | cfg['weights'] = ckpt_path 16 | 17 | # Load the network weights 18 | if device is not None: 19 | matanyone = MatAnyone(cfg, single_object=True).to(device).eval() 20 | model_weights = torch.load(cfg.weights, map_location=device) 21 | else: # if device is not specified, `.cuda()` by default 22 | matanyone = MatAnyone(cfg, single_object=True).cuda().eval() 23 | model_weights = torch.load(cfg.weights) 24 | 25 | matanyone.load_weights(model_weights) 26 | 27 | return matanyone 28 | -------------------------------------------------------------------------------- /matanyone/utils/inference_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torchvision 8 | 9 | IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG') 10 | VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi', '.MP4', '.MOV', '.AVI') 11 | 12 | def read_frame_from_videos(frame_root): 13 | if frame_root.endswith(VIDEO_EXTENSIONS): # Video file path 14 | video_name = os.path.basename(frame_root)[:-4] 15 | frames, _, info = 
torchvision.io.read_video(filename=frame_root, pts_unit='sec', output_format='TCHW') # RGB 16 | fps = info['video_fps'] 17 | else: 18 | video_name = os.path.basename(frame_root) 19 | frames = [] 20 | fr_lst = sorted(os.listdir(frame_root)) 21 | for fr in fr_lst: 22 | frame = cv2.imread(os.path.join(frame_root, fr))[...,[2,1,0]] # RGB, HWC 23 | frames.append(frame) 24 | fps = 24 # default 25 | frames = torch.Tensor(np.array(frames)).permute(0, 3, 1, 2).contiguous() # TCHW 26 | 27 | length = frames.shape[0] 28 | 29 | return frames, fps, length, video_name 30 | 31 | def get_video_paths(input_root): 32 | video_paths = [] 33 | for root, _, files in os.walk(input_root): 34 | for file in files: 35 | if file.lower().endswith(VIDEO_EXTENSIONS): 36 | video_paths.append(os.path.join(root, file)) 37 | return sorted(video_paths) 38 | 39 | def str_to_list(value): 40 | return list(map(int, value.split(','))) 41 | 42 | def gen_dilate(alpha, min_kernel_size, max_kernel_size): 43 | kernel_size = random.randint(min_kernel_size, max_kernel_size) 44 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size)) 45 | fg_and_unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) 46 | dilate = cv2.dilate(fg_and_unknown, kernel, iterations=1)*255 47 | return dilate.astype(np.float32) 48 | 49 | def gen_erosion(alpha, min_kernel_size, max_kernel_size): 50 | kernel_size = random.randint(min_kernel_size, max_kernel_size) 51 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size,kernel_size)) 52 | fg = np.array(np.equal(alpha, 255).astype(np.float32)) 53 | erode = cv2.erode(fg, kernel, iterations=1)*255 54 | return erode.astype(np.float32) -------------------------------------------------------------------------------- /matanyone/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterable 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | # STM 7 | def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]): 8 | h, w = in_img.shape[-2:] 9 | 10 | if h % d > 0: 11 | new_h = h + d - h % d 12 | else: 13 | new_h = h 14 | if w % d > 0: 15 | new_w = w + d - w % d 16 | else: 17 | new_w = w 18 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) 19 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) 20 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 21 | out = F.pad(in_img, pad_array) 22 | return out, pad_array 23 | 24 | 25 | def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor: 26 | if len(img.shape) == 4: 27 | if pad[2] + pad[3] > 0: 28 | img = img[:, :, pad[2]:-pad[3], :] 29 | if pad[0] + pad[1] > 0: 30 | img = img[:, :, :, pad[0]:-pad[1]] 31 | elif len(img.shape) == 3: 32 | if pad[2] + pad[3] > 0: 33 | img = img[:, pad[2]:-pad[3], :] 34 | if pad[0] + pad[1] > 0: 35 | img = img[:, :, pad[0]:-pad[1]] 36 | elif len(img.shape) == 5: 37 | if pad[2] + pad[3] > 0: 38 | img = img[:, :, :, pad[2]:-pad[3], :] 39 | if pad[0] + pad[1] > 0: 40 | img = img[:, :, :, :, pad[0]:-pad[1]] 41 | else: 42 | raise NotImplementedError 43 | return img 44 | 45 | 46 | # @torch.jit.script 47 | def aggregate(prob: torch.Tensor, dim: int) -> torch.Tensor: 48 | with torch.amp.autocast("cuda",enabled=False): 49 | prob = prob.float() 50 | new_prob = torch.cat([torch.prod(1 - prob, dim=dim, keepdim=True), prob], 51 | dim).clamp(1e-7, 1 - 1e-7) 52 | logits = torch.log((new_prob / (1 - new_prob))) # (0, 1) --> (-inf, inf) 53 | 54 | return logits 55 | 56 
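# --- Editorial sketch (not part of the original tensor_utils.py) ---------------
# aggregate() above performs soft aggregation: given per-object foreground
# probabilities stacked along `dim`, it prepends a background channel
# prod_i(1 - p_i), clamps to (1e-7, 1 - 1e-7), and converts the probabilities to
# logits via log(p / (1 - p)), so a softmax over the same dim yields a single
# per-pixel distribution over {background, object_1, ..., object_k}. The helper
# below is purely illustrative (made-up shapes) and is not called anywhere.
def _aggregate_sketch():
    probs = torch.rand(2, 1, 4, 4)        # two objects on a 4x4 frame
    logits = aggregate(probs, dim=0)      # -> (3, 1, 4, 4): background + 2 objects
    dist = torch.softmax(logits, dim=0)   # per-pixel distribution over the 3 channels
    assert dist.shape == (3, 1, 4, 4)
    assert torch.allclose(dist.sum(dim=0), torch.ones(1, 4, 4), atol=1e-5)
    return dist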
| 57 | # @torch.jit.script 58 | def cls_to_one_hot(cls_gt: torch.Tensor, num_objects: int) -> torch.Tensor: 59 | # cls_gt: B*1*H*W 60 | B, _, H, W = cls_gt.shape 61 | one_hot = torch.zeros(B, num_objects + 1, H, W, device=cls_gt.device).scatter_(1, cls_gt, 1) 62 | return one_hot -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.metadata] 6 | allow-direct-references = true 7 | 8 | [tool.yapf] 9 | based_on_style = "pep8" 10 | indent_width = 4 11 | column_limit = 100 12 | 13 | [project] 14 | name = "matanyone" 15 | version = "1.0.0" 16 | authors = [{ name = "Peiqing Yang", email = "peiqingyang99@outlook.com" }] 17 | description = "" 18 | readme = "README.md" 19 | requires-python = ">=3.8" 20 | classifiers = [ 21 | "Programming Language :: Python :: 3", 22 | "Operating System :: OS Independent", 23 | ] 24 | dependencies = [ 25 | 'cython', 26 | 'gitpython >= 3.1', 27 | 'thinplate@git+https://github.com/cheind/py-thin-plate-spline', 28 | 'hickle >= 5.0', 29 | 'tensorboard >= 2.11', 30 | 'numpy >= 1.21', 31 | 'Pillow >= 9.5', 32 | 'opencv-python >= 4.8', 33 | 'scipy >= 1.7', 34 | 'pycocotools >= 2.0.7', 35 | 'tqdm >= 4.66.1', 36 | 'gradio >= 3.34', 37 | 'gdown >= 4.7.1', 38 | 'einops >= 0.6', 39 | 'hydra-core >= 1.3.2', 40 | 'PySide6 >= 6.2.0', 41 | 'charset-normalizer >= 3.1.0', 42 | 'netifaces >= 0.11.0', 43 | 'cchardet >= 2.1.7', 44 | 'easydict', 45 | 'av >= 0.5.2', 46 | 'requests', 47 | 'pyqtdarktheme', 48 | 'imageio == 2.25.0', 49 | 'imageio[ffmpeg]', 50 | 'huggingface_hub', 51 | 'safetensors', 52 | ] 53 | 54 | [tool.hatch.build.targets.wheel] 55 | packages = ["matanyone"] 56 | 57 | [project.urls] 58 | "Homepage" = "https://github.com/pq-yang/MatAnyone" 59 | "Bug Tracker" = "https://github.com/pq-yang/MatAnyone/issues" 60 | 61 | [tool.setuptools] 62 | package-dir = {"" = "."} 63 | --------------------------------------------------------------------------------
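For quick experimentation with the package defined by the `pyproject.toml` above, the two utility modules shown earlier already cover model construction and video loading. The snippet below is a hedged usage sketch, not part of the repository: the checkpoint and video paths are placeholders, it assumes the package was installed from source (e.g. `pip install -e .`), and it stops short of running the actual matting loop, which lives in the inference code not exercised here.

import torch

from matanyone.utils.get_default_model import get_matanyone_model
from matanyone.utils.inference_utils import read_frame_from_videos

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build MatAnyone from its Hydra config and load checkpoint weights
# ("path/to/matanyone.pth" is a placeholder, not a file shipped with the repo).
model = get_matanyone_model("path/to/matanyone.pth", device=device)

# Read a video file (or a folder of frames) as a TCHW tensor plus metadata.
frames, fps, length, video_name = read_frame_from_videos("path/to/video.mp4")
print(video_name, tuple(frames.shape), fps, length)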